Reproduced from stackoverflow.com.

Tags: flask, keras, python, tensorflow

How to share a Tensorflow Keras model with a Flask route function?

Published on 2020-03-27 10:21:56

Does TensorFlow maintain its own internal global state, which is broken by loading the model in one function and trying to use it in another?

I'm using a singleton to store the model:

class Singleton(object):
    _instances = {}

    def __new__(class_, *args, **kwargs):
        # Create the instance on the first call, then return the cached one ever after.
        if class_ not in class_._instances:
            class_._instances[class_] = super(Singleton, class_).__new__(class_, *args, **kwargs)
        return class_._instances[class_]


class Context(Singleton):
    # Holds the shared model; `loaded` defaults to False so the route
    # doesn't raise AttributeError if it runs before api_run().
    loaded = False
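(A quick sanity check — not part of the original post — that the singleton behaves as intended:)

c1 = Context()
c2 = Context()
assert c1 is c2  # every call returns the same shared instance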

When I do:

import os

import cv2
import numpy as np
from flask import request
from tensorflow.keras.preprocessing.image import img_to_array

@app.route('/file', methods=['GET', 'POST'])
def upload_file():
    if request.method == 'POST':
        file = request.files['file']
        if file and allowed_file(file.filename):
            # filename = secure_filename(file.filename)
            filename = file.filename
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            file.save(filepath)

            context = Context()
            if context.loaded:
                # Preprocess: resize to the model's input size, scale to [0, 1],
                # convert to an array, and add the batch dimension.
                img = cv2.imread(filepath)
                img = cv2.resize(img, (96, 96))
                img = img.astype("float") / 255.0
                img = img_to_array(img)
                img = np.expand_dims(img, axis=0)

                classes = context.model.predict(img)
                return str(classes)
    return 'no prediction'

from tensorflow.keras.models import load_model

def api_run():
    # Load the model once and mark the shared context as ready.
    context = Context()
    context.model = load_model('model.h5')
    context.loaded = True
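(For context, assume api_run is called once before the server starts; the entry point below is my guess, not part of the original post:)

if __name__ == '__main__':
    api_run()   # load the model first...
    app.run()   # ...then start serving requests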

I'm getting this error: ValueError: Tensor Tensor("dense_1/Softmax:0", shape=(?, 2), dtype=float32) is not an element of this graph.

However, if I move context.model = load_model('model.h5') inside the upload_file function, everything works. Why is that happening? How should I store the model for later use?

Questioner: cnd
Viewed: 128

Answer from Fra (2019-10-16 14:16):

Yes, TensorFlow in graph mode has its own internal global state: every operation and tensor is registered on a default graph, and a loaded model can only be used from within the graph it was loaded into.
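A minimal illustration of that global state (TF 1.x API; this snippet is mine, not from the original answer):

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    x = tf.constant(1.0)  # x is registered on g, not on the global default graph

print(x.graph is tf.get_default_graph())  # False: running x in a session built on
                                          # the default graph fails with the same
                                          # "is not an element of this graph" error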

You don't want to reload your model at every prediction; that's really inefficient.

The right strategy is to load the model at the start of your web app and then reference the global state.

Use global variables for the model and the graph, and do something like this:

import tensorflow as tf
from tensorflow.keras.models import load_model as keras_load_model

loaded_model = None
graph = None

def load_model(export_path):

    # global variables
    global loaded_model
    global graph

    # Load once; keep a handle to the graph the model was loaded into
    # (aliasing Keras's load_model avoids recursing into this function).
    loaded_model = keras_load_model(export_path)
    graph = tf.get_default_graph()

Then, in your prediction function, re-enter that graph before calling predict (Flask may serve the request on a different thread, where the saved graph is not the default one):

from flask import request, jsonify

@app.route('/', methods=["POST"])
def predict():
    if request.method == "POST":
        data = request.data  # deserialize into a model-ready numpy array as needed

        with graph.as_default():
            probas = loaded_model.predict(data)

        return jsonify(probas.tolist())
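As a quick smoke test (hypothetical client code using the requests library; the exact payload depends on how you deserialize data in the route):

import requests

with open('sample_input.bin', 'rb') as f:
    resp = requests.post('http://localhost:5000/', data=f.read())
print(resp.json())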

A complete, short example of how to do this can be found here.

Alternatively, if you use TensorFlow 2.0, which defaults to eager mode, you have no graph and therefore no problem.
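A minimal sketch of the same app on TF 2.x (the model path and the 96x96x3 input shape are assumptions carried over from the question):

import numpy as np
import tensorflow as tf
from flask import Flask, request, jsonify

app = Flask(__name__)
model = tf.keras.models.load_model('model.h5')  # loaded once, at startup

@app.route('/', methods=['POST'])
def predict():
    # Assumes the client posts a raw float32 buffer of one 96x96x3 image.
    data = np.frombuffer(request.data, dtype=np.float32).reshape(1, 96, 96, 3)
    probas = model.predict(data)  # no graph.as_default() needed in eager mode
    return jsonify(probas.tolist())

if __name__ == '__main__':
    app.run()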