Code Example #1
File: prediction.py  Project: fragalassi/ot_da_v0
def run_validation_cases(validation_keys_file,
                         model_file,
                         training_modalities,
                         labels,
                         hdf5_file,
                         output_label_map=False,
                         output_dir=".",
                         threshold=0.5,
                         overlap=16,
                         permute=False):
    validation_indices = pickle_load(validation_keys_file)
    model = load_old_model(model_file)
    data_file = tables.open_file(hdf5_file, "r")
    for i, index in enumerate(validation_indices):
        actual = round(i / len(validation_indices) * 100, 2)
        print("Running validation case: ", actual, "%")
        if 'subject_ids' in data_file.root:
            case_directory = os.path.join(
                output_dir, data_file.root.subject_ids[index].decode('utf-8'))
        else:
            case_directory = os.path.join(output_dir,
                                          "validation_case_{}".format(index))
        run_validation_case(data_index=index,
                            output_dir=case_directory,
                            model=model,
                            data_file=data_file,
                            training_modalities=training_modalities,
                            output_label_map=output_label_map,
                            labels=labels,
                            threshold=threshold,
                            overlap=overlap,
                            permute=permute)
    data_file.close()
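
As a usage sketch only: the call below shows how run_validation_cases might be invoked. Every path, modality name, and label value here is an invented placeholder, not something taken from the original project.

# Hypothetical invocation; all paths, modalities, and labels are placeholders.
run_validation_cases(validation_keys_file="validation_ids.pkl",
                     model_file="tumor_segmentation_model.h5",
                     training_modalities=["t1", "t1ce", "flair", "t2"],
                     labels=(1, 2, 4),
                     hdf5_file="brats_data.h5",
                     output_label_map=True,
                     output_dir="prediction")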
Code Example #2
def main():
    if not os.path.exists(config["hdf5_file"]):
        training_files = list()
        for label_file in glob.glob("./data/training/subject-*-label.hdr"):
            training_files.append((label_file.replace("label", "T1"), label_file.replace("label", "T2"), label_file))

        write_data_to_file(training_files, config["hdf5_file"], image_shape=config["image_shape"])

    hdf5_file_opened = tables.open_file(config["hdf5_file"], "r")

    if os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(input_shape=config["input_shape"], n_labels=config["n_labels"])

    # get training and testing generators
    train_generator, validation_generator, nb_train_samples, nb_test_samples = get_training_and_validation_generators(
        hdf5_file_opened, batch_size=config["batch_size"], data_split=config["validation_split"],
        validation_keys_file=config["validation_file"], training_keys_file=config["training_file"],
        n_labels=config["n_labels"], labels=config["labels"], augment=True)

    # run training
    train_model(model=model, model_file=config["model_file"], training_generator=train_generator,
                validation_generator=validation_generator, steps_per_epoch=nb_train_samples,
                validation_steps=nb_test_samples, initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_epochs=config["decay_learning_rate_every_x_epochs"], n_epochs=config["n_epochs"])
    hdf5_file_opened.close()
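
The snippet above reads a module-level config dictionary. Below is a minimal sketch of the keys it touches; the key names all come from the code itself, but every value is an illustrative assumption.

# Sketch only -- key names are taken from the code above, values are assumed.
config = {
    "hdf5_file": "brats_data.h5",
    "image_shape": (144, 144, 144),
    "model_file": "3d_unet_model.h5",
    "input_shape": (3, 144, 144, 144),   # (n_channels,) + image_shape, assumed
    "n_labels": 1,
    "labels": (1,),
    "batch_size": 6,
    "validation_split": 0.8,
    "validation_file": "validation_ids.pkl",
    "training_file": "training_ids.pkl",
    "initial_learning_rate": 1e-5,
    "learning_rate_drop": 0.5,
    "decay_learning_rate_every_x_epochs": 10,
    "n_epochs": 50,
}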
Code Example #3
File: train.py  Project: CocoInParis/2d-segmentation
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"]
                           )  #config["image_shape"] = (144, 144, 144)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        #print('start model computing')
        # model = unet_model_3d(input_shape=config["input_shape"],                # 4+(32, 32, 32)
        #                       pool_size=config["pool_size"],                    #config["pool_size"] = (2, 2, 2), maxpooling size
        #                       n_labels=config["n_labels"],                      #config["n_labels"] = len(config["labels"])
        #                       initial_learning_rate=config["initial_learning_rate"],        #config["initial_learning_rate"] = 0.00001
        #                       deconvolution=config["deconvolution"])                        #config["deconvolution"] = True  # if False, will use upsampling instead of deconvolution

        model = custom_unet(reu2018)
    #print('model loaded')
    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],  #config["batch_size"] = 6
        data_split=config["validation_split"],  #validation_split = 0.8
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    print('training model')
    train_model(
        model=model,
        model_file=config["model_file"],  # config["model_file"] = os.path.abspath("tumor_segmentation_model.h5")
        training_generator=train_generator,
        validation_generator=validation_generator,
        steps_per_epoch=n_train_steps,
        validation_steps=n_validation_steps,
        initial_learning_rate=config["initial_learning_rate"],
        learning_rate_drop=config["learning_rate_drop"],
        learning_rate_patience=config["patience"],
        early_stopping_patience=config["early_stop"],
        n_epochs=config["n_epochs"])
    data_file_opened.close()
    print("model has been trained already")
Code Example #4
File: train.py  Project: zjngjng/3DUnetCNN
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["hdf5_file"]):
        training_files = fetch_training_data_files()

        write_data_to_file(training_files, config["hdf5_file"], image_shape=config["image_shape"])
    hdf5_file_opened = tables.open_file(config["hdf5_file"], "r")

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(input_shape=config["input_shape"],
                              downsize_filters_factor=config["downsize_nb_filters_factor"],
                              pool_size=config["pool_size"], n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"])

    # get training and testing generators
    train_generator, validation_generator, nb_train_samples, nb_test_samples = get_training_and_validation_generators(
        hdf5_file_opened, batch_size=config["batch_size"], data_split=config["validation_split"], overwrite=overwrite,
        validation_keys_file=config["validation_file"], training_keys_file=config["training_file"],
        n_labels=config["n_labels"])

    # run training
    train_model(model=model, model_file=config["model_file"], training_generator=train_generator,
                validation_generator=validation_generator, steps_per_epoch=nb_train_samples,
                validation_steps=nb_test_samples, initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_epochs=config["decay_learning_rate_every_x_epochs"], n_epochs=config["n_epochs"])
    hdf5_file_opened.close()
Code Example #5
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        print("Loading old model file from the location: ",
              config["model_file"])
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        print("Creating new model at the location: ", config["model_file"])
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    print("Running the Training. Model file:", config["model_file"])
    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Code Example #6
def main(overwrite=False):
    # convert input images into an hdf5 file
    # If a data file already exists it is reused; note that image_shape stays whatever was set when the file was written
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)

        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    # Load or create the model file
    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = att_res_ds_unet.att_res_ds_unet_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
                                                      initial_learning_rate=config["initial_learning_rate"],
                                                      n_base_filters=config["n_base_filters"])
    from keras.utils.vis_utils import plot_model
    plot_model(model, to_file='att_res_ds_unet.png', show_shapes=True)

    # get training and testing generators
    # ../unet3d/generator.py
    # Create the generators used for training below
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    # ../unet3d/training.py
    # Train a Keras model
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Code Example #7
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        print("Number of Training file Found:", len(training_files))
        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"])

    print("Opening data file.")
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        print("Loading existing model file.")
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        print("Instantiating new model file.")
        model = unet_model_3d(input_shape=config["input_shape"],
                              pool_size=config["pool_size"],
                              n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"],
                              deconvolution=config["deconvolution"])

    # get training and testing generators
    print("Getting training and testing generators.")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    print("Running the training......")
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
    print("Training DONE")
Code Example #8
def validate(model_file,
             ibis_data,
             input_shape=(96, 96, 96),
             batch_size=2,
             data_split=0.8,
             num_gpus=None,
             only_aa=False):
    model = load_old_model(model_file)
    validate_data = IBISGenerator(ibis_data,
                                  input_shape=input_shape,
                                  batch_size=batch_size,
                                  only_aa=only_aa,
                                  start_index=data_split,
                                  shuffle=False)

    data_gen = validate_data.generate()

    y_true = []
    y_score = []
    # integer division so the batch count works as a range() bound
    num_batches = validate_data.data.shape[0] // batch_size
    for i in range(num_batches):
        print(i, "of", num_batches - 1)
        X, y = next(data_gen)
        y_true_sample = y.flatten().astype(int)
        # keep the raw scores: truncating them to int (or thresholding) before
        # roc_curve would collapse the ROC curve to a single operating point
        y_score_sample = model.predict_on_batch(X).flatten()

        y_true += y_true_sample.tolist()
        y_score += y_score_sample.tolist()

        print(np.where(y_true_sample == 1))
        print(np.where(np.rint(y_score_sample) == 1))

        # roc_curve returns (fpr, tpr, thresholds), in that order
        fpr, tpr, _ = roc_curve(y_true_sample, y_score_sample)
        roc_auc = auc(fpr, tpr)
        print("Batch ROCAUC:", roc_auc)

    # aggregate ROC over all accumulated batches, not just the last one
    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)
    print("Total TPR, FPR:", tpr, fpr)

    model_name = os.path.splitext(os.path.basename(model_file))[0]
    pp = PdfPages("model_evaluation-{}.pdf".format(model_name))

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(fpr, tpr, lw=3., label="ROC (AUC: {})".format(roc_auc))
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_xlim([0, 1.0])
    ax.set_ylim([0.0, 1.05])

    fig.suptitle("{} Model Evaluation".format(model_name), fontsize=20)

    pp.savefig()
    pp.close()

    return roc_auc
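
A side note on the ROC fix above: scikit-learn's roc_curve expects continuous scores, and thresholding (or int-truncating) the predictions first collapses the curve to a single operating point. A minimal self-contained illustration, assuming scikit-learn is available:

import numpy as np
from sklearn.metrics import roc_curve, auc

y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])   # raw model scores, not hard labels
fpr, tpr, _ = roc_curve(y_true, y_score)    # returned in (fpr, tpr) order
print(auc(fpr, tpr))                        # 0.75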
Code Example #9
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"])
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"])

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distortion_factor"],
        augment_rotation_factor=config["rotation_factor"],
        mirror=config["mirror"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                logging_path=config["logging_path"])
    data_file_opened.close()
Code Example #10
def main(overwrite=False):
    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = isensee2017_model(input_shape=config["input_shape"],
                                  n_labels=config["n_labels"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  n_base_filters=config["n_base_filters"])

    # get training and validation generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        npy_path=config["npy_path"],
        subject_ids_file=config["subject_ids_file"],
        batch_size=config["batch_size"],
        validation_batch_size=config["validation_batch_size"],
        n_labels=config["n_labels"],
        labels=config["labels"],

        training_keys_file=config["training_keys_file"],
        validation_keys_file=config["validation_keys_file"],
        data_split=config["validation_split"],
        overwrite=overwrite,

        augment=config["augment"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"],
        permute=config["permute"],

        image_shape=config["image_shape"],
        patch_shape=config["patch_shape"],

        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],

        skip_blank=config["skip_blank"]
        )

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
Code Example #11
def main(config=None):
    model = load_old_model(config, re_compile=False)
    data_file_opened = open_data_file(config["data_file"])
    validation_idxs = pickle_load(config['validation_file'])
    validation_generator = data_generator(data_file_opened, validation_idxs, 
                                        batch_size=config['validation_batch_size'], 
                                        n_labels=config['n_labels'], labels=config['labels'],
                                        skip_blank=config['skip_blank'], shuffle_index_list=False)
    steps = math.ceil(len(validation_idxs) / config['validation_batch_size'])
    results = model.evaluate(validation_generator, steps=steps, verbose=1)
    metrics_names = model.metrics_names
    for i, x in enumerate(metrics_names):
        print('{}: {}'.format(x, results[i]))
        
    data_file_opened.close()
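
One portability note on the example above, offered as an observation rather than a statement from the original project: passing a generator straight to model.evaluate works with tf.keras (TensorFlow 2.x); with older standalone Keras the equivalent call would be model.evaluate_generator(validation_generator, steps=steps). The math.ceil on the step count ensures the final partial batch is still consumed.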
Code Example #12
File: segmentation.py  Project: votnhan/3DUnet
def segmentation_for_patient(subject_fd,
                             config,
                             output_path,
                             model=None,
                             mode='size_same_input'):

    if model is None:
        model = load_old_model(config)
    subject_name = os.path.basename(subject_fd)
    image_mris, original_affine, foreground = get_subject_tensor(
        subject_fd, subject_name)
    if mode == 'size_same_input':
        slices = get_slices(foreground)
        subject_data_fixed_size, affine = crop_subject_modals(
            image_mris, input_shape, slices)
    elif mode == 'size_interpolate':
        target_shape = tuple(config['inference_shape'])
        subject_data_fixed_size, affine = resize_modal_image(
            image_mris, target_shape)
    else:
        print('Mode {} is not supported for inference'.format(mode))
        return

    subject_tensor = normalize_data(subject_data_fixed_size)

    subject_tensor = np.expand_dims(subject_tensor, axis=0)
    output_predict = predict(model, subject_tensor, affine)

    if mode == 'size_same_input':
        output = restore_dimension(output_predict, slices, original_affine)
    elif mode == 'size_interpolate':
        output = resize(output_predict,
                        new_shape=original_shape,
                        interpolation='nearest')
    else:
        print('Mode {} is not supported for inference'.format(mode))
        return

    output_fd = os.path.join(output_path, subject_name)
    if not os.path.exists(output_fd):
        os.makedirs(output_fd)
    output_file = os.path.join(
        output_fd, '{}_prediction{}'.format(subject_name, extension))
    output.to_filename(output_file)

    print('Patient {} is done!'.format(subject_fd))
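
A hypothetical call to the function above (the subject folder, config, and output path are placeholders):

# Placeholder paths; 'size_interpolate' resizes to config['inference_shape'].
segmentation_for_patient('/data/brats/HGG/subject_001', config,
                         output_path='/data/predictions',
                         mode='size_interpolate')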
Code Example #13
File: train.py  Project: tsaiwentage/3DUnetCNN
def main(overwrite=False):
    # # convert input images into an hdf5 file
    # if overwrite or not os.path.exists(config["data_file"]):
    #     training_files = fetch_training_data_files()
    #
    #     write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"])
    # data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"])

    from keras.utils.vis_utils import plot_model
    plot_model(model, to_file='original_unet.png', show_shapes=True)
Code Example #14
File: segmentation.py  Project: votnhan/3DUnet
def segmentation_for_set_patients(list_ids_file, path_dataset, config,
                                  output_path):
    model = load_old_model(config)
    with open(list_ids_file, 'r') as f:
        list_ids = f.read().split('\n')
    pattern = os.path.join(path_dataset, '*', '*')
    list_paths = glob.glob(pattern)
    for idx, ids in enumerate(list_ids):
        for path_subject in list_paths:
            if ids in path_subject:
                segmentation_for_patient(path_subject,
                                         config,
                                         output_path,
                                         model=model)
                break

        print('Done {}/{} patients'.format(idx + 1, len(list_ids)))

    print('Done for dataset: {}'.format(path_dataset))
Code Example #15
def main(overwrite=False):
    # # convert input images into an hdf5 file
    # # If a data file already exists it is reused; note that image_shape stays whatever was set when the file was written
    # if overwrite or not os.path.exists(config["data_file"]):
    #     training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)
    #
    #     write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
    #                        subject_ids=subject_ids)
    # data_file_opened = open_data_file(config["data_file"])

    # Load or create the model file
    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])
    from keras.utils.vis_utils import plot_model
    plot_model(model, to_file='isensee_unet.png', show_shapes=True)
Code Example #16
    def __init__(self, conf):
        self.config = conf
        self.model = load_old_model(self.config.model_file)
Code Example #17
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

    with open('isensemodel_original.txt', 'w') as fh:
        # Pass the file handle in as a lambda function to make it callable
        model.summary(line_length=150, print_fn=lambda x: fh.write(x + '\n'))

    # Save Model
    plot_model(model, to_file="isensemodel_original.png", show_shapes=True)

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Code Example #18
File: train_old.py  Project: maikia/StrokeUNET
def main(overwrite=False):
    # convert input images into an hdf5 file
    print(overwrite or not os.path.exists(config["data_file"]))
    print('path: ', os.path.exists(config["data_file"]))
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        # try:
        write_data_to_file(
            training_files,
            config["data_file"],
            image_shape=config["image_shape"])  #, normalize=False)
        # except:
        #    import pdb; pdb.set_trace()
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"])
        print(model.summary())
    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=False,  # forced to False so that the same training indices
        # are reused as before; they were already used for normalization in
        # write_data_to_file (above)
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])
    # normalize the dataset if required
    # use only the training img (training_keys_file)
    fetch_training_data_files()

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Code Example #19
def main(overwrite_data=False, overwrite_model=False):
    # run if the data not already stored hdf5
    if overwrite_data or not os.path.exists(config["data_file"]):
        _save_new_h5_datafile(config["data_file"],
                              new_image_shape=config["image_shape"])

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite_model and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model

        print('initializing new isensee model with input shape',
              config['input_shape'])
        '''
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])
        '''
        model = unet_model_3d(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            n_base_filters=config["n_base_filters"])

    # get training and testing generators
    (train_generator, validation_generator, n_train_steps,
     n_validation_steps) = get_training_and_validation_generators(
         data_file_opened,
         batch_size=config["batch_size"],
         data_split=config["validation_split"],
         overwrite=overwrite_data,
         validation_keys_file=config["validation_file"],
         training_keys_file=config["training_file"],
         n_labels=config["n_labels"],
         labels=config["labels"],
         patch_shape=config["patch_shape"],
         validation_batch_size=config["validation_batch_size"],
         validation_patch_overlap=config["validation_patch_overlap"],
         training_patch_start_offset=config["training_patch_start_offset"],
         permute=config["permute"],
         augment=config["augment"],
         skip_blank=config["skip_blank"],
         augment_flip=config["flip"],
         augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Code Example #20
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)

        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
        # new_model = isensee2017_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
        #                              initial_learning_rate=config["initial_learning_rate"],
        #                              n_base_filters=config["n_base_filters"])

        if config['freeze_encoder']:
            last_index = list(layer.name for layer in model.layers) \
                .index('up_sampling3d_1')
            for layer in model.layers[:last_index]:
                layer.trainable = False
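            # Keras only applies changes to layer.trainable at compile time,
            # which is why the model is re-compiled just below.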
            from keras.optimizers import Adam
            from unet3d.model.isensee2017 import weighted_dice_coefficient_loss
            model.compile(optimizer=Adam(lr=config['initial_learning_rate']), loss=weighted_dice_coefficient_loss)
        # for new_layer, layer in zip(new_model.layers[1:], old_model.layers[1:]):
        #     assert new_layer.name == layer.name
        #     new_layer.set_weights(layer.get_weights())
        # model = new_model
    else:
        # instantiate new model
        model = isensee2017_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  n_base_filters=config["n_base_filters"])
    model.summary()

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        augment=config["augment"],
        skip_blank=config["skip_blank"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Code Example #21
File: train.py  Project: nggbaobkit/3DUnet
def main(config=None):
    # convert input images into an hdf5 file
    overwrite = config['overwrite']
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        # If this branch is taken, the code won't care what "model_name" says in the
        # config: training resumes with whatever architecture the pre-trained model
        # has (either 3d_unet_residual or attention_unet). Be careful with this.
        model = load_old_model(config, re_compile=False)
        model.summary()
        # visualize_filters_shape(model)
    else:
        # instantiate new model
        if (config["model_name"] == "3d_unet_residual"):
            """3D Unet Residual Model"""
            model = isensee2017_model(input_shape=config["input_shape"],
                                      n_labels=config["n_labels"],
                                      n_base_filters=config["n_base_filters"],
                                      activation_name='softmax')
            optimizer = getattr(
                opts,
                config["optimizer"]["name"])(**config["optimizer"].get('args'))
            loss = getattr(module_metric, config["loss_fc"])
            metrics = [getattr(module_metric, x) for x in config["metrics"]]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
            model.summary()
            # visualize_filters_shape(model)
        elif (config["model_name"] == "attention_unet"):
            """Attention Unet Model"""
            model = attention_unet_model(
                input_shape=config["input_shape"],
                n_labels=config["n_labels"],
                n_base_filters=config["n_base_filters"],
                activation_name='softmax')
            optimizer = getattr(
                opts,
                config["optimizer"]["name"])(**config["optimizer"].get('args'))
            loss = getattr(module_metric, config["loss_fc"])
            metrics = [getattr(module_metric, x) for x in config["metrics"]]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
            model.summary()
            # visualize_filters_shape(model)
        else:
            """Wrong entry for model_name"""
            raise Exception(
                'Check the model_name field in config.json! It must be either 3d_unet_residual or attention_unet.'
            )

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        validation_batch_size=config["validation_batch_size"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["optimizer"]["args"]["lr"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                model_best_path=config['model_best'])
    data_file_opened.close()
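
This variant resolves the optimizer, loss, and metrics by name with getattr. A sketch of the matching config fragment is shown below; only the key names come from the code, while the concrete values (Adam, the learning rate, the metric names) are illustrative assumptions.

{
    "model_name": "attention_unet",
    "optimizer": {"name": "Adam", "args": {"lr": 5e-4}},
    "loss_fc": "weighted_dice_coefficient_loss",
    "metrics": ["dice_coefficient"],
    "model_best": "attention_unet_best.h5"
}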
Code Example #22
import tables
import predict_h5
from unet3d.training import load_old_model

model = load_old_model('isensee_2017_model.h5')
data_file = tables.open_file('data/h5/ozerki.h5', "r")

for i in range(9):
    with open(f"data/out/result{i}", "w") as f:
        print(predict_h5.run_validation_case(i, 'data/out', model, data_file,
                                             ["t1"]),
              file=f)
Code Example #23
    def main(self, overwrite_data=True, overwrite_model=True):
        # convert input images into an hdf5 file
        if overwrite_data or not os.path.exists(self.config.data_file):
            training_files, subject_ids = self.fetch_training_data_files(
                return_subject_ids=True)
            write_data_to_file(training_files,
                               self.config.data_file,
                               image_shape=self.config.image_shape,
                               subject_ids=subject_ids)
        else:
            print(
                "Reusing previously written data file. Set overwrite_data to True to overwrite this file."
            )

        data_file_opened = open_data_file(self.config.data_file)

        if not overwrite_model and os.path.exists(self.config.model_file):
            model = load_old_model(self.config.model_file)
        else:
            # instantiate new model

            model, context_output_name = isensee2017_model(
                input_shape=self.config.input_shape,
                n_labels=self.config.n_labels,
                initial_learning_rate=self.config.initial_learning_rate,
                n_base_filters=self.config.n_base_filters,
                loss_function=self.config.loss_function,
                shortcut=self.config.shortcut)

        # get training and testing generators

        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
            data_file_opened,
            batch_size=self.config.batch_size,
            data_split=self.config.validation_split,
            overwrite_data=overwrite_data,
            validation_keys_file=self.config.validation_file,
            training_keys_file=self.config.training_file,
            n_labels=self.config.n_labels,
            labels=self.config.labels,
            patch_shape=self.config.patch_shape,
            validation_batch_size=self.config.validation_batch_size,
            validation_patch_overlap=self.config.validation_patch_overlap,
            training_patch_overlap=self.config.training_patch_overlap,
            training_patch_start_offset=self.config.training_patch_start_offset,
            permute=self.config.permute,
            augment=self.config.augment,
            skip_blank=self.config.skip_blank,
            augment_flip=self.config.flip,
            augment_distortion_factor=self.config.distort)

        # run training
        train_model(model=model,
                    model_file=self.config.model_file,
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=self.config.initial_learning_rate,
                    learning_rate_drop=self.config.learning_rate_drop,
                    learning_rate_patience=self.config.patience,
                    early_stopping_patience=self.config.early_stop,
                    n_epochs=self.config.epochs,
                    niseko=self.config.niseko)

        data_file_opened.close()
Code Example #24
def main(overwrite=False):
    args = get_args.train()
    overwrite = args.overwrite

    # config["data_file"] = get_brats_data_h5_path(args.challenge, args.year,
    #                                              args.inputshape, args.isbiascorrection,
    #                                              args.normalization, args.clahe,
    #                                              args.histmatch)

    # print(config["data_file"])

    print_section("Open file")
    data_file_opened = open_data_file(config["data_file"])

    print_section("get training and testing generators")
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators_new(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_steps_file=config["n_steps_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        is_create_patch_index_list_original=config["is_create_patch_index_list_original"],
        augment_flipud=config["augment_flipud"],
        augment_fliplr=config["augment_fliplr"],
        augment_elastic=config["augment_elastic"],
        augment_rotation=config["augment_rotation"],
        augment_shift=config["augment_shift"],
        augment_shear=config["augment_shear"],
        augment_zoom=config["augment_zoom"],
        n_augment=config["n_augment"],
        skip_blank=config["skip_blank"])

    print("-"*60)
    print("# Load or init model")
    print("-"*60)
    if not overwrite and os.path.exists(config["model_file"]):
        print("load old model")
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        print("init model model")
        model = unet_model_3d(input_shape=config["input_shape"],
                              pool_size=config["pool_size"],
                              n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"],
                              deconvolution=config["deconvolution"],
                              depth=config["depth"],
                              n_base_filters=config["n_base_filters"])

    # model.summary()

    # import nibabel as nib
    # laptop_save_dir = "C:/Users/minhm/Desktop/temp/"
    # desktop_save_dir = "/home/minhvu/Desktop/temp/"
    # save_dir = desktop_save_dir
    # temp_in_path = desktop_save_dir + "template.nii.gz"
    # temp_out_path = desktop_save_dir + "out.nii.gz"
    # temp_out_truth_path = desktop_save_dir + "truth.nii.gz"

    # n_validation_samples = 0
    # validation_samples = list()
    # for i in range(20):
    #     print(i)
    #     x, y = next(train_generator)
    #     hash_x = hash(str(x))
    #     validation_samples.append(hash_x)
    #     n_validation_samples += x.shape[0]

    #     temp_in = nib.load(temp_in_path)
    #     temp_out = nib.Nifti1Image(x[0][0], affine=temp_in.affine)
    #     nib.save(temp_out, temp_out_path)

    #     temp_out = nib.Nifti1Image(y[0][0], affine=temp_in.affine)
    #     nib.save(temp_out, temp_out_truth_path)

    # print(n_validation_samples)

    print("-"*60)
    print("# start training")
    print("-"*60)
    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])
    data_file_opened.close()
Code Example #25
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)

        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        base_model = load_old_model(config["model_file"])
        model = get_multiGPUmodel(base_model=base_model,n_labels=config["n_labels"],GPU=config["GPU"])
    else:
        # instantiate new model HYbrid Dense-Unet model from HDense project
        parser = argparse.ArgumentParser(description='Keras DenseUnet Training')
        parser.add_argument('-b', type=int, default=1)  # config["batch_size"]
        parser.add_argument('-input_size', type=int, default=config["patch_shape"][0])  # 224
        parser.add_argument('-input_cols', type=int, default=config["patch_shape"][2])  # 8
        args = parser.parse_args()
        #print(args.b)
        #model = dense_rnn_net(args)
        base_model = denseunet_3d(args)
        sgd = SGD(lr=1e-3, momentum=0.9, nesterov=True)
        model = base_model
        base_model.compile(optimizer=sgd, loss=[weighted_crossentropy])

    # get training and testing generators

    # Save Model
    plot_model(base_model, to_file="liver_segmentation_HDenseUnet.png", show_shapes=True)

    # Open the file
    with open(config['model_summaryfile'],'w') as fh:
        # Pass the file handle in as a lambda function to make it callable
        base_model.summary(line_length=150, print_fn=lambda x: fh.write(x + '\n'))

    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"]*config["GPU"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"]*config["GPU"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])
      
    print('INFO: Training Details',
          '\n Batch Size : ', config["batch_size"] * config["GPU"],
          '\n Epoch Size : ', config["n_epochs"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"], base_model=base_model)
    data_file_opened.close()
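
A hedged reading of the base_model/model split above: in Keras' multi-GPU setup the weights live in the single-device base_model, so the multi-GPU wrapper returned by get_multiGPUmodel is what gets trained, while base_model is the object that is plotted, summarized, and handed to train_model via base_model= -- presumably so checkpoints are written from the single-device copy.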
Code Example #26
def main(overwrite=False):
    # convert input images into an hdf5 file
    pdb.set_trace()
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           modality_names=config['all_modalities'],
                           subject_ids=subject_ids,
                           mean_std_file=config['mean_std_file'])
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = isensee2017_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"],
        pred_specific=config['pred_specific'],
        overlap_label=config['overlap_label_generator'],
        for_final_val=config['for_final_val'])

    # run training, timing the whole run
    time_0 = time.time()
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                logging_file=config['logging_file'])
    print('Training time:', sec2hms(time.time() - time_0))
    data_file_opened.close()
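`sec2hms` is used above to report elapsed training time but is not part of the snippet. A minimal sketch of such a helper, assuming it formats a duration in seconds as HH:MM:SS:

def sec2hms(seconds):
    # Format a duration in seconds as an HH:MM:SS string.
    seconds = int(seconds)
    hours, remainder = divmod(seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return '{:02d}:{:02d}:{:02d}'.format(hours, minutes, secs)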
Code example #27
File: train.py Project: votnhan/3DUnet
def main(config=None):
    # convert input images into an hdf5 file
    overwrite = config['overwrite']
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids,
                           norm_type=config['normalization_type'])
    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config, re_compile=False)
    else:
        # instantiate new model
        model = isensee2017_model(input_shape=config["input_shape"],
                                  n_labels=config["n_labels"],
                                  n_base_filters=config["n_base_filters"],
                                  activation_name='softmax')

        optimizer = getattr(
            opts, config["optimizer"]["name"])(**config["optimizer"]["args"])
        loss = getattr(module_metric, config["loss_fc"])
        metrics = [getattr(module_metric, x) for x in config["metrics"]]
        model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        validation_batch_size=config["validation_batch_size"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["optimizer"]["args"]["lr"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                model_best_path=config['model_best'])
    data_file_opened.close()
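This example resolves the optimizer, loss, and metrics by name from the config via getattr, which keeps the training script model-agnostic. A minimal sketch of the pattern, using keras.optimizers in the role of `opts`; the concrete config values are assumptions, and the loss/metric strings would have to name functions in the project's module_metric:

from keras import optimizers as opts

# Hypothetical config fragment mirroring the lookups above; "Adam" must name
# a class in keras.optimizers.
config = {
    "optimizer": {"name": "Adam", "args": {"lr": 5e-4}},
    "loss_fc": "weighted_dice_coefficient_loss",
    "metrics": ["dice_coefficient"],
}

# Resolve the optimizer class by name, then instantiate it with the
# configured keyword arguments.
optimizer = getattr(opts, config["optimizer"]["name"])(**config["optimizer"]["args"])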
Code example #28
File: train_2.py Project: cchmc-dll/ai_training
def main(overwrite=False):
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)

        write_data_to_file(training_files,
                           config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    data_file_opened = open_data_file(config["data_file"])

    if not overwrite and os.path.exists(config["model_file"]):
        base_model = load_old_model(config["model_file"])
        model = get_multiGPUmodel(base_model=base_model,
                                  n_labels=config["n_labels"],
                                  GPU=config["GPU"])
    else:
        # instantiate new model
        base_model, model = unet_model_3d_multiGPU(
            input_shape=config["input_shape"],
            pool_size=config["pool_size"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            deconvolution=config["deconvolution"],
            GPU=config["GPU"])
    # Save Model
    plot_model(base_model,
               to_file="liver_segmentation_model_581_resize_1GPU.png",
               show_shapes=True)
    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened,
        batch_size=config["batch_size"] * config["GPU"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"] * config["GPU"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])

    print('INFO: Training Details', '\n Batch Size : ',
          config["batch_size"] * config["GPU"], '\n Epoch Size : ',
          config["n_epochs"])

    # For debugging ONLY
    # n_train_steps = 10

    # run training
    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=validation_generator,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"],
                base_model=base_model)
    data_file_opened.close()
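`get_multiGPUmodel` is not shown in any of these snippets. In Keras of this era the usual way to obtain a data-parallel replica was keras.utils.multi_gpu_model, so a plausible sketch, offered as an assumption rather than the project's actual implementation, is:

from keras.utils import multi_gpu_model

def get_multiGPUmodel(base_model, n_labels, GPU):
    # Replicate the model across GPUs; each batch is split into per-GPU
    # sub-batches and the gradients are merged on the parameter device.
    # n_labels is kept only to match the call sites above; this sketch
    # does not use it.
    if GPU > 1:
        return multi_gpu_model(base_model, gpus=GPU)
    return base_model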
Code example #29
def main(overwrite=False):

    # convert input images into an hdf5 file
    if overwrite or not (os.path.exists(config["data_file0"])
                         and os.path.exists(config["data_file1"])):

        training_files, subject_ids = fetch_training_data_files(
            return_subject_ids=True)
        training_files0, training_files1 = training_files
        subject_ids0, subject_ids1 = subject_ids

        if not os.path.exists(config["data_file0"]):
            write_data_to_file(training_files0,
                               config["data_file0"],
                               image_shape=config["image_shape"],
                               subject_ids=subject_ids0)
        if not os.path.exists(config["data_file1"]):
            write_data_to_file(training_files1,
                               config["data_file1"],
                               image_shape=config["image_shape"],
                               subject_ids=subject_ids1)

    data_file_opened0 = open_data_file(config["data_file0"])
    data_file_opened1 = open_data_file(config["data_file1"])

    if not overwrite and os.path.exists(config["model_file"]):
        model = load_old_model(config["model_file"])
    else:
        # instantiate new model
        model = siam3dunet_model(
            input_shape=config["input_shape"],
            n_labels=config["n_labels"],
            initial_learning_rate=config["initial_learning_rate"],
            n_base_filters=config["n_base_filters"])

    # get training and testing generators
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
        data_file_opened0,
        data_file_opened1,
        batch_size=config["batch_size"],
        data_split=config["validation_split"],
        overwrite=overwrite,
        validation_keys_file=config["validation_file"],
        training_keys_file=config["training_file"],
        n_labels=config["n_labels"],
        labels=config["labels"],
        patch_shape=config["patch_shape"],
        validation_batch_size=config["validation_batch_size"],
        validation_patch_overlap=config["validation_patch_overlap"],
        training_patch_start_offset=config["training_patch_start_offset"],
        permute=config["permute"],
        augment=config["augment"],
        skip_blank=config["skip_blank"],
        augment_flip=config["flip"],
        augment_distortion_factor=config["distort"])
    # (disabled debug aid) The original code here iterated the training
    # generator, min-max normalized axial slices of the first modality, and
    # saved them under vis_img/ via imsave for visual inspection; a corrected
    # sketch of that idea appears at the end of this example.

    # draw one fixed batch from each generator; the (data, label) tuple
    # test_g is passed to train_model below in place of a validation generator
    test_data, test_label = next(validation_generator)
    test_g = (test_data, test_label)

    train_data, train_label = next(train_generator)
    train_g = (train_data, train_label)

    if not overwrite and os.path.exists(config["model_file"]):
        # evaluation-only path: dump (prediction, label) pairs for the
        # held-out batch and one training batch, then stop without training
        with open("output_log.txt", "w") as txt_file:
            pre = model.predict(test_data)
            for i in range(len(pre[0])):
                txt_file.write(
                    str(pre[0][i][0]) + ' ' + str(test_label[0][i]) + "\n")

            pre_train = model.predict(train_data)
            for i in range(len(pre_train[0])):
                txt_file.write(
                    str(pre_train[0][i][0]) + ' ' + str(train_label[0][i]) + "\n")
        return

    # run training

    train_model(model=model,
                model_file=config["model_file"],
                training_generator=train_generator,
                validation_generator=test_g,
                steps_per_epoch=n_train_steps,
                validation_steps=n_validation_steps,
                initial_learning_rate=config["initial_learning_rate"],
                learning_rate_drop=config["learning_rate_drop"],
                learning_rate_patience=config["patience"],
                early_stopping_patience=config["early_stop"],
                n_epochs=config["n_epochs"])

    data_file_opened0.close()
    data_file_opened1.close()
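The disabled debug aid noted in this example can be made to work; below is a minimal sketch that pulls one batch from a generator and saves normalized axial slices as images. The function name, the imageio dependency, and the output layout are assumptions for illustration.

import os
import numpy as np
from imageio import imwrite

def dump_patch_slices(generator, out_dir='vis_img', n_slices=8):
    # Pull one batch; siamese generators yield a pair of inputs, so take
    # the first input in that case.
    data, _ = next(generator)
    if isinstance(data, (list, tuple)):
        data = data[0]
    volume = np.asarray(data)[0, 0]  # first sample, first modality: (x, y, z)
    os.makedirs(out_dir, exist_ok=True)
    for k in range(min(n_slices, volume.shape[-1])):
        sl = volume[..., k].astype(np.float32)
        rng = float(sl.max() - sl.min())
        # Min-max normalize each slice so it is viewable as an 8-bit image.
        sl = (sl - sl.min()) / rng if rng > 0 else np.zeros_like(sl)
        imwrite(os.path.join(out_dir, '{}.jpg'.format(k)),
                (sl * 255).astype(np.uint8))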