Predicting Churn for a Streaming Music Platform

Project Definition:

Problem Overview

This project predicts user churn for a fictitious music streaming platform called “Sparkify”. It also seeks to create and identify significant features that describe this behavior.

But first, what is churn? A user has churned when they stop using our products or services. The churn rate for a given period is therefore the number of customers who left during that period divided by the number of customers at the beginning of it.
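As a quick illustration of this definition, here is a minimal Python sketch. The function name churn_rate and the customer counts are hypothetical and do not come from the Sparkify data.

```python
def churn_rate(customers_at_start: int, customers_lost: int) -> float:
    """Share of the customers present at the start of a period who left during it."""
    if customers_at_start <= 0:
        raise ValueError("customers_at_start must be positive")
    return customers_lost / customers_at_start


# Hypothetical month: 1,000 customers at the start, 50 of them cancel.
print(churn_rate(1000, 50))  # 0.05
```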

The problem domain is customer loyalty and engagement with a product or service. Many companies face this problem today, and it is always challenging, especially for data scientists, who have the difficult task of analyzing large datasets, finding patterns and predicting the behavior of this metric.

For this task we use a small subset of the complete dataset. It contains the interactions of 226 different users and includes demographic information as well as transactional data on user behavior on the platform.

Problem Statement

The main challenge is the size of the full dataset, so we manipulate it with Apache Spark, a powerful distributed data processing framework.

After processing and structuring the data and creating the churn indicator from user events, we carry out a basic analysis and some visualizations that help us understand in depth how the data behaves. With that understanding, we can move on to evaluating and creating new relevant features that help us achieve our goal.

Finally, we use Spark MLlib to build and test different machine learning models on large datasets. This library also lets us improve the training process and optimize the models using cross-validation and grid search.


This can be framed as a binary classification problem in which we classify each user into one of two groups: those who are going to churn and those who are not. Evaluating such a classifier relies on four basic concepts, taking churn as the positive class:

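  • True positive (TP): the model predicts churn and the user actually churned.
  • True negative (TN): the model predicts no churn and the user did not churn.
  • False positive (FP): the model predicts churn but the user did not churn, a false alarm (Type I error).
  • False negative (FN): the model predicts no churn but the user churned, a missed churner (Type II error).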

To evaluate how well the model performs, we first look at accuracy, the share of predictions the model got right: (TP + TN) / (TP + TN + FP + FN).
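To make the four counts and the accuracy formula concrete, here is a small pure-Python sketch. It assumes labels encoded as 1 for churn and 0 for no churn; the label lists are made up.

```python
from typing import Sequence


def confusion_counts(y_true: Sequence[int], y_pred: Sequence[int]) -> dict:
    """Count TP, TN, FP and FN, assuming 1 = churn and 0 = no churn."""
    counts = {"TP": 0, "TN": 0, "FP": 0, "FN": 0}
    for actual, predicted in zip(y_true, y_pred):
        if predicted == 1 and actual == 1:
            counts["TP"] += 1
        elif predicted == 0 and actual == 0:
            counts["TN"] += 1
        elif predicted == 1 and actual == 0:
            counts["FP"] += 1
        else:  # predicted 0, actual 1
            counts["FN"] += 1
    return counts


def accuracy(counts: dict) -> float:
    """(TP + TN) / (TP + TN + FP + FN)."""
    return (counts["TP"] + counts["TN"]) / sum(counts.values())


y_true = [1, 0, 0, 1, 0, 0, 1, 0]
y_pred = [1, 0, 1, 0, 0, 0, 1, 0]
c = confusion_counts(y_true, y_pred)
print(c)            # {'TP': 2, 'TN': 4, 'FP': 1, 'FN': 1}
print(accuracy(c))  # 0.75
```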

Another important tool we apply is the ROC curve. It plots the true positive rate, TP / (TP + FN), against the false positive rate, FP / (FP + TN), as the decision threshold that separates churn from no churn is varied, which shifts predictions between TP, TN, FP and FN. Once the curve is drawn, we measure the area under it (AUC), which tells us how well the model distinguishes between the two classes: 0.5 corresponds to random guessing and 1.0 to perfect separation. The higher the AUC, the better the model.
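Sweeping the threshold and integrating the ROC curve gives the same number as counting, over all churner/non-churner pairs, how often the churner receives the higher score (ties count as half). The sketch below uses this pairwise form because it is short. It is quadratic in the number of users, so it only suits small examples; in Spark, BinaryClassificationEvaluator with metricName="areaUnderROC" computes the metric on a predictions data frame. The labels and scores are made up.

```python
from itertools import product
from typing import Sequence


def roc_auc(y_true: Sequence[int], scores: Sequence[float]) -> float:
    """AUC as the probability that a random churner (label 1) gets a higher
    score than a random non-churner (label 0); ties count as half a win."""
    positives = [s for y, s in zip(y_true, scores) if y == 1]
    negatives = [s for y, s in zip(y_true, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("need at least one example of each class")
    wins = 0.0
    for p, n in product(positives, negatives):
        if p > n:
            wins += 1.0
        elif p == n:
            wins += 0.5
    return wins / (len(positives) * len(negatives))


y_true = [1, 1, 0, 0, 1, 0]
scores = [0.9, 0.4, 0.35, 0.8, 0.7, 0.1]  # hypothetical churn probabilities
print(roc_auc(y_true, scores))  # 7 of 9 pairs ranked correctly, about 0.78
```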


Data Exploration

First we need to know the schema of the data; this helps us identify its initial structure and the columns that are useful for our objective.

Next we run a basic analysis to learn more about the data. The dataset has 286,500 rows and 18 columns. We divide this basic analysis into two large groups: categorical variables and numerical variables.

For the categorical variables, we run a simple analysis to identify the possible values in each column and how they are distributed.

This first analysis shows 8,346 rows with null values. Looking more closely at the events behind these rows, we find that they correspond to page navigation by users who are not logged in. This information is not useful for our analysis, so we filter it out.

For the numerical variables, we compute basic statistics such as count, mean, standard deviation, minimum and maximum. This analysis shows that the two date columns (registration and ts) are not in a usable date format, so we apply a date transformation that lets us perform operations between them and derive richer features.

In this step we also remove some columns that we consider unnecessary for the project, such as auth, firstName, lastName, location and method.

With this first basic analysis done, we need to define the churn indicator. For this we use the “Cancellation Confirmation” event, which occurs for both paid and free users: a user who has this event is flagged as churned.
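A minimal pure-Python sketch of the per-user churn flag, assuming each event is a dict with userId and page keys (a simplification of the real event log). In Spark the same logic could be written as a groupBy on userId with a max over a when(...) expression from pyspark.sql.functions. The events below are made up.

```python
from typing import Iterable, Mapping


def churn_flags(events: Iterable[Mapping[str, str]]) -> dict:
    """Return {userId: 1 or 0}; 1 means the user has at least one
    'Cancellation Confirmation' page event."""
    flags: dict = {}
    for event in events:
        user = event["userId"]
        churned = int(event["page"] == "Cancellation Confirmation")
        # Keep 1 once any churn event has been seen for this user.
        flags[user] = max(flags.get(user, 0), churned)
    return flags


events = [
    {"userId": "10", "page": "NextSong"},
    {"userId": "10", "page": "Cancel"},
    {"userId": "10", "page": "Cancellation Confirmation"},
    {"userId": "22", "page": "NextSong"},
    {"userId": "22", "page": "Home"},
]
print(churn_flags(events))  # {'10': 1, '22': 0}
```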

Data Visualization

For this step we group the data to observe how it differs between the two classes, using the churn flag created earlier to separate the users and better understand the impact of each variable.

Later, in the preprocessing stage, we add further visualizations for the engineered features.

First, we study how the categorical variables differ between the two groups of users. The two most relevant plots, from which we can draw some hypotheses, are:

  • Gender: Men seem to churn more than women, but in the modeling stage we will analyze whether this difference is significant or not.
  • Level: Users with a free account cancel in a higher proportion than paid users.
Users by churn classification
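The comparison behind these plots boils down to the share of churned users within each category. A small sketch, assuming one row per user with a single level value (a simplification, since a user can switch between free and paid over time); the rows are made up.

```python
from collections import defaultdict


def churn_share_by(users, column):
    """Share of churned users within each value of a categorical column."""
    totals = defaultdict(int)
    churned = defaultdict(int)
    for user in users:
        key = user[column]
        totals[key] += 1
        churned[key] += user["churn"]
    return {key: churned[key] / totals[key] for key in totals}


users = [
    {"gender": "M", "level": "free", "churn": 1},
    {"gender": "M", "level": "paid", "churn": 0},
    {"gender": "F", "level": "paid", "churn": 0},
    {"gender": "F", "level": "free", "churn": 1},
    {"gender": "M", "level": "free", "churn": 0},
    {"gender": "F", "level": "paid", "churn": 0},
]
print(churn_share_by(users, "level"))  # free: 2 of 3 users churned, paid: none
```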

To interpret the numerical variables we use box plots, a useful tool because they summarize the distribution of the data: the median, the interquartile range and the spread, including outliers.
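The numbers a box plot draws can be computed directly. This sketch uses Python's statistics.quantiles (Python 3.8+, default "exclusive" method) and the common 1.5 × IQR convention for whisker limits; plotting libraries may use slightly different quartile methods. The sample values are made up.

```python
import statistics


def box_plot_summary(values):
    """Quartiles, IQR and the usual 1.5 * IQR whisker limits."""
    q1, median, q3 = statistics.quantiles(values, n=4)
    iqr = q3 - q1
    return {
        "q1": q1,
        "median": median,
        "q3": q3,
        "iqr": iqr,
        "lower_whisker_limit": q1 - 1.5 * iqr,
        "upper_whisker_limit": q3 + 1.5 * iqr,
    }


# Hypothetical number of distinct artists listened to by nine users.
print(box_plot_summary([3, 7, 8, 5, 12, 14, 21, 13, 18]))
# q1 6.0, median 12.0, q3 16.0, IQR 10.0; points outside -9.0..31.0 are outliers
```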

At first glance, the plots suggest that users who listen to a greater diversity of artists and have more items per session are less likely to cancel their subscription.

Artists listened to and maximum items in session by churn classification


Data Preprocessing

Now that we have a first picture of the data, we can generate new features to enrich it and improve the final model. First, let’s apply some necessary transformations:

  • The dates in the data frame are not in a format we can process, so we convert them to a date and time format we can work with. This enables some interesting analyses; in this case, the tenure of each user on the platform, computed from the registration date and the activity date.
  • Extract the name of the operating system from the userAgent column. Inspecting the data reveals a pattern that lets us pull out this information with string functions such as strip and substring.
  • Use the status column to detect errors when interacting with the platform. These can be a promising signal of user churn.
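The three transformations above can be sketched in plain Python for a single event. This assumes that ts and registration are Unix timestamps in milliseconds, that the operating system is the first item inside the parentheses of userAgent, and that "error" means an HTTP status of 400 or above; all three are assumptions for illustration. In Spark, the date part could be expressed with pyspark.sql.functions such as from_unixtime and datediff.

```python
from datetime import datetime, timezone


def to_datetime(epoch_ms: int) -> datetime:
    """Convert a Unix timestamp in milliseconds (assumed format of ts and
    registration) to a UTC datetime."""
    return datetime.fromtimestamp(epoch_ms / 1000, tz=timezone.utc)


def tenure_days(registration_ms: int, activity_ms: int) -> float:
    """Days elapsed between registration and an activity timestamp."""
    elapsed = to_datetime(activity_ms) - to_datetime(registration_ms)
    return elapsed.total_seconds() / 86400


def operating_system(user_agent: str) -> str:
    """Text inside the first parentheses, up to the first ';'.

    'Mozilla/5.0 (Windows NT 6.1; WOW64) ...' -> 'Windows NT 6.1'
    """
    start = user_agent.find("(")
    end = user_agent.find(")", start + 1)
    if start == -1 or end == -1:
        return "unknown"
    return user_agent[start + 1:end].split(";")[0].strip()


def is_error(status: int) -> int:
    """1 for HTTP error responses (assumed: status >= 400), else 0."""
    return int(status >= 400)


print(to_datetime(1538352117000))  # 2018-10-01 00:01:57+00:00
print(operating_system("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36"))  # Macintosh
print(is_error(404), is_error(200))  # 1 0
```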

After applying these transformations, it is time to define new features that we consider valuable for churn classification. First, we use the different navigation pages to define events that can help us decide whether a user is going to churn:

  • advert: the number of adverts the user receives. Some users dislike constant advertising while listening to their favorite music, so this can be an interesting variable for predicting churn.
  • addFriend: the number of friends the user adds, which reflects their affinity with the platform.
  • addToPlaylist: the number of songs the user adds to a playlist, which reflects deeper interaction with the platform.
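A minimal sketch of these per-user page counts. The page values "Roll Advert", "Add Friend" and "Add to Playlist" are assumptions about how the events appear in the log, and the events themselves are made up.

```python
from collections import Counter, defaultdict

# Assumed page value in the event log for each feature (hypothetical mapping).
PAGE_FEATURES = {
    "advert": "Roll Advert",
    "addFriend": "Add Friend",
    "addToPlaylist": "Add to Playlist",
}


def page_event_counts(events):
    """Count, for every user, how often each feature page appears in their events."""
    counts = defaultdict(Counter)
    for event in events:
        user_counts = counts[event["userId"]]  # registers the user even with no matches
        for feature, page in PAGE_FEATURES.items():
            if event["page"] == page:
                user_counts[feature] += 1
    # Report every feature for every user; a Counter returns 0 for missing keys.
    return {
        user: {feature: user_counts[feature] for feature in PAGE_FEATURES}
        for user, user_counts in counts.items()
    }


events = [
    {"userId": "10", "page": "Roll Advert"},
    {"userId": "10", "page": "NextSong"},
    {"userId": "10", "page": "Roll Advert"},
    {"userId": "10", "page": "Add Friend"},
    {"userId": "22", "page": "Add to Playlist"},
]
print(page_event_counts(events))
```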

Some of the features designed for this exercise are based on per-session behavior, so we need a data frame that groups the events by userId and sessionId:

  • avgSongListened: average number of songs the user listens to per session.
  • avgSessionDuration: average duration of the user’s sessions.
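The two-level aggregation (first per user and session, then per user) can be sketched like this. It assumes that a "NextSong" page event marks a song played and that a session lasts from its first to its last event timestamp in milliseconds. Both are simplifying assumptions, and the events are made up.

```python
from collections import defaultdict
from statistics import mean


def session_features(events):
    """Average songs per session and average session duration (seconds) per user."""
    # Step 1: aggregate per (userId, sessionId).
    sessions = defaultdict(lambda: {"songs": 0, "first": None, "last": None})
    for e in events:
        s = sessions[(e["userId"], e["sessionId"])]
        if e["page"] == "NextSong":
            s["songs"] += 1
        s["first"] = e["ts"] if s["first"] is None else min(s["first"], e["ts"])
        s["last"] = e["ts"] if s["last"] is None else max(s["last"], e["ts"])

    # Step 2: collect each user's sessions and average them.
    per_user = defaultdict(lambda: {"songs": [], "durations": []})
    for (user, _session), s in sessions.items():
        per_user[user]["songs"].append(s["songs"])
        per_user[user]["durations"].append((s["last"] - s["first"]) / 1000)

    return {
        user: {
            "avgSongListened": mean(v["songs"]),
            "avgSessionDuration": mean(v["durations"]),
        }
        for user, v in per_user.items()
    }


events = [
    {"userId": "10", "sessionId": 1, "page": "NextSong", "ts": 0},
    {"userId": "10", "sessionId": 1, "page": "NextSong", "ts": 200_000},
    {"userId": "10", "sessionId": 1, "page": "Home", "ts": 260_000},
    {"userId": "10", "sessionId": 2, "page": "NextSong", "ts": 1_000_000},
    {"userId": "10", "sessionId": 2, "page": "NextSong", "ts": 1_100_000},
    {"userId": "10", "sessionId": 2, "page": "NextSong", "ts": 1_200_000},
    {"userId": "10", "sessionId": 2, "page": "NextSong", "ts": 1_340_000},
]
print(session_features(events))  # user 10: 3 songs and 300 seconds per session on average
```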

After defining all these rules and parameters for the new features, we apply them to the final data frame. We must make sure it contains only one row per user, so we let Spark aggregate the data using the variables from the data exploration stage together with the new features.

Preview of the final data frame

Visualizing the new variables shows that some of them are relevant. For example, the devices on which the platform runs seem to be associated with different churn rates, and the number of friends added differs clearly between users who cancel their subscription and those who do not. The most relevant difference is in the distribution of tenure: users who remain on the platform for a long time are more engaged with it.

Additional data visualization

Pair plots are useful for examining numerical variables: they show each variable’s distribution and, through scatter plots, its direct relationship with the others.

This plot shows that some variables are highly correlated and therefore redundant, so we can remove them from the data frame. In the next step we keep only the numerical columns that do not visually show high correlation and run a correlation analysis on them.

After this analysis we delete the columns avgItemSession, artistListened and friendsAdded. Highly correlated variables are undesirable in model development because they add redundant information and can cause problems such as overfitting and reduced generalization.
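A pure-Python sketch of this kind of correlation screen: compute Pearson's r for every pair of columns and report the pairs above a cutoff. The 0.8 threshold and the column values are hypothetical; in Spark, pyspark.ml.stat.Correlation provides a correlation matrix over a vector column.

```python
import math
from itertools import combinations


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


def highly_correlated_pairs(columns, threshold=0.8):
    """Return (a, b, r) for every pair of columns with |r| above the threshold."""
    pairs = []
    for a, b in combinations(columns, 2):
        r = pearson(columns[a], columns[b])
        if abs(r) > threshold:
            pairs.append((a, b, r))
    return pairs


columns = {
    "songsListened": [10, 20, 30, 40],
    "artistListened": [8, 15, 26, 33],
    "tenureDays": [50, 10, 40, 20],
}
for a, b, r in highly_correlated_pairs(columns):
    print(a, b, round(r, 3))  # songsListened artistListened 0.996
```

Only one column from each flagged pair needs to be dropped; which one to keep is a judgment call based on interpretability.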


To implement the model, we first take advantage of what Spark ML offers. This library simplifies model training with tools such as VectorAssembler and StringIndexer, which put the data into a standard structure that makes it easy to test many different types of algorithms.

Next we build a pipeline with the data transformations needed to get good results:

  • Apply one-hot encoding to the categorical columns, so that the data frame contains only numeric values and each category becomes a binary indicator.
  • Normalize the numerical variables so that differences in magnitude between them do not distort the result; depending on the scaler, values are mapped to a range between 0 and 1 or between -1 and 1.
  • Assemble all the features into one vector column and index the label.
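The three pipeline steps can be illustrated without Spark. This sketch keeps one column per category and sorts categories alphabetically. Spark’s StringIndexer orders categories by frequency by default, and its OneHotEncoder drops the last category by default (dropLast=True), so the exact vectors would differ. The constant-column behavior of min_max_scale is this sketch’s own choice, and the data is made up.

```python
def one_hot(values):
    """Map each category to a 0/1 vector, one position per distinct category."""
    categories = sorted(set(values))
    return categories, [[int(v == c) for c in categories] for v in values]


def min_max_scale(values):
    """Rescale values linearly to [0, 1]; a constant column maps to 0.0 here."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


def assemble(*feature_lists):
    """Concatenate per-row feature lists into one feature vector per row."""
    return [sum(parts, []) for parts in zip(*feature_lists)]


categories, gender_vectors = one_hot(["M", "F", "M"])
tenure_scaled = [[x] for x in min_max_scale([10, 40, 100])]
print(categories)                                 # ['F', 'M']
print(assemble(gender_vectors, tenure_scaled)[0])  # [0, 1, 0.0]
```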

The next step is a train-test split. For this exercise we use 70% of the data for training and 30% for testing, which gives 153 users in the training set and 71 in the test set. We then test two classifiers, Random Forest and Logistic Regression, to see how each performs on our data.
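Because the final data frame has one row per user, splitting its rows also splits the users, so no user ends up in both sets. Spark’s DataFrame.randomSplit([0.7, 0.3], seed=...) assigns each row independently at random, which is why the resulting sizes are only approximately 70/30, as with the 153/71 split above. The sketch below instead shuffles and cuts at exactly 70%; the user ids and the seed are hypothetical.

```python
import random


def split_users(user_ids, train_fraction=0.7, seed=42):
    """Shuffle the distinct users and cut the list into train and test groups,
    so that no user appears in both."""
    users = sorted(set(user_ids))
    random.Random(seed).shuffle(users)
    cut = round(len(users) * train_fraction)
    return users[:cut], users[cut:]


users = [f"user{i}" for i in range(224)]  # hypothetical user ids
train, test = split_users(users)
print(len(train), len(test))  # 157 67
```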

The Random Forest Classifier was trained with the following hyperparameters:

  • maxDepth: 5 (the maximum depth of each tree; deeper trees make more splits and can capture more information about the data).
  • numTrees: 20 (the number of trees in the forest; since each tree sees a random sample of rows and random subsets of variables, more trees cover more combinations of our variables).
  • maxBins: 32 (the maximum number of bins used to discretize continuous features; more bins let the algorithm consider more split candidates).

Evaluating this model gives an AUC of 0.79 and an accuracy of 87%, a good initial performance. With a random forest it is also useful to inspect the logic of the trees. The forest is large, so we only show part of it as an example; the complete structure can be seen in the Jupyter notebook.

Random Forest logic structure

The Logistic Regression classifier was trained with the following hyperparameters:

  • maxIter: 100 (maximum number of iterations of the optimization algorithm).
  • threshold: 0.5 (a user is predicted to churn when the predicted churn probability is above this value; moving it trades false positives against false negatives).
Preview of the predicted dataframe
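To see what the threshold does, here is a small sketch. The logistic function turns a linear score into a churn probability, and the threshold converts that probability into a 0/1 prediction. The scores are made up. Lowering the threshold catches more churners (fewer false negatives) at the cost of more false alarms (false positives).

```python
import math


def sigmoid(z: float) -> float:
    """Logistic function: maps a linear score to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def predict(probabilities, threshold=0.5):
    """Predict churn (1) when the churn probability is above the threshold."""
    return [int(p > threshold) for p in probabilities]


scores = [-2.0, -0.2, 0.3, 1.5]       # hypothetical linear scores for four users
probs = [sigmoid(z) for z in scores]  # about 0.12, 0.45, 0.57, 0.82
for t in (0.4, 0.5, 0.7):
    print(t, predict(probs, t))
# 0.4 [0, 1, 1, 1]
# 0.5 [0, 0, 1, 1]
# 0.7 [0, 0, 0, 1]
```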

Evaluating the main metrics, we obtain an AUC of 0.78 and an accuracy of 84%. These are good results, which suggests that the initial choice of variables for the exercise was adequate.


Grid search is very important for tuning our models and getting the best out of each one. In this case we take the Random Forest classifier as the reference and build a parameter grid over three main hyperparameters: maxDepth, numTrees and impurity.

Parameter grid configuration

This method trains every combination of hyperparameters in the grid and returns the best model among them. We must choose carefully which hyperparameters to test and a sensible range of values for each, because the number of combinations grows multiplicatively, and with too many parameters the process can take an impractically long time.
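The multiplicative growth is easy to check. The candidate values below are hypothetical, since the exact grid is not listed here. In Spark the grid would be built with ParamGridBuilder().addGrid(...).build(), and CrossValidator then fits every combination once per fold, plus one final refit of the best combination on the full training data.

```python
from itertools import product

# Hypothetical candidate values for the three hyperparameters.
param_grid = {
    "maxDepth": [3, 5, 7],
    "numTrees": [10, 20, 50],
    "impurity": ["gini", "entropy"],
}

# Every combination of one value per hyperparameter.
grid = [dict(zip(param_grid, values)) for values in product(*param_grid.values())]
num_folds = 10
print(len(grid))              # 18 parameter combinations
print(len(grid) * num_folds)  # 180 model fits during 10-fold cross-validation
```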

We also use 10-fold cross-validation, one of the most widely used refinement methods when building machine learning models. The data is divided into 10 folds, and each candidate model is trained 10 times; in each iteration a different fold is held out for testing while the other nine are used for training. This reduces the dependence of the results on a single, possibly unlucky, train-test split.
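A minimal sketch of how k-fold indices can be generated: rows are shuffled, dealt into k folds, and each fold is held out exactly once while the rest is used for training. Like Spark’s CrossValidator, it assigns rows to folds randomly rather than stratifying by class. The row count and seed are hypothetical.

```python
import random


def k_fold_indices(n_rows, k=10, seed=0):
    """Yield (train_idx, test_idx) pairs; every row is held out exactly once."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::k] for i in range(k)]
    for i, test_idx in enumerate(folds):
        train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train_idx, test_idx


for fold, (train_idx, test_idx) in enumerate(k_fold_indices(224, k=10)):
    print(fold, len(train_idx), len(test_idx))  # test folds of 22 or 23 rows
```

The cross-validated score is then the average of the metric computed on each held-out fold.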

Using the parameter grid and 10-fold cross-validation, we obtain good results, with an accuracy of 86%. We then save the best model and retrieve its hyperparameters:

  • maxDepth: 3
  • numTrees: 20
  • impurity: gini


Model Evaluation and Validation

After all the analysis, we chose as the best model for this problem the one obtained with grid search, with its 86% accuracy. Below we explain why we consider it the best option.

We also want to check whether the features we created are really significant for the model, so we extract the importance of each variable and plot the 10 most important ones for the final prediction.

Our first impression of the data appears to be correct: the user’s tenure on the platform is the most important feature, followed by the number of adverts the user receives and the number of songs the user adds to playlists.
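Ranking the importances is a simple sort once names and scores are paired. In Spark, a fitted RandomForestClassificationModel exposes featureImportances as a vector aligned with the assembled feature positions, so the names have to be recovered from the assembler’s input columns. The names and values in this sketch are made up and do not reproduce the results above.

```python
def top_features(names, importances, k=10):
    """Pair each feature name with its importance score and keep the k largest."""
    if len(names) != len(importances):
        raise ValueError("names and importances must have the same length")
    ranked = sorted(zip(names, importances), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


names = ["tenureDays", "advert", "addToPlaylist", "gender_M"]
importances = [0.15, 0.25, 0.40, 0.20]  # made-up values for illustration
print(top_features(names, importances, k=2))
# [('addToPlaylist', 0.4), ('advert', 0.25)]
```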


First, we chose the random forest over logistic regression because, even though we apply one-hot encoding, decision trees handle categorical data naturally: they build rules on the categories and follow a path of splits down to the probability of the final event. Logistic regression, by contrast, models a linear relationship with the log-odds of churn and will not always interpret categorical variables the way we want.

Why choose a model with a lower score than another? The answer always depends on the case. Here, the simple default random forest reaches a solid 87% accuracy, but 10-fold cross-validation (CV) gives us a real advantage in validating the final model. Because it iterates over different train and test partitions and averages the resulting scores, its 86% is a more reliable estimate than a figure from a single split. CV is also recommended when one class is much less frequent than the other, because every row, including every churned user, is used for evaluation in exactly one fold; note, however, that Spark’s CrossValidator assigns folds randomly rather than stratifying them by class.



Spark is a high-performance data processing tool that helps us manage large datasets, and its ML pipeline structure makes it easy to test how different algorithms work on the same data.

The feature engineering process is crucial in data science: it extracts value that is not obvious in the raw data at the beginning, which can be decisive for the performance of an analytical model and gives us a complete view when addressing problems.

Model refinement combined with logical thinking is the right way to choose a model for a business case: we should focus not only on the model’s output but also on whether the result is coherent and fits the needs of the case.


This case can be taken much further with deeper analyses of user behavior over time and by validating other classification algorithms. There is also a great opportunity for improvement in building a much more robust parameter grid with more variables and hyperparameters.

The complete project code is available on GitHub here.


