Beispiel #1
0
def create_callbacks(model, original_model, args):
    """Assemble the list of Keras callbacks used during training.

    ``args`` is read for ``gpus``, ``debug``, ``weights``, ``vis`` and
    ``batchsize``. The returned list always contains a checkpoint callback;
    a TensorBoard callback is added when ``args.debug`` is non-empty and a
    ``Visualisation`` callback when ``args.vis`` is non-empty.
    """
    # Without a debug directory the weights go straight to ``args.weights``;
    # otherwise one file per improving epoch is written under
    # ``<debug>/weights``.
    if args.debug == '':
        checkpoint_path = args.weights
    else:
        checkpoint_path = os.path.join(
            args.debug, 'weights', 'weights-improvement-{epoch:02d}.hdf5')

    # With exactly one GPU ``model`` is checkpointed, in every other case
    # ``original_model`` is.
    checkpoint_target = model if args.gpus == 1 else original_model

    callbacks = [
        AltModelCheckpoint(checkpoint_path,
                           checkpoint_target,
                           monitor='val_dice_coef',
                           mode='max',
                           verbose=1,
                           save_best_only=True,
                           save_weights_only=True)
    ]

    # Early stopping is intentionally not part of this callback list.

    # TensorBoard logging (debug runs only); the directories are created
    # before the callback is instantiated.
    if args.debug != '':
        log_dir = os.path.join(args.debug, 'logs')
        for directory in (args.debug,
                          os.path.join(args.debug, 'weights'),
                          log_dir):
            mkdir_s(directory)
        callbacks.append(
            TensorBoard(log_dir=log_dir,
                        histogram_freq=0,
                        write_graph=True,
                        write_images=True))

    # Optional visualisation of training progress.
    if args.vis != '':
        callbacks.append(
            Visualisation(dir_name=args.vis,
                          batchsize=args.batchsize,
                          monitor='val_dice_coef',
                          save_best_epochs_only=True,
                          mode='max'))

    return callbacks
Beispiel #2
0
    def test_on_epoch_end(self):
        """``on_epoch_end`` saves the alternate model and keeps ``callback.model``."""
        checkpoint_path = 'path/to/model.hdf5'

        training_model = Mock()

        alternate_model = Mock()
        alternate_model.save = Mock()

        callback = AltModelCheckpoint(checkpoint_path, alternate_model)
        callback.model = training_model

        callback.on_epoch_end(42)

        # After the epoch hook, the callback must still reference the model
        # that was assigned to it before the call.
        self.assertIs(callback.model, training_model,
                      'original model is restored')

        # The alternate model is the one that gets written out.
        alternate_model.save.assert_called_once_with(checkpoint_path,
                                                     overwrite=True)
Beispiel #3
0
def set_checkpoint(base_model, weights_dir):
    """Build the callback list for a training run.

    Parameters
    ----------
    base_model
        Model handed to ``AltModelCheckpoint`` as the model to save.
    weights_dir
        Directory for the periodic weight files and ``log.csv``; created
        if it does not exist.

    Returns
    -------
    list
        ``[lr_reducer (if not None), checkpoint, csv_logger, tensorboard]``.

    Notes
    -----
    ``lr_reducer`` is not defined in this function.
    NOTE(review): presumably a module-level callback or ``None`` - confirm.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(weights_dir, exist_ok=True)

    filepath = os.path.join(weights_dir, "weights-{epoch:03d}.h5")
    checkpoint = AltModelCheckpoint(filepath,
                                    base_model,
                                    monitor='val_loss',
                                    mode='min',
                                    verbose=1,
                                    period=5)
    csv_logger = CSVLogger(os.path.join(weights_dir, 'log.csv'),
                           append=True,
                           separator=';')
    callbacks_list = [checkpoint, csv_logger]
    if lr_reducer is not None:
        callbacks_list = [lr_reducer] + callbacks_list

    # Format the timestamp explicitly. Slicing str(datetime.now()) at the
    # first '.' silently drops the last digit of the seconds whenever the
    # microsecond field is 0 (str() then emits no '.', so find() returns -1).
    timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
    tensorboard_logdir = 'logs/{}'.format(timestamp)
    callbacks_list.append(TensorBoard(log_dir=tensorboard_logdir))
    return callbacks_list
Beispiel #4
0
def get_callbacks_mgpu(model_file,
                       initial_learning_rate=0.0001,
                       learning_rate_drop=0.5,
                       learning_rate_epochs=None,
                       learning_rate_patience=50,
                       logging_file="training.log",
                       verbosity=1,
                       early_stopping_patience=None,
                       base_model=None):
    """Return the training callbacks for a multi-GPU run.

    The list holds, in order: an ``AltModelCheckpoint`` that saves
    ``base_model`` to ``model_file`` (best only), a ``CSVLogger`` appending
    to ``logging_file``, one learning-rate callback, and optionally an
    ``EarlyStopping`` callback.

    The learning-rate callback is a ``LearningRateScheduler`` driven by
    ``step_decay`` when ``learning_rate_epochs`` is truthy, and a
    ``ReduceLROnPlateau`` otherwise. Early stopping is only added when
    ``early_stopping_patience`` is truthy.
    """
    callbacks = [
        AltModelCheckpoint(model_file, base_model, save_best_only=True),
        CSVLogger(logging_file, append=True),
    ]

    if not learning_rate_epochs:
        # No fixed schedule requested: shrink the rate when progress stalls.
        lr_callback = ReduceLROnPlateau(factor=learning_rate_drop,
                                        patience=learning_rate_patience,
                                        verbose=verbosity)
    else:
        schedule = partial(step_decay,
                           initial_lrate=initial_learning_rate,
                           drop=learning_rate_drop,
                           epochs_drop=learning_rate_epochs)
        lr_callback = LearningRateScheduler(schedule)
    callbacks.append(lr_callback)

    if early_stopping_patience:
        callbacks.append(
            EarlyStopping(verbose=verbosity, patience=early_stopping_patience))

    return callbacks
Beispiel #5
0
def main():
    """Train the network, checkpoint the best model and log to TensorBoard.

    All hyper-parameters are fixed inside the function. After training, the
    base model is additionally saved to ``model_final_epoch.h5``.
    """
    gpu_count = 6
    samples_per_batch = 30
    n_epochs = 100
    architecture = dict(
        cnn_filters=[64, 64, 128, 128, 256],
        max_pool_sizes=[2, 2, 2, 2, 2],
        rnn_sizes=[256, 256],
        fc_sizes=[256],
        dropout_rate=0.0,
    )

    features_dir = 'features'
    class_to_id, id_to_class = get_dicts(features_dir)

    generator = SEDGenerator(data_dir=features_dir,
                             concurrent=2,
                             class_to_id=class_to_id,
                             id_to_class=id_to_class,
                             batch_size=samples_per_batch,
                             random_seed=1,
                             train_size=0.9)

    train_batches = generator.generate(training=True)
    test_batches = generator.generate(training=False)
    # One batch is drawn up front and handed to build_model.
    first_batch = next(train_batches)

    base_model, model_to_train = build_model(first_batch,
                                             multi_gpu=gpu_count,
                                             **architecture)

    # The checkpoint callback saves ``base_model`` while ``model_to_train``
    # is the model that is actually fitted.
    callbacks_list = [
        TensorBoard(log_dir='tensorboard_logs/{}'.format(time())),
        AltModelCheckpoint('model_checkpoint.h5',
                           base_model,
                           verbose=1,
                           save_best_only=True),
    ]

    model_to_train.fit_generator(
        generator=train_batches,
        validation_data=test_batches,
        steps_per_epoch=generator.steps_per_epoch,
        validation_steps=generator.steps_per_validation,
        callbacks=callbacks_list,
        epochs=n_epochs)

    base_model.save('model_final_epoch.h5')
Beispiel #6
0
def create_callbacks(
    model: keras_model,
    original_model: keras_model,
    debug: str,
    num_gpus: int,
    augmentate: bool,
    batchsize: int,
    vis: str,
    weights_path: str,
) -> list:
    """Create Keras callbacks for training.

    Parameters
    ----------
    model: keras_model
        model that is checkpointed when num_gpus == 1
    original_model: keras_model
        model that is checkpointed when num_gpus != 1
    debug: str
        directory for per-epoch weights and tensorboard logs;
        an empty string disables both
    num_gpus: int
        number of gpus
    augmentate: bool
        accepted for interface compatibility; not used by this function
    batchsize: int
        batchsize to use during training visualization
    vis: str
        images to read for training visualization;
        an empty string disables the visualization callback
    weights_path: str
        path to save weights when debug is empty

    Returns
    -------
    list
        list of callback objects to use in training: model checkpoint,
        early stopping, and optionally tensorboard and visualisation.

    See Also
    --------
    Visualisation()

    Example
    -------
    robin.callbacks_utils.create_callbacks(
        model,
        gpu_model,
        logs,
        1,
        True,
        32,
        vis_imgs,
        weights
        )

    """
    # Model checkpoint. Only the checkpointed model depends on the GPU
    # count; the target path and all other options are shared.
    if debug == "":
        checkpoint_path = weights_path
    else:
        checkpoint_path = os.path.join(
            debug, "weights", "weights-improvement-{epoch:02d}.hdf5")
    checkpoint_target = model if num_gpus == 1 else original_model
    callbacks = [
        AltModelCheckpoint(
            checkpoint_path,
            checkpoint_target,
            monitor="val_dice_coef",
            mode="max",
            verbose=1,
            save_best_only=True,
            save_weights_only=True,
        )
    ]

    # Early stopping.
    callbacks.append(
        EarlyStopping(
            monitor="val_dice_coef",
            min_delta=0.001,
            patience=20,
            verbose=1,
            mode="max",
        )
    )

    # Tensorboard logs.
    if debug != "":
        log_dir = os.path.join(debug, "logs")
        for directory in (debug, os.path.join(debug, "weights"), log_dir):
            mkdir_s(directory)
        callbacks.append(
            TensorBoard(
                log_dir=log_dir,
                histogram_freq=0,
                write_graph=True,
                write_images=True,
            )
        )

    # Training visualisation.
    if vis != "":
        callbacks.append(
            Visualisation(
                dir_name=vis,
                batchsize=batchsize,
                monitor="val_dice_coef",
                save_best_epochs_only=True,
                mode="max",
            )
        )

    return callbacks
Beispiel #7
0
def _load_training_split(config, data_file, input_type):
    """Check the data file has the groups ``input_type`` needs and return the split lists."""
    training_file = os.path.abspath(
        os.path.join(config["data_dir"], config['training_split']))
    validation_file = os.path.abspath(
        os.path.join(config["data_dir"], config['validation_split']))

    if '/truth' not in data_file:
        print('Truth data: ', '/truth' in data_file)
        raise Exception(
            "data file does not contain the truth group required to train")

    has_clinical = '/cldata' in data_file
    has_image = '/imdata' in data_file
    requirements = {
        "Both": has_clinical and has_image,
        "Image": has_image,
        "Clinical": has_clinical,
    }
    if not requirements.get(input_type, False):
        print('Input Type: ', input_type)
        print('Clincial data: ', has_clinical)
        print('Image data: ', has_image)
        raise Exception(
            "data file does not contain the input group required to train")

    # create_validation_split receives the split file paths; the original
    # comment states an 80/20 split is used when no pickle file is present.
    return create_validation_split(config["problem_type"],
                                   data_file.root.truth,
                                   training_file,
                                   validation_file,
                                   train_split=0.80,
                                   overwrite=0)


def _patch_lists_and_steps(config, data_file, training_list, validation_list,
                           batch_size):
    """Run both index lists through get_number_of_patches and derive step counts."""
    num_validation_patches, _, validation_list_valid = get_number_of_patches(
        data_file,
        validation_list,
        patch_shape=config["patch_shape"],
        skip_blank=config["skip_blank"],
        patch_overlap=config["validation_patch_overlap"])
    # NOTE(review): the training count also uses validation_patch_overlap,
    # exactly as before - confirm this is intended.
    num_training_patches, _, training_list_valid = get_number_of_patches(
        data_file,
        training_list,
        patch_shape=config["patch_shape"],
        skip_blank=config["skip_blank"],
        patch_overlap=config["validation_patch_overlap"])
    num_validation_steps = get_number_of_steps(
        num_validation_patches, config["validation_batch_size"])
    num_training_steps = get_number_of_steps(num_training_patches, batch_size)
    return (training_list_valid, validation_list_valid, num_training_steps,
            num_validation_steps)


def _classification_generators(config, data_file, input_type, classes,
                               training_list, validation_list, batch_size):
    """Return (training_generator, validation_generator, training_steps, validation_steps)."""
    if input_type == "Clinical":
        # Clinical-only input has no patches: every list entry is a sample.
        num_validation_steps = get_number_of_steps(
            len(validation_list), config["validation_batch_size"])
        num_training_steps = get_number_of_steps(len(training_list),
                                                 batch_size)
        training_generator = DataGenerator_CL_Classification(
            data_file,
            training_list,
            batch_size=config['batch_size'],
            n_classes=config['n_classes'],
            classes=classes)
        validation_generator = DataGenerator_CL_Classification(
            data_file,
            validation_list,
            batch_size=config['validation_batch_size'],
            n_classes=config['n_classes'],
            classes=classes)
        return (training_generator, validation_generator, num_training_steps,
                num_validation_steps)

    (training_list_valid, validation_list_valid, num_training_steps,
     num_validation_steps) = _patch_lists_and_steps(
         config, data_file, training_list, validation_list, batch_size)

    if input_type == "Both":
        generator_cls = DataGenerator_3DCL_Classification
    else:
        generator_cls = DataGenerator_3D_Classification

    # NOTE(review): the validation generator receives the same augmentation
    # settings as the training generator, as before - confirm intended.
    shared_kwargs = dict(n_classes=config['n_classes'],
                         classes=classes,
                         augment=config['augment'],
                         augment_flip=config['flip'],
                         augment_distortion_factor=config['distort'],
                         skip_blank=False,
                         permute=config['permute'],
                         reduce=config['reduce'])
    training_generator = generator_cls(data_file,
                                       training_list_valid,
                                       batch_size=config['batch_size'],
                                       **shared_kwargs)
    validation_generator = generator_cls(
        data_file,
        validation_list_valid,
        batch_size=config['validation_batch_size'],
        **shared_kwargs)
    return (training_generator, validation_generator, num_training_steps,
            num_validation_steps)


def _classification_model(config, input_type):
    """Build (and plot) the classification network for ``input_type``."""
    if input_type == "Both":
        # Two branches: an MLP on the clinical features and a 3D ResNet on
        # the images.
        mlp = MLP.build(dim=config['CL_features'], num_outputs=8, branch=True)
        cnn = Resnet3D.build_resnet_18(config['input_shape'],
                                       num_outputs=8,
                                       branch=True)

        # The fused head consumes the concatenated branch outputs and ends
        # in a 2-way softmax.
        # NOTE(review): the output width is hard-coded to 2 although
        # config['n_classes'] exists - confirm.
        combined = concatenate([mlp.output, cnn.output])
        x = Dense(8, activation="relu")(combined)
        x = Dense(4, activation="relu")(x)
        x = Dense(2, activation="softmax")(x)

        model = Model(inputs=[mlp.input, cnn.input], outputs=x)
        plot_model(model, to_file="Combined.png", show_shapes=True)
    elif input_type == "Image":
        model = Resnet3D.build_resnet_18(config['input_shape'],
                                         num_outputs=2,
                                         reg_factor=1e-4,
                                         branch=False)
        plot_model(model, to_file="Resnet_nolabel.png", show_shapes=True)
    else:  # "Clinical" - input_type was validated when loading the split
        model = MLP.build(dim=config['CL_features'],
                          num_outputs=2,
                          branch=False)
        plot_model(model, to_file="MLP.png", show_shapes=True)
    return model


def _train_with_open_data(config, data_file, input_type, problem_type):
    """Steps 4-7 of run_training; the caller owns (and closes) ``data_file``."""
    # Step 4: Check the data file contains the arrays required for the input
    # type and load the training/validation split.
    training_list, validation_list = _load_training_split(
        config, data_file, input_type)

    # Step 5: Define data generators and models for the problem/input type.
    Ngpus = config['GPU']
    Ncpus = config['CPU']
    # NOTE(review): step counts use batch_size * Ngpus while the training
    # generators are built with config['batch_size'] - unchanged, confirm.
    batch_size = config['batch_size'] * Ngpus
    n_epochs = config['n_epochs']

    training_generator = None
    validation_generator = None
    num_validation_steps = None
    num_training_steps = None
    model1 = None
    classWeight = None

    # Strings are compared with ==; identity (`is`) only works by accident
    # of interning and fails for values read from a file.
    if problem_type == 'Classification':
        classes = np.unique(data_file.root.truth)
        print(classes)
        classes = [y.decode("utf-8") for y in classes]

        # Class weights for balanced training: max class count / class count.
        Y = data_file.root.truth.read()
        Y = np.asarray([y.decode("utf-8") for y in Y])
        le = preprocessing.LabelEncoder()
        Y = np_utils.to_categorical(le.fit_transform(Y), config['n_classes'])
        classTotals = Y.sum(axis=0)
        # NOTE(review): Keras documents class_weight as a dict mapping class
        # index to weight; confirm an array is honoured by the Keras version
        # in use.
        classWeight = classTotals.max() / classTotals
        print('classWeight: ', classWeight)

        (training_generator, validation_generator, num_training_steps,
         num_validation_steps) = _classification_generators(
             config, data_file, input_type, classes, training_list,
             validation_list, batch_size)
        model1 = _classification_model(config, input_type)

    elif problem_type == 'Segmentation' and input_type == "Image":
        (training_list_valid, validation_list_valid, num_training_steps,
         num_validation_steps) = _patch_lists_and_steps(
             config, data_file, training_list, validation_list, batch_size)

        # `labels` was an undefined name here (NameError at runtime).
        # NOTE(review): config['labels'] is assumed to hold the label values
        # that accompany config['n_labels'] - confirm the key name.
        segmentation_kwargs = dict(
            n_labels=config['n_labels'],
            labels=config['labels'],
            augment=config['augment'],
            augment_flip=config['flip'],
            augment_distortion_factor=config['distort'],
            patch_shape=config['patch_shape'],
            patch_overlap=0,
            patch_start_offset=0,
            skip_blank=False,
            permute=config['permute'],
            reduce=config['reduce'])
        training_generator = DataGenerator_3D_Segmentation(
            data_file,
            training_list_valid,
            batch_size=config['batch_size'],
            **segmentation_kwargs)
        validation_generator = DataGenerator_3D_Segmentation(
            data_file,
            validation_list_valid,
            batch_size=config['validation_batch_size'],
            **segmentation_kwargs)
        model1 = isensee2017_model.build()

    if model1 is None:
        # Previously this fell through and crashed later on a None model or
        # an undefined generator.
        raise Exception(
            "No model is implemented for problem type {!r} with input type {!r}"
            .format(problem_type, input_type))

    # Step 6: Train model after compiling with problem specific parameters.
    # model1 stays the single-device template: it is the one that is
    # checkpointed and summarised, while `model` is the one that is trained.
    if Ngpus > 1:
        model = multi_gpu_model(model1, gpus=Ngpus)
    else:
        model = model1

    # "\{}" relied on an invalid escape sequence and a Windows-only
    # separator; join the path portably instead.
    tensorboard = TensorBoard(
        log_dir=os.path.join(config['monitor'], str(time())))

    # Learning rate (resolved before the optimizer that consumes it).
    learning_rate = config['lr'] if config['lr'] else 1e-3

    # Optimizer: always an instance built with the configured learning rate.
    if config['opt'] == 'adam':
        optimizer = Adam(lr=learning_rate)
    else:
        optimizer = SGD(lr=learning_rate, momentum=0.9)

    # Loss function
    if config['loss'] == 'dsc':
        loss_function = weighted_dice_coefficient_loss
    else:
        loss_function = "binary_crossentropy"

    # Metrics passed to compile(). The old code read config['lr'] here and
    # defaulted to the monitor name "val_acc", which is not a metric.
    metrics = config['metrics'] if config['metrics'] else ["accuracy"]
    if isinstance(metrics, str):
        metrics = [metrics]

    # General callbacks for all problems
    earlystop = EarlyStopping(monitor='val_acc',
                              min_delta=0.0005,
                              patience=30,
                              verbose=0,
                              mode='auto')
    checkpoint = AltModelCheckpoint(config["training_model"] + '_model.h5',
                                    model1,
                                    monitor="val_acc",
                                    save_best_only=True,
                                    verbose=1)
    logger = CSVLogger(config["training_model"] + '_log.txt', append=True)
    if config["learning_rate_epochs"]:
        lr_callback = LearningRateScheduler(
            partial(step_decay,
                    initial_lrate=learning_rate,
                    drop=0.5,
                    epochs_drop=config["learning_rate_epochs"]))
    else:
        lr_callback = ReduceLROnPlateau(factor=0.5, patience=30, verbose=1)

    # The CSV logger was constructed but never registered before.
    callbacks = [lr_callback, tensorboard, checkpoint, earlystop, logger]

    model.compile(optimizer=optimizer, loss=loss_function, metrics=metrics)

    with open(config['training_model'] + '_summary.txt', 'w') as fh:
        # summary() prints line by line; redirect every line into the file.
        model1.summary(line_length=150, print_fn=lambda x: fh.write(x + '\n'))

    # train the network
    print("[INFO] training network...")
    H = model.fit_generator(generator=training_generator,
                            steps_per_epoch=num_training_steps,
                            epochs=n_epochs,
                            validation_data=validation_generator,
                            validation_steps=num_validation_steps,
                            callbacks=callbacks,
                            class_weight=classWeight,
                            use_multiprocessing=False,
                            workers=Ncpus)

    # Step 7: plot the training + validation loss and accuracy. Curves whose
    # key is absent from the history (e.g. custom metrics) are skipped.
    history = H.history
    epoch_axis = np.arange(0, len(history['loss']))
    plt.style.use("ggplot")
    plt.figure()
    for key, label in (("loss", "train_loss"), ("val_loss", "val_loss"),
                       ("acc", "acc"), ("val_acc", "val_acc")):
        if key in history:
            plt.plot(epoch_axis, history[key], label=label)
    plt.title("Training Loss and Accuracy")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    plt.legend()
    figpath_final = input_type + '.png'
    plt.savefig(figpath_final)
    plt.show()


def run_training(config):
    """Train a classification or segmentation model described by ``config``.

    ``config`` is indexed by string keys. Keys read here or in the helpers:
    ``input_type`` ("Image", "Clinical" or "Both"), ``problem_type``
    ("Classification" or "Segmentation"), ``data_dir``, ``data_file``,
    ``training_split``, ``validation_split``, ``GPU``, ``CPU``,
    ``batch_size``, ``validation_batch_size``, ``n_epochs``, ``n_classes``,
    ``patch_shape``, ``skip_blank``, ``validation_patch_overlap``,
    ``augment``, ``flip``, ``distort``, ``permute``, ``reduce``,
    ``CL_features``, ``input_shape``, ``n_labels``, ``labels``, ``monitor``,
    ``opt``, ``loss``, ``lr``, ``metrics``, ``training_model`` and
    ``learning_rate_epochs``.

    Side effects: writes ``<training_model>_model.h5``, ``_log.txt`` and
    ``_summary.txt``, TensorBoard logs under ``config['monitor']``, a model
    diagram for classification, and ``<input_type>.png`` with the training
    curves, which is also shown on screen.

    Raises
    ------
    Exception
        If the input type or problem type is missing, the data file cannot
        be opened, it lacks the required groups, or the problem/input
        combination has no model.
    """
    # Step 1: Check if input type is defined
    try:
        input_type = config["input_type"]
    except (KeyError, TypeError) as err:
        raise Exception(
            "Error: Input type not defined | \t Set  config[\"input_type\"] to \"Image\", \"Clinical\" or \"Both\" \n"
        ) from err

    # Step 2: Check if problem type is defined
    try:
        problem_type = config["problem_type"]
    except (KeyError, TypeError) as err:
        raise Exception(
            "Error: Problem type not defined | \t Set  config[\"problem_type\"] to \"Classification\", \"Segmentation\" or \"Regression\" \n"
        ) from err

    # Step 3: Check if the data file is defined and open it
    try:
        data_file = tables.open_file(os.path.abspath(
            os.path.join(config["data_dir"], config["data_file"])),
                                     mode='r')
    except Exception as err:
        raise Exception(
            "Error: Could not open data file, check if config[\"data_file\"] is defined \n"
        ) from err

    # Steps 4-7 run with the file open; it is closed even when training
    # fails (the old code called close() on an undefined name).
    try:
        _train_with_open_data(config, data_file, input_type, problem_type)
    finally:
        data_file.close()
Beispiel #8
0
 def test_kwargs_pass_through(self):
     """The file path and the ``monitor`` keyword end up as callback attributes."""
     target_path = 'path/to/model.hdf5'
     callback = AltModelCheckpoint(target_path, None, monitor='foobar')
     self.assertEqual(callback.filepath, target_path)
     self.assertEqual(callback.monitor, 'foobar')
def keras_fit_generator(img_rows=96,
                        img_cols=96,
                        n_imgs=10**4,
                        batch_size=32,
                        regenerate=True):
    """Train a U-Net on augmented image/mask batches using two GPUs.

    Parameters
    ----------
    img_rows, img_cols
        Passed to ``data_to_array`` when ``regenerate`` is true. Afterwards
        both are overwritten with the row/column counts of the loaded
        training array.
    n_imgs
        Together with ``batch_size`` determines ``steps_per_epoch``
        (``n_imgs // batch_size``).
    batch_size
        Batch size of the augmented training generators.
    regenerate
        When true, ``data_to_array`` is run before the data is loaded.

    The best weights (by ``val_loss``) of the single-device model are
    written to ``../data/weights7.h5``. Nothing is returned.
    """
    if regenerate:
        # Step 2: rebuild the preprocessed arrays before loading them.
        data_to_array(img_rows, img_cols)

    # Step 6: load the training and validation arrays.
    X_train, y_train, X_val, y_val = load_data()

    # From here on the image size is whatever the loaded data says
    # (axis 1 = rows, axis 2 = columns of X_train).
    img_rows, img_cols = X_train.shape[1], X_train.shape[2]

    # Coordinate grids in 'ij' (matrix) indexing, handed to the elastic
    # transform together with strength parameters scaled by the image size.
    row_grid, col_grid = np.meshgrid(np.arange(img_rows),
                                     np.arange(img_cols),
                                     indexing='ij')
    elastic = partial(elastic_transform,
                      x=row_grid,
                      y=col_grid,
                      alpha=img_rows * 1.5,
                      sigma=img_rows * 0.07)

    # Classic Keras augmentation settings. The very same arguments are used
    # for the image and the mask generator.
    augmentation = {
        'featurewise_center': False,  # no dataset-wide mean subtraction
        'featurewise_std_normalization': False,  # no dataset-wide std scaling
        'rotation_range': 10.,  # random rotation range, in degrees
        'width_shift_range': 0.1,  # horizontal shift, fraction of width
        'height_shift_range': 0.1,  # vertical shift, fraction of height
        'horizontal_flip': True,  # random horizontal flips
        'vertical_flip': True,  # random vertical flips
        'zoom_range': [1, 1.2],  # [lower, upper] random zoom bounds
        'fill_mode': 'constant',  # how points outside the input are filled
        'preprocessing_function': elastic,  # elastic deformation per image
    }
    image_datagen = ImageDataGenerator(**augmentation)
    mask_datagen = ImageDataGenerator(**augmentation)

    # Provide the same seed and keyword arguments to the fit and flow
    # methods of both generators.
    seed = 2
    image_datagen.fit(X_train, seed=seed)
    mask_datagen.fit(y_train, seed=seed)

    # flow() yields augmented batches indefinitely; zipping the two streams
    # pairs each image batch with a mask batch.
    train_generator = zip(
        image_datagen.flow(X_train, batch_size=batch_size, seed=seed),
        mask_datagen.flow(y_train, batch_size=batch_size, seed=seed))

    # Step 7: build the U-Net for single-channel inputs of the loaded size.
    base_model = UNet((img_rows, img_cols, 1),
                      start_ch=8,
                      depth=7,
                      batchnorm=True,
                      dropout=0.5,
                      maxpool=True,
                      residual=True)

    # Print the layer/parameter overview for a sanity check.
    base_model.summary()

    training_callbacks = [
        # Keep only the best weights according to val_loss. The checkpoint
        # is given the single-device model, not the multi-GPU wrapper
        # created below.
        AltModelCheckpoint('../data/weights7.h5',
                           base_model,
                           monitor='val_loss',
                           save_best_only=True),
        # Stop once the training loss has not improved by at least 0.001
        # for 5 epochs.
        # NOTE(review): monitors 'loss' although validation data is
        # available - the original author questioned this as well.
        EarlyStopping(monitor='loss', min_delta=0.001, patience=5),
    ]

    parallel_model = multi_gpu_model(base_model, gpus=2)

    # Adam optimiser with a Dice-based loss and Dice coefficient metric.
    parallel_model.compile(optimizer=Adam(lr=0.001),
                           loss=dice_coef_loss,
                           metrics=[dice_coef])

    parallel_model.fit_generator(
        train_generator,
        # An epoch ends after this many generator batches.
        steps_per_epoch=n_imgs // batch_size,
        epochs=20,
        verbose=2,  # one log line per epoch
        shuffle=True,
        validation_data=(X_val, y_val),
        callbacks=training_callbacks,
        use_multiprocessing=True)