Using streamlit to set the threshold for a classifier

In this post I will demonstrate how to use streamlit to build an app that helps you interactively set the threshold for your machine learning classifier. The code is available at timvink/demo_streamlit_threshold_classifier.

Introduction

Many machine learning classifiers can output probabilities instead of direct predictions (e.g. using sklearn’s .predict_proba() method). To predict a class, these probabilities are cut off at a certain threshold. This threshold usually defaults to 0.5 (predict the positive class when the probability is >= 0.5), but many business problems benefit from setting it thoughtfully. This is especially true for imbalanced machine learning problems. Changing the threshold is inherently a tradeoff between precision and recall: a lower threshold flags more observations as positive (higher recall, usually lower precision), and a higher threshold does the reverse. It should be done together with business stakeholders who understand the problem domain. And why not do that using an interactive app instead of a slide?
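To make the tradeoff concrete, here is a minimal sketch with made-up labels and probabilities (not taken from the post's model). The helper `predict_with_threshold` is hypothetical; it simply applies a cut-off to positive-class probabilities such as those returned by `model.predict_proba(X)[:, 1]`.

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score


def predict_with_threshold(proba_positive, threshold=0.5):
    """Hypothetical helper: turn positive-class probabilities into 0/1 labels."""
    return (np.asarray(proba_positive) >= threshold).astype(int)


# Toy actuals and positive-class probabilities (made up for illustration).
y_true = np.array([0, 0, 0, 0, 1, 0, 1, 1])
proba = np.array([0.1, 0.2, 0.35, 0.45, 0.55, 0.6, 0.8, 0.9])

for threshold in (0.3, 0.5, 0.7):
    y_pred = predict_with_threshold(proba, threshold)
    print(
        f"threshold={threshold}: "
        f"precision={precision_score(y_true, y_pred):.2f}, "
        f"recall={recall_score(y_true, y_pred):.2f}"
    )
# threshold=0.3: precision=0.50, recall=1.00
# threshold=0.5: precision=0.75, recall=1.00
# threshold=0.7: precision=1.00, recall=0.67
```

Raising the threshold from 0.3 to 0.7 trades recall for precision, which is exactly the decision the app puts in front of stakeholders.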

Streamlit

streamlit is an open-source Python library that makes it easy to build custom web apps. It is comparable to Dash or R’s Shiny package. Dash is more fully featured and customizable, but for quick prototyping I find streamlit much simpler and easier to learn.

Modelling problem

Let’s say we have a straightforward binary classification problem. We generate a dataset X with 30k observations, 20 features and a 9:1 class imbalance. We use a stratified train_test_split with 80% train and 20% test. Next, we train a simple RandomForestClassifier on the train set, using a 5-fold cross-validated grid search to tune the hyperparameters. The code is in model.py.
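The actual code lives in model.py in the repository; what follows is only a hedged sketch of the setup described above. It assumes scikit-learn's `make_classification` for the synthetic data, and both the hyperparameter grid and the `roc_auc` scoring for the grid search are made-up choices, not taken from the post. It returns the four arrays in the same order the app unpacks them.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, train_test_split


def get_predictions(n_samples=30_000, param_grid=None, random_state=42):
    """Sketch: train a tuned random forest, return actuals and positive-class probabilities."""
    # Imbalanced binary problem: roughly 90% class 0 and 10% class 1.
    X, y = make_classification(
        n_samples=n_samples,
        n_features=20,
        weights=[0.9, 0.1],
        random_state=random_state,
    )
    # Stratify so the train and test sets keep the 9:1 class ratio.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=random_state
    )
    if param_grid is None:
        # Assumed grid, for illustration only.
        param_grid = {"n_estimators": [100, 200], "max_depth": [5, 10]}
    # 5-fold cross-validated grid search; the best model is refit on all of X_train.
    search = GridSearchCV(
        RandomForestClassifier(random_state=random_state),
        param_grid=param_grid,
        cv=5,
        scoring="roc_auc",
    )
    search.fit(X_train, y_train)
    # Keep only the probability of the positive class (column 1).
    yhat_prob_train = search.predict_proba(X_train)[:, 1]
    yhat_prob_test = search.predict_proba(X_test)[:, 1]
    return y_train, yhat_prob_train, y_test, yhat_prob_test
```

For a quick local check, a smaller `n_samples` and a one-point `param_grid` keep training fast.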

To tune the threshold, we’ll need to save the actuals and the predicted probabilities for both train and test datasets. With those 4 arrays, we can compute static performance metrics like roc_auc_score, but also metrics that depend on the threshold, like precision_score and recall_score. You can find the code in eval.py.
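A possible shape for `get_metrics_df`, which the app calls later, is sketched here. This is an assumption about eval.py rather than its actual contents: one metric per row and one dataset per column, with ROC AUC computed from the raw probabilities and precision and recall computed from the thresholded predictions.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import precision_score, recall_score, roc_auc_score


def get_metrics_df(y_train, yhat_prob_train, y_test, yhat_prob_test, threshold=0.5):
    """Sketch: metrics as rows, datasets ("train", "test") as columns."""
    datasets = {
        "train": (y_train, yhat_prob_train),
        "test": (y_test, yhat_prob_test),
    }
    metrics = {}
    for name, (y_true, y_prob) in datasets.items():
        y_prob = np.asarray(y_prob)
        y_pred = (y_prob >= threshold).astype(int)
        metrics[name] = {
            # Threshold-independent: uses the probabilities directly.
            "roc_auc": roc_auc_score(y_true, y_prob),
            # Threshold-dependent: uses the 0/1 predictions.
            # zero_division=0 avoids warnings when nothing is predicted positive.
            "precision": precision_score(y_true, y_pred, zero_division=0),
            "recall": recall_score(y_true, y_pred, zero_division=0),
        }
    return pd.DataFrame(metrics)
```

Only the last two rows change when the threshold moves; `roc_auc` stays fixed, which is a useful sanity check in the app.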

Our app

In a new app.py file, we can add user interface elements like a title and slider with:

import streamlit as st
st.title("Setting the threshold for our classifier")
threshold = st.slider("Threshold", min_value=0.00, max_value=1.0, step=0.01, value=0.5)

We also need to get the predicted probabilities and actuals from our model. Streamlit reruns the whole script every time a widget changes, and we don’t want to retrain the model every time we move the threshold slider. Streamlit offers the possibility to cache the results:

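from model import get_predictions  # custom model code

# Cache the output so the model is trained once, not on every rerun.
# st.cache_data replaces the deprecated @st.cache (Streamlit >= 1.18).
@st.cache_data
def cached_get_predictions():
    return get_predictions()

y_train, yhat_prob_train, y_test, yhat_prob_test = cached_get_predictions()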

Next up is to calculate the metrics that depend on the threshold and display them as a table in the user interface:

from eval import get_metrics_df # custom code

metrics = get_metrics_df(
    y_train, yhat_prob_train, y_test, yhat_prob_test, threshold=threshold
)
st.dataframe(metrics)

There are many other components I could add (see the Streamlit API reference). I added a matplotlib plot to visualize the threshold setting. The whole project is on GitHub: timvink/demo_streamlit_threshold_classifier. Here’s what it looks like:

demo streamlit app
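The plot in the screenshot is the author's own; the code for it is in the repository. As a hedged sketch of a similar figure, the hypothetical `plot_threshold` below draws precision and recall as functions of the threshold, using scikit-learn's `precision_recall_curve`, and marks the value currently selected on the slider.

```python
from matplotlib.figure import Figure
from sklearn.metrics import precision_recall_curve


def plot_threshold(y_true, y_prob, threshold):
    """Hypothetical helper: precision and recall vs. threshold, chosen value marked."""
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    # precision_recall_curve returns one more precision/recall value than
    # thresholds, so drop the last point to align the arrays.
    ax.plot(thresholds, precision[:-1], label="precision")
    ax.plot(thresholds, recall[:-1], label="recall")
    ax.axvline(threshold, color="grey", linestyle="--", label=f"threshold = {threshold:.2f}")
    ax.set_xlabel("threshold")
    ax.set_ylim(0, 1.05)
    ax.legend()
    return fig
```

In app.py this could be rendered with `st.pyplot(plot_threshold(y_test, yhat_prob_test, threshold))`. Building a `matplotlib.figure.Figure` directly, rather than through `pyplot`, avoids relying on pyplot's global figure state across Streamlit reruns.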

Conclusion & further reading

This was a basic example of setting a threshold. From here, you could consider extending the app to cover aspects like fairness (see for example Attacking discrimination with smarter machine learning).

I find streamlit an easy-to-use, quick-to-learn library for adding interactivity to an analysis. I still find it too involved for a one-off analysis, but when an analysis will be reused, it is well worth the time to learn.

Some good resources for further reading on streamlit:

Tim Vink

Amsterdam, Netherlands timvink.nl
