Example #1
    def predict_and_show(self, image, show_output_channels):
        """
        
        Args:
            img: str(image path) or numpy array(b=1, h=576, w=576, c=1)
            show_output_channels: 1 or 2

        Returns:

        """
        if isinstance(image, str):
            images_src = self.read_images([image])
        else:
            images_src = image
        img = DataSet.preprocess(images_src, mode="image")
        predict_mask = self.predict(img, 1, use_channels=show_output_channels)
        predict_mask = np.squeeze(predict_mask, axis=0)
        predict_mask = self.postprocess(predict_mask)
        predict_mask = DataSet.de_preprocess(predict_mask, mode="mask")
        if show_output_channels == 2:
            mask0 = predict_mask[..., 0]
            mask1 = predict_mask[..., 1]
            image_c3 = np.concatenate([np.squeeze(images_src, axis=0) for i in range(3)], axis=-1)
            image_mask0 = apply_mask(image_c3, mask0, color=[255, 106, 106], alpha=0.5)
            # result = np.concatenate((np.squeeze(images_src, axis=[0, -1]), mask0, mask1, image_mask0), axis=1)
            plt.imshow(image_mask0)
        else:
            result = np.concatenate((np.squeeze(images_src, axis=[0, -1]), predict_mask), axis=1)
            plt.imshow(result, cmap="gray")

        plt.show()
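
A minimal usage sketch of predict_and_show, assuming a ModelDeployment object built from a model definition and a weights file as in the other examples; the paths are placeholders:

import cv2
import numpy as np

# Sketch only: model_def / model_weights come from the project's model constructors.
model_obj = ModelDeployment(model_def, model_weights)
# Either pass an image path ...
model_obj.predict_and_show("/path/to/image.tif", show_output_channels=2)
# ... or a preloaded grayscale image shaped (1, 576, 576, 1).
img = cv2.imread("/path/to/image.tif", cv2.IMREAD_GRAYSCALE)
img = img[np.newaxis, ..., np.newaxis]
model_obj.predict_and_show(img, show_output_channels=1)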
Example #2
    def evaluate(self, x_val, y_val):
        """Compile the model with the dice loss and evaluate it on the given validation images and masks."""
        x_val = DataSet.preprocess(x_val, "image")
        y_val = DataSet.preprocess(y_val, "mask")
        fit_loss = sigmoid_dice_loss
        fit_metrics = [binary_acc_ch0]
        self.model.compile(loss=fit_loss,
                           optimizer="Adam",
                           metrics=fit_metrics)
        # Score the trained model.
        scores = self.model.evaluate(x_val, y_val, batch_size=5, verbose=1)
        print('Test loss:', scores[0])
        print('Test accuracy:', scores[1])
Example #3
    def predict_from_h5data(self, h5_data_path, val_fold_nb, use_channels, is_train=False, save_dir=None,
                            random_k_fold=False, random_k_fold_npy=None, input_channels=1,
                            output_channels=2, random_crop_size=None, mask_nb=0, batch_size=4
                            ):
        """Predict masks for the train or val split of the HDF5 dataset and, if save_dir is given, save them as .tif files."""
        dataset = DataSet(h5_data_path, val_fold_nb, random_k_fold=random_k_fold, random_k_fold_npy=random_k_fold_npy,
                          input_channels=input_channels,
                          output_channels=output_channels, random_crop_size=random_crop_size, mask_nb=mask_nb,
                          batch_size=batch_size)
        images, _ = dataset.prepare_data(is_train)
        pred = self.predict(images, batch_size, use_channels=use_channels)

        if save_dir:
            keys = dataset.get_keys(is_train)
            mask_file_lst = ["{:03}.tif".format(int(key)) for key in keys]
            self.save_mask(pred, mask_file_lst, mask_nb, result_save_dir=save_dir)
        return pred
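
A hedged call sketch for predict_from_h5data, using placeholder paths; model_obj is a ModelDeployment instance as in the other examples:

# Predict val-fold masks and save them as 000.tif, 001.tif, ... under save_dir.
pred = model_obj.predict_from_h5data("/path/to/data_0717.hdf5", val_fold_nb="01",
                                     use_channels=1, is_train=False,
                                     save_dir="/path/to/pred_masks")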
Example #4
def get_acc(model_def, model_weights, h5_data_path, val_fold_nb, is_train=False):
    dataset = DataSet(h5_data_path, val_fold_nb)
    images, masks = dataset.prepare_1i_1o_data(is_train=is_train, mask_nb=0)

    model_obj = ModelDeployment(model_def, model_weights)

    y_pred = model_obj.predict(images, batch_size=4)

    K.clear_session()
    if y_pred.shape[-1] == 2:
        y_pred = y_pred[..., 0]
    print(y_pred.shape)

    y_true = masks
    acc = get_pixel_wise_acc(y_true, y_pred)
    with tf.Session() as sess:
        print(sess.run(acc))
Example #5
    def predict_and_save_stage1_masks(self, h5_data_path, h5_result_saved_path, fold_k=0, batch_size=4):
        """
        从h5data中读取images进行预测,并把预测mask保存进h5data中。
        Args:
            h5_data_path: str, 存放有训练数据的h5文件路径。
            batch_size: int, 批大小。

        Returns: None.

        """

        f_result = h5py.File(h5_result_saved_path, "a")
        # Open the result group if it already exists, otherwise create it.
        stage1_predict_masks_grp = f_result.require_group("stage1_fold{}_predict_masks".format(fold_k))

        dataset = DataSet(h5_data_path, fold_k)

        images_train = dataset.get_images(is_train=True)
        images_val = dataset.get_images(is_train=False)
        keys_train = dataset.get_keys(is_train=True)
        keys_val = dataset.get_keys(is_train=False)
        images = np.concatenate([images_train, images_val], axis=0)
        keys = np.concatenate([keys_train, keys_val], axis=0)
        print("predicting ...")
        images = dataset.preprocess(images, mode="image")
        y_pred = self.predict(images, batch_size, use_channels=1)
        print(y_pred.shape)
        print("Saving predicted masks ...")
        for i, key in enumerate(keys):
            stage1_predict_masks_grp.create_dataset(key, dtype=np.float32, data=y_pred[i])
        print("Done.")
Example #6
def do_evaluate():
    h5_data_path = "/home/topsky/helloworld/study/njai_challenge/cbct/inputs/data_0717.hdf5"
    val_fold_nb = "01"
    output_channels = 2
    is_train = False
    model_def = get_se_inception_resnet_v2_unet_sigmoid_gn(weights=None, output_channels=output_channels)
    model_weights = "/home/topsky/helloworld/study/njai_challenge/cbct/model_weights/new/20180802_0/se_inception_resnet_v2_gn_fold01_random_kfold_0_1i_2o.h5"

    model_obj = ModelDeployment(model_def, model_weights)
    dataset = DataSet(h5_data_path, val_fold_nb=val_fold_nb)
    images, masks = dataset.prepare_data(is_train=is_train)
    print(images.shape)
    print(masks.shape)
    # idx_lst = [0, 5, 10, 15, 20]
    # val_images = np.array([images[i] for i in idx_lst])
    # val_masks = np.array([masks[i] for i in idx_lst])
    # model_obj.evaluate(val_images, val_masks)
    model_obj.evaluate(images, masks)
Example #7
def do_predict_custom():
    model = get_dilated_unet(
        input_shape=(None, None, 1),
        mode='cascade',
        filters=32,
        n_class=1
    )
    model_weights = "/home/topsky/helloworld/study/njai_challenge/cbct/func/others_try/model_weights.hdf5"
    img_path = "/media/topsky/HHH/jzhang_root/data/njai/cbct/CBCT_testingset/CBCT_testingset/04+246ori.tif"
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    img = np.expand_dims(img, axis=-1)
    img = np.expand_dims(img, axis=0)
    img = DataSet.preprocess(img, mode="image")
    # print(img.shape)
    # exit()
    model.load_weights(model_weights)
    pred = model.predict(img, batch_size=1)
    pred_img = np.squeeze(pred[0], -1)
    pred_img = DataSet.de_preprocess(pred_img, mode="image")
    plt.imshow(pred_img, "gray")
    plt.show()
Example #8
def inference_2stages_from_files(model_def_stage1, model_weights_stage1, model_def_stage2, model_weights_stage2,
                                 file_dir, pred_save_dir):
    """Two-stage inference: stage 1 predicts a coarse mask, which is concatenated with the source image as input to stage 2."""
    if not os.path.isdir(pred_save_dir):
        os.makedirs(pred_save_dir)
    model_obj = ModelDeployment(model_def_stage1, model_weights_stage1)
    file_path_lst = get_file_path_list(file_dir, ext=".tif")
    dst_file_path_lst = [os.path.join(pred_save_dir, os.path.basename(x)) for x in file_path_lst]

    imgs_src = model_obj.read_images(file_path_lst)
    imgs = DataSet.preprocess(imgs_src, mode="image")
    pred_stage1 = model_obj.predict(imgs, batch_size=5, use_channels=1)
    pred_stage1 = np.expand_dims(pred_stage1, axis=-1)
    input_stage2 = np.concatenate([imgs_src, pred_stage1], axis=-1)
    del model_obj
    print(pred_stage1.shape)
    print(input_stage2.shape)
    model_obj = ModelDeployment(model_def_stage2, model_weights_stage2)

    pred = model_obj.predict(input_stage2, batch_size=5, use_channels=1)
    pred = model_obj.postprocess(pred)
    pred = DataSet.de_preprocess(pred, mode="mask")
    for i in range(len(pred)):
        cv2.imwrite(dst_file_path_lst[i], pred[i])
Example #9
    def predict(self, images, batch_size, use_channels=2):
        """
        对未预处理过的图片进行预测。
        Args:
            images: 4-d numpy array. preprocessed image. (b, h, w, c=1)
            batch_size:
            use_channels: int, default to 2. 如果模型输出通道数为2,可以控制输出几个channel.默认输出第一个channel的预测值.

        Returns: 4-d numpy array.

        """
        images = DataSet.preprocess(images, mode="image")

        outputs = self.model.predict(images, batch_size)
        if use_channels == 1:
            outputs = outputs[..., 0]
            outputs = np.expand_dims(outputs, -1)
        return outputs
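
A usage sketch of predict, assuming images is a raw (un-preprocessed) array of shape (b, h, w, 1) and model_obj is a ModelDeployment instance:

pred_2ch = model_obj.predict(images, batch_size=4, use_channels=2)  # (b, h, w, 2)
pred_1ch = model_obj.predict(images, batch_size=4, use_channels=1)  # (b, h, w, 1), channel 0 only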
Example #10
def tta_test():
    img = cv2.imread(
        "/media/topsky/HHH/jzhang_root/data/njai/cbct/CBCT_testingset/CBCT_testingset/04+246ori.tif",
        cv2.IMREAD_GRAYSCALE)
    img = np.expand_dims(img, axis=-1)
    img = DataSet.preprocess(img, mode="image")
    img = np.expand_dims(img, axis=0)

    model = get_densenet121_unet_sigmoid_gn(input_shape=(None, None, 1),
                                            output_channels=2,
                                            weights=None)
    model.load_weights(
        "/home/topsky/helloworld/study/njai_challenge/cbct/model_weights/20180731_0/best_val_acc_se_densenet_gn_fold0_random_0_1i_2o_20180801.h5"
    )
    pred = tta_predict(model, img, batch_size=1)
    # print(pred)
    pred = np.squeeze(pred, 0)
    print(pred.shape)
    # Binarize at 0.5 and convert to uint8 so cv2.imwrite can write the mask.
    pred = np.where(pred > 0.5, 255, 0).astype(np.uint8)
    cv2.imwrite("/home/topsky/Desktop/mask_04+246ori_f1_random.tif", pred[..., 0])
Example #11
    def predict_from_files_old(self, image_path_lst, batch_size=5, use_channels=2, mask_file_lst=None, tta=False,
                               is_save_npy=False, is_save_mask0=False, is_save_mask1=False, result_save_dir=""):
        """
        给定图片路径列表,返回预测结果(未处理过的),如果指定了预测结果保存路径,则保存预测结果(已处理过的)。
        如果指定了预测结果保存的文件名列表,则该列表顺序必须与image_path_lst一致;
        如果没有指定预测结果保存的文件名列表,则自动生成和输入相同的文件名列表。
        Args:
            image_path_lst: list.
            batch_size:
            use_channels: 输出几个channel。
            mask_file_lst: list, 预测结果保存的文件名列表。
            tta: bool, 预测时是否进行数据增强。
            is_save_npy: bool, 是否保存npy文件。
            is_save_mask0: bool
            is_save_mask1: bool
            result_save_dir: str, 结果保存的目录路径。
        Returns: 4-d numpy array, predicted result.

        """
        imgs = self.read_images(image_path_lst)
        imgs = DataSet.preprocess(imgs, mode="image")
        if tta:
            pred = tta_predict(self.model, imgs, batch_size=batch_size)
        else:
            pred = self.predict_old(imgs, batch_size=batch_size, use_channels=use_channels)
        if mask_file_lst is None:
            mask_file_lst = [os.path.basename(x) for x in image_path_lst]
        if is_save_npy:
            # Save the raw predictions as .npy files.
            npy_dir = os.path.join(result_save_dir, "npy")
            self.save_npy(pred, mask_file_lst, npy_dir)
        if is_save_mask0:
            mask_nb = 0
            mask_save_dir = os.path.join(result_save_dir, "mask{}".format(mask_nb))
            self.save_mask(pred, mask_file_lst, mask_nb=mask_nb, result_save_dir=mask_save_dir)
        if is_save_mask1:
            mask_nb = 1
            mask_save_dir = os.path.join(result_save_dir, "mask{}".format(mask_nb))
            self.save_mask(pred, mask_file_lst, mask_nb=mask_nb, result_save_dir=mask_save_dir)
        return pred
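
A hedged usage sketch of predict_from_files_old with placeholder paths; get_file_path_list is the helper used in Example #8:

image_path_lst = get_file_path_list("/path/to/test_images", ext=".tif")
pred = model_obj.predict_from_files_old(image_path_lst, batch_size=5, use_channels=2,
                                        tta=True, is_save_mask0=True,
                                        result_save_dir="/path/to/results")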
Example #12
    @staticmethod
    def save_mask(pred, mask_file_lst, mask_nb, result_save_dir):
        """
        Save one channel of the predicted masks as image files.
        Args:
            pred: 4-d numpy array, (b, h, w, c).
            mask_file_lst: list, output file names, one per prediction.
            mask_nb: int, index of the mask channel to save.
            result_save_dir: str, directory the mask files are written to.

        Returns: None.

        """
        if not os.path.isdir(result_save_dir):
            os.makedirs(result_save_dir)
        masks = pred[..., mask_nb]
        mask_file_path_lst = [os.path.join(result_save_dir, x) for x in mask_file_lst]
        # Convert the predictions to a 0-1 array.
        masks = ModelDeployment.postprocess(masks)
        # Convert the 0-1 array to a 0-255 array.
        masks = DataSet.de_preprocess(masks, mode="mask")
        for i in range(len(pred)):
            cv2.imwrite(mask_file_path_lst[i], masks[i])
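
The comments above imply that postprocess binarizes the predictions to 0-1 and DataSet.de_preprocess(mode="mask") rescales them to 0-255. A rough, assumed equivalent (the 0.5 threshold is a guess based on Example #10), not the project's actual implementation:

import numpy as np

# Assumed behaviour only: binarize sigmoid outputs, then scale to 0-255 for saving.
def postprocess_sketch(pred, threshold=0.5):
    return np.where(pred > threshold, 1, 0).astype(np.uint8)

def de_preprocess_mask_sketch(mask01):
    return (mask01 * 255).astype(np.uint8)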
Example #13
    def predict_from_h5data_old(self, h5_data_path, val_fold_nb, is_train=False, save_dir=None,
                                color_lst=None):
        dataset = DataSet(h5_data_path, val_fold_nb)

        images = dataset.get_images(is_train=is_train)
        imgs_src = np.concatenate([images for i in range(3)], axis=-1)
        masks = dataset.get_masks(is_train=is_train, mask_nb=0)
        masks = np.squeeze(masks, axis=-1)
        print("predicting ...")
        y_pred = self.predict(dataset.preprocess(images, mode="image"), batch_size=4, use_channels=1)
        y_pred = self.postprocess(y_pred)
        y_pred = DataSet.de_preprocess(y_pred, mode="mask")
        print(y_pred.shape)

        if save_dir:
            keys = dataset.get_keys(is_train)
            if color_lst is None:
                color_gt = [255, 106, 106]
                color_pred = [0, 191, 255]
                # color_pred = [255, 255, 0]
            else:
                color_gt = color_lst[0]
                color_pred = color_lst[1]
            # BGR to RGB
            imgs_src = imgs_src[..., ::-1]
            image_masks = [apply_mask(image, mask, color_gt, alpha=0.5) for image, mask in zip(imgs_src, masks)]
            image_preds = [apply_mask(image, mask, color_pred, alpha=0.5) for image, mask in zip(imgs_src, y_pred)]
            dst_image_path_lst = [os.path.join(save_dir, "{:03}.tif".format(int(key))) for key in keys]
            if not os.path.isdir(save_dir):
                os.makedirs(save_dir)
            image_mask_preds = np.concatenate([imgs_src, image_masks, image_preds], axis=2)
            for i in range(len(image_masks)):
                cv2.imwrite(dst_image_path_lst[i], image_mask_preds[i])
            print("Done.")
        else:
            return y_pred
Example #14
def train_generator(model_def, model_saved_path, h5_data_path, batch_size, epochs, model_weights, gpus=1, verbose=1,
                    csv_log_suffix="0", fold_k="0", random_k_fold=False):
    learning_rate_scheduler = LearningRateScheduler(schedule=get_learning_rate_scheduler, verbose=0)
    opt = Adam(amsgrad=True)
    # opt = SGD()
    log_path = os.path.join(CONFIG.log_root, "log_" + os.path.splitext(os.path.basename(model_saved_path))[
        0]) + "_" + csv_log_suffix + ".csv"
    if os.path.isfile(log_path):
        print("Log file exists.")
        # exit()
    csv_logger = CSVLogger(log_path, append=False)

    # tensorboard = TensorBoard(log_dir='/home/jzhang/helloworld/mtcnn/cb/logs/tensorboard', write_images=True)

    fit_metrics = [dice_coef, metrics.binary_crossentropy, binary_acc_ch0]
    fit_loss = sigmoid_dice_loss_1channel_output

    # fit_loss = sigmoid_dice_loss
    # fit_metrics = [dice_coef_rounded_ch0, dice_coef_rounded_ch1, metrics.binary_crossentropy, mean_iou_ch0,
    #                binary_acc_ch0]
    # fit_metrics = [dice_coef_rounded_ch0, metrics.binary_crossentropy, mean_iou_ch0]
    # es = EarlyStopping('val_acc', patience=30, mode="auto", min_delta=0.0)
    # reduce_lr = ReduceLROnPlateau(monitor='val_acc', factor=0.1, patience=20, verbose=2, epsilon=1e-4,
    #                               mode='auto')

    if model_weights:
        model = model_def
        print("Loading weights ...")
        model.load_weights(model_weights, by_name=True, skip_mismatch=True)
        print("Model weights {} have been loaded.".format(model_weights))
    else:
        model = model_def

        print("Model created.")

    # prepare train and val data.
    dataset = DataSet(h5_data_path, val_fold_nb=fold_k, random_k_fold=random_k_fold)
    # x_train, y_train = dataset.prepare_stage2_data(is_train=True)
    x_train, y_train = dataset._prepare_3i_1o_data(is_train=True)
    print(x_train.shape)
    print(y_train.shape)
    # x_val, y_val = dataset.prepare_stage2_data(is_train=False)
    x_val, y_val = dataset._prepare_3i_1o_data(is_train=False)
    print(x_val.shape)
    print(y_val.shape)
    # we create two instances with the same arguments
    train_data_gen_args = dict(featurewise_center=False,
                               featurewise_std_normalization=False,
                               rotation_range=15,
                               width_shift_range=0.1,
                               height_shift_range=0.1,
                               horizontal_flip=True,
                               fill_mode="nearest",
                               shear_range=0.,
                               zoom_range=0.15, )
    train_image_datagen = ImageDataGenerator(**train_data_gen_args)
    train_mask_datagen = ImageDataGenerator(**train_data_gen_args)
    # Provide the same seed and keyword arguments to the fit and flow methods
    seed = np.random.randint(12488421)
    # train_image_datagen.fit(x_train, augment=True, seed=seed)
    # train_mask_datagen.fit(y_train, augment=True, seed=seed)
    train_image_generator = train_image_datagen.flow(x_train, None, batch_size, shuffle=True, seed=seed)
    train_mask_generator = train_mask_datagen.flow(y_train, None, batch_size, shuffle=True, seed=seed)
    # combine generators into one which yields image and masks
    train_data_generator = zip(train_image_generator, train_mask_generator)

    # no val augmentation.
    val_data_gen_args = dict(featurewise_center=False,
                             featurewise_std_normalization=False,
                             rotation_range=0.,
                             width_shift_range=0.,
                             height_shift_range=0.,
                             horizontal_flip=False)
    val_image_datagen = ImageDataGenerator(**val_data_gen_args)
    val_mask_datagen = ImageDataGenerator(**val_data_gen_args)
    val_image_generator = val_image_datagen.flow(x_val, None, batch_size, shuffle=False, seed=seed)
    val_mask_generator = val_mask_datagen.flow(y_val, None, batch_size, shuffle=False, seed=seed)
    val_data_generator = zip(val_image_generator, val_mask_generator)

    model_save_root, model_save_basename = os.path.split(model_saved_path)
    model_saved_path0 = os.path.join(model_save_root, "best_val_loss_" + model_save_basename)
    model_saved_path1 = os.path.join(model_save_root, "best_val_acc_" + model_save_basename)

    if gpus > 1:
        parallel_model = multi_gpu_model(model, gpus=gpus)
        model_checkpoint0 = ModelCheckpointMGPU(model, model_saved_path0, save_best_only=True, save_weights_only=True,
                                                monitor="val_loss",
                                                mode='min')
        model_checkpoint1 = ModelCheckpointMGPU(model, model_saved_path1, save_best_only=True, save_weights_only=True,
                                                monitor="val_binary_acc_ch0",
                                                mode='max')
    else:
        parallel_model = model
        model_checkpoint0 = ModelCheckpoint(model_saved_path0, save_best_only=True, save_weights_only=True,
                                            monitor='val_loss', mode='min')
        model_checkpoint1 = ModelCheckpoint(model_saved_path1, save_best_only=True, save_weights_only=True,
                                            monitor='val_binary_acc_ch0', mode='max')
    parallel_model.compile(loss=fit_loss,
                           optimizer=opt,
                           metrics=fit_metrics)
    model.summary()

    count_train = x_train.shape[0]
    count_val = x_val.shape[0]
    parallel_model.fit_generator(
        train_data_generator,
        validation_data=val_data_generator,
        steps_per_epoch=count_train // batch_size,
        validation_steps=count_val // batch_size,
        epochs=epochs,
        callbacks=[model_checkpoint0, model_checkpoint1, csv_logger, learning_rate_scheduler],
        verbose=verbose,
        workers=2,
        use_multiprocessing=True,
        shuffle=True
    )
    # model_save_root, model_save_basename = os.path.split(model_saved_path)
    # final_model_save_path = os.path.join(model_save_root, "final_" + model_save_basename)
    # model.save_weights(final_model_save_path)

    del model, parallel_model
    K.clear_session()
    gc.collect()
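
A call sketch for this training entry point, with placeholder paths; model_def is a Keras model built by one of the U-Net constructors used in the other examples:

train_generator(model_def,
                model_saved_path="/path/to/model_weights/se_unet_fold01.h5",
                h5_data_path="/path/to/data_0717.hdf5",
                batch_size=4,
                epochs=100,
                model_weights=None,  # or a weights file to resume from
                gpus=1,
                fold_k="01",
                random_k_fold=False)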
Example #15
        '00000041', '00000042', '00000043', '00000044', '00000045', '00000046',
        '00000047', '00000048', '00000049', '00000050', '00000051', '00000052',
        '00000053', '00000054', '00000055', '00000056', '00000057', '00000058',
        '00000059', '00000060', '00000061', '00000062', '00000063', '00000064',
        '00000065', '00000066', '00000067', '00000068', '00000069', '00000070',
        '00000071', '00000072', '00000073', '00000074', '00000075', '00000076',
        '00000077', '00000078', '00000079', '00000080'
    ]

    val_ids = ['00000041', '00000059', '00000074', '00000075']
    # prepare train and val data.
    dataset = DataSet(h5_data_path,
                      val_fold_nb=fold_k,
                      random_k_fold=random_k_fold,
                      input_channels=input_channels,
                      output_channels=output_channels,
                      random_crop_size=random_crop_size,
                      mask_nb=mask_nb,
                      batch_size=batch_size,
                      train_ids=train_ids,
                      val_ids=val_ids)
    # we create two instances with the same arguments
    train_data_gen_args = dict(
        featurewise_center=False,
        featurewise_std_normalization=False,
        rotation_range=0,
        width_shift_range=0.,
        height_shift_range=0.,
        horizontal_flip=True,
        fill_mode="nearest",
        shear_range=0.,
        zoom_range=[1, 1.4],
Example #16
def train_generator(model_def,
                    model_saved_path,
                    h5_data_path,
                    batch_size,
                    epochs,
                    model_weights,
                    gpus=1,
                    verbose=1,
                    csv_log_suffix="0",
                    fold_k="0",
                    random_k_fold=False,
                    input_channels=1,
                    output_channels=2,
                    random_crop_size=(256, 256),
                    mask_nb=0,
                    seed=0):
    model_weights_root = os.path.dirname(model_saved_path)
    if not os.path.isdir(model_weights_root):
        os.makedirs(model_weights_root)

    learning_rate_scheduler = LearningRateScheduler(
        schedule=get_learning_rate_scheduler, verbose=0)
    opt = Adam(amsgrad=True)
    # opt = SGD()
    log_path = os.path.join(
        CONFIG.log_root,
        "log_" + os.path.splitext(os.path.basename(model_saved_path))[0]
    ) + "_" + csv_log_suffix + ".csv"
    if not os.path.isdir(CONFIG.log_root):
        os.makedirs(CONFIG.log_root)

    if os.path.isfile(log_path):
        print("Log file exists.")
        # exit()
    csv_logger = CSVLogger(log_path, append=False)

    # tensorboard = TensorBoard(log_dir='/home/jzhang/helloworld/mtcnn/cb/logs/tensorboard', write_images=True)

    # fit_metrics = [dice_coef, metrics.binary_crossentropy, binary_acc_ch0]
    # fit_loss = sigmoid_dice_loss_1channel_output

    fit_loss = sigmoid_dice_loss
    fit_metrics = [
        dice_coef_rounded_ch0, dice_coef_rounded_ch1,
        metrics.binary_crossentropy, mean_iou_ch0, binary_acc_ch0
    ]
    # fit_metrics = [dice_coef_rounded_ch0, metrics.binary_crossentropy, mean_iou_ch0]
    # es = EarlyStopping('val_acc', patience=30, mode="auto", min_delta=0.0)
    # reduce_lr = ReduceLROnPlateau(monitor='val_acc', factor=0.1, patience=20, verbose=2, epsilon=1e-4,
    #                               mode='auto')

    if model_weights:
        model = model_def
        print("Loading weights ...")
        model.load_weights(model_weights, by_name=True, skip_mismatch=True)
        print("Model weights {} have been loaded.".format(model_weights))
    else:
        model = model_def

        print("Model created.")

    # # prepare train and val data.
    # dataset = DataSet(h5_data_path, val_fold_nb=fold_k, random_k_fold=random_k_fold,
    #                   input_channels=input_channels, output_channels=output_channels,
    #                   random_crop_size=random_crop_size, mask_nb=mask_nb, batch_size=batch_size
    #                   )

    train_ids = [
        '00000041', '00000042', '00000043', '00000044', '00000045', '00000046',
        '00000047', '00000048', '00000049', '00000050', '00000051', '00000052',
        '00000053', '00000054', '00000055', '00000056', '00000057', '00000058',
        '00000059', '00000060', '00000061', '00000062', '00000063', '00000064',
        '00000065', '00000066', '00000067', '00000068', '00000069', '00000070',
        '00000071', '00000072', '00000073', '00000074', '00000075', '00000076',
        '00000077', '00000078', '00000079', '00000080'
    ]

    val_ids = ['00000041', '00000059', '00000074', '00000075']

    dataset = DataSet(h5_data_path,
                      val_fold_nb=fold_k,
                      random_k_fold=random_k_fold,
                      input_channels=input_channels,
                      output_channels=output_channels,
                      random_crop_size=random_crop_size,
                      mask_nb=mask_nb,
                      batch_size=batch_size,
                      train_ids=train_ids,
                      val_ids=val_ids)
    # we create two instances with the same arguments
    # train_data_gen_args = dict(featurewise_center=False,
    #                            featurewise_std_normalization=False,
    #                            rotation_range=15,
    #                            width_shift_range=0.1,
    #                            height_shift_range=0.1,
    #                            horizontal_flip=True,
    #                            fill_mode="nearest",
    #                            shear_range=0.,
    #                            zoom_range=0.15,
    #                            )
    train_data_gen_args = dict(
        featurewise_center=False,
        featurewise_std_normalization=False,
        rotation_range=0,
        width_shift_range=0.,
        height_shift_range=0.,
        horizontal_flip=True,
        fill_mode="nearest",
        shear_range=0.,
        zoom_range=0.,
    )
    if CONFIG.random_crop_size is None:
        train_data_generator = dataset.get_keras_data_generator(
            is_train=True, keras_data_gen_param=train_data_gen_args, seed=seed)
    else:
        train_data_generator = dataset.get_custom_data_generator(
            is_train=True, keras_data_gen_param=train_data_gen_args, seed=seed)
    # no val augmentation.
    val_data_gen_args = dict(featurewise_center=False,
                             featurewise_std_normalization=False,
                             rotation_range=0.,
                             width_shift_range=0.,
                             height_shift_range=0.,
                             horizontal_flip=False)

    val_data_generator = dataset.get_keras_data_generator(
        is_train=False, keras_data_gen_param=val_data_gen_args, seed=seed)

    model_save_root, model_save_basename = os.path.split(model_saved_path)
    # model_saved_path_best_loss = os.path.join(model_save_root, "best_val_loss_" + model_save_basename)

    if gpus > 1:
        parallel_model = multi_gpu_model(model, gpus=gpus)
        model_checkpoint0 = ModelCheckpointMGPU(model,
                                                model_saved_path,
                                                save_best_only=True,
                                                save_weights_only=True,
                                                monitor="val_loss",
                                                mode='min')
    else:
        parallel_model = model
        model_checkpoint0 = ModelCheckpoint(model_saved_path,
                                            save_best_only=True,
                                            save_weights_only=True,
                                            monitor='val_loss',
                                            mode='min')
    parallel_model.compile(loss=fit_loss, optimizer=opt, metrics=fit_metrics)
    # model.summary()

    train_steps = dataset.get_train_val_steps(is_train=True)
    val_steps = dataset.get_train_val_steps(is_train=False)
    print("Training ...")
    parallel_model.fit_generator(
        train_data_generator,
        validation_data=val_data_generator,
        steps_per_epoch=train_steps,
        validation_steps=val_steps,
        epochs=epochs,
        callbacks=[model_checkpoint0, csv_logger, learning_rate_scheduler],
        verbose=verbose,
        workers=1,
        use_multiprocessing=False,
        shuffle=True)
    # model_save_root, model_save_basename = os.path.split(model_saved_path)
    # final_model_save_path = os.path.join(model_save_root, "final_" + model_save_basename)
    # model.save_weights(final_model_save_path)

    del model, parallel_model
    K.clear_session()
    gc.collect()