How to deploy a Keras model to production using Flask (part 2)

Hello everyone, this is part two of a two-part tutorial series on how to deploy a Keras model to production. In part one, we used a Convolutional Neural Network (CNN) in Keras to classify MNIST handwritten digits, and we saved the model files obtained after training. In this part, we are going to deploy that Keras model to production using Flask.

Flask is a micro-framework. Micro-frameworks keep their core small and have few dependencies on external libraries, so the framework stays lightweight and there are few dependencies to keep updated and watch for security bugs. A very simple Flask app that renders a web page looks something like this:

from flask import Flask
app = Flask(__name__)

@app.route("/")
def index():
    return "Hello World!"

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=5000)

Save it as app.py and run it with the command:

python3 app.py

Just like that, your ‘Hello World’ Flask web app is up and running at http://localhost:5000. I can’t cover Flask completely in this post, but you can refer to the Flask documentation if you want to learn more; it’s well documented and easy to understand.
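If you want to check a route without opening a browser, Flask’s built-in test client can call the app in-process, with no server running. The sketch below is a minimal example that redefines the same hello-world app so it runs on its own; it assumes only that Flask is installed.

```python
from flask import Flask

# same minimal app as in app.py, redefined so this sketch is self-contained
app = Flask(__name__)


@app.route("/")
def index():
    return "Hello World!"


# the test client sends requests straight to the app, no server needed
client = app.test_client()
response = client.get("/")
print(response.status_code, response.get_data(as_text=True))  # 200 Hello World!
```

The same pattern works for any route you add later, which makes it a quick smoke test before you start the real server.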

The basic structure of a flask web application looks like this:

$ tree deploy_mnist_flask/
deploy_mnist_flask/
|-- static
|-- templates

The templates folder holds the HTML templates that Flask renders. The static folder holds any other files the web application needs, such as images, CSS, and JavaScript.

Create a file named index.html inside the templates directory and copy/paste the code below into it. This is the HTML file we will render using Flask.

<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- The above 3 meta tags *must* come first in the head; any other head content must come *after* these tags -->
    <meta name="description" content="">
    <meta name="author" content="">

    <title>MNIST Handwritten text recognition using keras</title>

    <!-- Bootstrap core CSS -->
    <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap.min.css" integrity="sha384-BVYiiSIFeK1dGmJRAkycuHAHRg32OmUcww7on3RYdg4Va+PmSTsz/K68vbdEjh4u" crossorigin="anonymous">
    <link rel="stylesheet" href="{{ url_for('static',filename='style.css') }}">
  </head>

  <body>

    <div class="container">
      <div class="header clearfix">
        <nav>
          <ul class="nav nav-pills pull-right">
            <li role="presentation" class="active"><a href="#">Home</a></li>
            <li role="presentation"><a href="http://www.python36.com/">About</a></li>
          </ul>
        </nav>
        <h3 class="text-muted">MNIST Handwritten CNN</h3>
      </div>

      <div class="jumbotron">
        <h3 class="jumbotronHeading">Draw the digit inside this Box!</h3>
    <div class="slidecontainer">
      <p>Drag the slider to change the line width.</p>
      <input type="range" min="10" max="50" value="15" id="myRange">
      <p>Value: <span id="sliderValue"></span></p>
    </div>
    <div class="canvasDiv">
          <canvas id="canvas" width="280" height="280"></canvas>
          <br>
          <p style="text-align:center;">
            <a class="btn btn-success myButton" href="#" role="button">Predict</a>
            <a class="btn btn-primary" href="#" id="clearButton" role="button">Clear</a>
      	</p>
        </div>
      </div>

      <div class="jumbotron">
      	<p id="result">Get your prediction here!!!</p>
      </div>

      <footer class="footer">
        <p>&copy; 2018, python36.com</p>
      </footer>

    </div> <!-- /container -->


  <script src='https://cdnjs.cloudflare.com/ajax/libs/jquery/2.1.3/jquery.min.js'></script>

    <script src="{{ url_for('static',filename='index.js') }}"></script>

    <script type="text/javascript">
     
    $(".myButton").click(function(){
      var $SCRIPT_ROOT = {{ request.script_root|tojson|safe }};
      var canvasObj = document.getElementById("canvas");
      var img = canvasObj.toDataURL();
      $.ajax({
        type: "POST",
        url: $SCRIPT_ROOT + "/predict/",
        data: img,
        success: function(data){
          $('#result').text(' Predicted Output: '+data);
        }
      });
    });
   
   	</script>

  </body>
</html>

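This is the HTML for our landing page. Besides Bootstrap and jQuery, which are loaded from CDNs, it links to two files of our own, index.js and style.css, through url_for('static', ...). When the Predict button is clicked, the inline script reads the canvas as a base64 image with toDataURL() and POSTs it to /predict/.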
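Since jQuery sends a string `data` value unchanged, the request body that reaches Flask is the data URL itself, something like `data:image/png;base64,iVBORw0KGgo...`, and `request.get_data()` returns it as bytes. The sketch below shows how the PNG bytes can be pulled out on the Python side. The `extract_png_bytes` helper and the sample body are hypothetical, made up for illustration; the sample uses only the 8-byte PNG signature rather than a full image.

```python
import base64
import re

# every PNG file starts with this 8-byte signature; a real canvas export
# would contain a complete image after it
png_bytes = b"\x89PNG\r\n\x1a\n"

# hypothetical request body, shaped like what the AJAX call sends
body = b"data:image/png;base64," + base64.b64encode(png_bytes)


def extract_png_bytes(raw_body: bytes) -> bytes:
    """Return the decoded image bytes from a data-URL request body (hypothetical helper)."""
    # match on the bytes directly instead of str(raw_body)
    match = re.search(rb"base64,(.*)", raw_body)
    if match is None:
        raise ValueError("request body is not a base64 data URL")
    return base64.b64decode(match.group(1))


decoded = extract_png_bytes(body)
print(decoded == png_bytes)  # True
print(str(b"abc"))           # b'abc': str() on bytes keeps the b'...' wrapper
```

Matching with a bytes pattern avoids the `b'...'` wrapper that `str()` adds, and raising a clear error is friendlier than the `AttributeError` you get from calling `.group()` on `None` when the body is not a data URL.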

The code for index.js is:

(function()
{
  var canvas = document.querySelector( "#canvas" );
  var context = canvas.getContext( "2d" );
  canvas.width = 280;
  canvas.height = 280;

  var Mouse = { x: 0, y: 0 };
  var lastMouse = { x: 0, y: 0 };
  context.fillStyle="black";
  context.fillRect(0,0,canvas.width,canvas.height);
  context.color = "white";
  context.lineWidth = 15;
    context.lineJoin = context.lineCap = 'round';
  
  debug();

  canvas.addEventListener( "mousemove", function( e )
  {
    lastMouse.x = Mouse.x;
    lastMouse.y = Mouse.y;

    Mouse.x = e.pageX - this.offsetLeft;
    Mouse.y = e.pageY - this.offsetTop;

  }, false );

  canvas.addEventListener( "mousedown", function( e )
  {
    canvas.addEventListener( "mousemove", onPaint, false );

  }, false );

  canvas.addEventListener( "mouseup", function()
  {
    canvas.removeEventListener( "mousemove", onPaint, false );

  }, false );

  var onPaint = function()
  {	
    context.lineJoin = "round";
    context.lineCap = "round";
    context.strokeStyle = context.color;
  
    context.beginPath();
    context.moveTo( lastMouse.x, lastMouse.y );
    context.lineTo( Mouse.x, Mouse.y );
    context.closePath();
    context.stroke();
  };

  function debug()
  {
    /* CLEAR BUTTON */
    var clearButton = $( "#clearButton" );
    
    clearButton.on( "click", function()
    {
      
        context.clearRect( 0, 0, 280, 280 );
        context.fillStyle="black";
        context.fillRect(0,0,canvas.width,canvas.height);
      
    });
    
    /* LINE WIDTH */

    var slider = document.getElementById("myRange");
    var output = document.getElementById("sliderValue");
    output.innerHTML = slider.value;

    slider.oninput = function() {
      output.innerHTML = this.value;
      context.lineWidth = $( this ).val();
    }
    
  }
}());

And the code for style.css is:

/* Space out content a bit */
body {
  padding-top: 20px;
  padding-bottom: 20px;
}

/* Everything but the jumbotron gets side spacing for mobile first views */
.header,
.footer {
  padding-right: 15px;
  padding-left: 15px;
}

/* Custom page header */
.header {
  padding-bottom: 20px;
  border-bottom: 1px solid #e5e5e5;
}
/* Make the masthead heading the same height as the navigation */
.header h3 {
  margin-top: 0;
  margin-bottom: 0;
  line-height: 40px;
}

/* Custom page footer */
.footer {
  padding-top: 19px;
  color: #777;
  border-top: 1px solid #e5e5e5;
}

/* Customize container */
@media (min-width: 768px) {
  .container {
    max-width: 730px;
  }
}
.container-narrow > hr {
  margin: 30px 0;
}

/* Main marketing message and sign up button */
.jumbotron {
  text-align: center;
  border-bottom: 1px solid #e5e5e5;
  padding-top: 20px;
  padding-bottom: 20px;
}

.bodyDiv{
  text-align: center;
}

@media screen and (min-width: 768px) {
  /* Remove the padding we set earlier */
  .header,
  .footer {
    padding-right: 0;
    padding-left: 0;
  }
  /* Space out the masthead */
  .header {
    margin-bottom: 30px;
  }
  /* Remove the bottom border on the jumbotron for visual effect */
  .jumbotron {
    border-bottom: 0;
  }
}

@media screen and (max-width: 500px) {
  .slidecontainer{
    display: none;
  }

}

.slidecontainer{
  float: left;
  width: 30%;
}

.jumbotronHeading{
  margin-bottom: 7vh;
}

.canvasDiv{
  display: flow-root;
  text-align: center;
}

Save both files inside the static directory. I won’t go into depth on the JavaScript or CSS in this post, but if you have trouble understanding any of it, feel free to ask in the comments section below and we will help you.

Okay! We now have all the files required to render our web page. Flask’s render_template function renders an HTML file from the templates folder. Import it at the top of app.py and use it to render the index.html page we just created:

from flask import Flask, render_template
app = Flask(__name__)

@app.route("/")
def index():
    return render_template("index.html")

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=5000)

If you haven’t modified the HTML/CSS files, you should see a web page that looks like the one below:
[Image: Deploy Keras model to production using Flask]

Next, create a directory named model in the working directory and copy the model.h5 and model.json files we created in part 1 of this tutorial into it. Inside the same model folder, create a file named load.py, which loads the model structure and weights. Copy/paste the code below into the file:

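import os

from keras.models import model_from_json
import tensorflow as tf

# folder containing this file (the model folder); building paths from it
# means model.json and model.h5 are found even when app.py is started
# from the project root
MODEL_DIR = os.path.dirname(os.path.abspath(__file__))


def init():
    # load the model structure from JSON
    with open(os.path.join(MODEL_DIR, 'model.json'), 'r') as json_file:
        loaded_model_json = json_file.read()
    loaded_model = model_from_json(loaded_model_json)
    # load weights into new model
    loaded_model.load_weights(os.path.join(MODEL_DIR, 'model.h5'))
    print("Loaded model from disk")

    # compile the loaded model
    loaded_model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    # remember the graph the model was built in (TensorFlow 1.x API) so that
    # predictions made from Flask's request threads can run inside it
    graph = tf.get_default_graph()

    return loaded_model, graph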

We will call init() from app.py later.
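By now the project has grown beyond the two folders in the tree shown earlier. Assuming the layout this post implies (load.py sits in the model folder, because app.py adds that folder to sys.path before importing load), a small standard-library sketch can list which files are still missing before you start the server. The `EXPECTED_FILES` list and `missing_files` helper are hypothetical names for illustration.

```python
from pathlib import Path

# files this tutorial expects, relative to the project root (deploy_mnist_flask/)
EXPECTED_FILES = [
    "app.py",
    "model/load.py",
    "model/model.json",
    "model/model.h5",
    "templates/index.html",
    "static/index.js",
    "static/style.css",
]


def missing_files(project_root):
    """Return the expected files that do not exist yet under project_root."""
    root = Path(project_root)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    # run from the project root
    missing = missing_files(".")
    print("All files in place." if not missing else "Missing: " + ", ".join(missing))
```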

Right now, the line-width slider and the Clear button work, but as you can see, the Predict button doesn’t do anything yet. Let’s fix that by updating app.py with the code below. The code is commented, so I won’t explain every line, but we will go through the core concepts afterwards.

# request is the object Flask uses to expose incoming request data (GET, POST, etc.)
from flask import Flask, render_template, request
# Pillow for reading and resizing images (scipy.misc.imread/imresize,
# used in older versions of this code, were removed from SciPy 1.2/1.3)
from PIL import Image
# for matrix math
import numpy as np
# for regular expressions, saves time dealing with string data
import re
# to extend the module search path (sys.path)
import sys
# for building file paths
import os
# for decoding the base64 image sent by the browser
import base64

# tell our app where our saved model and load.py are (relative to this file)
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "model"))

from load import init

# initialize our flask app
app = Flask(__name__)
# load the model and its graph once at startup; these module-level
# variables are reused by every request
model, graph = init()


# decoding an image from base64 into raw representation
def convertImage(imgData1):
    # imgData1 is the raw request body (bytes): b'data:image/png;base64,...'
    imgstr = re.search(rb'base64,(.*)', imgData1).group(1)
    with open('output.png', 'wb') as output:
        output.write(base64.b64decode(imgstr))


@app.route('/')
def index():
    return render_template("index.html")


@app.route('/predict/', methods=['POST'])
def predict():
    # whenever the predict method is called, we're going
    # to input the user drawn character as an image into the model,
    # perform inference, and return the classification
    # get the raw data format of the image
    imgData = request.get_data()
    # decode it and save it as output.png
    convertImage(imgData)
    # read the image into memory as grayscale ('L') and resize it
    # to 28x28 pixels, like the MNIST images
    with Image.open('output.png') as img:
        x = np.array(img.convert('L').resize((28, 28), Image.BILINEAR), dtype='float32')
    # convert to a 4D tensor (batch, height, width, channels) to feed into our model
    x = x.reshape(1, 28, 28, 1)
    # in our computation graph
    with graph.as_default():
        # perform the prediction
        out = model.predict(x)
        print(out)
        # the predicted digit is the class with the highest score
        response = np.argmax(out, axis=1)
        print(response)
        # convert the response to a string
        return str(response[0])


if __name__ == "__main__":
    # run the app locally on the given port
    app.run(host='0.0.0.0', port=5000)
    # optional if we want to run in debugging mode
    # app.run(debug=True)

At the top of the file we import the necessary libraries, then call init() from load.py to load the model structure and weights once, when the server starts. Whenever someone clicks the Predict button, the browser reads the canvas with toDataURL() and POSTs the image to /predict/ as a base64-encoded PNG. The predict function decodes it and saves it as a .png file, reads it back as a grayscale image, and resizes it to 28×28 pixels, the size of the MNIST images. The image is then reshaped into a 4D tensor of shape (1, 28, 28, 1), the same layout as the training data, and model.predict() returns one score per digit. The digit with the highest score (np.argmax) is returned as the response to the AJAX call, which displays it in the result div.
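Because every request writes to the same output.png, two users clicking Predict at the same moment could overwrite each other’s image. A minimal sketch of the same preprocessing done in memory instead, using Pillow with io.BytesIO and no temporary file, follows. The `preprocess` function name is hypothetical, and `Image.BILINEAR` is chosen to mirror the bilinear default of the old scipy.misc.imresize; the fake canvas and probability vector stand in for a real drawing and a real model.predict() result.

```python
import base64
import io
import re

import numpy as np
from PIL import Image


def preprocess(raw_body: bytes) -> np.ndarray:
    """Turn a data-URL request body into a (1, 28, 28, 1) float32 array (hypothetical helper)."""
    match = re.search(rb"base64,(.*)", raw_body)
    if match is None:
        raise ValueError("request body is not a base64 data URL")
    png_bytes = base64.b64decode(match.group(1))
    # decode the PNG from memory instead of a shared file on disk
    img = Image.open(io.BytesIO(png_bytes)).convert("L")  # 'L' = 8-bit grayscale
    # shrink the 280x280 canvas to the 28x28 MNIST size
    img = img.resize((28, 28), Image.BILINEAR)
    x = np.asarray(img, dtype=np.float32)
    # batch of one image, 28x28 pixels, one channel
    return x.reshape(1, 28, 28, 1)


# fake canvas: a white block on a black 280x280 background
canvas = Image.new("L", (280, 280), color=0)
canvas.paste(255, (100, 60, 180, 220))
buf = io.BytesIO()
canvas.save(buf, format="PNG")
body = b"data:image/png;base64," + base64.b64encode(buf.getvalue())

x = preprocess(body)
print(x.shape, x.dtype)  # (1, 28, 28, 1) float32

# model.predict(x) returns one row of 10 scores; the predicted digit
# is the index of the largest one (these numbers are made up)
fake_probs = np.array([[0.01, 0.02, 0.05, 0.01, 0.01, 0.02, 0.03, 0.80, 0.03, 0.02]])
print(str(np.argmax(fake_probs, axis=1)[0]))  # 7
```

Inside the predict view this would become `x = preprocess(request.get_data())`, with no output.png involved. If your part-1 training script scaled pixel values (for example by dividing by 255), apply the same scaling to `x` here, since the model expects inputs like the ones it was trained on.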

And presto!

[Image: MNIST digit classification in the browser]

If you have trouble with any of the steps, please reach out to us in the comments section and we will help you sort it out. The complete code used in this tutorial can be found in this GitHub repo.

A huge shout-out to Siraj Raval, whose video “How to deploy Keras model to production” inspired this tutorial series.
