示例#1
0
    # Training-time augmentation parameters, passed to the dataset's generator
    # as `keras_data_gen_param`. Horizontal flips and zoom_range=[1, 1.4]
    # (a [lower, upper] pair in Keras ImageDataGenerator convention) are
    # enabled; rotation, width/height shifts and shear are all disabled.
    train_data_gen_args = dict(
        featurewise_center=False,
        featurewise_std_normalization=False,
        rotation_range=0,
        width_shift_range=0.,
        height_shift_range=0.,
        horizontal_flip=True,
        fill_mode="nearest",
        shear_range=0.,
        zoom_range=[1, 1.4],
    )
    # `random_crop_size`, `dataset` and `seed` are defined outside this
    # excerpt. With no crop size the dataset's Keras generator is used;
    # otherwise its custom generator (presumably applying random crops of
    # `random_crop_size` — confirm in DataSet).
    if random_crop_size is None:
        train_data_generator = dataset.get_keras_data_generator(
            is_train=True, keras_data_gen_param=train_data_gen_args, seed=seed)
    else:
        train_data_generator = dataset.get_custom_data_generator(
            is_train=True, keras_data_gen_param=train_data_gen_args, seed=seed)
    # No validation augmentation: every transform below is disabled.
    val_data_gen_args = dict(featurewise_center=False,
                             featurewise_std_normalization=False,
                             rotation_range=0.,
                             width_shift_range=0.,
                             height_shift_range=0.,
                             horizontal_flip=False)

    # Validation always uses the Keras generator, regardless of crop size.
    val_data_generator = dataset.get_keras_data_generator(
        is_train=False, keras_data_gen_param=val_data_gen_args, seed=seed)

    # Build a dilated U-Net in 'cascade' mode with 32 filters and n_class=1.
    # NOTE(review): input_shape=(None, None, 1) presumably means variable
    # height/width with a single channel — confirm against get_dilated_unet.
    model = get_dilated_unet(input_shape=(None, None, 1),
                             mode='cascade',
                             filters=32,
                             n_class=1)
示例#2
0
def train_generator(model_def,
                    model_saved_path,
                    h5_data_path,
                    batch_size,
                    epochs,
                    model_weights,
                    gpus=1,
                    verbose=1,
                    csv_log_suffix="0",
                    fold_k="0",
                    random_k_fold=False,
                    input_channels=1,
                    output_channels=2,
                    random_crop_size=(256, 256),
                    mask_nb=0,
                    seed=0,
                    train_ids=None,
                    val_ids=None):
    """Train ``model_def`` on HDF5 data using generators and checkpointing.

    The weights with the lowest ``val_loss`` are saved to ``model_saved_path``.
    Per-epoch metrics are written to a CSV log under ``CONFIG.log_root``. The
    Keras session is cleared when training finishes or fails.

    Args:
        model_def: An already-built Keras model to train.
        model_saved_path: File path for the best checkpoint weights. Its
            directory is created if it is missing.
        h5_data_path: Path to the HDF5 data consumed by ``DataSet``.
        batch_size: Batch size forwarded to ``DataSet``.
        epochs: Number of training epochs.
        model_weights: Optional weights file to load (by layer name, skipping
            mismatches) before training. If falsy, training starts from the
            current weights of ``model_def``.
        gpus: Number of GPUs. When > 1 the model is wrapped with
            ``multi_gpu_model`` and checkpointed via ``ModelCheckpointMGPU``.
        verbose: Verbosity forwarded to ``fit_generator``.
        csv_log_suffix: Suffix appended to the CSV log file name.
        fold_k, random_k_fold, input_channels, output_channels, mask_nb:
            Forwarded to ``DataSet``.
        random_crop_size: Forwarded to ``DataSet``. It also selects the
            training generator: ``None`` uses the Keras generator, and any
            other value uses the dataset's custom generator.
        seed: Seed for the train and validation generators.
        train_ids: Training sample ids. ``None`` keeps the historical default,
            ``'00000041'`` through ``'00000080'``.
        val_ids: Validation sample ids. ``None`` keeps the historical default,
            ``['00000041', '00000059', '00000074', '00000075']``.
    """
    model_weights_root = os.path.dirname(model_saved_path)
    # dirname() is "" for a bare file name, and os.makedirs("") would raise.
    if model_weights_root:
        os.makedirs(model_weights_root, exist_ok=True)

    learning_rate_scheduler = LearningRateScheduler(
        schedule=get_learning_rate_scheduler, verbose=0)
    opt = Adam(amsgrad=True)

    log_path = os.path.join(
        CONFIG.log_root,
        "log_" + os.path.splitext(os.path.basename(model_saved_path))[0]
    ) + "_" + csv_log_suffix + ".csv"
    if CONFIG.log_root:
        os.makedirs(CONFIG.log_root, exist_ok=True)

    if os.path.isfile(log_path):
        # CSVLogger below uses append=False, so an existing log is overwritten.
        print("Log file exists.")
    csv_logger = CSVLogger(log_path, append=False)

    fit_loss = sigmoid_dice_loss
    fit_metrics = [
        dice_coef_rounded_ch0, dice_coef_rounded_ch1,
        metrics.binary_crossentropy, mean_iou_ch0, binary_acc_ch0
    ]

    model = model_def
    if model_weights:
        print("Loading weights ...")
        # Load by layer name and skip layers whose weights do not match, so
        # partially compatible checkpoints can still be used.
        model.load_weights(model_weights, by_name=True, skip_mismatch=True)
        print("Model weights {} have been loaded.".format(model_weights))
    else:
        print("Model created.")

    if train_ids is None:
        # Historical default: samples 00000041 .. 00000080 (inclusive).
        train_ids = ['{:08d}'.format(i) for i in range(41, 81)]
    if val_ids is None:
        # NOTE(review): every default validation id is also a default training
        # id. Unless DataSet removes the overlap, val_loss (which drives
        # checkpointing) is partly measured on training data. Confirm this is
        # intended.
        val_ids = ['00000041', '00000059', '00000074', '00000075']

    dataset = DataSet(h5_data_path,
                      val_fold_nb=fold_k,
                      random_k_fold=random_k_fold,
                      input_channels=input_channels,
                      output_channels=output_channels,
                      random_crop_size=random_crop_size,
                      mask_nb=mask_nb,
                      batch_size=batch_size,
                      train_ids=train_ids,
                      val_ids=val_ids)

    # Training augmentation: horizontal flips only. Rotation, shifts, shear
    # and zoom are all disabled.
    train_data_gen_args = dict(
        featurewise_center=False,
        featurewise_std_normalization=False,
        rotation_range=0,
        width_shift_range=0.,
        height_shift_range=0.,
        horizontal_flip=True,
        fill_mode="nearest",
        shear_range=0.,
        zoom_range=0.,
    )
    # Choose the generator from the crop size actually given to DataSet, not
    # the global CONFIG value, so the two cannot disagree.
    if random_crop_size is None:
        train_data_generator = dataset.get_keras_data_generator(
            is_train=True, keras_data_gen_param=train_data_gen_args, seed=seed)
    else:
        train_data_generator = dataset.get_custom_data_generator(
            is_train=True, keras_data_gen_param=train_data_gen_args, seed=seed)

    # No validation augmentation: every transform is disabled.
    val_data_gen_args = dict(featurewise_center=False,
                             featurewise_std_normalization=False,
                             rotation_range=0.,
                             width_shift_range=0.,
                             height_shift_range=0.,
                             horizontal_flip=False)
    val_data_generator = dataset.get_keras_data_generator(
        is_train=False, keras_data_gen_param=val_data_gen_args, seed=seed)

    if gpus > 1:
        parallel_model = multi_gpu_model(model, gpus=gpus)
        # The single-device `model` is handed to the checkpoint callback,
        # presumably so its weights (not the wrapper's) are saved; confirm in
        # ModelCheckpointMGPU.
        model_checkpoint0 = ModelCheckpointMGPU(model,
                                                model_saved_path,
                                                save_best_only=True,
                                                save_weights_only=True,
                                                monitor="val_loss",
                                                mode='min')
    else:
        parallel_model = model
        model_checkpoint0 = ModelCheckpoint(model_saved_path,
                                            save_best_only=True,
                                            save_weights_only=True,
                                            monitor='val_loss',
                                            mode='min')
    parallel_model.compile(loss=fit_loss, optimizer=opt, metrics=fit_metrics)

    train_steps = dataset.get_train_val_steps(is_train=True)
    val_steps = dataset.get_train_val_steps(is_train=False)
    print("Training ...")
    try:
        parallel_model.fit_generator(
            train_data_generator,
            validation_data=val_data_generator,
            steps_per_epoch=train_steps,
            validation_steps=val_steps,
            epochs=epochs,
            callbacks=[model_checkpoint0, csv_logger, learning_rate_scheduler],
            verbose=verbose,
            workers=1,
            use_multiprocessing=False,
            shuffle=True)
    finally:
        # Release Keras session resources even if training raised.
        del model, parallel_model
        K.clear_session()
        gc.collect()