Sparkify — Churn Prediction

This blog was written to meet the requirements of my Udacity “Data Scientist Nanodegree” capstone project. In this project, I tested several prediction models using Spark to identify users who are more likely to churn from a fictitious music streaming service.


Sparkify is a music streaming service (like Spotify, Pandora, etc.). Sparkify users can either listen to music for free or buy a subscription. Free users have to listen to ads, while subscription (paid) users listen to songs ad-free. Users need to log in to listen to a song on the service.

From a subscription perspective at any point, users can do any of the following:

  1. Upgrade from the free tier to the paid tier

Key Business Question

The goal is to identify users who are likely to cancel their account and leave the service. By identifying this population upfront, we could incentivize them with discounts or other rewards. This could persuade them to stay, giving us a loyal customer base, which is key for a company’s growth.

To attain the objective, it’s important to first do exploratory analysis to glean insights from the data set and identify key variables of interest. The next step is to test different model algorithms and then pick the best model based on the key evaluation metric (F1 score), using the Spark ML library.


The data we have from Sparkify consists of user events: every interaction of every user with the application. Every time a user goes to the Home page, listens to a song, gives a song a thumbs up, etc., a corresponding event appears in the data. Each record in our data represents one event by one user.

Some of the key attributes present in each event in the data are:

  1. userId : A unique ID that identifies a user of the service


1. Data Cleaning and Exploration

  1. Invalid data, such as records with missing UserIDs or SessionIDs, was removed from the analysis.

9. The churn rate for users on the ‘Firefox’ browser is significantly higher than for users on other browser types, as shown in the chart below.

10. The churn rate for users on the ‘Linux’ OS is significantly higher than for users on other OS types, as shown in the chart below.

11. The churn rate for ‘free’ users is higher than for ‘paid’ users (24% vs 20%), as shown in the chart below.

2. Feature Engineering

Below is a list of the features I used to build an ML model that predicts churn at the user level:

Numerical Features

  1. Total number of songs

Categorical Features

  1. Gender type

All missing values were imputed with 0 in this case. The dataset was then split into train (80%) and validation (20%) sets for building and validating the model respectively.

The next step was to combine the features into a single Vector and scale them using MinMaxScaler. For each value in a feature, MinMaxScaler subtracts the minimum value of that feature and then divides by the range (maximum minus minimum), so every feature ends up between 0 and 1. MinMaxScaler preserves the shape of the original distribution.
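To make the scaling formula concrete, here is a small plain-Python sketch of min-max scaling on a single feature. The function name and the sample values are made up for illustration; in the project itself the scaling is done by Spark's MinMaxScaler.

```python
def min_max_scale(values):
    """Rescale values to [0, 1] using (x - min) / (max - min)."""
    lo, hi = min(values), max(values)
    span = hi - lo
    if span == 0:
        # A constant feature has no range; map every value to the midpoint 0.5 here.
        return [0.5 for _ in values]
    return [(x - lo) / span for x in values]


songs_per_user = [10, 20, 40, 50]  # hypothetical feature values
scaled = min_max_scale(songs_per_user)
print(scaled)  # [0.0, 0.25, 0.75, 1.0]
```

The relative spacing between values is unchanged, which is what is meant by preserving the shape of the distribution.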

3. Model Building

Below are the classification models I tried:

  1. Logistic Regression (LR)

For each model algorithm mentioned above, the following steps were followed:

  1. The classifier was defined and placed in a Pipeline made up of three stages: the Vector assembler, the Scaler, and the Classifier.

2. A Parameter Grid was created to perform hyperparameter optimization and identify the best-performing model.

For logistic regression, due to memory constraints, only regParam ([0.1, 0.01]) was included in the Grid Search. This is the regularization parameter, which penalizes large coefficients and shrinks them toward zero, reducing the risk of overfitting.

For the tree-based methods, I tuned only two parameters: maxDepth ([2, 4]) and maxIter ([1, 5]).

maxDepth is the maximum depth of each tree. Deeper trees are more expressive and may improve prediction quality, but they are computationally expensive. Hence, I chose a small set of values for this parameter.

maxIter is the maximum number of iterations. For gradient-boosted trees, each iteration adds one tree, so it sets the number of trees in the ensemble. Tree-based ensemble methods combine several base models into one stronger model. I tried some higher values, but due to computational constraints I had to cap this at 5.

3. CrossValidator objects were instantiated using Pipeline and Parameter Grid values for different algorithms as shown below.

4. The model was then trained on the train dataset using cross-validation with k=2 folds, and evaluated on the validation dataset.

Here’s a snippet of the code:

4. Model Evaluation

F1 Score:

As it’s a classification model, F1 score was chosen as the key evaluation metric to select the best model. Below are the reasons why F1 score was chosen:

  • It combines Precision and Recall, and therefore True Positives, False Positives, and False Negatives, into a single number.
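As a worked example of how these quantities combine, here is a small plain-Python sketch that computes F1 for the positive (churn) class from confusion-matrix counts. The function name and the counts are made up for illustration and are not results from this project.

```python
def f1_score(tp, fp, fn):
    """F1 is the harmonic mean of precision and recall."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Hypothetical counts: 30 churners caught, 10 false alarms, 20 churners missed.
# precision = 30 / 40 = 0.75, recall = 30 / 50 = 0.60
score = f1_score(tp=30, fp=10, fn=20)
print(round(score, 4))  # 0.6667
```

One caveat: if the score is computed with Spark's MulticlassClassificationEvaluator and metricName="f1", the reported value is a label-weighted average of the per-class F1 scores rather than this single positive-class value.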

The final metrics for all models after 2-fold cross-validation are as follows:

GBTs build trees one at a time, where each new tree helps to correct the errors made by the previously trained trees. With each tree added, the model becomes more expressive. After hyperparameter tuning, the best values identified by Grid Search for maximum depth and maximum iterations are 4 and 5 respectively. This makes sense, as increasing the depth and the number of trees tends to improve model performance. Both are the largest values in the grid, so the true optimum may lie beyond the range searched. Because the hyperparameters were selected with 2-fold cross-validation rather than on a single split, the risk of overfitting to one particular split was reduced as well.

Based on the results above, the GBT classifier is the best model, as it has the highest F1 score (0.77).


Let’s take a step back and look at the whole journey.

  1. The objective was to predict users who could potentially churn in a hypothetical music service Sparkify.

To me there were 2 challenging aspects in the project:

  1. Working with Spark
    I used Spark in local mode, and it was a computationally taxing job due to memory constraints. I have yet to leverage the true power of Spark in building scalable models with a greater emphasis on speed and efficiency.

Future Enhancements

There are a couple of potential improvements in future:

  1. Run the code on the full dataset.
    I used a subset of the full 12 GB dataset for my analysis. It would be a great experience to deploy this model in the cloud and truly experience the power of Spark by building a scalable model with far fewer speed and memory constraints.

Github Repo
