示例#1
0
    #------------------------------------------------------#
    #   主干特征提取网络特征通用,冻结训练可以加快训练速度
    #   也可以在训练初期防止权值被破坏。
    #   Init_Epoch为起始世代
    #   Freeze_Epoch为冻结训练的世代
    #   Epoch总训练世代
    #   提示OOM或者显存不足请调小Batch_size
    #------------------------------------------------------#
    if True:
        lr = 1e-4
        Init_Epoch = 0
        Freeze_Epoch = 50
        Batch_size = 2

        model.compile(loss=dice_loss_with_CE() if dice_loss else CE(),
                      optimizer=Adam(lr=lr),
                      metrics=[f_score()])
        print('Train on {} samples, with batch size {}.'.format(
            len(train_lines), Batch_size))

        gen = Generator(Batch_size, train_lines, inputs_size,
                        num_classes).generate()

        model.fit_generator(
            gen,
            steps_per_epoch=max(1,
                                len(train_lines) // Batch_size),
            epochs=Freeze_Epoch,
            initial_epoch=Init_Epoch,
            callbacks=[checkpoint_period, reduce_lr, tensorboard])
示例#2
0
    #   Freeze_Epoch为冻结训练的世代
    #   Epoch总训练世代
    #   提示OOM或者显存不足请调小Batch_size
    #------------------------------------------------------#
    if True:
        batch_size = Freeze_batch_size
        lr = Freeze_lr
        start_epoch = Init_Epoch
        end_epoch = Freeze_Epoch

        epoch_step = len(train_lines) // batch_size

        if epoch_step == 0:
            raise ValueError("数据集过小,无法进行训练,请扩充数据集。")

        model.compile(loss=loss, optimizer=Adam(lr=lr), metrics=[f_score()])

        train_dataloader = UnetDataset(train_lines, input_shape, batch_size,
                                       num_classes, True, VOCdevkit_path)

        print('Train on {} samples, with batch size {}.'.format(
            len(train_lines), batch_size))
        model.fit_generator(
            generator=train_dataloader,
            steps_per_epoch=epoch_step,
            epochs=end_epoch,
            initial_epoch=start_epoch,
            use_multiprocessing=True if num_workers > 1 else False,
            workers=num_workers,
            callbacks=[
                logging, checkpoint, reduce_lr, early_stopping, loss_history