Exemple #1
0
def main():
    """Train and/or validate an nnU-Net model on several GPUs.

    Positional arguments select the configuration (``network``), the trainer
    class name, the task (a name starting with ``"Task"`` or a numeric task
    id) and the fold (an integer or the literal ``'all'``). ``-gpus`` is
    required and is forwarded to the trainer as ``num_gpus``; ``--dbs`` is
    forwarded as ``distribute_batch_size``.

    Flow: resolve plans/output folder/trainer class with
    ``get_default_configuration``, build the trainer, then either run the
    learning-rate finder (``--find_lr``) or train / resume (``-c``) / only load
    a checkpoint (``-val``), followed by validation. For ``3d_lowres`` the best
    checkpoint is afterwards used to predict the inputs of the next cascade
    stage.

    Raises:
        RuntimeError: if no trainer class with the requested name was found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("network")
    parser.add_argument("network_trainer")
    parser.add_argument("task", help="can be task name or task id")
    parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'')
    parser.add_argument("-val",
                        "--validation_only",
                        help="use this if you want to only run the validation",
                        action="store_true")
    parser.add_argument("-c",
                        "--continue_training",
                        help="use this if you want to continue a training",
                        action="store_true")
    parser.add_argument(
        "-p",
        help=
        "plans identifier. Only change this if you created a custom experiment planner",
        default=default_plans_identifier,
        required=False)
    parser.add_argument(
        "--use_compressed_data",
        default=False,
        action="store_true",
        help=
        "If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data "
        "is much more CPU and RAM intensive and should only be used if you know what you are "
        "doing",
        required=False)
    parser.add_argument(
        "--deterministic",
        help=
        "Makes training deterministic, but reduces training speed substantially. I (Fabian) think "
        "this is not necessary. Deterministic training will make you overfit to some random seed. "
        "Don't use that.",
        required=False,
        default=False,
        action="store_true")
    parser.add_argument("-gpus",
                        help="number of gpus",
                        required=True,
                        type=int)
    parser.add_argument("--dbs",
                        required=False,
                        default=False,
                        action="store_true",
                        help="distribute batch size. If "
                        "True then whatever "
                        "batch_size is in plans will "
                        "be distributed over DDP "
                        "models, if False then each "
                        "model will have batch_size "
                        "for a total of "
                        "GPUs*batch_size")
    parser.add_argument("--npz",
                        required=False,
                        default=False,
                        action="store_true",
                        help="if set then nnUNet will "
                        "export npz files of "
                        "predicted segmentations "
                        "in the validation as well. "
                        "This is needed to run the "
                        "ensembling step so unless "
                        "you are developing nnUNet "
                        "you should enable this")
    parser.add_argument("--valbest",
                        required=False,
                        default=False,
                        action="store_true",
                        help="")
    parser.add_argument("--find_lr",
                        required=False,
                        default=False,
                        action="store_true",
                        help="")
    parser.add_argument(
        "--fp32",
        required=False,
        default=False,
        action="store_true",
        help="disable mixed precision training and run old school fp32")
    parser.add_argument(
        "--val_folder",
        required=False,
        default="validation_raw",
        help=
        "name of the validation folder. No need to use this for most people")
    parser.add_argument(
        "--disable_saving",
        required=False,
        action='store_true',
        help=
        "If set nnU-Net will not save any parameter files. Useful for development when you are "
        "only interested in the results and want to save some disk space")
    # parser.add_argument("--interp_order", required=False, default=3, type=int,
    #                     help="order of interpolation for segmentations. Testing purpose only. Hands off")
    # parser.add_argument("--interp_order_z", required=False, default=0, type=int,
    #                     help="order of interpolation along z if z is resampled separately. Testing purpose only. "
    #                          "Hands off")
    # parser.add_argument("--force_separate_z", required=False, default="None", type=str,
    #                     help="force_separate_z resampling. Can be None, True or False. Testing purpose only. Hands off")

    args = parser.parse_args()

    task = args.task
    fold = args.fold
    network = args.network
    network_trainer = args.network_trainer
    validation_only = args.validation_only
    plans_identifier = args.p

    # --use_compressed_data means "keep the cases compressed", i.e. do not unpack them
    decompress_data = not args.use_compressed_data

    deterministic = args.deterministic
    valbest = args.valbest
    find_lr = args.find_lr
    num_gpus = args.gpus
    fp32 = args.fp32
    val_folder = args.val_folder
    # interp_order = args.interp_order
    # interp_order_z = args.interp_order_z
    # force_separate_z = args.force_separate_z

    # anything that does not start with "Task" is interpreted as a numeric task id
    if not task.startswith("Task"):
        task_id = int(task)
        task = convert_id_to_task_name(task_id)

    # fold is either the literal 'all' or an integer fold index
    if fold != 'all':
        fold = int(fold)

    # if force_separate_z == "None":
    #     force_separate_z = None
    # elif force_separate_z == "False":
    #     force_separate_z = False
    # elif force_separate_z == "True":
    #     force_separate_z = True
    # else:
    #     raise ValueError("force_separate_z must be None, True or False. Given: %s" % force_separate_z)

    plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
        trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)

    if trainer_class is None:
        raise RuntimeError("Could not find trainer class")

    if network == "3d_cascade_fullres":
        assert issubclass(trainer_class, nnUNetTrainerCascadeFullRes), "If running 3d_cascade_fullres then your " \
                                                                       "trainer class must be derived from " \
                                                                       "nnUNetTrainerCascadeFullRes"
    else:
        assert issubclass(trainer_class, nnUNetTrainer), "network_trainer was found but is not derived from " \
                                                         "nnUNetTrainer"

    trainer = trainer_class(plans_file,
                            fold,
                            output_folder=output_folder_name,
                            dataset_directory=dataset_directory,
                            batch_dice=batch_dice,
                            stage=stage,
                            unpack_data=decompress_data,
                            deterministic=deterministic,
                            distribute_batch_size=args.dbs,
                            num_gpus=num_gpus,
                            fp16=not fp32)

    if args.disable_saving:
        # Every checkpoint flag is switched off. NOTE(review): save_latest_only
        # presumably only matters while intermediate checkpoints are written, so
        # its value is moot here; confirm against the trainer implementation.
        trainer.save_latest_only = False  # if false it will not store/overwrite _latest but separate files each
        trainer.save_intermediate_checkpoints = False  # whether or not to save checkpoint_latest
        trainer.save_best_checkpoint = False  # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
        trainer.save_final_checkpoint = False  # whether or not to save the final checkpoint

    # the argument tells initialize() whether training will follow
    trainer.initialize(not validation_only)

    if find_lr:
        trainer.find_lr()
    else:
        if not validation_only:
            if args.continue_training:
                trainer.load_latest_checkpoint()
            trainer.run_training()
        else:
            if valbest:
                trainer.load_best_checkpoint(train=False)
            else:
                trainer.load_latest_checkpoint(train=False)

        trainer.network.eval()

        # predict validation
        trainer.validate(save_softmax=args.npz,
                         validation_folder_name=val_folder)

        if network == '3d_lowres':
            # the next cascade stage is always predicted from the best checkpoint
            trainer.load_best_checkpoint(False)
            print("predicting segmentations for the next stage of the cascade")
            predict_next_stage(
                trainer,
                join(dataset_directory,
                     trainer.plans['data_identifier'] + "_stage%d" % 1))
def main():
    """Train and/or validate an nnU-Net model with optional resampling overrides.

    Positional arguments select the configuration (``network``), the trainer
    class name, the task (a name starting with ``"Task"`` or a numeric task
    id), the fold (an integer or ``'all'``) and the value written to
    ``CUDA_VISIBLE_DEVICES``. Three validation-only resampling knobs
    (``--interp_order``, ``--interp_order_z``, ``--force_separate_z``) are
    forwarded to ``trainer.validate``.

    Raises:
        ValueError: if ``--force_separate_z`` is not "None", "True" or "False".
        RuntimeError: if no trainer class with the requested name was found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("network")
    parser.add_argument("network_trainer")
    parser.add_argument("task", help="can be task name or task id")
    parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'')
    parser.add_argument("gpu",
                        help="GPU id(s) to expose via CUDA_VISIBLE_DEVICES, e.g. 0")
    parser.add_argument("-val",
                        "--validation_only",
                        help="use this if you want to only run the validation",
                        action="store_true")
    parser.add_argument("-c",
                        "--continue_training",
                        help="use this if you want to continue a training",
                        action="store_true")
    parser.add_argument(
        "-p",
        help=
        "plans identifier. Only change this if you created a custom experiment planner",
        default=default_plans_identifier,
        required=False)
    parser.add_argument(
        "--use_compressed_data",
        default=False,
        action="store_true",
        help=
        "If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data "
        "is much more CPU and RAM intensive and should only be used if you know what you are "
        "doing",
        required=False)
    parser.add_argument(
        "--deterministic",
        help=
        "Makes training deterministic, but reduces training speed substantially. I (Fabian) think "
        "this is not necessary. Deterministic training will make you overfit to some random seed. "
        "Don't use that.",
        required=False,
        default=False,
        action="store_true")
    parser.add_argument("--npz",
                        required=False,
                        default=False,
                        action="store_true",
                        help="if set then nnUNet will "
                        "export npz files of "
                        "predicted segmentations "
                        "in the validation as well. "
                        "This is needed to run the "
                        "ensembling step so unless "
                        "you are developing nnUNet "
                        "you should enable this")
    parser.add_argument("--find_lr",
                        required=False,
                        default=False,
                        action="store_true",
                        help="run the learning rate finder instead of training and validation")
    parser.add_argument("--valbest",
                        required=False,
                        default=False,
                        action="store_true",
                        help="hands off. This is not intended to be used")
    parser.add_argument(
        "--fp32",
        required=False,
        default=False,
        action="store_true",
        help="disable mixed precision training and run old school fp32")
    parser.add_argument(
        "--val_folder",
        required=False,
        default="validation_raw",
        help=
        "name of the validation folder. No need to use this for most people")
    parser.add_argument(
        "--interp_order",
        required=False,
        default=3,
        type=int,
        help=
        "order of interpolation for segmentations. Testing purpose only. Hands off"
    )
    parser.add_argument(
        "--interp_order_z",
        required=False,
        default=0,
        type=int,
        help=
        "order of interpolation along z if z is resampled separately. Testing purpose only. "
        "Hands off")
    parser.add_argument(
        "--force_separate_z",
        required=False,
        default="None",
        type=str,
        help=
        "force_separate_z resampling. Can be None, True or False. Testing purpose only. Hands off"
    )

    args = parser.parse_args()

    # Restrict the visible devices before any work starts. NOTE(review): this
    # presumably only takes effect if CUDA was not already initialised by a
    # module-level import; confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    task = args.task
    fold = args.fold
    network = args.network
    network_trainer = args.network_trainer
    validation_only = args.validation_only
    plans_identifier = args.p
    find_lr = args.find_lr

    # --use_compressed_data means "keep the cases compressed", i.e. do not unpack them
    decompress_data = not args.use_compressed_data

    deterministic = args.deterministic
    valbest = args.valbest

    run_mixed_precision = not args.fp32

    val_folder = args.val_folder
    interp_order = args.interp_order
    interp_order_z = args.interp_order_z
    force_separate_z = args.force_separate_z

    # anything that does not start with "Task" is interpreted as a numeric task id
    if not task.startswith("Task"):
        task_id = int(task)
        task = convert_id_to_task_name(task_id)

    # fold is either the literal 'all' or an integer fold index
    if fold != 'all':
        fold = int(fold)

    # argparse delivers a string; map it onto the None/False/True value that
    # trainer.validate receives
    separate_z_choices = {"None": None, "False": False, "True": True}
    if force_separate_z not in separate_z_choices:
        raise ValueError(
            "force_separate_z must be None, True or False. Given: %s" %
            force_separate_z)
    force_separate_z = separate_z_choices[force_separate_z]

    plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
        trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)

    if trainer_class is None:
        raise RuntimeError(
            "Could not find trainer class in nnunet.training.network_training")

    if network == "3d_cascade_fullres":
        assert issubclass(trainer_class, (nnUNetTrainerCascadeFullRes, nnUNetTrainerV2CascadeFullRes)), \
            "If running 3d_cascade_fullres then your " \
            "trainer class must be derived from " \
            "nnUNetTrainerCascadeFullRes"
    else:
        assert issubclass(
            trainer_class, nnUNetTrainer
        ), "network_trainer was found but is not derived from nnUNetTrainer"

    trainer = trainer_class(plans_file,
                            fold,
                            output_folder=output_folder_name,
                            dataset_directory=dataset_directory,
                            batch_dice=batch_dice,
                            stage=stage,
                            unpack_data=decompress_data,
                            deterministic=deterministic,
                            fp16=run_mixed_precision)

    # the argument tells initialize() whether training will follow
    trainer.initialize(not validation_only)

    if find_lr:
        trainer.find_lr()
    else:
        if not validation_only:
            if args.continue_training:
                trainer.load_latest_checkpoint()
            trainer.run_training()
        else:
            if valbest:
                trainer.load_best_checkpoint(train=False)
            else:
                trainer.load_latest_checkpoint(train=False)

        trainer.network.eval()

        # predict validation, applying the resampling overrides parsed above
        trainer.validate(save_softmax=args.npz,
                         validation_folder_name=val_folder,
                         force_separate_z=force_separate_z,
                         interpolation_order=interp_order,
                         interpolation_order_z=interp_order_z)

        if network == '3d_lowres':
            # the next cascade stage is always predicted from the best checkpoint
            trainer.load_best_checkpoint(False)
            print("predicting segmentations for the next stage of the cascade")
            predict_next_stage(
                trainer,
                join(dataset_directory,
                     trainer.plans['data_identifier'] + "_stage%d" % 1))
Exemple #4
0
def main():
    """Parse the nnU-Net training command line, then train and/or validate.

    Positional arguments pick the configuration (``network``), the trainer
    class name, the task (a name starting with ``"Task"`` or a numeric id) and
    the fold (an integer or ``'all'``). Depending on the flags, the trainer:

    * runs only the learning-rate finder (``--find_lr``), or
    * trains from scratch, resumes (``-c``) or starts from a pretrained
      checkpoint (``-pretrained_weights``), then validates, or
    * loads the final/best checkpoint and only validates (``-val``).

    For ``3d_lowres`` the next cascade stage is predicted afterwards unless
    ``--disable_next_stage_pred`` is given.

    Raises:
        RuntimeError: if no trainer class with the requested name was found.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # --- positional arguments -------------------------------------------
    add("network")
    add("network_trainer")
    add("task", help="can be task name or task id")
    add("fold", help='0, 1, ..., 5 or \'all\'')

    # --- what to run ----------------------------------------------------
    add("-val", "--validation_only", action="store_true",
        help="use this if you want to only run the validation")
    add("-c", "--continue_training", action="store_true",
        help="use this if you want to continue a training")
    add("-p", default=default_plans_identifier, required=False,
        help="plans identifier. Only change this if you created a custom experiment planner")
    add("--use_compressed_data", default=False, action="store_true", required=False,
        help="If you set use_compressed_data, the training cases will not be decompressed. "
             "Reading compressed data is much more CPU and RAM intensive and should only be "
             "used if you know what you are doing")
    add("--deterministic", required=False, default=False, action="store_true",
        help="Makes training deterministic, but reduces training speed substantially. "
             "I (Fabian) think this is not necessary. Deterministic training will make you "
             "overfit to some random seed. Don't use that.")
    add("--npz", required=False, default=False, action="store_true",
        help="if set then nnUNet will export npz files of predicted segmentations in the "
             "validation as well. This is needed to run the ensembling step so unless you "
             "are developing nnUNet you should enable this")
    add("--find_lr", required=False, default=False, action="store_true",
        help="not used here, just for fun")
    add("--valbest", required=False, default=False, action="store_true",
        help="hands off. This is not intended to be used")
    add("--fp32", required=False, default=False, action="store_true",
        help="disable mixed precision training and run old school fp32")
    add("--val_folder", required=False, default="validation_raw",
        help="name of the validation folder. No need to use this for most people")
    add("--disable_saving", required=False, action='store_true',
        help="If set nnU-Net will not save any parameter files (except a temporary checkpoint "
             "that will be removed at the end of the training). Useful for development when "
             "you are only interested in the results and want to save some disk space")
    add("--disable_postprocessing_on_folds", required=False, action='store_true',
        help="Running postprocessing on each fold only makes sense when developing with nnU-Net and "
             "closely observing the model performance on specific configurations. You do not need it "
             "when applying nnU-Net because the postprocessing for this will be determined only once "
             "all five folds have been trained and nnUNet_find_best_configuration is called. Usually "
             "running postprocessing on each fold is computationally cheap, but some users have "
             "reported issues with very large images. If your images are large (>600x600x600 voxels) "
             "you should consider setting this flag.")
    # The resampling overrides (--interp_order, --interp_order_z,
    # --force_separate_z) are intentionally not exposed in this variant.
    # store_false: the parsed value is True (= overwrite) unless the flag is given
    add('--val_disable_overwrite', action='store_false', default=True,
        help='Validation does not overwrite existing segmentations')
    add('--disable_next_stage_pred', action='store_true', default=False,
        help='do not predict next stage')
    add('-pretrained_weights', type=str, required=False, default=None,
        help='path to nnU-Net checkpoint file to be used as pretrained model (use .model '
             'file, for example model_final_checkpoint.model). Will only be used when actually training. '
             'Optional. Beta. Use with caution.')

    opts = parser.parse_args()

    # a task that does not start with "Task" is treated as a numeric task id
    task = opts.task
    if not task.startswith("Task"):
        task = convert_id_to_task_name(int(task))

    # fold stays the literal 'all' or becomes an integer index
    fold = opts.fold if opts.fold == 'all' else int(opts.fold)

    network = opts.network
    (plans_file, output_folder_name, dataset_directory,
     batch_dice, stage, trainer_class) = get_default_configuration(
        network, task, opts.network_trainer, opts.p)

    if trainer_class is None:
        raise RuntimeError("Could not find trainer class in nnunet.training.network_training")

    if network == "3d_cascade_fullres":
        assert issubclass(trainer_class, (nnUNetTrainerCascadeFullRes, nnUNetTrainerV2CascadeFullRes)), (
            "If running 3d_cascade_fullres then your trainer class must be derived from "
            "nnUNetTrainerCascadeFullRes")
    else:
        assert issubclass(trainer_class, nnUNetTrainer), \
            "network_trainer was found but is not derived from nnUNetTrainer"

    trainer = trainer_class(
        plans_file,
        fold,
        output_folder=output_folder_name,
        dataset_directory=dataset_directory,
        batch_dice=batch_dice,
        stage=stage,
        unpack_data=not opts.use_compressed_data,  # compressed data stays packed
        deterministic=opts.deterministic,
        fp16=not opts.fp32,  # mixed precision unless --fp32 is requested
    )

    if opts.disable_saving:
        # No final or best checkpoint; only a single, overwritten
        # checkpoint_latest is kept so an interrupted training can resume.
        trainer.save_final_checkpoint = False
        trainer.save_best_checkpoint = False
        trainer.save_intermediate_checkpoints = True
        trainer.save_latest_only = True

    # the argument tells initialize() whether training will follow
    trainer.initialize(not opts.validation_only)

    if opts.find_lr:
        trainer.find_lr()
        return

    if opts.validation_only:
        # validation-only: restore weights from disk instead of training
        if opts.valbest:
            trainer.load_best_checkpoint(train=False)
        else:
            trainer.load_final_checkpoint(train=False)
    else:
        # -c wins over -pretrained_weights; without either, train from scratch
        if opts.continue_training:
            trainer.load_latest_checkpoint()
        elif opts.pretrained_weights is not None:
            load_pretrained_weights(trainer.network, opts.pretrained_weights)
        trainer.run_training()

    trainer.network.eval()

    trainer.validate(
        save_softmax=opts.npz,
        validation_folder_name=opts.val_folder,
        run_postprocessing_on_folds=not opts.disable_postprocessing_on_folds,
        overwrite=opts.val_disable_overwrite,
    )

    if network == '3d_lowres' and not opts.disable_next_stage_pred:
        print("predicting segmentations for the next stage of the cascade")
        next_stage_dir = join(dataset_directory, trainer.plans['data_identifier'] + "_stage%d" % 1)
        predict_next_stage(trainer, next_stage_dir)
Exemple #5
0
def main():
    """CoTr training/validation entry point built on the nnU-Net pipeline.

    Every argument is an optional flag with a default (GPU ``'0'``,
    configuration ``'3d_fullres'``, trainer ``'nnUNetTrainerV2_ResTrans'``,
    task ``'17'``, fold ``'all'``). The output path is
    ``<outpath>_<norm_cfg>_<activation_cfg>``. The trainer class is searched in
    the ``CoTr.training.network_training`` package and receives ``norm_cfg``
    and ``activation_cfg`` as extra positional constructor arguments.

    Raises:
        RuntimeError: if no trainer class with the requested name was found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-gpu", type=str, default='0')

    parser.add_argument("-network", type=str, default='3d_fullres')
    parser.add_argument("-network_trainer",
                        type=str,
                        default='nnUNetTrainerV2_ResTrans')
    parser.add_argument("-task",
                        type=str,
                        default='17',
                        help="can be task name or task id")
    parser.add_argument("-fold",
                        type=str,
                        default='all',
                        help='0, 1, ..., 5 or \'all\'')
    parser.add_argument("-outpath",
                        type=str,
                        default='Trainer_CoTr',
                        help='output path')
    parser.add_argument("-norm_cfg",
                        type=str,
                        default='IN',
                        help='BN, SyncBN, IN or GN')
    parser.add_argument("-activation_cfg",
                        type=str,
                        default='LeakyReLU',
                        help='LeakyReLU or ReLU')

    parser.add_argument("-val",
                        "--validation_only",
                        default=False,
                        help="use this if you want to only run the validation",
                        required=False,
                        action="store_true")
    parser.add_argument("-c",
                        "--continue_training",
                        help="use this if you want to continue a training",
                        action="store_true")
    parser.add_argument(
        "-p",
        help=
        "plans identifier. Only change this if you created a custom experiment planner",
        default=default_plans_identifier,
        required=False)
    parser.add_argument(
        "--use_compressed_data",
        default=False,
        action="store_true",
        help=
        "If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data "
        "is much more CPU and RAM intensive and should only be used if you know what you are "
        "doing",
        required=False)
    parser.add_argument("--deterministic", default=False, action="store_true")
    parser.add_argument("--npz",
                        required=False,
                        default=False,
                        action="store_true",
                        help="if set then nnUNet will "
                        "export npz files of "
                        "predicted segmentations "
                        "in the validation as well. "
                        "This is needed to run the "
                        "ensembling step so unless "
                        "you are developing nnUNet "
                        "you should enable this")
    parser.add_argument("--find_lr",
                        required=False,
                        default=False,
                        action="store_true",
                        help="not used here, just for fun")
    parser.add_argument("--valbest",
                        required=False,
                        default=False,
                        action="store_true",
                        help="hands off. This is not intended to be used")
    parser.add_argument(
        "--fp32",
        required=False,
        default=False,
        action="store_true",
        help="disable mixed precision training and run old school fp32")
    parser.add_argument(
        "--val_folder",
        required=False,
        default="validation_raw",
        help=
        "name of the validation folder. No need to use this for most people")
    parser.add_argument("--disable_saving",
                        required=False,
                        action='store_true',
                        help="If set, do not save any checkpoint files")

    args = parser.parse_args()

    # NOTE(review): presumably only effective if CUDA has not been initialised
    # yet by a module-level import; confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    norm_cfg = args.norm_cfg
    activation_cfg = args.activation_cfg
    # The output folder is named after the *requested* normalisation, so a
    # model trained with SyncBN is still found in its ..._SyncBN_... folder
    # even when validation below swaps the layers for plain BN.
    outpath = args.outpath + '_' + norm_cfg + '_' + activation_cfg

    task = args.task
    fold = args.fold
    network = args.network
    network_trainer = args.network_trainer
    validation_only = args.validation_only
    plans_identifier = args.p
    find_lr = args.find_lr

    # --use_compressed_data means "keep the cases compressed", i.e. do not unpack them
    decompress_data = not args.use_compressed_data

    deterministic = args.deterministic
    valbest = args.valbest

    run_mixed_precision = not args.fp32

    val_folder = args.val_folder

    if validation_only and norm_cfg == 'SyncBN':
        # Build the network with plain BN for validation-only runs. This must be
        # an assignment; a comparison here would leave SyncBN in place.
        # NOTE(review): presumably SyncBN needs a distributed process group that
        # this single-process validation does not set up; confirm.
        norm_cfg = 'BN'

    # anything that does not start with "Task" is interpreted as a numeric task id
    if not task.startswith("Task"):
        task_id = int(task)
        task = convert_id_to_task_name(task_id)

    # fold is either the literal 'all' or an integer fold index
    if fold != 'all':
        fold = int(fold)

    plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class = \
        get_default_configuration(outpath, network, task, network_trainer, plans_identifier,
                                  search_in=(CoTr.__path__[0], "training", "network_training"),
                                  base_module='CoTr.training.network_training')

    if trainer_class is None:
        # fail with a clear message instead of "'NoneType' object is not callable"
        raise RuntimeError("Could not find trainer class in CoTr.training.network_training")

    trainer = trainer_class(plans_file,
                            fold,
                            norm_cfg,
                            activation_cfg,
                            output_folder=output_folder_name,
                            dataset_directory=dataset_directory,
                            batch_dice=batch_dice,
                            stage=stage,
                            unpack_data=decompress_data,
                            deterministic=deterministic,
                            fp16=run_mixed_precision)

    if args.disable_saving:
        trainer.save_latest_only = False  # if false it will not store/overwrite _latest but separate files each
        trainer.save_intermediate_checkpoints = False  # whether or not to save checkpoint_latest
        trainer.save_best_checkpoint = False  # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
        trainer.save_final_checkpoint = False  # whether or not to save the final checkpoint

    # the argument tells initialize() whether training will follow
    trainer.initialize(not validation_only)

    if find_lr:
        trainer.find_lr()
    else:
        if not validation_only:
            if args.continue_training:
                trainer.load_latest_checkpoint()
            trainer.run_training()
        else:
            if valbest:
                trainer.load_best_checkpoint(train=False)
            else:
                trainer.load_latest_checkpoint(train=False)

        trainer.network.eval()

        # predict validation
        trainer.validate(save_softmax=args.npz,
                         validation_folder_name=val_folder)

        if network == '3d_lowres':
            print("predicting segmentations for the next stage of the cascade")
            predict_next_stage(
                trainer,
                join(dataset_directory,
                     trainer.plans['data_identifier'] + "_stage%d" % 1))