Example #1
def generate_overlays_for_task(task_name_or_id,
                               output_folder,
                               num_processes=8,
                               modality_idx=0,
                               use_preprocessed=True,
                               data_identifier=default_data_identifier):
    if isinstance(task_name_or_id, str):
        if not task_name_or_id.startswith("Task"):
            task_name_or_id = int(task_name_or_id)
            task_name = convert_id_to_task_name(task_name_or_id)
        else:
            task_name = task_name_or_id
    else:
        task_name = convert_id_to_task_name(int(task_name_or_id))

    if not use_preprocessed:
        folder = join(nnUNet_raw_data, task_name)

        identifiers = [
            i[:-7] for i in subfiles(
                join(folder, 'labelsTr'), suffix='.nii.gz', join=False)
        ]

        image_files = [
            join(folder, 'imagesTr', i + "_%04.0d.nii.gz" % modality_idx)
            for i in identifiers
        ]
        seg_files = [
            join(folder, 'labelsTr', i + ".nii.gz") for i in identifiers
        ]

        assert all([isfile(i) for i in image_files])
        assert all([isfile(i) for i in seg_files])

        maybe_mkdir_p(output_folder)
        output_files = [join(output_folder, i + '.png') for i in identifiers]
        multiprocessing_plot_overlay(image_files, seg_files, output_files, 0.6,
                                     num_processes)
    else:
        folder = join(preprocessing_output_dir, task_name)
        if not isdir(folder):
            raise RuntimeError("run preprocessing for that task first")
        matching_folders = subdirs(folder, prefix=data_identifier + "_stage")
        if len(matching_folders) == 0:
            "run preprocessing for that task first (use default experiment planner!)"
        matching_folders.sort()
        folder = matching_folders[-1]
        identifiers = [
            i[:-4] for i in subfiles(folder, suffix='.npz', join=False)
        ]
        maybe_mkdir_p(output_folder)
        output_files = [join(output_folder, i + '.png') for i in identifiers]
        image_files = [join(folder, i + ".npz") for i in identifiers]
        multiprocessing_plot_overlay_preprocessed(image_files,
                                                  output_files,
                                                  overlay_intensity=0.6,
                                                  num_processes=num_processes,
                                                  modality_index=modality_idx)
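
A minimal usage sketch for the function above, assuming the nnU-Net environment variables and the imports used inside the function are already set up; the task id and output path are placeholders:

# Hypothetical call: render PNG overlays for the raw data of task 4, first modality.
# Set use_preprocessed=True to draw the overlays from the preprocessed .npz files instead.
generate_overlays_for_task(4,
                           output_folder="/tmp/overlays_Task004",
                           num_processes=8,
                           modality_idx=0,
                           use_preprocessed=False)
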
Example #2
def main():
    argparser = argparse.ArgumentParser(usage="Used to determine the postprocessing for a trained model. Useful for "
                                              "when the best configuration (2d, 3d_fullres etc) as selected manually.")
    argparser.add_argument("-m", type=str, required=True, help="U-Net model (2d, 3d_lowres, 3d_fullres or "
                                                               "3d_cascade_fullres)")
    argparser.add_argument("-t", type=str, required=True, help="Task name or id")
    argparser.add_argument("-tr", type=str, required=False, default=None,
                           help="nnUNetTrainer class. Default: %s, unless 3d_cascade_fullres "
                                "(then it's %s)" % (default_trainer, default_cascade_trainer))
    argparser.add_argument("-pl", type=str, required=False, default=default_plans_identifier,
                           help="Plans name, Default=%s" % default_plans_identifier)
    argparser.add_argument("-val", type=str, required=False, default="validation_raw",
                           help="Validation folder name. Default: validation_raw")

    args = argparser.parse_args()
    model = args.m
    task = args.t
    trainer = args.tr
    plans = args.pl
    val = args.val

    if not task.startswith("Task"):
        task_id = int(task)
        task = convert_id_to_task_name(task_id)

    if trainer is None:
        if model == "3d_cascade_fullres":
            trainer = "nnUNetTrainerV2CascadeFullRes"
        else:
            trainer = "nnUNetTrainerV2"

    folder = get_output_folder_name(model, task, trainer, plans, None)

    consolidate_folds(folder, val)
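
For illustration, a hedged sketch of how this entry point could be driven programmatically for a quick test; normally it is run from the command line, and the model/task values below are placeholders (the program name in sys.argv[0] is cosmetic):

import sys

# Simulate: <script> -m 3d_fullres -t 4 -val validation_raw
sys.argv = ["determine_postprocessing", "-m", "3d_fullres", "-t", "4",
            "-val", "validation_raw"]
main()  # parses the simulated arguments and ends in consolidate_folds(folder, val)
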
Example #3
def main():
    import argparse
    parser = argparse.ArgumentParser(
        description=
        "We extend nnUNet to offer self-supervision tasks. This step is to"
        " split the dataset into two - self-supervision input and self- "
        "supervisio output folder.")
    parser.add_argument(
        "-t",
        type=int,
        help="Task id. The task name you wish to run self-supervision task for. "
        "It must have a matching folder 'TaskXXX_' in the raw "
        "data folder",
        required=True)
    parser.add_argument(
        "-ss",
        help=
        "Run self-supervision pretext asks. Specify which self-supervision task you "
        "wish to train. Current supported tasks: context_restoration| jigsaw_puzzle | byol"
    )

    args = parser.parse_args()

    base = join(os.environ['nnUNet_raw_data_base'], 'nnUNet_raw_data')
    task_name = convert_id_to_task_name(args.t)
    target_base = join(base, task_name)
    pretext = str(args.ss)

    print(
        f'Hey there: here\'s pretext task {pretext} for {task_name}. '
        f'Paths to the self-supervision datasets are {join(target_base, "ssInput" + "BYOL")} and {join(target_base, "ssOutput" + "BYOL")}'
    )
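
A small worked example of the path construction used in the print statement above; the base folder and task name are placeholders:

from os.path import join

base = "/data/nnUNet_raw_data_base/nnUNet_raw_data"  # placeholder
task_name = "Task004_Hippocampus"                    # placeholder
target_base = join(base, task_name)
print(join(target_base, "ssInput" + "BYOL"))   # .../Task004_Hippocampus/ssInputBYOL
print(join(target_base, "ssOutput" + "BYOL"))  # .../Task004_Hippocampus/ssOutputBYOL
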
Example #4
def export_entry_point():
    import argparse
    parser = argparse.ArgumentParser(description="Use this script to export models to a zip file for sharing with "
                                                 "others. You can upload the zip file and then either share the url "
                                                 "for usage with nnUNet_download_pretrained_model_by_url, or share the "
                                                 "zip for usage with nnUNet_install_pretrained_model_from_zip")
    parser.add_argument('-t', type=str, help='task name or task id')
    parser.add_argument('-o', type=str, help='output file name. Should end with .zip')
    parser.add_argument('-m', nargs='+',
                        help='list of model configurations. Default: 2d 3d_lowres 3d_fullres 3d_cascade_fullres. Must '
                             'be adapted to fit the available models of a task',
                        default=("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), required=False)
    parser.add_argument('-tr', type=str, help='trainer class used for 2d 3d_lowres and 3d_fullres. '
                                              'Default: %s' % default_trainer, required=False, default=default_trainer)
    parser.add_argument('-trc', type=str, help='trainer class used for 3d_cascade_fullres. '
                                              'Default: %s' % default_cascade_trainer, required=False,
                        default=default_cascade_trainer)
    parser.add_argument('-pl', type=str, help='nnunet plans identifier. Default: %s' % default_plans_identifier,
                        required=False, default=default_plans_identifier)
    args = parser.parse_args()

    taskname = args.t
    if taskname.startswith("Task"):
        pass
    else:
        try:
            taskid = int(taskname)
        except Exception as e:
            print('-t must be either a Task name (TaskXXX_YYY) or a task id (integer)')
            raise e
        taskname = convert_id_to_task_name(taskid)

    export_pretrained_model(taskname, args.o, args.m, args.tr, args.trc, args.pl)
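
A hedged usage sketch for this export entry point; the program name, task id and output path are placeholders:

import sys

# Simulate: <script> -t 4 -o /tmp/Task004.zip -m 2d 3d_fullres
sys.argv = ["nnUNet_export_model_to_zip", "-t", "4", "-o", "/tmp/Task004.zip",
            "-m", "2d", "3d_fullres"]
export_entry_point()  # an integer id is converted to its TaskXXX_ name before exporting
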
Example #5
def export_for_paper():
    output_base = "/media/fabian/DeepLearningData/nnunet_trained_models"
    task_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 24, 27, 29, 35, 48, 55, 61, 38]
    for t in task_ids:
        if t == 61:
            models = ("3d_fullres",)
        else:
            models = ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres")
        taskname = convert_id_to_task_name(t)
        print(taskname)
        output_folder = join(output_base, taskname)
        os.makedirs(output_folder, exist_ok=True)
        copy_pretrained_models_for_task(taskname, output_folder, models)
        copy_ensembles(taskname, output_folder)
    compress_everything(output_base, 8)
Example #6
def test_nnUNetTrainerV2_train_and_validate(tmp_path: Path, network: str,
                                            fold: int):
    prepare_paths(output_dir=tmp_path)
    task = nnp.convert_id_to_task_name(HIPPOCAMPUS_TASK_ID)
    decompress_data = True
    deterministic = False
    run_mixed_precision = True
    (
        plans_file,
        output_folder_name,
        dataset_directory,
        batch_dice,
        stage,
        trainer_class,
    ) = nndc.get_default_configuration(network, task, TEST_TRAINER_CLASS_NAME,
                                       default_plans_identifier)
    assert issubclass(trainer_class, nnUNetTrainerV2)
    assert nnUNetTrainerV2_test is trainer_class
    trainer = trainer_class(
        plans_file,
        fold,
        output_folder=output_folder_name,
        dataset_directory=dataset_directory,
        batch_dice=batch_dice,
        stage=stage,
        unpack_data=decompress_data,
        deterministic=deterministic,
        fp16=run_mixed_precision,
    )
    assert trainer.max_num_epochs == 2
    assert trainer.num_batches_per_epoch == 2
    assert trainer.num_val_batches_per_epoch == 2
    trainer.initialize(True)
    trainer.run_training()
    trainer.network.eval()
    trainer.validate(
        save_softmax=False,
        validation_folder_name="validation_raw",
        run_postprocessing_on_folds=True,
        overwrite=True,
    )
    check_expected_training_output(check_dir=tmp_path, network=network)
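
The test above expects network and fold as parameters. A hedged sketch of how they could be supplied with pytest parametrization (the chosen configurations are placeholders, not necessarily what the original test suite uses):

import pytest

# Hypothetical wrapper: run the shortened train/validate smoke test for two configurations.
@pytest.mark.parametrize("network,fold", [("2d", 0), ("3d_fullres", 0)])
def test_train_and_validate_smoke(tmp_path, network, fold):
    test_nnUNetTrainerV2_train_and_validate(tmp_path, network, fold)
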
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        '--input_folder',
        help="Must contain all modalities for each patient in the correct"
        " order (same as training). Files must be named "
        "CASENAME_XXXX.nii.gz where XXXX is the modality "
        "identifier (0000, 0001, etc)",
        required=True)
    parser.add_argument('-o',
                        "--output_folder",
                        required=True,
                        help="folder for saving predictions")
    parser.add_argument('-t',
                        '--task_name',
                        help='task name or task ID, required.',
                        default=default_plans_identifier,
                        required=True)

    parser.add_argument(
        '-tr',
        '--trainer_class_name',
        help=
        'Name of the nnUNetTrainer used for 2D U-Net, full resolution 3D U-Net and low resolution '
        'U-Net. The default is %s. If you are running inference with the cascade and the folder '
        'pointed to by --lowres_segmentations does not contain the segmentation maps generated by '
        'the low resolution U-Net then the low resolution segmentation maps will be automatically '
        'generated. For this case, make sure to set the trainer class here that matches your '
        '--cascade_trainer_class_name (this part can be ignored if defaults are used).'
        % default_trainer,
        required=False,
        default=default_trainer)
    parser.add_argument(
        '-ctr',
        '--cascade_trainer_class_name',
        help=
        "Trainer class name used for predicting the 3D full resolution U-Net part of the cascade."
        "Default is %s" % default_cascade_trainer,
        required=False,
        default=default_cascade_trainer)

    parser.add_argument(
        '-m',
        '--model',
        help=
        "2d, 3d_lowres, 3d_fullres or 3d_cascade_fullres. Default: 3d_fullres",
        default="3d_fullres",
        required=False)

    parser.add_argument(
        '-p',
        '--plans_identifier',
        help='do not touch this unless you know what you are doing',
        default=default_plans_identifier,
        required=False)

    parser.add_argument(
        '-f',
        '--folds',
        nargs='+',
        default='None',
        help=
        "folds to use for prediction. Default is None which means that folds will be detected "
        "automatically in the model output folder")

    parser.add_argument(
        '-z',
        '--save_npz',
        required=False,
        action='store_true',
        help=
        "use this if you want to ensemble these predictions with those of other models. Softmax "
        "probabilities will be saved as compressed numpy arrays in output_folder and can be "
        "merged between output_folders with nnUNet_ensemble_predictions")

    parser.add_argument(
        '-l',
        '--lowres_segmentations',
        required=False,
        default='None',
        help=
        "if model is the highres stage of the cascade then you can use this folder to provide "
        "predictions from the low resolution 3D U-Net. If this is left at default, the "
        "predictions will be generated automatically (provided that the 3D low resolution U-Net "
        "network weights are present")

    parser.add_argument("--part_id",
                        type=int,
                        required=False,
                        default=0,
                        help="Used to parallelize the prediction of "
                        "the folder over several GPUs. If you "
                        "want to use n GPUs to predict this "
                        "folder you need to run this command "
                        "n times with --part_id=0, ... n-1 and "
                        "--num_parts=n (each with a different "
                        "GPU (for example via "
                        "CUDA_VISIBLE_DEVICES=X)")

    parser.add_argument("--num_parts",
                        type=int,
                        required=False,
                        default=1,
                        help="Used to parallelize the prediction of "
                        "the folder over several GPUs. If you "
                        "want to use n GPUs to predict this "
                        "folder you need to run this command "
                        "n times with --part_id=0, ... n-1 and "
                        "--num_parts=n (each with a different "
                        "GPU (via "
                        "CUDA_VISIBLE_DEVICES=X)")

    parser.add_argument(
        "--num_threads_preprocessing",
        required=False,
        default=6,
        type=int,
        help=
        "Determines many background processes will be used for data preprocessing. Reduce this if you "
        "run into out of memory (RAM) problems. Default: 6")

    parser.add_argument(
        "--num_threads_nifti_save",
        required=False,
        default=2,
        type=int,
        help=
        "Determines many background processes will be used for segmentation export. Reduce this if you "
        "run into out of memory (RAM) problems. Default: 2")

    parser.add_argument(
        "--disable_tta",
        required=False,
        default=False,
        action="store_true",
        help=
        "set this flag to disable test time data augmentation via mirroring. Speeds up inference "
        "by roughly factor 4 (2D) or 8 (3D)")

    parser.add_argument(
        "--overwrite_existing",
        required=False,
        default=False,
        action="store_true",
        help=
        "Set this flag if the target folder contains predictions that you would like to overwrite"
    )

    parser.add_argument("--mode",
                        type=str,
                        default="normal",
                        required=False,
                        help="Hands off!")
    parser.add_argument("--all_in_gpu",
                        type=str,
                        default="None",
                        required=False,
                        help="can be None, False or True. "
                        "Do not touch.")
    parser.add_argument("--step_size",
                        type=float,
                        default=0.5,
                        required=False,
                        help="don't touch")
    # parser.add_argument("--interp_order", required=False, default=3, type=int,
    #                     help="order of interpolation for segmentations, has no effect if mode=fastest. Do not touch this.")
    # parser.add_argument("--interp_order_z", required=False, default=0, type=int,
    #                     help="order of interpolation along z is z is done differently. Do not touch this.")
    # parser.add_argument("--force_separate_z", required=False, default="None", type=str,
    #                     help="force_separate_z resampling. Can be None, True or False, has no effect if mode=fastest. "
    #                          "Do not touch this.")
    parser.add_argument(
        '-chk',
        help='checkpoint name, default: model_final_checkpoint',
        required=False,
        default='model_final_checkpoint')
    parser.add_argument(
        '--disable_mixed_precision',
        default=False,
        action='store_true',
        required=False,
        help=
        'Predictions are done with mixed precision by default. This improves speed and reduces '
        'the required VRAM. If you want to disable mixed precision you can set this flag. Note '
        'that this is not recommended (mixed precision is ~2x faster!)')
    ### ----------- added by Camila
    parser.add_argument(
        '--disable_sliding_window',
        default=False,
        action='store_true',
        required=False,
        help='Disable sliding window to predict the whole image')
    ### ----------- end added by Camila

    args = parser.parse_args()
    input_folder = args.input_folder
    output_folder = args.output_folder
    part_id = args.part_id
    num_parts = args.num_parts
    folds = args.folds
    save_npz = args.save_npz
    lowres_segmentations = args.lowres_segmentations
    num_threads_preprocessing = args.num_threads_preprocessing
    num_threads_nifti_save = args.num_threads_nifti_save
    disable_tta = args.disable_tta
    step_size = args.step_size
    # interp_order = args.interp_order
    # interp_order_z = args.interp_order_z
    # force_separate_z = args.force_separate_z
    overwrite_existing = args.overwrite_existing
    mode = args.mode
    all_in_gpu = args.all_in_gpu
    model = args.model
    trainer_class_name = args.trainer_class_name
    cascade_trainer_class_name = args.cascade_trainer_class_name
    ### ----------- added by Camila
    disable_sliding_window = args.disable_sliding_window
    ### ----------- end added by Camila

    task_name = args.task_name

    if not task_name.startswith("Task"):
        task_id = int(task_name)
        task_name = convert_id_to_task_name(task_id)

    assert model in ["2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"], "-m must be 2d, 3d_lowres, 3d_fullres or " \
                                                                             "3d_cascade_fullres"

    # if force_separate_z == "None":
    #     force_separate_z = None
    # elif force_separate_z == "False":
    #     force_separate_z = False
    # elif force_separate_z == "True":
    #     force_separate_z = True
    # else:
    #     raise ValueError("force_separate_z must be None, True or False. Given: %s" % force_separate_z)

    if lowres_segmentations == "None":
        lowres_segmentations = None

    if isinstance(folds, list):
        if folds[0] == 'all' and len(folds) == 1:
            pass
        else:
            folds = [int(i) for i in folds]
    elif folds == "None":
        folds = None
    else:
        raise ValueError("Unexpected value for argument folds")

    assert all_in_gpu in ['None', 'False', 'True']
    if all_in_gpu == "None":
        all_in_gpu = None
    elif all_in_gpu == "True":
        all_in_gpu = True
    elif all_in_gpu == "False":
        all_in_gpu = False

    # we need to catch the case where model is 3d cascade fullres and the low resolution folder has not been set.
    # In that case we need to try and predict with 3d low res first
    if model == "3d_cascade_fullres" and lowres_segmentations is None:
        print(
            "lowres_segmentations is None. Attempting to predict 3d_lowres first..."
        )
        assert part_id == 0 and num_parts == 1, "if you don't specify a --lowres_segmentations folder for the " \
                                                "inference of the cascade, custom values for part_id and num_parts " \
                                                "are not supported. If you wish to have multiple parts, please " \
                                                "run the 3d_lowres inference first (separately)"
        model_folder_name = join(
            network_training_output_dir, "3d_lowres", task_name,
            trainer_class_name + "__" + args.plans_identifier)
        assert isdir(
            model_folder_name
        ), "model output folder not found. Expected: %s" % model_folder_name
        lowres_output_folder = join(output_folder, "3d_lowres_predictions")
        predict_from_folder(model_folder_name,
                            input_folder,
                            lowres_output_folder,
                            folds,
                            False,
                            num_threads_preprocessing,
                            num_threads_nifti_save,
                            None,
                            part_id,
                            num_parts,
                            not disable_tta,
                            overwrite_existing=overwrite_existing,
                            mode=mode,
                            overwrite_all_in_gpu=all_in_gpu,
                            mixed_precision=not args.disable_mixed_precision,
                            step_size=step_size,
                            disable_sliding_window=disable_sliding_window)
        lowres_segmentations = lowres_output_folder
        torch.cuda.empty_cache()
        print("3d_lowres done")

    if model == "3d_cascade_fullres":
        trainer = cascade_trainer_class_name
    else:
        trainer = trainer_class_name

    model_folder_name = join(network_training_output_dir, model, task_name,
                             trainer + "__" + args.plans_identifier)
    print("using model stored in ", model_folder_name)
    assert isdir(
        model_folder_name
    ), "model output folder not found. Expected: %s" % model_folder_name

    predict_from_folder(model_folder_name,
                        input_folder,
                        output_folder,
                        folds,
                        save_npz,
                        num_threads_preprocessing,
                        num_threads_nifti_save,
                        lowres_segmentations,
                        part_id,
                        num_parts,
                        not disable_tta,
                        overwrite_existing=overwrite_existing,
                        mode=mode,
                        overwrite_all_in_gpu=all_in_gpu,
                        mixed_precision=not args.disable_mixed_precision,
                        step_size=step_size,
                        checkpoint_name=args.chk,
                        disable_sliding_window=disable_sliding_window)
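
A hedged sketch of a typical invocation of this prediction entry point; the folders and task id are placeholders, and the program name in sys.argv[0] is cosmetic:

import sys

# Simulate: predict task 4 with the 3d_fullres model, folds detected automatically.
sys.argv = ["nnUNet_predict",
            "-i", "/data/imagesTs",       # placeholder input folder
            "-o", "/data/predictions",    # placeholder output folder
            "-t", "4",
            "-m", "3d_fullres"]
main()
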
Example #8
# NOTE: this snippet starts mid-signature in the source; the header below is a
# reconstruction from the parameters used in the body (the function name and
# exact parameter order are assumptions).
def export_one_task(taskname, output_folder, models,
                    nnunet_trainer=default_trainer,
                    nnunet_trainer_cascade=default_cascade_trainer,
                    plans_identifier=default_plans_identifier):
    copy_pretrained_models_for_task(taskname, output_folder, models,
                                    nnunet_trainer, nnunet_trainer_cascade,
                                    plans_identifier)
    copy_ensembles(taskname, output_folder, models,
                   (nnunet_trainer, nnunet_trainer_cascade),
                   (plans_identifier, ))
    compress_folder(join(output_folder, taskname + '.zip'),
                    join(output_folder, taskname))


if __name__ == "__main__":
    output_base = "/media/fabian/DeepLearningData/nnunet_trained_models"
    task_ids = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 24, 27, 29, 35, 48, 55, 61, 38
    ]
    for t in task_ids:
        if t == 61:
            models = ("3d_fullres", )
        else:
            models = ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres")
        taskname = convert_id_to_task_name(t)
        print(taskname)
        output_folder = join(output_base, taskname)
        maybe_mkdir_p(output_folder)
        copy_pretrained_models_for_task(taskname, output_folder, models,
                                        'nnUNetTrainer', 'nnUNetTrainer',
                                        'nnUNetPlans')
        copy_ensembles(taskname, output_folder)
    compress_everything(output_base, 8)
Example #9
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--task_ids",
        nargs="+",
        help="List of integers belonging to the task ids you wish to run"
        " experiment planning and preprocessing for. Each of these "
        "ids must, have a matching folder 'TaskXXX_' in the raw "
        "data folder")
    parser.add_argument(
        "-pl3d",
        "--planner3d",
        type=str,
        default="ExperimentPlanner3D_v21",
        help=
        "Name of the ExperimentPlanner class for the full resolution 3D U-Net and U-Net cascade. "
        "Default is ExperimentPlanner3D_v21. Can be 'None', in which case these U-Nets will not be "
        "configured")
    parser.add_argument(
        "-pl2d",
        "--planner2d",
        type=str,
        default="ExperimentPlanner2D_v21",
        help=
        "Name of the ExperimentPlanner class for the 2D U-Net. Default is ExperimentPlanner2D_v21. "
        "Can be 'None', in which case this U-Net will not be configured")
    parser.add_argument(
        "-no_pp",
        action="store_true",
        help=
        "Set this flag if you dont want to run the preprocessing. If this is set then this script "
        "will only run the experiment planning and create the plans file")
    parser.add_argument(
        "-tl",
        type=int,
        required=False,
        default=8,
        help=
        "Number of processes used for preprocessing the low resolution data for the 3D low "
        "resolution U-Net. This can be larger than -tf. Don't overdo it or you will run out of "
        "RAM")
    parser.add_argument(
        "-tf",
        type=int,
        required=False,
        default=8,
        help=
        "Number of processes used for preprocessing the full resolution data of the 2D U-Net and "
        "3D U-Net. Don't overdo it or you will run out of RAM")
    parser.add_argument(
        "--verify_dataset_integrity",
        required=False,
        default=False,
        action="store_true",
        help=
        "set this flag to check the dataset integrity. This is useful and should be done once for "
        "each dataset!")
    args = parser.parse_args()
    task_ids = args.task_ids
    task_ids = [6]  # NOTE: this hard-coded list overrides the -t argument parsed above
    dont_run_preprocessing = args.no_pp
    tl = args.tl
    tf = args.tf
    planner_name3d = args.planner3d
    planner_name2d = args.planner2d

    if planner_name3d == "None":
        planner_name3d = None
    if planner_name2d == "None":
        planner_name2d = None

    # we need raw data
    tasks = []
    for i in task_ids:
        i = int(i)

        task_name = convert_id_to_task_name(i)
        #task_name = "Breast"

        if args.verify_dataset_integrity:
            verify_dataset_integrity(join(nnUNet_raw_data, task_name))

        crop(task_name, False, tf)

        tasks.append(task_name)

    search_in = join(nnunet.__path__[0], "experiment_planning")

    if planner_name3d is not None:
        planner_3d = recursive_find_python_class(
            [search_in],
            planner_name3d,
            current_module="nnunet.experiment_planning")
        if planner_3d is None:
            raise RuntimeError(
                "Could not find the Planner class %s. Make sure it is located somewhere in "
                "nnunet.experiment_planning" % planner_name3d)
    else:
        planner_3d = None

    if planner_name2d is not None:
        planner_2d = recursive_find_python_class(
            [search_in],
            planner_name2d,
            current_module="nnunet.experiment_planning")
        if planner_2d is None:
            raise RuntimeError(
                "Could not find the Planner class %s. Make sure it is located somewhere in "
                "nnunet.experiment_planning" % planner_name2d)
    else:
        planner_2d = None

    for t in tasks:
        print("\n\n\n", t)
        cropped_out_dir = os.path.join(nnUNet_cropped_data, t)
        preprocessing_output_dir_this_task = os.path.join(
            preprocessing_output_dir, t)
        #splitted_4d_output_dir_task = os.path.join(nnUNet_raw_data, t)
        #lists, modalities = create_lists_from_splitted_dataset(splitted_4d_output_dir_task)

        # we need to figure out if we need the intensity properties. We collect them only if one of the modalities is CT
        dataset_json = load_json(join(cropped_out_dir, 'dataset.json'))
        modalities = list(dataset_json["modality"].values())
        collect_intensityproperties = True if (("CT" in modalities) or
                                               ("ct" in modalities)) else False
        dataset_analyzer = DatasetAnalyzer(
            cropped_out_dir, overwrite=False,
            num_processes=tf)  # this class creates the fingerprint
        _ = dataset_analyzer.analyze_dataset(
            collect_intensityproperties
        )  # this will write output files that will be used by the ExperimentPlanner

        #maybe_mkdir_p(preprocessing_output_dir_this_task)
        if not os.path.isdir(preprocessing_output_dir_this_task):
            os.makedirs(preprocessing_output_dir_this_task)
        shutil.copy(join(cropped_out_dir, "dataset_properties.pkl"),
                    preprocessing_output_dir_this_task)
        shutil.copy(join(nnUNet_raw_data, t, "dataset.json"),
                    preprocessing_output_dir_this_task)

        threads = (tl, tf)

        print("number of threads: ", threads, "\n")

        if planner_3d is not None:
            exp_planner = planner_3d(cropped_out_dir,
                                     preprocessing_output_dir_this_task)
            exp_planner.plan_experiment()
            if not dont_run_preprocessing:  # double negative, yooo
                exp_planner.run_preprocessing(threads)
        if planner_2d is not None:
            exp_planner = planner_2d(cropped_out_dir,
                                     preprocessing_output_dir_this_task)
            exp_planner.plan_experiment()
            if not dont_run_preprocessing:  # double negative, yooo
                exp_planner.run_preprocessing(threads)
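
A small worked example of the CT check used above to decide whether intensity properties are collected; the dataset.json fragment is a placeholder:

# Hypothetical 'modality' block from a dataset.json:
dataset_json = {"modality": {"0": "CT"}}
modalities = list(dataset_json["modality"].values())
collect_intensityproperties = ("CT" in modalities) or ("ct" in modalities)
print(collect_intensityproperties)  # True -> intensity statistics will be gathered
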
Example #10
def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--task_ids",
        nargs="+",
        help="List of integers belonging to the task ids you wish to run"
        " experiment planning and preprocessing for. Each of these "
        "ids must, have a matching folder 'TaskXXX_' in the raw "
        "data folder")
    parser.add_argument(
        "-pl3d",
        "--planner3d",
        type=str,
        default="ExperimentPlanner3D_v21",
        help=
        "Name of the ExperimentPlanner class for the full resolution 3D U-Net and U-Net cascade. "
        "Default is ExperimentPlanner3D_v21. Can be 'None', in which case these U-Nets will not be "
        "configured")
    parser.add_argument(
        "-pl2d",
        "--planner2d",
        type=str,
        default="ExperimentPlanner2D_v21",
        help=
        "Name of the ExperimentPlanner class for the 2D U-Net. Default is ExperimentPlanner2D_v21. "
        "Can be 'None', in which case this U-Net will not be configured")
    parser.add_argument(
        "-no_pp",
        action="store_true",
        help=
        "Set this flag if you dont want to run the preprocessing. If this is set then this script "
        "will only run the experiment planning and create the plans file")
    parser.add_argument(
        "-tl",
        type=int,
        required=False,
        default=8,
        help=
        "Number of processes used for preprocessing the low resolution data for the 3D low "
        "resolution U-Net. This can be larger than -tf. Don't overdo it or you will run out of "
        "RAM")
    parser.add_argument(
        "-tf",
        type=int,
        required=False,
        default=8,
        help=
        "Number of processes used for preprocessing the full resolution data of the 2D U-Net and "
        "3D U-Net. Don't overdo it or you will run out of RAM")
    parser.add_argument(
        "--verify_dataset_integrity",
        required=False,
        default=False,
        action="store_true",
        help=
        "set this flag to check the dataset integrity. This is useful and should be done once for "
        "each dataset!")
    parser.add_argument(
        "-overwrite_plans",
        type=str,
        default=None,
        required=False,
        help=
        "Use this to specify a plans file that should be used instead of whatever nnU-Net would "
        "configure automatically. This will overwrite everything: intensity normalization, "
        "network architecture, target spacing etc. Using this is useful for using pretrained "
        "model weights as this will guarantee that the network architecture on the target "
        "dataset is the same as on the source dataset and the weights can therefore be transferred.\n"
        "Pro tip: If you want to pretrain on Hepaticvessel and apply the result to LiTS then use "
        "the LiTS plans to run the preprocessing of the HepaticVessel task.\n"
        "Make sure to only use plans files that were "
        "generated with the same number of modalities as the target dataset (LiTS -> BCV or "
        "LiTS -> Task008_HepaticVessel is OK. BraTS -> LiTS is not (BraTS has 4 input modalities, "
        "LiTS has just one)). Also only do things that make sense. This functionality is beta with"
        "no support given.\n"
        "Note that this will first print the old plans (which are going to be overwritten) and "
        "then the new ones (provided that -no_pp was NOT set).")
    parser.add_argument(
        "-overwrite_plans_identifier",
        type=str,
        default=None,
        required=False,
        help=
        "If you set overwrite_plans you need to provide a unique identifier so that nnUNet knows "
        "where to look for the correct plans and data. Assume your identifier is called "
        "IDENTIFIER, the correct training command would be:\n"
        "'nnUNet_train CONFIG TRAINER TASKID FOLD -p nnUNetPlans_pretrained_IDENTIFIER "
        "-pretrained_weights FILENAME'")

    args = parser.parse_args()
    task_ids = args.task_ids
    dont_run_preprocessing = args.no_pp
    tl = args.tl
    tf = args.tf
    planner_name3d = args.planner3d
    planner_name2d = args.planner2d

    if planner_name3d == "None":
        planner_name3d = None
    if planner_name2d == "None":
        planner_name2d = None

    if args.overwrite_plans is not None:
        if planner_name2d is not None:
            print(
                "Overwriting plans only works for the 3d planner. I am setting '--planner2d' to None. This will "
                "skip 2d planning and preprocessing.")
            planner_name2d = None
        assert planner_name3d == 'ExperimentPlanner3D_v21_Pretrained', "When using --overwrite_plans you need to use " \
                                                                       "'-pl3d ExperimentPlanner3D_v21_Pretrained'"

    # we need raw data
    tasks = []
    for i in task_ids:
        i = int(i)

        task_name = convert_id_to_task_name(i)

        if args.verify_dataset_integrity:
            verify_dataset_integrity(join(nnUNet_raw_data, task_name))

        crop(task_name, False, tf)

        tasks.append(task_name)

    search_in = join(nnunet.__path__[0], "experiment_planning")

    if planner_name3d is not None:
        planner_3d = recursive_find_python_class(
            [search_in],
            planner_name3d,
            current_module="nnunet.experiment_planning")
        if planner_3d is None:
            raise RuntimeError(
                "Could not find the Planner class %s. Make sure it is located somewhere in "
                "nnunet.experiment_planning" % planner_name3d)
    else:
        planner_3d = None

    if planner_name2d is not None:
        planner_2d = recursive_find_python_class(
            [search_in],
            planner_name2d,
            current_module="nnunet.experiment_planning")
        if planner_2d is None:
            raise RuntimeError(
                "Could not find the Planner class %s. Make sure it is located somewhere in "
                "nnunet.experiment_planning" % planner_name2d)
    else:
        planner_2d = None

    for t in tasks:
        print("\n\n\n", t)
        cropped_out_dir = os.path.join(nnUNet_cropped_data, t)
        preprocessing_output_dir_this_task = os.path.join(
            preprocessing_output_dir, t)
        #splitted_4d_output_dir_task = os.path.join(nnUNet_raw_data, t)
        #lists, modalities = create_lists_from_splitted_dataset(splitted_4d_output_dir_task)

        # we need to figure out if we need the intensity properties. We collect them only if one of the modalities is CT
        dataset_json = load_json(join(cropped_out_dir, 'dataset.json'))
        modalities = list(dataset_json["modality"].values())
        collect_intensityproperties = True if (("CT" in modalities) or
                                               ("ct" in modalities)) else False
        dataset_analyzer = DatasetAnalyzer(
            cropped_out_dir, overwrite=False,
            num_processes=tf)  # this class creates the fingerprint
        _ = dataset_analyzer.analyze_dataset(
            collect_intensityproperties
        )  # this will write output files that will be used by the ExperimentPlanner

        maybe_mkdir_p(preprocessing_output_dir_this_task)
        shutil.copy(join(cropped_out_dir, "dataset_properties.pkl"),
                    preprocessing_output_dir_this_task)
        shutil.copy(join(nnUNet_raw_data, t, "dataset.json"),
                    preprocessing_output_dir_this_task)

        threads = (tl, tf)

        print("number of threads: ", threads, "\n")

        if planner_3d is not None:
            if args.overwrite_plans is not None:
                assert args.overwrite_plans_identifier is not None, "You need to specify -overwrite_plans_identifier"
                exp_planner = planner_3d(cropped_out_dir,
                                         preprocessing_output_dir_this_task,
                                         args.overwrite_plans,
                                         args.overwrite_plans_identifier)
            else:
                exp_planner = planner_3d(cropped_out_dir,
                                         preprocessing_output_dir_this_task)
            exp_planner.plan_experiment()
            if not dont_run_preprocessing:  # double negative, yooo
                exp_planner.run_preprocessing(threads)
        if planner_2d is not None:
            exp_planner = planner_2d(cropped_out_dir,
                                     preprocessing_output_dir_this_task)
            exp_planner.plan_experiment()
            if not dont_run_preprocessing:  # double negative, yooo
                exp_planner.run_preprocessing(threads)
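
A hedged sketch of how the -overwrite_plans mechanism described above might be invoked; the task id, plans path and identifier are placeholders, and the program name in sys.argv[0] is cosmetic:

import sys

# Simulate: preprocess task 8 with the plans of a source dataset so the resulting
# architecture matches the source and pretrained weights can be transferred.
sys.argv = ["nnUNet_plan_and_preprocess",
            "-t", "8",
            "-pl3d", "ExperimentPlanner3D_v21_Pretrained",
            "-overwrite_plans", "/path/to/source_task/nnUNetPlansv2.1_plans_3D.pkl",  # placeholder
            "-overwrite_plans_identifier", "SOURCE_TASK"]                             # placeholder
main()
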
Example #11
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("gpu", help='0, 1, ..., 5 or \'all\'')
    parser.add_argument("-val",
                        "--validation_only",
                        help="use this if you want to only run the validation",
                        action="store_true")
    parser.add_argument("-c",
                        "--continue_training",
                        help="use this if you want to continue a training",
                        action="store_true")
    parser.add_argument(
        "-p",
        help=
        "plans identifier. Only change this if you created a custom experiment planner",
        default=default_plans_identifier,
        required=False)
    parser.add_argument(
        "--use_compressed_data",
        default=False,
        action="store_true",
        help=
        "If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data "
        "is much more CPU and RAM intensive and should only be used if you know what you are "
        "doing",
        required=False)
    parser.add_argument(
        "--deterministic",
        help=
        "Makes training deterministic, but reduces training speed substantially. I (Fabian) think "
        "this is not necessary. Deterministic training will make you overfit to some random seed. "
        "Don't use that.",
        required=False,
        default=False,
        action="store_true")
    parser.add_argument("--npz",
                        required=False,
                        default=False,
                        action="store_true",
                        help="if set then nnUNet will "
                        "export npz files of "
                        "predicted segmentations "
                        "in the validation as well. "
                        "This is needed to run the "
                        "ensembling step so unless "
                        "you are developing nnUNet "
                        "you should enable this")
    parser.add_argument("--find_lr",
                        required=False,
                        default=False,
                        action="store_true",
                        help="not used here, just for fun")
    parser.add_argument("--valbest",
                        required=False,
                        default=False,
                        action="store_true",
                        help="hands off. This is not intended to be used")
    parser.add_argument(
        "--fp32",
        required=False,
        default=False,
        action="store_true",
        help="disable mixed precision training and run old school fp32")
    parser.add_argument(
        "--val_folder",
        required=False,
        default="validation_raw",
        help=
        "name of the validation folder. No need to use this for most people")
    parser.add_argument(
        "--interp_order",
        required=False,
        default=3,
        type=int,
        help=
        "order of interpolation for segmentations. Testing purpose only. Hands off"
    )
    parser.add_argument(
        "--interp_order_z",
        required=False,
        default=0,
        type=int,
        help=
        "order of interpolation along z if z is resampled separately. Testing purpose only. "
        "Hands off")
    parser.add_argument(
        "--force_separate_z",
        required=False,
        default="None",
        type=str,
        help=
        "force_separate_z resampling. Can be None, True or False. Testing purpose only. Hands off"
    )

    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    tasks = ["100", "101", "102", "103", "104"]
    # tasks = ["100"]
    # fold = args.fold
    fold = "4"
    # network = args.network
    network = "3d_fullres"
    network_trainer = "nnUNetMultiTrainerV2"
    validation_only = args.validation_only
    plans_identifier = args.p
    find_lr = args.find_lr

    use_compressed_data = args.use_compressed_data
    decompress_data = not use_compressed_data

    deterministic = args.deterministic
    valbest = args.valbest

    fp32 = args.fp32
    run_mixed_precision = not fp32

    val_folder = args.val_folder
    # val_folder = "mk_validation"   #temp_validation
    interp_order = args.interp_order
    interp_order_z = args.interp_order_z
    force_separate_z = args.force_separate_z

    classes_dict = {}
    for i, task in enumerate(tasks):
        if not task.startswith("Task"):
            task_id = int(task)
            task = convert_id_to_task_name(task_id)
        tasks[i] = task

        json_file = join(preprocessing_output_dir, task, "dataset.json")
        classes = []
        with open(json_file) as jsn:
            d = json.load(jsn)
            tags = d['labels']
            for i in tags:
                if int(i) != 0:  # skip the background label (0)
                    classes.append(tags[i])
            classes_dict[task] = classes
    # print("task:",tasks)# ['Task100_MALB', 'Task101_Liver', 'Task102_Spleen', 'Task103_Pancreas', 'Task104_KiTS']
    # print("classes_dict", classes_dict)#{'Task100_MALB': ['Liver', 'Spleen', 'Pancreas', 'LeftKidney', 'RightKidney'],...}
    if fold == 'all':
        pass
    else:
        fold = int(fold)

    if force_separate_z == "None":
        force_separate_z = None
    elif force_separate_z == "False":
        force_separate_z = False
    elif force_separate_z == "True":
        force_separate_z = True
    else:
        raise ValueError(
            "force_separate_z must be None, True or False. Given: %s" %
            force_separate_z)

    plans_file, output_folder_names, dataset_directorys, batch_dice, stage, \
        trainer_class = get_default_configuration_with_multiTask(
            network, tasks, network_trainer, plans_identifier)

    if trainer_class is None:
        raise RuntimeError(
            "Could not find trainer class in nnunet.training.network_training")

    if network == "3d_cascade_fullres":
        assert issubclass(trainer_class, (nnUNetTrainerCascadeFullRes, nnUNetTrainerV2CascadeFullRes)), \
            "If running 3d_cascade_fullres then your " \
            "trainer class must be derived from " \
            "nnUNetTrainerCascadeFullRes"
    else:
        assert issubclass(
            trainer_class, nnUNetTrainer
        ), "network_trainer was found but is not derived from nnUNetMultiTrainer"

    trainer = trainer_class(plans_file,
                            fold,
                            tasks,
                            tags=classes_dict,
                            output_folder_dict=output_folder_names,
                            dataset_directory_dict=dataset_directorys,
                            batch_dice=batch_dice,
                            stage=0,
                            unpack_data=decompress_data,
                            deterministic=deterministic,
                            fp16=run_mixed_precision)

    trainer.initialize(not validation_only)

    if find_lr:
        trainer.find_lr()
    else:
        if not validation_only:
            if args.continue_training:
                trainer.load_latest_checkpoint()
            trainer.run_training()  #training
        else:
            if valbest:
                trainer.load_best_checkpoint(train=False)
            else:
                trainer.load_latest_checkpoint(train=False)

        trainer.network.eval()

        # predict validation
        for task in tasks:
            print(f"test task: {task}")
            trainer.validate_specific_data(
                task,
                save_softmax=args.npz,
                validation_folder_name=val_folder,
                force_separate_z=force_separate_z,
                overwrite=True,
                interpolation_order=interp_order,
                interpolation_order_z=interp_order_z)
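
To illustrate how classes_dict is built above, a small worked example with a hypothetical labels block from a dataset.json:

# Hypothetical dataset.json 'labels' entry; label 0 (background) is skipped.
d = {"labels": {"0": "background", "1": "Liver", "2": "Spleen"}}
classes = [d["labels"][i] for i in d["labels"] if int(i) != 0]
print(classes)  # ['Liver', 'Spleen'] -> stored as classes_dict[task]
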
Example #12
def main():
    import argparse
    parser = argparse.ArgumentParser(description="We extend nnUNet to offer self-supervision tasks. This step is to"
                                                 " split the dataset into two - self-supervision input and self- "
                                                 "supervision output folder.")
    parser.add_argument("-t", type=int, help="Task id. The task name you wish to run self-supervision task for. "
                                             "It must have a matching folder 'TaskXXX_' in the raw "
                                             "data folder", required=True)
    parser.add_argument("-ss_tasks", nargs="+",
                        help="Self-supervision Tasks. Specify which self-supervision task you wish to "
                             "run. Current supported tasks: context_restoration| jigsaw_puzzle | byol")
    parser.add_argument("-p", default=default_num_threads, type=int, required=False,
                        help="Use this to specify how many processes are used to run the script. "
                             "Default is %d" % default_num_threads)
    parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true",
                        help="set this flag to check the dataset integrity. This is useful and should be done once for "
                             "each dataset!")
    args = parser.parse_args()

    ss_tasks = args.ss_tasks
    base = join(os.environ['nnUNet_raw_data_base'], 'nnUNet_raw_data')
    task_name = convert_id_to_task_name(args.t)
    target_base = join(base, task_name)

    import json
    with open(join(target_base, 'dataset.json')) as json_file:
        updated_json_file = json.load(json_file).copy()

    if "context_restoration" in ss_tasks:
        pretext_task = convert_pretext_task("context_restoration")
        file_names = generate_augmented_datasets(pretext_task, target_base, swap_image)
        updated_json_file['contextRestoration'] = [{'input': f"./ssInput{pretext_task}/_%s" % i.split("/")[-1],
                                                    "output": f"./ssOutput{pretext_task}/%s" % i.split("/")[-1]} \
                                                   for i in file_names]
        if args.verify_dataset_integrity:
            verify_dataset_integrity(target_base, "context_restoration")

    if "jigsaw_puzzle" in ss_tasks:
        pretext_task = convert_pretext_task("jigsaw_puzzle")
        file_names = generate_augmented_datasets(pretext_task, target_base, swap_image)
        updated_json_file['jigsawPuzzle'] = [{'input': f"./ssInput{pretext_task}/_%s" % i.split("/")[-1],
                                              "output": f"./ssOutput{pretext_task}/%s" % i.split("/")[-1]} \
                                             for i in file_names]
        if args.verify_dataset_integrity:
            verify_dataset_integrity(target_base, "jigsaw_puzzle")

    if "byol" in ss_tasks:
        pretext_task = convert_pretext_task("byol")
        file_names = generate_augmented_datasets(pretext_task, target_base, byol_aug)
        updated_json_file['byol'] = [{'input': f"./ssInput{pretext_task}/_%s" % i.split("/")[-1],
                                      "output": f"./ssOutput{pretext_task}/%s" % i.split("/")[-1]} \
                                     for i in file_names]
        if args.verify_dataset_integrity:
            verify_dataset_integrity(target_base, "byol")

    # remove the original dataset.json
    os.remove(join(target_base, 'dataset.json'))
    # save the updated dataset.json
    save_json(updated_json_file, join(target_base, "dataset.json"))
    print('Updated dataset.json')

    print('Preparation for self supervision succeeded! Move on to the plan_and_preprocessing stage.')
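
For clarity, a small sketch of the kind of entry each pretext task appends to dataset.json above; the pretext folder name and file name are placeholders (the real value comes from convert_pretext_task):

pretext_task = "ContextRestoration"         # placeholder for convert_pretext_task(...)
file_names = ["/path/to/case_0001.nii.gz"]  # placeholder
entries = [{"input": f"./ssInput{pretext_task}/_%s" % i.split("/")[-1],
            "output": f"./ssOutput{pretext_task}/%s" % i.split("/")[-1]}
           for i in file_names]
print(entries)
# [{'input': './ssInputContextRestoration/_case_0001.nii.gz',
#   'output': './ssOutputContextRestoration/case_0001.nii.gz'}]
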
Example #13
def manually_set_configurations():
    """
    ALSO NOT USED!
    :return:
    """
    task115_dir = join(preprocessing_output_dir, convert_id_to_task_name(115))

    ## larger patch size

    # task115 3d_fullres default is:
    """
    {'batch_size': 2, 
    'num_pool_per_axis': [2, 6, 6], 
    'patch_size': array([ 28, 256, 256]), 
    'median_patient_size_in_voxels': array([ 62, 512, 512]), 
    'current_spacing': array([5.        , 0.74199998, 0.74199998]), 
    'original_spacing': array([5.        , 0.74199998, 0.74199998]), 
    'do_dummy_2D_data_aug': True, 
    'pool_op_kernel_sizes': [[1, 2, 2], [1, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2], [1, 2, 2]], 
    'conv_kernel_sizes': [[1, 3, 3], [1, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]}
    """
    plans = load_pickle(join(task115_dir, 'nnUNetPlansv2.1_plans_3D.pkl'))
    fullres_stage = plans['plans_per_stage'][1]
    fullres_stage['patch_size'] = np.array([ 64, 320, 320])
    fullres_stage['num_pool_per_axis'] = [4, 6, 6]
    fullres_stage['pool_op_kernel_sizes'] = [[1, 2, 2],
                                             [1, 2, 2],
                                             [2, 2, 2],
                                             [2, 2, 2],
                                             [2, 2, 2],
                                             [2, 2, 2]]
    fullres_stage['conv_kernel_sizes'] = [[1, 3, 3],
                                          [1, 3, 3],
                                          [3, 3, 3],
                                          [3, 3, 3],
                                          [3, 3, 3],
                                          [3, 3, 3],
                                          [3, 3, 3]]

    save_pickle(plans, join(task115_dir, 'nnUNetPlansv2.1_custom_plans_3D.pkl'))

    ## larger batch size
    # (default for all 3d trainings is batch size 2)
    increase_batch_size(join(task115_dir, 'nnUNetPlansv2.1_plans_3D.pkl'), join(task115_dir, 'nnUNetPlansv2.1_bs3x_plans_3D.pkl'), 3)
    increase_batch_size(join(task115_dir, 'nnUNetPlansv2.1_plans_3D.pkl'), join(task115_dir, 'nnUNetPlansv2.1_bs5x_plans_3D.pkl'), 5)

    # residual unet
    """
    default is:
    Out[7]: 
    {'batch_size': 2,
     'num_pool_per_axis': [2, 6, 5],
     'patch_size': array([ 28, 256, 224]),
     'median_patient_size_in_voxels': array([ 62, 512, 512]),
     'current_spacing': array([5.        , 0.74199998, 0.74199998]),
     'original_spacing': array([5.        , 0.74199998, 0.74199998]),
     'do_dummy_2D_data_aug': True,
     'pool_op_kernel_sizes': [[1, 1, 1],
      [1, 2, 2],
      [1, 2, 2],
      [2, 2, 2],
      [2, 2, 2],
      [1, 2, 2],
      [1, 2, 1]],
     'conv_kernel_sizes': [[1, 3, 3],
      [1, 3, 3],
      [3, 3, 3],
      [3, 3, 3],
      [3, 3, 3],
      [3, 3, 3],
      [3, 3, 3]],
     'num_blocks_encoder': (1, 2, 3, 4, 4, 4, 4),
     'num_blocks_decoder': (1, 1, 1, 1, 1, 1)}
    """
    plans = load_pickle(join(task115_dir, 'nnUNetPlans_FabiansResUNet_v2.1_plans_3D.pkl'))
    fullres_stage = plans['plans_per_stage'][1]
    fullres_stage['patch_size'] = np.array([ 56, 256, 256])
    fullres_stage['num_pool_per_axis'] = [3, 6, 6]
    fullres_stage['pool_op_kernel_sizes'] = [[1, 1, 1],
                                             [1, 2, 2],
                                             [1, 2, 2],
                                             [2, 2, 2],
                                             [2, 2, 2],
                                             [2, 2, 2],
                                             [1, 2, 2]]
    fullres_stage['conv_kernel_sizes'] = [[1, 3, 3],
                                          [1, 3, 3],
                                          [3, 3, 3],
                                          [3, 3, 3],
                                          [3, 3, 3],
                                          [3, 3, 3],
                                          [3, 3, 3]]
    save_pickle(plans, join(task115_dir, 'nnUNetPlans_FabiansResUNet_v2.1_custom_plans_3D.pkl'))
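
A minimal sketch for sanity-checking the customized plans written above, assuming task115_dir is defined as in the function and the same load_pickle/join helpers are imported:

# Load the custom 3D plans and confirm the edited full resolution stage.
plans = load_pickle(join(task115_dir, 'nnUNetPlansv2.1_custom_plans_3D.pkl'))
fullres_stage = plans['plans_per_stage'][1]
print(fullres_stage['patch_size'])         # expected: [ 64 320 320]
print(fullres_stage['num_pool_per_axis'])  # expected: [4, 6, 6]
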
Example #14
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-gpu", type=str, default='0')

    parser.add_argument("-network", type=str, default='3d_fullres')
    parser.add_argument("-network_trainer",
                        type=str,
                        default='nnUNetTrainerV2_ResTrans')
    parser.add_argument("-task",
                        type=str,
                        default='17',
                        help="can be task name or task id")
    parser.add_argument("-fold",
                        type=str,
                        default='all',
                        help='0, 1, ..., 5 or \'all\'')
    parser.add_argument("-outpath",
                        type=str,
                        default='Trainer_CoTr',
                        help='output path')
    parser.add_argument("-norm_cfg",
                        type=str,
                        default='IN',
                        help='BN, IN or GN')
    parser.add_argument("-activation_cfg",
                        type=str,
                        default='LeakyReLU',
                        help='LeakyReLU or ReLU')

    parser.add_argument("-val",
                        "--validation_only",
                        default=False,
                        help="use this if you want to only run the validation",
                        required=False,
                        action="store_true")
    parser.add_argument("-c",
                        "--continue_training",
                        help="use this if you want to continue a training",
                        action="store_true")
    parser.add_argument(
        "-p",
        help=
        "plans identifier. Only change this if you created a custom experiment planner",
        default=default_plans_identifier,
        required=False)
    parser.add_argument(
        "--use_compressed_data",
        default=False,
        action="store_true",
        help=
        "If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data "
        "is much more CPU and RAM intensive and should only be used if you know what you are "
        "doing",
        required=False)
    parser.add_argument("--deterministic", default=False, action="store_true")
    parser.add_argument("--npz",
                        required=False,
                        default=False,
                        action="store_true",
                        help="if set then nnUNet will "
                        "export npz files of "
                        "predicted segmentations "
                        "in the validation as well. "
                        "This is needed to run the "
                        "ensembling step so unless "
                        "you are developing nnUNet "
                        "you should enable this")
    parser.add_argument("--find_lr",
                        required=False,
                        default=False,
                        action="store_true",
                        help="not used here, just for fun")
    parser.add_argument("--valbest",
                        required=False,
                        default=False,
                        action="store_true",
                        help="hands off. This is not intended to be used")
    parser.add_argument(
        "--fp32",
        required=False,
        default=False,
        action="store_true",
        help="disable mixed precision training and run old school fp32")
    parser.add_argument(
        "--val_folder",
        required=False,
        default="validation_raw",
        help=
        "name of the validation folder. No need to use this for most people")
    parser.add_argument("--disable_saving",
                        required=False,
                        action='store_true')

    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    norm_cfg = args.norm_cfg
    activation_cfg = args.activation_cfg
    outpath = args.outpath + '_' + norm_cfg + '_' + activation_cfg

    task = args.task
    fold = args.fold
    network = args.network
    network_trainer = args.network_trainer
    validation_only = args.validation_only
    plans_identifier = args.p
    find_lr = args.find_lr

    use_compressed_data = args.use_compressed_data
    decompress_data = not use_compressed_data

    deterministic = args.deterministic
    valbest = args.valbest

    fp32 = args.fp32
    run_mixed_precision = not fp32

    val_folder = args.val_folder

    if validation_only and (norm_cfg == 'SyncBN'):
        norm_cfg = 'BN'

    if not task.startswith("Task"):
        task_id = int(task)
        task = convert_id_to_task_name(task_id)

    if fold == 'all':
        pass
    else:
        fold = int(fold)

    plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
    trainer_class = get_default_configuration(outpath, network, task, network_trainer, plans_identifier, \
                                              search_in=(CoTr.__path__[0], "training", "network_training"), \
                                              base_module='CoTr.training.network_training')

    trainer = trainer_class(plans_file,
                            fold,
                            norm_cfg,
                            activation_cfg,
                            output_folder=output_folder_name,
                            dataset_directory=dataset_directory,
                            batch_dice=batch_dice,
                            stage=stage,
                            unpack_data=decompress_data,
                            deterministic=deterministic,
                            fp16=run_mixed_precision)

    if args.disable_saving:
        trainer.save_latest_only = False  # if False, do not store/overwrite _latest but write separate files each time
        trainer.save_intermediate_checkpoints = False  # whether or not to save checkpoint_latest
        trainer.save_best_checkpoint = False  # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
        trainer.save_final_checkpoint = False  # whether or not to save the final checkpoint

    trainer.initialize(not validation_only)

    if find_lr:
        trainer.find_lr()
    else:
        if not validation_only:
            if args.continue_training:
                trainer.load_latest_checkpoint()
            trainer.run_training()
        else:
            if valbest:
                trainer.load_best_checkpoint(train=False)
            else:
                trainer.load_latest_checkpoint(train=False)

        trainer.network.eval()

        # predict validation
        trainer.validate(save_softmax=args.npz,
                         validation_folder_name=val_folder)

        if network == '3d_lowres':
            print("predicting segmentations for the next stage of the cascade")
            predict_next_stage(
                trainer,
                join(dataset_directory,
                     trainer.plans['data_identifier'] + "_stage%d" % 1))
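Because everything above is wired through argparse, this CoTr trainer can also be exercised end to end by populating sys.argv with the flags defined in the parser and calling main(). The snippet below is only a hedged illustration (the flag values simply echo the parser defaults and the script name is a placeholder); a real run would normally pass the arguments on the command line.

# Hypothetical programmatic invocation: argparse reads sys.argv[1:], so populating it
# before calling main() is equivalent to a command-line run. All flags are defined above.
import sys
sys.argv = ['run_training', '-gpu', '0', '-network', '3d_fullres',
            '-network_trainer', 'nnUNetTrainerV2_ResTrans', '-task', '17',
            '-fold', '0', '-norm_cfg', 'IN', '-activation_cfg', 'LeakyReLU']
main()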
Example #15
def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--task_ids", nargs="+", default=["4"])
    parser.add_argument("-pl3d", "--planner3d", type=str, default="ExperimentPlanner3D_v21")
    parser.add_argument("-pl2d", "--planner2d", type=str, default="ExperimentPlanner2D_v21")
    parser.add_argument("-no_pp", action="store_true", default=False)
    parser.add_argument("-tl", type=int, required=False, default=8)
    parser.add_argument("-tf", type=int, required=False, default=8)
    parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true")

    args = parser.parse_args()

    task_ids = args.task_ids
    dont_run_preprocessing = args.no_pp
    tl = args.tl
    tf = args.tf
    planner_name3d = args.planner3d
    planner_name2d = args.planner2d

    if planner_name3d == "None":
        planner_name3d = None
    if planner_name2d == "None":
        planner_name2d = None

    # we need raw data
    tasks = []
    for i in task_ids:
        i = int(i)

        task_name = convert_id_to_task_name(i)

        if args.verify_dataset_integrity:
            verify_dataset_integrity(join(nnUNet_raw_data, task_name))

        crop(task_name, False, tf)

        tasks.append(task_name)

    search_in = join(nnunet.__path__[0], "experiment_planning")

    if planner_name3d is not None:
        planner_3d = recursive_find_python_class([search_in], planner_name3d, current_module="nnunet.experiment_planning")
        if planner_3d is None:
            raise RuntimeError("Could not find the Planner class %s. Make sure it is located somewhere in "
                               "nnunet.experiment_planning" % planner_name3d)
    else:
        planner_3d = None

    if planner_name2d is not None:
        planner_2d = recursive_find_python_class([search_in], planner_name2d, current_module="nnunet.experiment_planning")
        if planner_2d is None:
            raise RuntimeError("Could not find the Planner class %s. Make sure it is located somewhere in "
                               "nnunet.experiment_planning" % planner_name2d)
    else:
        planner_2d = None

    for t in tasks:
        print("\n\n\n", t)
        cropped_out_dir = os.path.join(nnUNet_cropped_data, t)
        preprocessing_output_dir_this_task = os.path.join(preprocessing_output_dir, t)
        #splitted_4d_output_dir_task = os.path.join(nnUNet_raw_data, t)
        #lists, modalities = create_lists_from_splitted_dataset(splitted_4d_output_dir_task)

        dataset_analyzer = DatasetAnalyzer(cropped_out_dir, overwrite=False)  # this class creates the fingerprint
        _ = dataset_analyzer.analyze_dataset()  # this will write output files that will be used by the ExperimentPlanner

        maybe_mkdir_p(preprocessing_output_dir_this_task)
        shutil.copy(join(cropped_out_dir, "dataset_properties.pkl"), preprocessing_output_dir_this_task)
        shutil.copy(join(nnUNet_raw_data, t, "dataset.json"), preprocessing_output_dir_this_task)

        threads = (tl, tf)

        print("number of threads: ", threads, "\n")

        if planner_3d is not None:
            exp_planner = planner_3d(cropped_out_dir, preprocessing_output_dir_this_task)
            exp_planner.plan_experiment()
            if not dont_run_preprocessing:  # double negative, yooo
                exp_planner.run_preprocessing(threads)
        if planner_2d is not None:
            exp_planner = planner_2d(cropped_out_dir, preprocessing_output_dir_this_task)
            exp_planner.plan_experiment()
            if not dont_run_preprocessing:  # double negative, yooo
                exp_planner.run_preprocessing(threads)
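For a single task the same planning step can also be driven directly, without the argparse wrapper. The sketch below is illustrative only: it assumes the usual nnU-Net v1 import path for ExperimentPlanner3D_v21, an example task name such as Task004_Hippocampus, and that cropping, dataset analysis, and the dataset_properties.pkl/dataset.json copies from the loop above have already been done.

# Illustrative single-task variant of the loop above (assumed import path and task name).
import os
from nnunet.experiment_planning.experiment_planner_baseline_3DUNet_v21 import ExperimentPlanner3D_v21

cropped_dir = os.path.join(nnUNet_cropped_data, 'Task004_Hippocampus')
preproc_dir = os.path.join(preprocessing_output_dir, 'Task004_Hippocampus')
maybe_mkdir_p(preproc_dir)

planner = ExperimentPlanner3D_v21(cropped_dir, preproc_dir)
planner.plan_experiment()
planner.run_preprocessing((8, 8))  # thread counts, mirroring threads = (tl, tf) above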
Example #16
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("network")
    parser.add_argument("network_trainer")
    parser.add_argument("task", help="can be task name or task id")
    parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'')
    parser.add_argument("-val", "--validation_only", help="use this if you want to only run the validation",
                        action="store_true")
    parser.add_argument("-c", "--continue_training", help="use this if you want to continue a training",
                        action="store_true")
    parser.add_argument("-p", help="plans identifier. Only change this if you created a custom experiment planner",
                        default=default_plans_identifier, required=False)
    parser.add_argument("--use_compressed_data", default=False, action="store_true",
                        help="If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data "
                             "is much more CPU and RAM intensive and should only be used if you know what you are "
                             "doing", required=False)
    parser.add_argument("--deterministic",
                        help="Makes training deterministic, but reduces training speed substantially. I (Fabian) think "
                             "this is not necessary. Deterministic training will make you overfit to some random seed. "
                             "Don't use that.",
                        required=False, default=False, action="store_true")
    parser.add_argument("--npz", required=False, default=False, action="store_true", help="if set then nnUNet will "
                                                                                          "export npz files of "
                                                                                          "predicted segmentations "
                                                                                          "in the validation as well. "
                                                                                          "This is needed to run the "
                                                                                          "ensembling step so unless "
                                                                                          "you are developing nnUNet "
                                                                                          "you should enable this")
    parser.add_argument("--find_lr", required=False, default=False, action="store_true",
                        help="not used here, just for fun")
    parser.add_argument("--valbest", required=False, default=False, action="store_true",
                        help="hands off. This is not intended to be used")
    parser.add_argument("--fp32", required=False, default=False, action="store_true",
                        help="disable mixed precision training and run old school fp32")
    parser.add_argument("--val_folder", required=False, default="validation_raw",
                        help="name of the validation folder. No need to use this for most people")
    parser.add_argument("--disable_saving", required=False, action='store_true',
                        help="If set nnU-Net will not save any parameter files (except a temporary checkpoint that "
                             "will be removed at the end of the training). Useful for development when you are "
                             "only interested in the results and want to save some disk space")
    parser.add_argument("--disable_postprocessing_on_folds", required=False, action='store_true',
                        help="Running postprocessing on each fold only makes sense when developing with nnU-Net and "
                             "closely observing the model performance on specific configurations. You do not need it "
                             "when applying nnU-Net because the postprocessing for this will be determined only once "
                             "all five folds have been trained and nnUNet_find_best_configuration is called. Usually "
                             "running postprocessing on each fold is computationally cheap, but some users have "
                             "reported issues with very large images. If your images are large (>600x600x600 voxels) "
                             "you should consider setting this flag.")
    # parser.add_argument("--interp_order", required=False, default=3, type=int,
    #                     help="order of interpolation for segmentations. Testing purpose only. Hands off")
    # parser.add_argument("--interp_order_z", required=False, default=0, type=int,
    #                     help="order of interpolation along z if z is resampled separately. Testing purpose only. "
    #                          "Hands off")
    # parser.add_argument("--force_separate_z", required=False, default="None", type=str,
    #                     help="force_separate_z resampling. Can be None, True or False. Testing purpose only. Hands off")
    parser.add_argument('--val_disable_overwrite', action='store_false', default=True,
                        help='Validation does not overwrite existing segmentations')
    parser.add_argument('--disable_next_stage_pred', action='store_true', default=False,
                        help='do not predict next stage')
    parser.add_argument('-pretrained_weights', type=str, required=False, default=None,
                        help='path to nnU-Net checkpoint file to be used as pretrained model (use .model '
                             'file, for example model_final_checkpoint.model). Will only be used when actually training. '
                             'Optional. Beta. Use with caution.')

    args = parser.parse_args()

    task = args.task
    fold = args.fold
    network = args.network
    network_trainer = args.network_trainer
    validation_only = args.validation_only
    plans_identifier = args.p
    find_lr = args.find_lr
    disable_postprocessing_on_folds = args.disable_postprocessing_on_folds

    use_compressed_data = args.use_compressed_data
    decompress_data = not use_compressed_data

    deterministic = args.deterministic
    valbest = args.valbest

    fp32 = args.fp32
    run_mixed_precision = not fp32

    val_folder = args.val_folder
    # interp_order = args.interp_order
    # interp_order_z = args.interp_order_z
    # force_separate_z = args.force_separate_z

    if not task.startswith("Task"):
        task_id = int(task)
        task = convert_id_to_task_name(task_id)

    if fold == 'all':
        pass
    else:
        fold = int(fold)

    # if force_separate_z == "None":
    #     force_separate_z = None
    # elif force_separate_z == "False":
    #     force_separate_z = False
    # elif force_separate_z == "True":
    #     force_separate_z = True
    # else:
    #     raise ValueError("force_separate_z must be None, True or False. Given: %s" % force_separate_z)

    plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
    trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)

    if trainer_class is None:
        raise RuntimeError("Could not find trainer class in nnunet.training.network_training")

    if network == "3d_cascade_fullres":
        assert issubclass(trainer_class, (nnUNetTrainerCascadeFullRes, nnUNetTrainerV2CascadeFullRes)), \
            "If running 3d_cascade_fullres then your " \
            "trainer class must be derived from " \
            "nnUNetTrainerCascadeFullRes"
    else:
        assert issubclass(trainer_class,
                          nnUNetTrainer), "network_trainer was found but is not derived from nnUNetTrainer"

    trainer = trainer_class(plans_file, fold, output_folder=output_folder_name, dataset_directory=dataset_directory,
                            batch_dice=batch_dice, stage=stage, unpack_data=decompress_data,
                            deterministic=deterministic,
                            fp16=run_mixed_precision)
    if args.disable_saving:
        trainer.save_final_checkpoint = False # whether or not to save the final checkpoint
        trainer.save_best_checkpoint = False  # whether or not to save the best checkpoint according to
        # self.best_val_eval_criterion_MA
        trainer.save_intermediate_checkpoints = True  # whether or not to save checkpoint_latest. We need that in case
        # the training crashes
        trainer.save_latest_only = True  # if False, do not store/overwrite _latest but write separate files each time

    trainer.initialize(not validation_only)

    if find_lr:
        trainer.find_lr()
    else:
        if not validation_only:
            if args.continue_training:
                # -c was set, continue a previous training and ignore pretrained weights
                trainer.load_latest_checkpoint()
            elif (not args.continue_training) and (args.pretrained_weights is not None):
                # we start a new training. If pretrained_weights are set, use them
                load_pretrained_weights(trainer.network, args.pretrained_weights)
            else:
                # new training without pretrained weights, do nothing
                pass

            trainer.run_training()
        else:
            if valbest:
                trainer.load_best_checkpoint(train=False)
            else:
                trainer.load_final_checkpoint(train=False)

        trainer.network.eval()

        # predict validation
        trainer.validate(save_softmax=args.npz, validation_folder_name=val_folder,
                         run_postprocessing_on_folds=not disable_postprocessing_on_folds,
                         overwrite=args.val_disable_overwrite)

        if network == '3d_lowres' and not args.disable_next_stage_pred:
            print("predicting segmentations for the next stage of the cascade")
            predict_next_stage(trainer, join(dataset_directory, trainer.plans['data_identifier'] + "_stage%d" % 1))
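The -pretrained_weights path points at an nnU-Net .model checkpoint, and load_pretrained_weights transfers its parameters into the freshly built network before training starts. As a rough, hypothetical sketch of that idea (not nnU-Net's actual helper), one could copy only the entries whose names and shapes match:

import torch

def load_matching_weights_sketch(network, checkpoint_path):
    # Illustrative only: copy parameters from a saved state dict into `network`
    # wherever the parameter name and shape match; everything else keeps its initialization.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    pretrained = checkpoint.get('state_dict', checkpoint)
    model_state = network.state_dict()
    compatible = {k: v for k, v in pretrained.items()
                  if k in model_state and v.shape == model_state[k].shape}
    model_state.update(compatible)
    network.load_state_dict(model_state)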
Example #17
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("network")
    parser.add_argument("network_trainer")
    parser.add_argument("task", help="can be task name or task id")
    parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'')
    parser.add_argument("-val",
                        "--validation_only",
                        help="use this if you want to only run the validation",
                        action="store_true")
    parser.add_argument("-c",
                        "--continue_training",
                        help="use this if you want to continue a training",
                        action="store_true")
    parser.add_argument(
        "-p",
        help=
        "plans identifier. Only change this if you created a custom experiment planner",
        default=default_plans_identifier,
        required=False)
    parser.add_argument(
        "--use_compressed_data",
        default=False,
        action="store_true",
        help=
        "If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data "
        "is much more CPU and RAM intensive and should only be used if you know what you are "
        "doing",
        required=False)
    parser.add_argument(
        "--deterministic",
        help=
        "Makes training deterministic, but reduces training speed substantially. I (Fabian) think "
        "this is not necessary. Deterministic training will make you overfit to some random seed. "
        "Don't use that.",
        required=False,
        default=False,
        action="store_true")
    parser.add_argument("-gpus",
                        help="number of gpus",
                        required=True,
                        type=int)
    parser.add_argument("--dbs",
                        required=False,
                        default=False,
                        action="store_true",
                        help="distribute batch size. If "
                        "True then whatever "
                        "batch_size is in plans will "
                        "be distributed over DDP "
                        "models, if False then each "
                        "model will have batch_size "
                        "for a total of "
                        "GPUs*batch_size")
    parser.add_argument("--npz",
                        required=False,
                        default=False,
                        action="store_true",
                        help="if set then nnUNet will "
                        "export npz files of "
                        "predicted segmentations "
                        "in the validation as well. "
                        "This is needed to run the "
                        "ensembling step so unless "
                        "you are developing nnUNet "
                        "you should enable this")
    parser.add_argument("--valbest",
                        required=False,
                        default=False,
                        action="store_true",
                        help="")
    parser.add_argument("--find_lr",
                        required=False,
                        default=False,
                        action="store_true",
                        help="")
    parser.add_argument(
        "--fp32",
        required=False,
        default=False,
        action="store_true",
        help="disable mixed precision training and run old school fp32")
    parser.add_argument(
        "--val_folder",
        required=False,
        default="validation_raw",
        help=
        "name of the validation folder. No need to use this for most people")
    parser.add_argument(
        "--disable_saving",
        required=False,
        action='store_true',
        help=
        "If set nnU-Net will not save any parameter files. Useful for development when you are "
        "only interested in the results and want to save some disk space")
    # parser.add_argument("--interp_order", required=False, default=3, type=int,
    #                     help="order of interpolation for segmentations. Testing purpose only. Hands off")
    # parser.add_argument("--interp_order_z", required=False, default=0, type=int,
    #                     help="order of interpolation along z if z is resampled separately. Testing purpose only. "
    #                          "Hands off")
    # parser.add_argument("--force_separate_z", required=False, default="None", type=str,
    #                     help="force_separate_z resampling. Can be None, True or False. Testing purpose only. Hands off")

    args = parser.parse_args()

    task = args.task
    fold = args.fold
    network = args.network
    network_trainer = args.network_trainer
    validation_only = args.validation_only
    plans_identifier = args.p

    use_compressed_data = args.use_compressed_data
    decompress_data = not use_compressed_data

    deterministic = args.deterministic
    valbest = args.valbest
    find_lr = args.find_lr
    num_gpus = args.gpus
    fp32 = args.fp32
    val_folder = args.val_folder
    # interp_order = args.interp_order
    # interp_order_z = args.interp_order_z
    # force_separate_z = args.force_separate_z

    if not task.startswith("Task"):
        task_id = int(task)
        task = convert_id_to_task_name(task_id)

    if fold == 'all':
        pass
    else:
        fold = int(fold)

    # if force_separate_z == "None":
    #     force_separate_z = None
    # elif force_separate_z == "False":
    #     force_separate_z = False
    # elif force_separate_z == "True":
    #     force_separate_z = True
    # else:
    #     raise ValueError("force_separate_z must be None, True or False. Given: %s" % force_separate_z)

    plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
        trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)

    if trainer_class is None:
        raise RuntimeError("Could not find trainer class")

    if network == "3d_cascade_fullres":
        assert issubclass(trainer_class, nnUNetTrainerCascadeFullRes), "If running 3d_cascade_fullres then your " \
                                                                       "trainer class must be derived from " \
                                                                       "nnUNetTrainerCascadeFullRes"
    else:
        assert issubclass(trainer_class, nnUNetTrainer), "network_trainer was found but is not derived from " \
                                                         "nnUNetTrainer"

    trainer = trainer_class(plans_file,
                            fold,
                            output_folder=output_folder_name,
                            dataset_directory=dataset_directory,
                            batch_dice=batch_dice,
                            stage=stage,
                            unpack_data=decompress_data,
                            deterministic=deterministic,
                            distribute_batch_size=args.dbs,
                            num_gpus=num_gpus,
                            fp16=not fp32)

    if args.disable_saving:
        trainer.save_latest_only = False  # if False, do not store/overwrite _latest but write separate files each time
        trainer.save_intermediate_checkpoints = False  # whether or not to save checkpoint_latest
        trainer.save_best_checkpoint = False  # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
        trainer.save_final_checkpoint = False  # whether or not to save the final checkpoint

    trainer.initialize(not validation_only)

    if find_lr:
        trainer.find_lr()
    else:
        if not validation_only:
            if args.continue_training:
                trainer.load_latest_checkpoint()
            trainer.run_training()
        else:
            if valbest:
                trainer.load_best_checkpoint(train=False)
            else:
                trainer.load_latest_checkpoint(train=False)

        trainer.network.eval()

        # predict validation
        trainer.validate(save_softmax=args.npz,
                         validation_folder_name=val_folder)

        if network == '3d_lowres':
            trainer.load_best_checkpoint(False)
            print("predicting segmentations for the next stage of the cascade")
            predict_next_stage(
                trainer,
                join(dataset_directory,
                     trainer.plans['data_identifier'] + "_stage%d" % 1))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("network")
    parser.add_argument("network_trainer")
    parser.add_argument("task", help="can be task name or task id")
    parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'')
    parser.add_argument("gpu", help='GPU id(s) to use (sets CUDA_VISIBLE_DEVICES)')
    parser.add_argument("-val",
                        "--validation_only",
                        help="use this if you want to only run the validation",
                        action="store_true")
    parser.add_argument("-c",
                        "--continue_training",
                        help="use this if you want to continue a training",
                        action="store_true")
    parser.add_argument(
        "-p",
        help=
        "plans identifier. Only change this if you created a custom experiment planner",
        default=default_plans_identifier,
        required=False)
    parser.add_argument(
        "--use_compressed_data",
        default=False,
        action="store_true",
        help=
        "If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data "
        "is much more CPU and RAM intensive and should only be used if you know what you are "
        "doing",
        required=False)
    parser.add_argument(
        "--deterministic",
        help=
        "Makes training deterministic, but reduces training speed substantially. I (Fabian) think "
        "this is not necessary. Deterministic training will make you overfit to some random seed. "
        "Don't use that.",
        required=False,
        default=False,
        action="store_true")
    parser.add_argument("--npz",
                        required=False,
                        default=False,
                        action="store_true",
                        help="if set then nnUNet will "
                        "export npz files of "
                        "predicted segmentations "
                        "in the validation as well. "
                        "This is needed to run the "
                        "ensembling step so unless "
                        "you are developing nnUNet "
                        "you should enable this")
    parser.add_argument("--find_lr",
                        required=False,
                        default=False,
                        action="store_true",
                        help="not used here, just for fun")
    parser.add_argument("--valbest",
                        required=False,
                        default=False,
                        action="store_true",
                        help="hands off. This is not intended to be used")
    parser.add_argument(
        "--fp32",
        required=False,
        default=False,
        action="store_true",
        help="disable mixed precision training and run old school fp32")
    parser.add_argument(
        "--val_folder",
        required=False,
        default="validation_raw",
        help=
        "name of the validation folder. No need to use this for most people")
    parser.add_argument(
        "--interp_order",
        required=False,
        default=3,
        type=int,
        help=
        "order of interpolation for segmentations. Testing purpose only. Hands off"
    )
    parser.add_argument(
        "--interp_order_z",
        required=False,
        default=0,
        type=int,
        help=
        "order of interpolation along z if z is resampled separately. Testing purpose only. "
        "Hands off")
    parser.add_argument(
        "--force_separate_z",
        required=False,
        default="None",
        type=str,
        help=
        "force_separate_z resampling. Can be None, True or False. Testing purpose only. Hands off"
    )

    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    task = args.task
    fold = args.fold
    network = args.network
    network_trainer = args.network_trainer
    validation_only = args.validation_only
    plans_identifier = args.p
    find_lr = args.find_lr

    use_compressed_data = args.use_compressed_data
    decompress_data = not use_compressed_data

    deterministic = args.deterministic
    valbest = args.valbest

    fp32 = args.fp32
    run_mixed_precision = not fp32

    val_folder = args.val_folder
    interp_order = args.interp_order
    interp_order_z = args.interp_order_z
    force_separate_z = args.force_separate_z

    if not task.startswith("Task"):
        task_id = int(task)
        task = convert_id_to_task_name(task_id)

    if fold == 'all':
        pass
    else:
        fold = int(fold)

    if force_separate_z == "None":
        force_separate_z = None
    elif force_separate_z == "False":
        force_separate_z = False
    elif force_separate_z == "True":
        force_separate_z = True
    else:
        raise ValueError(
            "force_separate_z must be None, True or False. Given: %s" %
            force_separate_z)

    plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
    trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)

    if trainer_class is None:
        raise RuntimeError(
            "Could not find trainer class in nnunet.training.network_training")

    if network == "3d_cascade_fullres":
        assert issubclass(trainer_class, (nnUNetTrainerCascadeFullRes, nnUNetTrainerV2CascadeFullRes)), \
            "If running 3d_cascade_fullres then your " \
            "trainer class must be derived from " \
            "nnUNetTrainerCascadeFullRes"
    else:
        assert issubclass(
            trainer_class, nnUNetTrainer
        ), "network_trainer was found but is not derived from nnUNetTrainer"

    trainer = trainer_class(plans_file,
                            fold,
                            output_folder=output_folder_name,
                            dataset_directory=dataset_directory,
                            batch_dice=batch_dice,
                            stage=stage,
                            unpack_data=decompress_data,
                            deterministic=deterministic,
                            fp16=run_mixed_precision)

    trainer.initialize(not validation_only)

    if find_lr:
        trainer.find_lr()
    else:
        if not validation_only:
            if args.continue_training:
                trainer.load_latest_checkpoint()
            trainer.run_training()
        else:
            if valbest:
                trainer.load_best_checkpoint(train=False)
            else:
                trainer.load_latest_checkpoint(train=False)

        trainer.network.eval()

        # predict validation
        trainer.validate(save_softmax=args.npz,
                         validation_folder_name=val_folder,
                         force_separate_z=force_separate_z,
                         interpolation_order=interp_order,
                         interpolation_order_z=interp_order_z)

        if network == '3d_lowres':
            trainer.load_best_checkpoint(False)
            print("predicting segmentations for the next stage of the cascade")
            predict_next_stage(
                trainer,
                join(dataset_directory,
                     trainer.plans['data_identifier'] + "_stage%d" % 1))