import tensorflow as tf from flask import Flask from flask import jsonify from flask import request, render_template import os,sys from cfg import settings # 导入模型 from models import my_mobilenet_v3 app = Flask(__name__) os.chdir(os.path.dirname(sys.argv[0])) #加载模型 model=my_mobilenet_v3() # 加载训练好的参数 if os.path.exists(settings.MODEL_PATH + '.index'): print('-------------load the model-----------------') model.load_weights(settings.MODEL_PATH) @app.route('/', methods=['GET']) #首页,vue入口 def index(): """ 首页,vue入口 """ return render_template('index.html') @app.route('/api/v1/dogs_classify/', methods=['POST']) #宠物狗图片分类接口
# -*- coding: utf-8 -*-
# @Time     : 2020/12/20
# @Author   : Barbra
# @File     : train.py
# @Software : PyCharm
# @Desc     : Training script
import os

import tensorflow as tf

import models
import settings
from matplotlib import pyplot as plt
from data import train_db, test_db

# Build the network defined in the models module and print its layer summary.
model = models.my_mobilenet_v3()
model.summary()

# Learning-rate schedule: starts at 0.1 and is multiplied by 0.99 on every
# optimizer step (decay_steps=1), i.e. lr(step) = 0.1 * 0.99 ** step.
# NOTE(review): 0.1 is an unusually high starting rate for Adam — confirm it
# is intentional.
exponential_decay = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1,
    decay_steps=1,
    decay_rate=0.99,
)

# Wire up optimizer, loss and the metric to monitor.
# categorical_crossentropy expects one-hot targets — presumably what train_db
# yields; confirm against the data module.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=exponential_decay),
    loss=tf.keras.losses.categorical_crossentropy,
    metrics=['accuracy'],
)

# Resumable training: when a previously saved checkpoint is present, restore
# its weights before training continues. (Saving after each epoch is
# presumably configured further down; not visible here.)
if os.path.exists(settings.MODEL_PATH + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(settings.MODEL_PATH)