Creating a Machine Learning Web API with Flask

In our previous post, we went over how to create a simple linear regression model with scikit-learn and how to use it to make predictions. But that's not very useful for anyone other than the creator of the model, since it's only available on their machine. In this post, we'll go over how to use Flask, a microframework for building websites and APIs in Python, to build our Web API, and how to persist our model so we can access it without retraining it each time we want to make a prediction.

Before we dive in, let’s go over what Flask is and how to set it up. If you want to skip to the code, the API and Jupyter Notebook are both on GitHub.

A Very Brief Introduction to Flask

If you downloaded the Anaconda distribution, you already have Flask installed. Otherwise, you will have to install it yourself with pip install flask.

Flask is very minimal since you only bring in the parts as you need them. To demonstrate this, here’s the Flask code to create a very simple web server.

from flask import Flask

app = Flask(__name__)


@app.route("/")
def hello():
    return "Hello World!"


if __name__ == '__main__':
    app.run(debug=True)

Once executed, you can navigate to the web address shown in the terminal and observe the expected result.

Hello World Flask app

Let’s review what the executed code is doing.

After importing Flask, we create an instance of the Flask class and pass in the __name__ variable, which Python fills in for us. This variable will be "__main__" if the file is run directly through Python as a script. If we import the file instead, the value of __name__ will be the module's name. For instance, if we had test.py and run.py, and we imported test.py into run.py, the __name__ value inside test.py would be "test".
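To make this concrete, here's a quick sketch using the hypothetical test.py and run.py files from the example above.

# test.py
print("__name__ in test.py is:", __name__)

# run.py
import test  # executing run.py prints: __name__ in test.py is: test

Running python test.py directly prints __main__, while running python run.py prints test, since test.py is executed as an imported module.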

Above our hello method definition is the @app.route("/") line. The @ denotes a decorator, which allows the function, property, or class it precedes to be dynamically altered.
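If decorators are new to you, here's a minimal sketch of one outside of Flask (log_call is a made-up example, not a Flask feature):

def log_call(func):
    def wrapper(*args, **kwargs):
        # Print a message before delegating to the wrapped function.
        print("calling", func.__name__)
        return func(*args, **kwargs)
    return wrapper

@log_call
def greet():
    return "Hello!"

greet()  # prints "calling greet", then returns "Hello!"

Flask's @app.route decorator works along the same lines: it takes the function below it and registers it as the handler for the given URL.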

The hello method is where we put the code to run whenever a request hits our app or API's top-level route: /.

If our __name__ variable is "__main__", indicating that we ran the file directly instead of importing it, we start the Flask app, which runs and waits for web requests until the process ends.

Creating the API

While I generally use Jupyter Notebooks for all of my regular Python and data code, creating web applications and APIs is a bit different, since a web server needs to be running to capture the requests. For web code I switch between PyCharm and Visual Studio; for this API I went with PyCharm, and I'll detail the parts of it I used to help with testing and debugging.

We already built our model in the last post, so calling its predict method shouldn't require building it all over again; that would be a lot of unnecessary overhead. But how can we avoid recreating the model in our API? This is where data persistence in Python comes into play.

The main goal of our API is to enable a client, whether a website or mobile app, to be able to use our model to make predictions. We’ll also add some extra features along the way, such as getting details about our model.

Before we can start building our API, we need a way to persist our model to a file so we can load it rather than retrain it on every request.

Built-in Model Persistence
Python has a built-in module for persisting data called pickle. The pickle module can serialize objects into a file that we can save to disk and load back later.

As the documentation mentions, be very careful about loading pickled files you don't trust: a pickle can contain malicious code, and loading one executes its contents.

Here’s a quick example of how a pickled file can be used with our linear regression model.

import pickle

# lin_reg is the linear regression model built in the previous post.

# Serialize the model to a file on disk...
with open("python_lin_reg_model.pkl", "wb") as file_handler:
    pickle.dump(lin_reg, file_handler)

# ...and deserialize it back into memory.
with open("python_lin_reg_model.pkl", "rb") as file_handler:
    loaded_pickle = pickle.load(file_handler)

loaded_pickle

Model Persistence with scikit-learn
While this method works, scikit-learn has its own model persistence method we will use: joblib. It's more efficient for scikit-learn models because it handles the large numpy arrays a model may store better than pickle does.

Since we’ve already created our model from the last post, we can just save it out to the disk.

from sklearn.externals import joblib  # on newer scikit-learn versions, use "import joblib" instead

joblib.dump(lin_reg, "linear_regression_model.pkl")

Prediction API
The prediction API is quite simple. We give it our data, the years of experience, and pass that into the predict method of our model.

# Requires: from flask import request, jsonify
@app.route("/predict", methods=['POST'])
def predict():
    if request.method == 'POST':
        try:
            data = request.get_json()
            years_of_experience = float(data["yearsOfExperience"])

            # Load the persisted model from disk.
            lin_reg = joblib.load("./linear_regression_model.pkl")
        except ValueError:
            return jsonify("Please enter a number.")

        # predict expects a 2D array of shape (n_samples, n_features).
        return jsonify(lin_reg.predict([[years_of_experience]]).tolist())

There are quite a few things going on here, so let’s break it down.

Similar to the code above that introduced Flask, we're calling the @app.route decorator. However, we give it some additional information: we tell it to handle requests where the URL path is /predict, and we let it know to handle only POST requests.

In our route's method definition, we check that the request method is POST and, if so, get the JSON of the request body so we can access its data. From that data, we read the key we want – yearsOfExperience – and parse it into a float. All of this is wrapped inside a [try/except](https://docs.python.org/3/tutorial/errors.html) block to catch any exception raised while parsing yearsOfExperience into a float, which is then passed into our predict method. We also load our model into memory from the persisted file.

If there were no errors parsing the yearsOfExperience data, we pass the parsed value into the predict method of the linear regression model we just loaded from disk. Next, we have to convert the output of the predict method to a list since, without that, we'd get an error: Object of type 'ndarray' is not JSON serializable. We then call the [jsonify](http://flask.pocoo.org/docs/0.12/api/#flask.json.jsonify) method of Flask to send the response data as JSON.
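As a quick illustration of that error, here's a small sketch, assuming a one-element numpy array like the one predict returns:

import json
import numpy as np

pred = np.array([100712.0])
# json.dumps(pred) raises TypeError: Object of type ndarray is not JSON serializable
json.dumps(pred.tolist())  # works: '[100712.0]'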

Finally, we can run our API (in PyCharm you can just right click anywhere in the script and click “Run”). With our API running we can execute code to call it.

import requests

# Flask's development server runs at http://127.0.0.1:5000 by default.
BASE_URL = "http://127.0.0.1:5000"

years_exp = {"yearsOfExperience": 8}

response = requests.post("{}/predict".format(BASE_URL), json=years_exp)

response.json()
response.json() output

To call our API we're going to use the [requests](http://docs.python-requests.org/en/master/) package, which makes calling APIs easier than using Python's built-in modules. With requests, we call the post method to indicate we want to send a POST request, and pass in our URL. We then hand our data to the method's json parameter as a plain Python dictionary, and requests automatically sends it to the API as JSON.

The next step is to save our response variable and call its json() method to extract the response values as JSON. The result: for an input of 8 years of experience, the model predicts a salary of about $100,712.
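We can also exercise the error handling in the predict route. Here's a quick sketch: sending a value that can't be parsed into a float triggers the ValueError branch, and we get the error message back instead of a prediction.

bad_input = {"yearsOfExperience": "eight"}

response = requests.post("{}/predict".format(BASE_URL), json=bad_input)
response.json()  # "Please enter a number."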

We can check our data to see how close we are by querying our data frame with the pandas [query](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.query.html) method.

# df_copy is the salary data frame from the previous post.
df_copy.query('YearsExperience > 7 and YearsExperience <= 8')

Our prediction was pretty close. We were only off by a few hundred dollars.

Retrain Model API
As we get more data, we can use it to improve our model's accuracy. Some models in scikit-learn support the partial_fit method for incremental learning; however, the LinearRegression model isn't one of them. So how can we improve our model? We'd have to retrain it with all of our old data plus the new data.
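For context, here's a hedged sketch of what incremental learning looks like with a scikit-learn model that does support partial_fit, such as SGDRegressor (we don't use it in this post):

import numpy as np
from sklearn.linear_model import SGDRegressor

sgd = SGDRegressor()

# Fit on an initial batch of data (years of experience -> salary).
sgd.partial_fit(np.array([[1.0], [2.0]]), np.array([40000.0, 50000.0]))

# Later, update the same model with a new batch; no full retrain needed.
sgd.partial_fit(np.array([[3.0]]), np.array([60000.0]))

Since LinearRegression doesn't offer this, our retrain endpoint instead rebuilds the model from the combined old and new data.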

In order to do this, we will need to save out the training data and labels.

# train_set and train_labels come from the previous post.
joblib.dump(train_set, "training_data.pkl")
joblib.dump(train_labels, "training_labels.pkl")

We will also be using pandas and the os Python module, so we will need to import both modules. Let’s take a look at the API and then review it.

@app.route("/retrain", methods=['POST'])
def retrain():
    if request.method == 'POST':
        data = request.get_json()

        try:
            # Load the existing training data and labels from disk.
            training_set = joblib.load("./training_data.pkl")
            training_labels = joblib.load("./training_labels.pkl")

            # Build a data frame from the JSON in the request body.
            df = pd.read_json(data)

            df_training_set = df.drop(["Salary"], axis=1)
            df_training_labels = df["Salary"]

            # Append the new rows to the existing data and labels.
            df_training_set = pd.concat([training_set, df_training_set])
            df_training_labels = pd.concat([training_labels, df_training_labels])

            # Fit a fresh model on the combined data.
            new_lin_reg = LinearRegression()
            new_lin_reg.fit(df_training_set, df_training_labels)

            # Remove the old files before saving the new versions.
            os.remove("./linear_regression_model.pkl")
            os.remove("./training_data.pkl")
            os.remove("./training_labels.pkl")

            joblib.dump(new_lin_reg, "linear_regression_model.pkl")
            joblib.dump(df_training_set, "training_data.pkl")
            joblib.dump(df_training_labels, "training_labels.pkl")

            # Reload the new model into memory.
            lin_reg = joblib.load("./linear_regression_model.pkl")
        except ValueError as e:
            return jsonify("Error when retraining - {}".format(e))

        return jsonify("Retrained model successfully.")

In this implementation, we get the JSON data from the request and load our existing training set and training labels into memory. We then use pandas to read the request data as JSON and create a data frame from it.

From the new data frame, we do the same splitting of the data into separate frames for the training set and the training labels. Then we use pandas to concatenate the previous training set with the new set, and the previous training labels with the new labels.

With our combined data, we call the fit method to create our new model. We then remove the previously saved model and data files with os.remove and save the new versions in their place. (Strictly speaking, joblib.dump overwrites an existing file, so the removal is defensive cleanup rather than a requirement.) Finally, we load the new model back into memory.

Once we have that added to our API, we can call it. Since we started the app with debug=True, Flask's development server restarts automatically when it detects code changes, so there's no need to restart it ourselves. First, we need some new data.

import json

data = json.dumps([{"YearsExperience": 12, "Salary": 140000},
                   {"YearsExperience": 12.1, "Salary": 142000}])

With this new data, we can then call the retrain API.

response = requests.post("{}/retrain".format(BASE_URL), json = data)

response.json()
response.json() output

Now that we’ve retrained the model, let’s do another prediction on it with the same input as before.

response = requests.post("{}/predict".format(BASE_URL), json = years_exp)

response.json()
response.json() output

With the additional data, our prediction is a bit more accurate than before, since the model now has more data to train on.

Model Details API
Another useful endpoint for our API is one that returns details about the model, such as its coefficients, intercept, and current score.

# The route matches the client call below; the function name is arbitrary.
@app.route("/currentDetails", methods=['GET'])
def current_details():
    if request.method == 'GET':
        try:
            lr = joblib.load("./linear_regression_model.pkl")
            training_set = joblib.load("./training_data.pkl")
            labels = joblib.load("./training_labels.pkl")

            return jsonify({"score": lr.score(training_set, labels),
                            "coefficients": lr.coef_.tolist(),
                            "intercepts": lr.intercept_})
        except (ValueError, TypeError) as e:
            return jsonify("Error when getting details - {}".format(e))

This time we use a GET method since we don't need to pass in any information. We load the training data and training labels along with our model, then call the model's score method with the training set and labels to get our score. The coefficients and intercept are just attributes of the model. As with the prediction results earlier, we have to call tolist() on the coefficients since they come back as a numpy array.

Now, let’s call it and see what we get.

response = requests.get("{}/currentDetails".format(BASE_URL))

response.json()


In this post, we learned what Flask is, how to use it to create APIs, and most importantly how to apply this knowledge to creating APIs for interacting with machine learning models.
