def generate(self): model_path = os.path.expanduser(self.model_path) assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.' # 计算总的种类 self.num_classes = len(self.class_names)+1 # 载入模型 inputs = Input(self.model_image_size) self.m2det = M2det.m2det(self.num_classes,inputs) self.m2det.load_weights(self.model_path,by_name=True) self.m2det.summary() print('{} model, anchors, and classes loaded.'.format(model_path)) # 画框设置不同的颜色 hsv_tuples = [(x / len(self.class_names), 1., 1.) for x in range(len(self.class_names))] self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples)) self.colors = list( map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
import numpy as np import keras from keras.optimizers import Adam from nets.M2det_training import Generator from nets.M2det_training import conf_loss, smooth_l1 from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, EarlyStopping from utils.utils import BBoxUtility from utils.anchors import get_anchors if __name__ == "__main__": NUM_CLASSES = 21 input_shape = (320, 320, 3) annotation_path = '2007_train.txt' inputs = keras.layers.Input(shape=input_shape) model = M2det.m2det(NUM_CLASSES, inputs) priors = get_anchors((input_shape[0], input_shape[1])) bbox_util = BBoxUtility(NUM_CLASSES, priors) model.load_weights("model_data\M2det_weights.h5", by_name=True, skip_mismatch=True) model.summary() # 0.1用于验证,0.9用于训练 val_split = 0.1 with open(annotation_path) as f: lines = f.readlines() np.random.seed(10101) np.random.shuffle(lines) np.random.seed(None)
# 训练之前一定要修改NUM_CLASSES # 修改成所需要区分的类的个数+1。 #----------------------------------------------------# NUM_CLASSES = 21 #----------------------------------------------------# # 输入图像大小 #----------------------------------------------------# input_shape = (320, 320, 3) #----------------------------------------------------# # 获得先验框 #----------------------------------------------------# anchors_size = [0.08, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05] priors = get_anchors((input_shape[0],input_shape[1]), anchors_size) bbox_util = BBoxUtility(NUM_CLASSES, priors) model = M2det.m2det(NUM_CLASSES,input_shape) #------------------------------------------------------# # 权值文件请看README,百度网盘下载 # 训练自己的数据集时提示维度不匹配正常 # 预测的东西都不一样了自然维度不匹配 #------------------------------------------------------# model.load_weights("model_data/M2det_weights.h5", by_name=True, skip_mismatch=True) #-------------------------------------------------------------------------------# # 训练参数的设置 # logging表示tensorboard的保存地址 # checkpoint用于设置权值保存的细节,period用于修改多少epoch保存一次 # reduce_lr用于设置学习率下降的方式 # early_stopping用于设定早停,val_loss多次不下降自动结束训练,表示模型基本收敛 #-------------------------------------------------------------------------------# logging = TensorBoard(log_dir='logs')