def main(config_file):
    """Run the aleatoric test pipeline.

    Args:
        config_file: Path to the test YAML configuration, or None to use
            'test_brats_aleatoric.yaml' from ``dirs.CONFIG_DIR``.

    Raises:
        ValueError: If ``others.is_log_sigma`` is absent from the loaded
            configuration.
    """
    config_path = config_file if config_file is not None else os.path.join(
        dirs.CONFIG_DIR, 'test_brats_aleatoric.yaml')

    context = ctx.TorchTestContext('cuda')
    context.load_from_config(config_path)

    dataset_builder = data.BuildParametrizableDataset()
    build_test = data.BuildData(build_dataset=dataset_builder)

    others = context.config.others
    # Mandatory setting; its value is forwarded to AleatoricPredictStep.
    if not hasattr(others, 'is_log_sigma'):
        raise ValueError('"is_log_sigma" entry missing in configuration file')
    is_log_sigma = others.is_log_sigma

    predict_steps = [AleatoricPredictStep(is_log_sigma)]
    per_subject_steps = [step.ExtractSubjectInfoStep(), step.EvalSubjectStep()]
    subject_assembler = assembler.SubjectAssembler()
    tester = loop.Test(predict_steps, per_subject_steps, subject_assembler, entries=None)

    # Console logging, a metrics CSV ('metrics.csv') and the script's WriteHook.
    reporting_hooks = [
        hooks.ConsoleTestLogHook(),
        hooks.WriteTestMetricsCsvHook('metrics.csv'),
        WriteHook(),
    ]
    tester(context, build_test, hook=hooks.ReducedComposeTestLoopHook(reporting_hooks))
Example #2
0
def main(config_file):
    """Run the auxiliary-segmentation test pipeline.

    Args:
        config_file: Path to the test YAML configuration, or None to use
            'test_brats_auxiliary_segm.yaml' from ``dirs.CONFIG_DIR``.
    """
    config_path = config_file
    if config_path is None:
        config_path = os.path.join(dirs.CONFIG_DIR, 'test_brats_auxiliary_segm.yaml')

    context = ctx.TorchTestContext('cuda')
    context.load_from_config(config_path)

    dataset_builder = data.BuildParametrizableDataset()
    build_test = data.BuildData(build_dataset=dataset_builder)

    predict_steps = [SegmentationPredictStep()]
    per_subject_steps = [step.ExtractSubjectInfoStep(), EvalSubjectStep()]
    subject_assembler = assembler.SubjectAssembler()
    # Only these output entries are handed to the subject assembler.
    kept_entries = ('probabilities', 'orig_prediction')
    tester = loop.Test(predict_steps, per_subject_steps, subject_assembler,
                       entries=kept_entries)

    reporting_hooks = [
        hooks.ConsoleTestLogHook(),
        hooks.WriteTestMetricsCsvHook('metrics.csv'),
        WriteHook(),
    ]
    tester(context, build_test, hook=hooks.ReducedComposeTestLoopHook(reporting_hooks))
def main(config_file: str):
    """Train with the auxiliary-segmentation setup.

    Args:
        config_file: Path to the training YAML configuration, or None to use
            'train_brats_auxiliary_segm.yaml' from ``dirs.CONFIG_DIR``.
    """
    config_path = config_file
    if config_path is None:
        config_path = os.path.join(dirs.CONFIG_DIR, 'train_brats_auxiliary_segm.yaml')

    context = ctx.TorchTrainContext('cuda')
    context.load_from_config(config_path)

    # Training data additionally gets a selection sampler; validation does not.
    build_train = data.BuildData(build_dataset=data.BuildParametrizableDataset(),
                                 build_sampler=data.BuildSelectionSampler())
    build_valid = data.BuildData(build_dataset=data.BuildParametrizableDataset())

    trainer = trainloop.Train([SpecialTrainStep(), step.EvalStep()], only_validate=False)

    subject_assembler = assembler.SubjectAssembler()
    prediction_steps = [SpecialSegmentationPredictStep()]
    per_subject_steps = [step.ExtractSubjectInfoStep(), EvalSubjectStep()]
    validate = trainloop.ValidateSubject(prediction_steps, per_subject_steps,
                                         subject_assembler, entries=('probabilities',))

    train_hooks = [
        hooks.TensorboardXHook(),
        hooks.ConsoleLogHook(),
        hooks.SaveBestModelHook(),
        hooks.SaveNLastModelHook(3),
    ]
    trainer(context, build_train, build_valid, validate,
            hook=hooks.ComposeTrainLoopHook(train_hooks))
def main(config_file: str, config_id: str):
    """Train the baseline configuration selected by path or by id.

    Args:
        config_file: Explicit path to a training YAML. When None, the file is
            derived from ``config_id``.
        config_id: Used only when ``config_file`` is None:
            'center' -> train_brats_center.yaml,
            'cv0'..'cv4' -> baseline_cv/train_brats_baseline_cv<k>.yaml,
            'ensemble0'..'ensemble9' -> train_ensemble/train_brats_ensemble_<k>.yaml,
            anything else (including 'baseline') -> train_brats_baseline.yaml.
    """
    if config_file is None:
        cv_ids = tuple('cv{}'.format(k) for k in range(5))
        ensemble_ids = tuple('ensemble{}'.format(k) for k in range(10))

        if config_id == 'center':
            config_file = os.path.join(dirs.CONFIG_DIR, 'train_brats_center.yaml')
        elif config_id in cv_ids:
            fold = config_id[-1]
            config_file = os.path.join(dirs.CONFIG_DIR, 'baseline_cv',
                                       'train_brats_baseline_cv{}.yaml'.format(fold))
        elif config_id in ensemble_ids:
            member = config_id[-1]
            config_file = os.path.join(dirs.CONFIG_DIR, 'train_ensemble',
                                       'train_brats_ensemble_{}.yaml'.format(member))
        else:
            # 'baseline' and every unrecognised id share the baseline config.
            config_file = os.path.join(dirs.CONFIG_DIR, 'train_brats_baseline.yaml')

    context = ctx.TorchTrainContext('cuda')
    context.load_from_config(config_file)

    build_train = data.BuildData(build_dataset=data.BuildParametrizableDataset(),
                                 build_sampler=data.BuildSelectionSampler())
    build_valid = data.BuildData(build_dataset=data.BuildParametrizableDataset())

    trainer = trainloop.Train([step.TrainStep(), step.EvalStep()], only_validate=False)

    subject_assembler = assembler.SubjectAssembler()
    prediction_steps = [step.SegmentationPredictStep(do_probs=True)]
    per_subject_steps = [step.ExtractSubjectInfoStep(), EvalSubjectStep()]
    validate = trainloop.ValidateSubject(prediction_steps, per_subject_steps,
                                         subject_assembler, entries=('probabilities',))

    train_hooks = [
        hooks.TensorboardXHook(),
        hooks.ConsoleLogHook(),
        hooks.SaveBestModelHook(),
        hooks.SaveNLastModelHook(3),
    ]
    trainer(context, build_train, build_valid, validate,
            hook=hooks.ComposeTrainLoopHook(train_hooks))
def main(config_file, config_id):
    """Test the baseline configuration selected by path or by id.

    Args:
        config_file: Explicit path to a test YAML. When None, the file is
            derived from ``config_id``.
        config_id: Used only when ``config_file`` is None: 'baseline',
            'baseline_mc', 'center', 'center_mc' map to the matching
            test_brats_<id>.yaml; 'cv0'..'cv4' map to
            baseline_cv/test_brats_baseline_cv<k>.yaml; anything else falls
            back to test_brats_baseline.yaml.
    """
    if config_file is None:
        named_configs = (
            ('baseline', 'test_brats_baseline.yaml'),
            ('baseline_mc', 'test_brats_baseline_mc.yaml'),
            ('center', 'test_brats_center.yaml'),
            ('center_mc', 'test_brats_center_mc.yaml'),
        )
        named_file = next(
            (file_name for known_id, file_name in named_configs if config_id == known_id),
            None)

        if named_file is not None:
            config_file = os.path.join(dirs.CONFIG_DIR, named_file)
        elif config_id in ('cv0', 'cv1', 'cv2', 'cv3', 'cv4'):
            fold = config_id[-1]
            config_file = os.path.join(dirs.CONFIG_DIR, 'baseline_cv',
                                       'test_brats_baseline_cv{}.yaml'.format(fold))
        else:
            config_file = os.path.join(dirs.CONFIG_DIR, 'test_brats_baseline.yaml')

    context = ctx.TorchTestContext('cuda')
    context.load_from_config(config_file)

    build_test = data.BuildData(build_dataset=data.BuildParametrizableDataset())

    # An 'mc' entry in the config switches to the McPredictStep + summary pipeline.
    others = context.config.others
    if not hasattr(others, 'mc'):
        predict_steps = [step.SegmentationPredictStep(do_probs=True)]
    else:
        predict_steps = [customstep.McPredictStep(others.mc),
                         customstep.MultiPredictionSummary()]
    per_subject_steps = [step.ExtractSubjectInfoStep(), EvalSubjectStep()]

    subject_assembler = assembler.SubjectAssembler()
    tester = loop.Test(predict_steps, per_subject_steps, subject_assembler,
                       entries=('probabilities',))

    reporting_hooks = [
        hooks.ConsoleTestLogHook(),
        hooks.WriteTestMetricsCsvHook('metrics.csv'),
        WriteHook(),
    ]
    tester(context, build_test, hook=hooks.ReducedComposeTestLoopHook(reporting_hooks))
Example #6
0
def main(config_file: str):
    """Train using a frozen, pretrained model loaded from the configuration.

    The pretrained model is loaded from ``others.model_dir`` at checkpoint
    ``others.test_at``. It is put in eval mode with gradients disabled. It is
    then handed to both the training step and the validation prediction step.

    Args:
        config_file: Path to the training YAML configuration, or None to use
            'train_brats_auxiliary_feat.yaml' from ``dirs.CONFIG_DIR``.

    Raises:
        ValueError: If ``others.model_dir`` or ``others.test_at`` is missing
            from the configuration.
    """
    if config_file is None:
        config_file = os.path.join(dirs.CONFIG_DIR,
                                   'train_brats_auxiliary_feat.yaml')

    context = ctx.TorchTrainContext('cuda')
    context.load_from_config(config_file)

    # SpecialTrainStep and SpecialSegmentationPredictStep below both receive the
    # pretrained model, so it is mandatory. Fail fast with a clear message;
    # otherwise ``test_model`` would be unbound when it is used below.
    if not hasattr(context.config.others, 'model_dir') or not hasattr(
            context.config.others, 'test_at'):
        raise ValueError(
            'missing "model_dir" or "test_at" entry in the configuration (others)')

    mf = mgt.ModelFiles.from_model_dir(context.config.others.model_dir)
    checkpoint_path = mgt.model_service.find_checkpoint_file(
        mf.weight_checkpoint_dir, context.config.others.test_at)

    model = mgt.model_service.load_model_from_parameters(
        mf.model_path(), with_optimizer=False)
    # NOTE(review): presumably makes the model also expose its features — confirm.
    model.provide_features = True
    mgt.model_service.load_checkpoint(checkpoint_path, model)
    test_model = model.to(context.device)

    # Freeze the pretrained model: eval mode and no gradient tracking.
    test_model.eval()
    for params in test_model.parameters():
        params.requires_grad = False

    build_train = data.BuildData(
        build_dataset=data.BuildParametrizableDataset(),
        build_sampler=data.BuildSelectionSampler(),
    )
    build_valid = data.BuildData(
        build_dataset=data.BuildParametrizableDataset(), )

    train_steps = [SpecialTrainStep(test_model), step.EvalStep()]
    train = trainloop.Train(train_steps, only_validate=False)

    subject_assembler = assembler.SubjectAssembler()
    validate = trainloop.ValidateSubject(
        [SpecialSegmentationPredictStep(test_model)],
        [step.ExtractSubjectInfoStep(),
         EvalSubjectStep()],
        subject_assembler,
        entries=('probabilities', 'net_predictions'))

    hook = hooks.ComposeTrainLoopHook([
        hooks.TensorboardXHook(),
        hooks.ConsoleLogHook(),
        hooks.SaveBestModelHook(),
        hooks.SaveNLastModelHook(3),
    ])
    train(context, build_train, build_valid, validate, hook=hook)
Example #7
0
def main(config_file):
    """Test an ensemble of frozen models listed in the configuration.

    ``others.model_dir`` may be a single directory (str) or a list of
    directories. Every model is loaded at checkpoint ``others.test_at``.

    Args:
        config_file: Path to the test YAML configuration, or None to use
            'test_brats_ensemble.yaml' from ``dirs.CONFIG_DIR``.

    Raises:
        ValueError: If ``others.model_dir`` or ``others.test_at`` is missing.
    """
    if config_file is None:
        config_file = os.path.join(dirs.CONFIG_DIR, 'test_brats_ensemble.yaml')

    context = ctx.TorchTestContext('cuda')
    context.load_from_config(config_file)

    build_test = data.BuildData(build_dataset=data.BuildParametrizableDataset())

    others = context.config.others
    has_required = hasattr(others, 'model_dir') and hasattr(others, 'test_at')
    if not has_required:
        raise ValueError('missing "model_dir" or "test_at" entry in the configuration (others)')

    model_dirs = others.model_dir
    if isinstance(model_dirs, str):
        model_dirs = [model_dirs]

    def load_frozen_model(model_dir):
        # Restore one ensemble member, move it to the device and freeze it.
        files = mgt.ModelFiles.from_model_dir(model_dir)
        checkpoint_path = mgt.model_service.find_checkpoint_file(files.weight_checkpoint_dir,
                                                                 others.test_at)
        net = mgt.model_service.load_model_from_parameters(files.model_path(), with_optimizer=False)
        mgt.model_service.load_checkpoint(checkpoint_path, net)
        net = net.to(context.device)
        net.eval()
        for param in net.parameters():
            param.requires_grad = False
        return net

    test_models = []
    for position, model_dir in enumerate(model_dirs, start=1):
        logging.info('load additional model [{}/{}] {}'.format(position, len(model_dirs),
                                                                os.path.basename(model_dir)))
        test_models.append(load_frozen_model(model_dir))

    predict_steps = [EnsemblePredictionStep(test_models), customstep.MultiPredictionSummary()]
    per_subject_steps = [step.ExtractSubjectInfoStep(), EvalSubjectStep()]

    subject_assembler = assembler.SubjectAssembler()
    # None means all output entries are kept.
    tester = loop.Test(predict_steps, per_subject_steps, subject_assembler, entries=None)

    reporting_hooks = [
        hooks.ConsoleTestLogHook(),
        hooks.WriteTestMetricsCsvHook('metrics.csv'),
        WriteHook(),
    ]
    tester(context, build_test, hook=hooks.ReducedComposeTestLoopHook(reporting_hooks))
def main(config_file):
    """Test using a frozen, pretrained model loaded from the configuration.

    The pretrained model is loaded from ``others.model_dir`` at checkpoint
    ``others.test_at``. It is put in eval mode with gradients disabled and
    passed to SegmentationPredictStep.

    Args:
        config_file: Path to the test YAML configuration, or None to use
            'test_brats_auxiliary_feat.yaml' from ``dirs.CONFIG_DIR``.

    Raises:
        ValueError: If ``others.model_dir`` or ``others.test_at`` is missing
            from the configuration.
    """
    if config_file is None:
        config_file = os.path.join(dirs.CONFIG_DIR,
                                   'test_brats_auxiliary_feat.yaml')

    context = ctx.TorchTestContext('cuda')
    context.load_from_config(config_file)

    build_test = data.BuildData(
        build_dataset=data.BuildParametrizableDataset(), )

    # The prediction step below requires the pretrained model. Fail fast with
    # a clear message; otherwise ``test_model`` would be unbound when used.
    if not hasattr(context.config.others, 'model_dir') or not hasattr(
            context.config.others, 'test_at'):
        raise ValueError(
            'missing "model_dir" or "test_at" entry in the configuration (others)')

    mf = mgt.ModelFiles.from_model_dir(context.config.others.model_dir)
    checkpoint_path = mgt.model_service.find_checkpoint_file(
        mf.weight_checkpoint_dir, context.config.others.test_at)

    model = mgt.model_service.load_model_from_parameters(
        mf.model_path(), with_optimizer=False)
    # NOTE(review): presumably makes the model also expose its features — confirm.
    model.provide_features = True
    mgt.model_service.load_checkpoint(checkpoint_path, model)
    test_model = model.to(context.device)

    # Freeze the pretrained model: eval mode and no gradient tracking.
    test_model.eval()
    for params in test_model.parameters():
        params.requires_grad = False

    test_steps = [SegmentationPredictStep(test_model)]
    subject_steps = [step.ExtractSubjectInfoStep(), EvalSubjectStep()]

    subject_assembler = assembler.SubjectAssembler()
    test = loop.Test(test_steps,
                     subject_steps,
                     subject_assembler,
                     entries=('probabilities', 'segm_probabilities'))

    hook = hooks.ReducedComposeTestLoopHook([
        hooks.ConsoleTestLogHook(),
        hooks.WriteTestMetricsCsvHook('metrics.csv'),
        WriteHook()
    ])
    test(context, build_test, hook=hook)
Example #9
0
def main(config_file: str):
    """Train with the aleatoric loss and validate with the aleatoric predictor.

    Args:
        config_file: Path to the training YAML configuration, or None to use
            'train_brats_aleatoric.yaml' from ``dirs.CONFIG_DIR``.

    Raises:
        ValueError: If ``others.is_log_sigma`` is absent from the loaded
            configuration.
    """
    config_path = config_file
    if config_path is None:
        config_path = os.path.join(dirs.CONFIG_DIR, 'train_brats_aleatoric.yaml')

    context = ctx.TorchTrainContext('cuda')
    context.load_from_config(config_path)

    build_train = data.BuildData(build_dataset=data.BuildParametrizableDataset(),
                                 build_sampler=data.BuildSelectionSampler())
    build_valid = data.BuildData(build_dataset=data.BuildParametrizableDataset())

    others = context.config.others
    # Mandatory setting, shared by the loss and the validation predictor.
    if not hasattr(others, 'is_log_sigma'):
        raise ValueError('"is_log_sigma" entry missing in configuration file')
    is_log_sigma = others.is_log_sigma

    aleatoric_loss = loss.AleatoricLoss(is_log_sigma)
    trainer = trainloop.Train([TrainStepWithEval(aleatoric_loss)], only_validate=False)

    subject_assembler = assembler.SubjectAssembler()
    prediction_steps = [AleatoricPredictStep(is_log_sigma=is_log_sigma)]
    per_subject_steps = [step.ExtractSubjectInfoStep(), step.EvalSubjectStep()]
    validate = trainloop.ValidateSubject(prediction_steps, per_subject_steps,
                                         subject_assembler, entries=None)

    train_hooks = [
        hooks.TensorboardXHook(),
        hooks.ConsoleLogHook(),
        hooks.SaveBestModelHook(),
        hooks.SaveNLastModelHook(3),
    ]
    trainer(context, build_train, build_valid, validate,
            hook=hooks.ComposeTrainLoopHook(train_hooks))