Does TensorFlow maintain its own internal global state, which is broken by loading the model in one function and trying to use it in another?
I'm using a singleton for storing the model:
class Singleton(object):
    _instances = {}

    def __new__(class_, *args, **kwargs):
        if class_ not in class_._instances:
            class_._instances[class_] = super(Singleton, class_).__new__(class_, *args, **kwargs)
        return class_._instances[class_]

class Context(Singleton):
    pass
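For reference, this singleton does guarantee that every call to Context() returns the same object, so attributes set in one function are visible in another. A minimal standalone check (repeating the class so the snippet runs on its own, no TensorFlow needed):

```python
class Singleton(object):
    _instances = {}

    def __new__(class_, *args, **kwargs):
        # Create the instance only once per class, then reuse it
        if class_ not in class_._instances:
            class_._instances[class_] = super(Singleton, class_).__new__(class_, *args, **kwargs)
        return class_._instances[class_]

class Context(Singleton):
    pass

a = Context()
a.loaded = True
b = Context()
print(a is b)    # True: both names point to the same instance
print(b.loaded)  # True: state set via `a` is visible via `b`
```

So the sharing itself works; the problem is what TensorFlow does underneath.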
When I do:
@app.route('/file', methods=['GET', 'POST'])
def upload_file():
    if request.method == 'POST':
        file = request.files['file']
        if file and allowed_file(file.filename):
            # filename = secure_filename(file.filename)
            filename = file.filename
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            file.save(filepath)

            context = Context()
            if context.loaded:
                img = cv2.imread(filepath)
                img = cv2.resize(img, (96, 96))
                img = img.astype("float") / 255.0
                img = img_to_array(img)
                img = np.expand_dims(img, axis=0)
                classes = context.model.predict(img)
def api_run():
    context = Context()
    context.model = load_model('model.h5')
    context.loaded = True
I'm getting this error: ValueError: Tensor Tensor("dense_1/Softmax:0", shape=(?, 2), dtype=float32) is not an element of this graph.
However, if I move context.model = load_model('model.h5') inside the upload_file function, everything works. Why is that happening? How can I store the model for later use?
Yes, TensorFlow in graph mode has its own internal global state.

You don't want to reload your model at every prediction; that's really inefficient. The right strategy is to load the model once at the start of your web app and then reference it through that global state.

Use a global variable for the model and the graph, and do something like this:
from keras.models import load_model
import tensorflow as tf

loaded_model = None
graph = None

def init_model(export_path):
    # Keep the model and the graph it was loaded into as module-level globals
    global loaded_model
    global graph
    loaded_model = load_model(export_path)
    graph = tf.get_default_graph()
Then, in your prediction function, you do:

@app.route('/', methods=["POST"])
def predict():
    if request.method == "POST":
        data = request.data
        # predict() must run inside the graph the model was loaded into
        with graph.as_default():
            probas = loaded_model.predict(data)
        return jsonify(probas.tolist())
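Stripped of the Flask and TensorFlow specifics, the load-once pattern boils down to populating a module-level global once at startup and reusing it for every request. A minimal sketch, with a hypothetical DummyModel standing in for the Keras model:

```python
loaded_model = None  # module-level global, populated once at startup

class DummyModel(object):
    """Hypothetical stand-in for a loaded Keras model."""
    def predict(self, data):
        # Dummy "prediction": just the length of the input
        return [len(data)]

def init_model():
    # Called once at startup, not once per request
    global loaded_model
    loaded_model = DummyModel()

def handle_request(data):
    # Every request reuses the already-loaded model
    return loaded_model.predict(data)

init_model()
print(handle_request("abc"))    # [3]
print(handle_request("hello"))  # [5]
```

The expensive work (loading) happens exactly once; request handlers only read the global.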
A complete, short example of how to do this can be found here.
Alternatively, if you use TensorFlow 2.0, which defaults to eager mode, there is no graph and therefore no problem.