Sparkify — Churn Prediction

Akshay Jain
Dec 20, 2020

This blog was written to meet the requirements of my Udacity “Data Scientist Nanodegree” capstone project. In this project, several prediction models were tested using Spark to identify users who are more likely to churn from a fictitious music streaming service.


Sparkify is a music streaming service (like Spotify, Pandora, etc.). Sparkify users can either listen to music for free or buy a subscription. Free users have to listen to ads, while subscription (paid) users listen to songs ad-free. Users must log in to listen to a song on the service.

From a subscription perspective, users can do any of the following at any point:

  1. Upgrade from the free tier to the paid tier
  2. Downgrade from the paid tier to the free tier
  3. Cancel their account and leave the service

Key Business Question

The goal is to identify the users who are likely to cancel their account and leave the service. By identifying this population up front, we could offer them discounts or other rewards to encourage them to stay. This would help build a loyal customer base, which is key to a company’s growth.

To attain this objective, it’s important to first do exploratory analysis to glean insights from the dataset and identify key variables of interest. The next step is to test different model algorithms with the Spark ML library and then pick the best model based on a key evaluation metric (F1 score).


The data we have from Sparkify consists of user events: every interaction of every user with the application is recorded. Every time a user visits the Home page, listens to a song, gives a song a thumbs up, etc., a corresponding event appears in the data. Each record in our data represents one event by a user.

Some of the key attributes present in each event in the data are:

  1. userId : A unique ID that identifies a user of the service
  2. sessionId : A unique ID that identifies a single continuous period of use by a user of the service
  3. artist : Artist that the user was listening to
  4. song : Song that the user was listening to
  5. page : Page the user visited when this event was generated. Examples include cancelling the service, thumbs up, thumbs down, add to playlist, and add a friend.
  6. level : Either Paid or Free. It indicates whether the user was on the paid or the free tier when triggering this event
  7. ts : The timestamp of this event, in epoch time
  8. location : The location (city and state) of the user
  9. userAgent : The device, platform, and browser the user was using when the event occurred.


1. Data Cleaning and Exploration

  1. Invalid records, such as those with missing userIds or sessionIds, were removed from the analysis.
  2. Duplicate records were removed from the analysis.
  3. The dataset contains 225 distinct users, and about 54% of them were identified as male.
  4. Definition of churn: a churned user is one who cancelled their account and left the platform.
  5. Out of 225 users, 52 were identified as churned; this is about 23% of all users.
  6. The top five states by number of users are California, Texas, New York, Florida, and Arizona.
  7. About 54% of users used Chrome as their browser, and 48% used Windows as their OS.
  8. The churn rate for males is noticeably higher than for females (26% vs 19%), as shown in the chart below.
  9. The churn rate for users of the Firefox browser is noticeably higher than for users of other browsers, as shown in the chart below.
  10. The churn rate for Linux users is noticeably higher than for users of other operating systems, as shown in the chart below.
  11. The churn rate for free users is higher than for paid users (24% vs 20%), as shown in the chart below.
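The churn label and the group-level churn rates above can be sketched in plain Python. This is a minimal illustration, not the project's Spark code. It assumes each event is a dict with userId, page, and gender keys, and that a cancellation is recorded with the page value "Cancellation Confirmation" (an assumption about the exact string). The helper names label_churn and churn_rate_by are hypothetical, and the tiny event list is made up purely to exercise them.

```python
from collections import defaultdict

# Assumed page value marking a cancelled account; adjust to the dataset's actual string.
CHURN_PAGE = "Cancellation Confirmation"


def label_churn(events):
    """Return {userId: 1 if the user ever reached the cancellation page, else 0}."""
    labels = {}
    for e in events:
        uid = e["userId"]
        if not uid:  # skip records with a missing/empty userId, as in the cleaning step
            continue
        churned = 1 if e["page"] == CHURN_PAGE else 0
        labels[uid] = max(labels.get(uid, 0), churned)
    return labels


def churn_rate_by(events, labels, attribute):
    """Churn rate per value of a user attribute (e.g. gender), counting each user once."""
    user_attr = {}
    for e in events:
        if e["userId"] in labels:
            user_attr[e["userId"]] = e[attribute]
    totals, churned = defaultdict(int), defaultdict(int)
    for uid, value in user_attr.items():
        totals[value] += 1
        churned[value] += labels[uid]
    return {value: churned[value] / totals[value] for value in totals}


# Made-up events purely for illustration.
events = [
    {"userId": "1", "page": "NextSong", "gender": "M"},
    {"userId": "1", "page": "Cancellation Confirmation", "gender": "M"},
    {"userId": "2", "page": "NextSong", "gender": "F"},
    {"userId": "3", "page": "Home", "gender": "M"},
    {"userId": "", "page": "Home", "gender": None},
]
labels = label_churn(events)
print(labels)  # {'1': 1, '2': 0, '3': 0}
print(churn_rate_by(events, labels, "gender"))  # {'M': 0.5, 'F': 0.0}
```

In Spark, the same idea is typically a flag column on the cancellation event, followed by a max over that flag grouped by userId.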

2. Feature Engineering

Below is a list of the features I used to build an ML model that predicts churn at the user level:

Numerical Features

  1. Total number of songs
  2. Average length of session
  3. Number of songs that received a thumbs up
  4. Number of songs that received a thumb down
  5. Number of friend additions
  6. Number of advertisements seen
  7. Number of playlist additions
  8. Number of unique artists and songs listened to
  9. Number of sessions
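Most of these numerical features are per-user aggregates of the event log. The sketch below shows the idea in plain Python. In the project this would be a Spark groupBy on userId. Here, each event is assumed to be a dict with userId, sessionId, page, artist, and song keys. The page names in PAGE_FEATURES ("NextSong", "Thumbs Up", and so on) are assumptions about how the dataset labels those actions, and user_features is a hypothetical helper. Average session length is left out for brevity.

```python
from collections import defaultdict

# Assumed page names for each counted action (hypothetical mapping; check the data).
PAGE_FEATURES = {
    "NextSong": "songs",
    "Thumbs Up": "thumbs_up",
    "Thumbs Down": "thumbs_down",
    "Add Friend": "friends_added",
    "Roll Advert": "ads_seen",
    "Add to Playlist": "playlist_adds",
}


def user_features(events):
    """Aggregate cleaned events into one feature dict per userId."""
    counts = defaultdict(lambda: defaultdict(int))
    sessions = defaultdict(set)
    artists = defaultdict(set)
    songs = defaultdict(set)
    for e in events:
        uid = e["userId"]
        sessions[uid].add(e["sessionId"])
        name = PAGE_FEATURES.get(e["page"])
        if name:
            counts[uid][name] += 1
        if e["page"] == "NextSong":  # only song plays carry artist/song values
            artists[uid].add(e["artist"])
            songs[uid].add((e["artist"], e["song"]))
    features = {}
    for uid in sessions:
        # Actions a user never performed come out as 0 via the defaultdict.
        row = {name: counts[uid][name] for name in PAGE_FEATURES.values()}
        row["sessions"] = len(sessions[uid])
        row["unique_artists"] = len(artists[uid])
        row["unique_songs"] = len(songs[uid])
        features[uid] = row
    return features


# Made-up events for one user.
events = [
    {"userId": "7", "sessionId": 1, "page": "NextSong", "artist": "A", "song": "x"},
    {"userId": "7", "sessionId": 1, "page": "NextSong", "artist": "A", "song": "y"},
    {"userId": "7", "sessionId": 2, "page": "Thumbs Up", "artist": None, "song": None},
    {"userId": "7", "sessionId": 2, "page": "NextSong", "artist": "B", "song": "x"},
]
feats = user_features(events)
print(feats["7"]["songs"], feats["7"]["sessions"], feats["7"]["unique_artists"])  # 3 2 2
```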

Categorical Features

  1. Gender type
  2. Browser type
  3. OS type
  4. Level of the user
  5. State a user lives in
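The state, browser, and OS features must be derived from the raw location and userAgent strings. The sketch below assumes location looks like "City, ST". When a metro area spans several states, as in "NY-NJ-PA", it takes the first code. It also uses simple substring checks on userAgent. These parsing rules and the helper names are illustrative assumptions, not the project's exact logic.

```python
def state_from_location(location):
    """'New York-Newark-Jersey City, NY-NJ-PA' -> 'NY' (assumes a 'City, ST' format)."""
    if not location or "," not in location:
        return "unknown"
    codes = location.rsplit(",", 1)[1].strip()
    return codes.split("-")[0]


def browser_from_agent(agent):
    """Rough browser detection; order matters because Chrome agents also mention Safari."""
    if not agent:
        return "unknown"
    for token, name in (("Firefox", "Firefox"), ("Chrome", "Chrome"),
                        ("Trident", "IE"), ("MSIE", "IE"), ("Safari", "Safari")):
        if token in agent:
            return name
    return "other"


def os_from_agent(agent):
    """Rough OS detection from the platform part of the user agent string."""
    if not agent:
        return "unknown"
    for token, name in (("Windows", "Windows"), ("iPhone", "iOS"), ("iPad", "iOS"),
                        ("Macintosh", "Mac"), ("Linux", "Linux")):
        if token in agent:
            return name
    return "other"


ua = ("Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 "
      "(KHTML, like Gecko) Chrome/37.0.2062.103 Safari/537.36")
print(state_from_location("Bakersfield, CA"))  # CA
print(browser_from_agent(ua), os_from_agent(ua))  # Chrome Windows
```

Substring rules like these are fragile. A dedicated user-agent parsing library would be more robust, but the simple version is enough to bucket the handful of agents seen in a small dataset.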

All missing values were imputed with 0. The dataset was then split into train (80%) and validation (20%) sets for building and validating the model, respectively.
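Here is a plain-Python sketch of the two steps above: filling missing feature values with 0 and making a seeded 80/20 split. In Spark, these would usually be DataFrame.fillna(0) and DataFrame.randomSplit([0.8, 0.2], seed=...). Note that randomSplit produces only approximately proportioned splits, whereas this sketch cuts at an exact index. The helper names and the 225 placeholder rows (matching the user count) are hypothetical.

```python
import random


def impute_zero(rows, feature_names):
    """Return copies of rows with absent or None feature values replaced by 0."""
    filled = []
    for row in rows:
        new_row = dict(row)
        for name in feature_names:
            if new_row.get(name) is None:
                new_row[name] = 0
        filled.append(new_row)
    return filled


def train_validation_split(rows, train_fraction=0.8, seed=42):
    """Shuffle with a fixed seed, then cut at train_fraction (exact sizes)."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    cut = int(round(len(shuffled) * train_fraction))
    return shuffled[:cut], shuffled[cut:]


# 225 placeholder user rows; every third user is missing the "songs" feature.
rows = [{"userId": str(i), "songs": i if i % 3 else None} for i in range(225)]
filled = impute_zero(rows, ["songs"])
train, validation = train_validation_split(filled)
print(len(train), len(validation))  # 180 45
```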

The next step was to combine the features into a single vector and scale them using MinMaxScaler. For each value in a feature, MinMaxScaler subtracts the feature's minimum value and then divides by its range (maximum minus minimum), mapping the feature to [0, 1] by default. Because this is a linear transformation, MinMaxScaler preserves the shape of the original distribution.
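The min-max formula can be written out in a few lines of plain Python. This sketch is not Spark's implementation. It handles one column and rescales to [new_min, new_max], defaulting to [0, 1]. For a constant column, it returns the midpoint of the target range, which follows the behavior Spark's MinMaxScaler documentation describes for that case. The function name min_max_scale is hypothetical.

```python
def min_max_scale(column, new_min=0.0, new_max=1.0):
    """Rescale values: (x - min) / (max - min), then stretch to [new_min, new_max]."""
    lo, hi = min(column), max(column)
    if hi == lo:
        # Constant feature: map every value to the middle of the target range.
        return [0.5 * (new_max + new_min)] * len(column)
    return [(x - lo) / (hi - lo) * (new_max - new_min) + new_min for x in column]


songs = [10, 20, 30, 50]
print(min_max_scale(songs))  # [0.0, 0.25, 0.5, 1.0]
```

The spacing between values stays proportional (20 is still a quarter of the way from 10 to 50), which is why the shape of the distribution is preserved. In a Spark Pipeline, the scaler learns the minimum and maximum during fit, so they come from the training data only.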

3. Model Building

Below are the classification models I tried:

  1. Logistic Regression (LR)
  2. Gradient Boosted Trees Classification (GBT)
  3. Support Vector Machines Classification (SVM)
  4. Random Forest Classification (RF)

For each model algorithm mentioned above, the following steps were followed:

  1. The classifier was defined and combined into a Pipeline with three stages: the feature vector assembly, the scaler, and the classifier.

  2. A parameter grid was created to perform hyperparameter optimization and identify the model that performs best on the evaluation metric.

For logistic regression, memory constraints limited the grid search to a single parameter, regParam, with values [0.1, 0.01]. regParam is the regularization parameter. It penalizes large coefficients, which reduces overfitting and stabilizes the model when variables are correlated with one another.

For the tree-based methods, I tuned only two parameters: maxDepth ([2, 4]) and maxIter ([1, 5]).

maxDepth is the maximum depth of each tree. Deeper trees are more expressive and may improve predictive performance, but they are computationally expensive and more prone to overfitting. Hence, I chose a small set of values for this parameter.

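maxIter is the number of trees in the ensemble. Tree-based methods such as GBT and Random Forest are ensemble methods, i.e., they combine many base models into a single, stronger model. In Spark ML, maxIter sets the number of boosting iterations (and thus trees) for GBTClassifier, while RandomForestClassifier sets its number of trees through numTrees. I tried some higher values, but due to computational constraints I had to cap this at 5.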

  3. CrossValidator objects were instantiated for each algorithm using its Pipeline and parameter grid, as shown below.

  4. Each CrossValidator was fit on the training dataset using 2-fold cross-validation (k=2), and the selected model was then evaluated on the validation dataset.
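The grid-search-plus-cross-validation steps can be sketched in plain Python to show how the pieces fit together and how many models get trained. This is not Spark's CrossValidator. It is a simplified stand-in where param_grid plays the role of ParamGridBuilder. Its fit_and_score argument is a hypothetical, caller-supplied hook for "train a model with these parameters and return its F1 on the held-out fold". The toy scorer below simply prefers larger parameter values.

```python
import itertools
import random


def param_grid(**params):
    """Expand lists of candidate values into every combination (like a parameter grid)."""
    names = sorted(params)
    return [dict(zip(names, combo)) for combo in itertools.product(*(params[n] for n in names))]


def k_fold_indices(n_rows, k, seed=42):
    """Shuffle row indices with a fixed seed and deal them into k folds."""
    idx = list(range(n_rows))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]


def cross_validate(rows, grid, k, fit_and_score):
    """Average each parameter set's score over k folds and return the best set.

    fit_and_score(train_rows, test_rows, params) -> float, higher is better.
    """
    folds = k_fold_indices(len(rows), k)
    avg_scores = []
    for params in grid:
        fold_scores = []
        for test_idx in folds:
            held_out = set(test_idx)
            train = [row for j, row in enumerate(rows) if j not in held_out]
            test = [rows[j] for j in test_idx]
            fold_scores.append(fit_and_score(train, test, params))
        avg_scores.append(sum(fold_scores) / len(folds))
    best = max(range(len(grid)), key=lambda g: avg_scores[g])
    return grid[best], avg_scores


# Toy scorer standing in for "train a GBT and compute F1"; it records each fit.
fits = []
def toy_fit_and_score(train, test, params):
    fits.append(params)
    return params["maxDepth"] * params["maxIter"] / 20


grid = param_grid(maxDepth=[2, 4], maxIter=[1, 5])
best_params, scores = cross_validate(list(range(100)), grid, k=2, fit_and_score=toy_fit_and_score)
print(len(grid), len(fits), best_params)  # 4 8 {'maxDepth': 4, 'maxIter': 5}
```

With 4 parameter combinations and k=2, eight models are trained, each on only half of the training data. Spark's CrossValidator then refits the winning parameter set on the full training data. This sketch stops at selecting it.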

Here’s a snippet of the code:

4. Model Evaluation

F1 Score:

As this is a classification problem, the F1 score was chosen as the key evaluation metric for selecting the best model, for the following reasons:

  • It combines precision and recall into a single number (their harmonic mean), so it accounts for true positives, false positives, and false negatives at once.
  • It gives equal weight to precision and recall, and it is more informative than accuracy when the classes are imbalanced, as they are here (roughly one churned user for every three who stay).
  • Optimizing the F1 score tends to keep both false positives and false negatives low, thereby reducing business costs.
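The F1 score can be computed by hand from the confusion counts, as the plain-Python sketch below shows. One detail matters when reading the reported numbers. The post does not say which evaluator was used. If it was Spark's MulticlassClassificationEvaluator with metricName="f1" (an assumption here), the score is a weighted average of per-class F1 values, weighted by each class's share of true labels, rather than the F1 of the churn class alone. weighted_f1 mimics that weighting. The labels are made up.

```python
def f1_from_counts(tp, fp, fn):
    """Binary F1 = harmonic mean of precision and recall (true negatives play no part)."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def weighted_f1(y_true, y_pred):
    """Per-class F1, averaged with weights equal to each class's share of true labels."""
    total = len(y_true)
    score = 0.0
    for c in sorted(set(y_true)):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        score += f1_from_counts(tp, fp, fn) * y_true.count(c) / total
    return score


# Made-up labels: 1 = churned, 0 = stayed.
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0, 0, 0]
print(round(f1_from_counts(2, 1, 1), 3))  # 0.667  (F1 of the churn class)
print(round(weighted_f1(y_true, y_pred), 4))  # 0.75  (weighted across both classes)
```

Because the majority "stayed" class usually scores well, a weighted F1 tends to be higher than the churn-class F1 on its own.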

The final metrics for all models after 2-fold cross-validation are as follows:

GBTs build trees one at a time, where each new tree helps correct the errors made by the previously trained trees. With each tree added, the model becomes more expressive. After hyperparameter optimization, the best values for the maximum depth and maximum iterations parameters identified by the grid search are 4 and 5, respectively, the largest values in the grid. This is plausible: within such a small grid, deeper trees and more trees give the model more capacity to fit the data. Because these values were selected with 2-fold cross-validation rather than on a single split, the risk of picking hyperparameters that merely overfit one particular split is also reduced.
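To make "each new tree corrects the errors of the previous trees" concrete, here is a toy gradient-boosting loop written from scratch. It uses depth-1 trees (stumps), squared-error loss, and a fixed learning rate on made-up 1-D data. This is a simplified illustration of the idea, not Spark's GBTClassifier, which uses a logistic loss for classification and deeper trees controlled by maxDepth. Here, n_trees plays the role of maxIter, and all function names are hypothetical.

```python
def fit_stump(xs, residuals):
    """Find the single split on x whose two leaf means best fit the residuals (least squares)."""
    best = None
    for threshold in sorted(set(xs))[:-1]:  # skip the max so the right side is never empty
        left = [r for x, r in zip(xs, residuals) if x <= threshold]
        right = [r for x, r in zip(xs, residuals) if x > threshold]
        left_mean, right_mean = sum(left) / len(left), sum(right) / len(right)
        sse = (sum((r - left_mean) ** 2 for r in left)
               + sum((r - right_mean) ** 2 for r in right))
        if best is None or sse < best[0]:
            best = (sse, threshold, left_mean, right_mean)
    _, threshold, left_value, right_value = best
    return lambda x: left_value if x <= threshold else right_value


def gradient_boost(xs, ys, n_trees, learning_rate=0.5):
    """Start from the mean, then fit each new stump to the residuals left so far."""
    base = sum(ys) / len(ys)
    trees = []

    def predict(x):
        return base + sum(learning_rate * tree(x) for tree in trees)

    for _ in range(n_trees):
        residuals = [y - predict(x) for x, y in zip(xs, ys)]  # current errors
        trees.append(fit_stump(xs, residuals))
    return predict


def mse(model, xs, ys):
    return sum((model(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


xs = [1, 2, 3, 4, 5, 6, 7, 8]
ys = [1.0, 1.2, 0.9, 3.0, 3.1, 2.9, 5.0, 5.2]
for n in (1, 5, 20):
    print(n, round(mse(gradient_boost(xs, ys, n), xs, ys), 4))
```

The training error shrinks as trees are added, which mirrors why the largest maxIter in the grid did well. On real data, more trees eventually start fitting noise, which is what the cross-validation is there to catch.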

Based on the results above, GBT is the best model, with the highest F1 score of 0.77.


Let’s take a step back and look at the whole journey.

  1. The objective was to predict users who could potentially churn in a hypothetical music service Sparkify.
  2. Any user who cancelled their account was labeled as a churned user.
  3. Exploratory data analysis was performed to better understand the churn distribution of users by level, gender, browser type, and OS type.
  4. Insights from this analysis, along with other key variables, guided the feature engineering.
  5. The data was split into train (80%) and validation (20%) datasets for model building and evaluation, respectively.
  6. ML Pipelines for the various classifiers, along with grid search parameters for hyperparameter tuning, were created to train the models.
  7. Gradient Boosted Trees was chosen as the best model on the basis of F1 score, with its hyperparameters selected by 2-fold cross-validation.

To me, there were two challenging aspects in this project:

  1. Working with Spark
    I used Spark in local mode, and it was a computationally taxing job due to memory constraints. I have yet to leverage the true power of Spark for building scalable models with a greater emphasis on speed and efficiency.
  2. Exploring Grid Search Parameters
    I’d love to learn how to choose better grid search parameters for hyperparameter tuning. I used very basic parameters, and even so the models took longer than expected to run. There is probably a better way to do this, and I’ll have to research it further.

Future Enhancements

There are several potential improvements for the future:

  1. Run the code on the full dataset.
    I used a subset of the full 12GB dataset for my analysis. It would be a great experience to deploy this model in the cloud and truly experience the power of Spark by building a scalable model without speed or memory constraints.
  2. Collect more data about users.
    We only looked at two months of user behavior, but in reality it would be helpful to have a longer history (e.g., two years). With that, we could create metrics such as the number of times a user logged in per month, the number of times a user upgraded or downgraded their service, and the number of times a user responded positively to promotions or incentives. We could also add demographic information to improve the accuracy of the model's predictions.
  3. Build a Recommendation Engine.
    With the additional data mentioned above, it would be a great opportunity to build a recommendation engine using collaborative filtering. We could identify users who are similar to one another based on the songs, artists, or genres they listen to, and recommend songs or artists that a user might like, making their experience with the app even better.

GitHub Repo