#------------------------------------------------------#
# The backbone's extracted features are generic, so freezing
# it during the first stage speeds up training and keeps the
# pretrained weights from being destroyed early on.
#   Init_Epoch   - epoch to start from
#   Freeze_Epoch - last epoch of the frozen-backbone stage
# If you hit OOM / run out of GPU memory, reduce Batch_size.
#------------------------------------------------------#
if True:
    lr = 1e-4
    Init_Epoch = 0
    Freeze_Epoch = 50
    Batch_size = 2

    # Pick the loss once (Dice + CE when dice_loss is enabled,
    # plain cross-entropy otherwise), then compile the model.
    chosen_loss = dice_loss_with_CE() if dice_loss else CE()
    model.compile(
        loss=chosen_loss,
        optimizer=Adam(lr=lr),
        metrics=[f_score()],
    )

    print('Train on {} samples, with batch size {}.'.format(
        len(train_lines), Batch_size))

    gen = Generator(Batch_size, train_lines, inputs_size,
                    num_classes).generate()

    # max(1, ...) guarantees at least one step per epoch even
    # when the dataset is smaller than one batch.
    model.fit_generator(
        gen,
        steps_per_epoch=max(1, len(train_lines) // Batch_size),
        epochs=Freeze_Epoch,
        initial_epoch=Init_Epoch,
        callbacks=[checkpoint_period, reduce_lr, tensorboard],
    )
# Freeze_Epoch为冻结训练的世代 # Epoch总训练世代 # 提示OOM或者显存不足请调小Batch_size #------------------------------------------------------# if True: batch_size = Freeze_batch_size lr = Freeze_lr start_epoch = Init_Epoch end_epoch = Freeze_Epoch epoch_step = len(train_lines) // batch_size if epoch_step == 0: raise ValueError("数据集过小,无法进行训练,请扩充数据集。") model.compile(loss=loss, optimizer=Adam(lr=lr), metrics=[f_score()]) train_dataloader = UnetDataset(train_lines, input_shape, batch_size, num_classes, True, VOCdevkit_path) print('Train on {} samples, with batch size {}.'.format( len(train_lines), batch_size)) model.fit_generator( generator=train_dataloader, steps_per_epoch=epoch_step, epochs=end_epoch, initial_epoch=start_epoch, use_multiprocessing=True if num_workers > 1 else False, workers=num_workers, callbacks=[ logging, checkpoint, reduce_lr, early_stopping, loss_history