def train_msml(feat_model, model, train_dataset, modelpath, val_dataset=None,
               train_flag=True, aug=True, models_folder='saved_models',
               epochs=120, batch_size=4, img_size=(224, 320),
               label_smoothing=False, wlr=False, BN=False, CenterLoss=False):
    """Compile ``model`` with the MSML loss setup and optionally train it.

    The model is always compiled, and previously saved weights are loaded
    when the checkpoint file already exists. Training only runs when
    ``train_flag`` is true.

    Args:
        feat_model: Forwarded as the first argument to the batch generator.
        model: Object that is compiled, optionally restored from
            ``modelpath`` and trained with ``fit_generator``.
        train_dataset: Value handed to ``Dataset(train_dataset, to_rgb=False)``.
        modelpath: File name, joined onto ``models_folder``, used both to
            restore weights and as the checkpoint target.
        val_dataset: Accepted for interface compatibility; not used here.
        train_flag: When false, the function stops after compiling the model
            and loading any existing weights.
        aug: Forwarded to the batch generator.
        models_folder: Directory that ``modelpath`` is resolved against.
        epochs: Number of epochs passed to ``fit_generator``.
        batch_size: Forwarded to the batch generator; also divides
            ``train_dataset.ident_num()`` to obtain ``steps_per_epoch``.
        img_size: Forwarded to the batch generator.
        label_smoothing: Forwarded to the batch generator.
        wlr: Use the ``warmup_lr`` schedule instead of ``step_decay``.
        BN: Accepted for interface compatibility; not used here.
        CenterLoss: Add two pass-through losses and draw batches from
            ``general_generator_center`` instead of ``general_generator``.

    Returns:
        Whatever ``model.fit_generator`` returns (presumably a Keras
        ``History`` object) when ``train_flag`` is true, otherwise ``None``.
    """
    losses_list = [msml(), 'categorical_crossentropy']
    if CenterLoss:
        print('Using CenterLoss')
        # These losses ignore y_true and hand y_pred back unchanged.
        # NOTE(review): presumably the extra model outputs already hold the
        # loss values. The original layout was collapsed onto one line, so
        # confirm that both appends belong under ``if CenterLoss``.
        losses_list.append(lambda y_true, y_pred: y_pred)
        losses_list.append(lambda y_true, y_pred: y_pred)

    model.compile(optimizer=standard_optimizer(), loss=losses_list)

    # NOTE(review): the script section of this file builds ``Dataset`` with
    # ``to_bgr=...`` while this call uses ``to_rgb=...`` - confirm that
    # ``Dataset`` accepts both keywords.
    train_dataset = Dataset(train_dataset, to_rgb=False)

    modelpath = os.path.join(models_folder, modelpath)
    if os.path.exists(modelpath):
        print('MSML carregada.')
        model.load_weights(modelpath)

    if not train_flag:
        return None

    # The checkpoint monitors the training loss and keeps only the best one.
    checkpoint = ModelCheckpoint(modelpath, monitor='loss', verbose=1,
                                 save_best_only=True, mode='min')

    if wlr:
        print('Using warmup lr')
        schedule = warmup_lr
    else:
        schedule = step_decay
    lrate = LearningRateScheduler(schedule)

    # Both generators take the same arguments; only the factory differs.
    make_generator = general_generator_center if CenterLoss else general_generator
    msml_gen = make_generator(feat_model, train_dataset,
                              batch_size=batch_size, img_size=img_size,
                              aug=aug, label_smoothing=label_smoothing)

    return model.fit_generator(
        msml_gen,
        steps_per_epoch=int(train_dataset.ident_num() / batch_size),
        epochs=epochs,
        callbacks=[lrate, checkpoint],
        workers=1,
    )
# Script setup: load the MoRe dataset and build the baseline classification
# network on top of a ResNet50 feature extractor.
models_folder = ''
MoRe_path = "MoRe.hdf5"
MoRe = Dataset(MoRe_path, to_bgr=True)  # Images were used in BGR format

INPUT_SHAPE = (256, 256)

# BASELINE STRIDE = 1
feat_model = ResNet50_LastStride(input_shape=INPUT_SHAPE + (3,))
# by_name=True is passed so weights are matched to layers by name.
feat_model.load_weights(
    os.path.join(models_folder, 'RESNET50_ORIGINAL_WEIGHTS.h5'),
    by_name=True,
)
# NOTE(review): private Keras API; presumably called so the predict function
# exists before the model is used inside the batch generators - confirm.
feat_model._make_predict_function()

# Loading BN
# NOTE(review): train_msml calls ident_num() without arguments, while a split
# name is passed here - confirm Dataset supports both call forms.
identity_model = classification_layer_baseline(MoRe.ident_num('train'), BN=True)
clsNet = classification_net(feat_model, identity_model, img_shape=INPUT_SHAPE)

# train_classifier(
#     feat_model = feat_model,
#     model = clsNet,
#     dataset = MoRe,
#     modelpath = 'tsting_cls.hdf5',
#     models_folder = models_folder,
#     epochs = 1,
#     batch_size = 32,
#     img_size = INPUT_SHAPE,
#     label_smoothing = True,
#     wlr = True,
#     BN = True