Example #1
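# NOTE: this snippet omits its imports and module-level configuration. It assumes
# (roughly) the following, which are not shown in the original listing:
#   import os
#   import numpy as np
#   import pandas as pd
#   from keras import backend as K
#   from keras.optimizers import Adam
#   from keras.callbacks import (ModelCheckpoint, ReduceLROnPlateau, EarlyStopping,
#                                 CSVLogger, TensorBoard)
# plus project helpers (create_xception_unet_n, get_loss, dice,
# create_train_date_generator, create_val_date_generator, get_score_from_all_slices)
# and globals such as input_shape, pretrained_weights_file, data_file_path,
# batch_size and num_epoch.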
def train(fold, train_patient_indexes, val_patient_indexes):

    log_dir = 'fold_' + str(fold) + '/'
    if not os.path.isdir(log_dir):
        os.mkdir(log_dir)
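    # 189 is assumed to be the number of axial slices per patient volume,
    # so the slice counts below scale linearly with the number of patients.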
    num_slices_train = len(train_patient_indexes) * 189
    num_slices_val = len(val_patient_indexes) * 189

    # Create model
    K.clear_session()
    model = create_xception_unet_n(
        input_shape=input_shape,
        pretrained_weights_file=pretrained_weights_file)
    model.compile(optimizer=Adam(lr=1e-3), loss=get_loss, metrics=[dice])

    # Get callbacks
    checkpoint = ModelCheckpoint(
        log_dir + 'ep={epoch:03d}-loss={loss:.3f}-val_loss={val_loss:.3f}.h5',
        verbose=1,
        monitor='val_loss',
        save_weights_only=True,
        save_best_only=True,
        period=1)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                  factor=0.2,
                                  min_delta=1e-3,
                                  patience=3,
                                  verbose=1)
    early_stopping = EarlyStopping(monitor='val_loss',
                                   min_delta=0,
                                   patience=5,
                                   verbose=1)
    csv_logger = CSVLogger(log_dir + 'record.csv')
    tensorboard = TensorBoard(log_dir=log_dir)

    # train the model
    model.fit_generator(
        create_train_date_generator(patient_indexes=train_patient_indexes,
                                    h5_file_path=data_file_path,
                                    batch_size=batch_size),
        steps_per_epoch=max(1, num_slices_train // batch_size),
        validation_data=create_val_date_generator(
            patient_indexes=val_patient_indexes,
            h5_file_path=data_file_path,
            batch_size=9),
        validation_steps=max(1, num_slices_val // 9),
        epochs=num_epoch,
        initial_epoch=0,
        callbacks=[
            checkpoint, reduce_lr, early_stopping, tensorboard, csv_logger
        ])
    model.save_weights(log_dir + 'trained_final_weights.h5')
    # save the full model as well
    model.save(os.path.join(log_dir, 'trained_final_model'))
    # model.save(os.path.join(log_dir, 'trained_final_model.h5'))

    # Evaluate model
    predicts = []
    labels = []
    f = create_val_date_generator(patient_indexes=val_patient_indexes,
                                  h5_file_path=data_file_path)
    for _ in range(num_slices_val):
        img, label = next(f)
        predicts.append(model.predict(img))
        labels.append(label)
    predicts = np.array(predicts)
    labels = np.array(labels)
    score_record = get_score_from_all_slices(labels=labels, predicts=predicts)

    # save score
    df = pd.DataFrame(score_record)
    df.to_csv(os.path.join(log_dir, 'score_record.csv'), index=False)

    # print score
    mean_score = {}
    for key in score_record.keys():
        mean_score[key] = np.mean(score_record[key])
        print(f"In fold {fold}, average {key} value is: {mean_score[key]}")

    # exit training
    K.clear_session()
    return mean_score
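
# A minimal usage sketch (not part of the original example): driving train() from a
# k-fold cross-validation loop. The patient count, fold count and score aggregation
# below are assumptions for illustration only.
import numpy as np

num_patients = 30                       # assumed total number of patients
patient_ids = np.arange(num_patients)
folds = np.array_split(patient_ids, 5)  # assumed 5-fold split

all_scores = []
for fold, val_patient_indexes in enumerate(folds):
    train_patient_indexes = np.setdiff1d(patient_ids, val_patient_indexes)
    all_scores.append(train(fold, train_patient_indexes, val_patient_indexes))

# Average each metric across folds (every fold returns the same metric keys).
for key in all_scores[0]:
    print(key, np.mean([score[key] for score in all_scores]))
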
Example #2
    sys.exit()

# sample_input = "/scratch/hasm/Data/Lesion/ATLAS_R1.1/Only_Data/Site1/031768/t01/031768_t1w_deface_stx.nii.gz"

print("".join(["data_file_path: (", str(data_file_path), ")"]), flush=True)
print("".join(["num_patients: (", str(num_patients), ")"]), flush=True)
print("".join(["num_slices: (", str(num_slices), ")"]), flush=True)
print("".join(["input_shape: (", str(input_shape), ")"]), flush=True)
print("".join([
    "xnet_pretrained_weights_file: (",
    str(xnet_pretrained_weights_file), ")"
]),
      flush=True)

model = create_xception_unet_n(
    input_shape=input_shape,
    pretrained_weights_file=xnet_pretrained_weights_file)

print("Generated Model", flush=True)

# val_patient_indexes = np.array([1])
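# Run inference for each patient in turn, treating each one as its own
# single-patient validation set.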
for patient_index in np.arange(num_patients):
    val_patient_indexes = np.array([patient_index])
    num_slices_val = len(val_patient_indexes) * num_slices

    print("".join(["val_patient_indexes: ",
                   str(val_patient_indexes)]),
          flush=True)

    output_path_seg_final = os.path.join(
        output_dir, "".join(