import tensorflow as tf from trainsetting import train_db, dev_db import models import setting # 从models文件中导入模型 model = models.my_densenet() model.summary() # 配置优化器、损失函数、以及监控指标 model.compile(tf.keras.optimizers.Adam(setting.LEARNING_RATE), loss=tf.keras.losses.categorical_crossentropy, metrics=['accuracy']) # 在每个epoch结束后尝试保存模型参数,只有当前参数的val_accuracy比之前保存的更优时,才会覆盖掉之前保存的参数 model_check_point = tf.keras.callbacks.ModelCheckpoint( filepath=setting.MODEL_PATH, monitor='val_accuracy', save_best_only=True) # 使用tf.keras的高级接口进行训练 model.fit_generator(train_db, epochs=setting.TRAIN_EPOCHS, validation_data=dev_db, callbacks=[model_check_point])
# @File : app.py # @Author : AaronJny # @Time : 2019/12/18 # @Desc : import tensorflow as tf from flask import Flask from flask import jsonify from flask import request, render_template import settings from models import my_densenet app = Flask(__name__) # 导入模型 model = my_densenet() # 加载训练好的参数 model.load_weights(settings.MODEL_PATH) @app.route('/', methods=['GET']) def index(): """ 首页,vue入口 """ return render_template('index.html') @app.route('/api/v1/pets_classify/', methods=['POST']) def pets_classify(): """