温馨提示:本文翻译自stackoverflow.com,查看原文请点击:python - How to share a Tensorflow Keras model with a Flask route function?

python - 如何使用Flask路由功能共享Tensorflow Keras模型?

发布于 2020-03-27 11:04:12

tensorflow是否维护自己的内部全局状态?这种状态是通过将模型加载到一个函数中并尝试在另一个函数中使用而破坏的?

使用单例存储模型:

class Singleton(object):
    _instances = {}

    def __new__(class_, *args, **kwargs):
        if class_ not in class_._instances:
            class_._instances[class_] = super(Singleton, class_).__new__(class_, *args, **kwargs)
        return class_._instances[class_]


class Context(Singleton):
    pass

当我做:

@app.route('/file', methods=['GET', 'POST'])
def upload_file():
    if request.method == 'POST':
        file = request.files['file']
        if file and allowed_file(file.filename):
            # filename = secure_filename(file.filename)
            filename = file.filename
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            file.save(filepath)

            context = Context()
            if context.loaded:

                img = cv2.imread(filepath)
                img = cv2.resize(img, (96, 96))
                img = img.astype("float") / 255.0
                img = img_to_array(img)
                img = np.expand_dims(img, axis=0)

                classes = context.model.predict(img)

def api_run():
    context = Context()
    context.model = load_model('model.h5')
    context.loaded = True

我遇到一些错误: ValueError: Tensor Tensor("dense_1/Softmax:0", shape=(?, 2), dtype=float32) is not an element of this graph.

但是,如果我要移入函数context.model = load_model('model.h5')内部,upload_file则一切正常。为什么会这样呢?如何存储模型以备后用?

查看更多

查看更多

提问者
cnd
被浏览
47
Fra 2019-10-16 14:16

是的,图模式下的Tensorflow具有自己的内部全局状态。

您不想在每次预测时都重新加载模型,这确实效率很低。

正确的策略是在Web应用程序的开头加载模型,然后引用全局状态。

对模型和图形使用全局变量,然后执行以下操作:

loaded_model = None
graph = None

def load_model(export_path):

    # global variables
    global loaded_model
    global graph

    loaded_model = load_model('model.h5'))
    graph = tf.get_default_graph()

然后,在您的预测函数中执行以下操作:

@app.route('/', methods=["POST"])
def predict():
    if request.method == "POST":
        data = request.data

        with graph.as_default():
            probas = loaded_model.predict(data)

有关如何执行此操作的完整简短示例,请参见此处

另外,如果您使用默认为“急切”模式的Tensorflow 2.0,则没有图形,因此没有问题。

发布
问题

分享
好友

手机
浏览

扫码手机浏览