def main(config_file: str):
    """Train the auxiliary-segmentation BraTS model.

    Args:
        config_file: Path to a YAML training configuration. When ``None``,
            ``train_brats_auxiliary_segm.yaml`` inside ``dirs.CONFIG_DIR`` is used.
    """
    if config_file is None:
        config_file = os.path.join(dirs.CONFIG_DIR, 'train_brats_auxiliary_segm.yaml')

    train_context = ctx.TorchTrainContext('cuda')
    train_context.load_from_config(config_file)

    # Training data additionally goes through a selection sampler; validation data does not.
    train_builder = data.BuildData(
        build_dataset=data.BuildParametrizableDataset(),
        build_sampler=data.BuildSelectionSampler(),
    )
    valid_builder = data.BuildData(build_dataset=data.BuildParametrizableDataset())

    train_loop = trainloop.Train([SpecialTrainStep(), step.EvalStep()], only_validate=False)

    subj_assembler = assembler.SubjectAssembler()
    validation = trainloop.ValidateSubject(
        [SpecialSegmentationPredictStep()],
        [step.ExtractSubjectInfoStep(), EvalSubjectStep()],
        subj_assembler,
        entries=('probabilities',),
    )

    loop_hooks = [
        hooks.TensorboardXHook(),
        hooks.ConsoleLogHook(),
        hooks.SaveBestModelHook(),
        hooks.SaveNLastModelHook(3),
    ]
    train_loop(train_context, train_builder, valid_builder, validation,
               hook=hooks.ComposeTrainLoopHook(loop_hooks))
def main(config_file: str, config_id: str):
    """Train the baseline BraTS segmentation model.

    Args:
        config_file: Path to a YAML training configuration. When ``None``, a
            file under ``dirs.CONFIG_DIR`` is selected from ``config_id``.
        config_id: Used only when ``config_file`` is ``None``:
            ``'center'`` -> ``train_brats_center.yaml``;
            ``'cv0'``..``'cv4'`` -> ``baseline_cv/train_brats_baseline_cv<k>.yaml``;
            ``'ensemble0'``..``'ensemble9'`` -> ``train_ensemble/train_brats_ensemble_<k>.yaml``;
            ``'baseline'`` and anything else -> ``train_brats_baseline.yaml``.
    """
    if config_file is None:
        cv_ids = tuple('cv{}'.format(fold) for fold in range(5))
        ensemble_ids = tuple('ensemble{}'.format(member) for member in range(10))
        if config_id in cv_ids:
            relative = ('baseline_cv', 'train_brats_baseline_cv{}.yaml'.format(config_id[-1]))
        elif config_id in ensemble_ids:
            relative = ('train_ensemble', 'train_brats_ensemble_{}.yaml'.format(config_id[-1]))
        elif config_id == 'center':
            relative = ('train_brats_center.yaml',)
        else:
            # 'baseline' and any unrecognised id both fall back to the baseline config.
            relative = ('train_brats_baseline.yaml',)
        config_file = os.path.join(dirs.CONFIG_DIR, *relative)

    train_context = ctx.TorchTrainContext('cuda')
    train_context.load_from_config(config_file)

    train_builder = data.BuildData(
        build_dataset=data.BuildParametrizableDataset(),
        build_sampler=data.BuildSelectionSampler(),
    )
    valid_builder = data.BuildData(build_dataset=data.BuildParametrizableDataset())

    train_loop = trainloop.Train([step.TrainStep(), step.EvalStep()], only_validate=False)

    subj_assembler = assembler.SubjectAssembler()
    validation = trainloop.ValidateSubject(
        [step.SegmentationPredictStep(do_probs=True)],
        [step.ExtractSubjectInfoStep(), EvalSubjectStep()],
        subj_assembler,
        entries=('probabilities',),
    )

    loop_hooks = [
        hooks.TensorboardXHook(),
        hooks.ConsoleLogHook(),
        hooks.SaveBestModelHook(),
        hooks.SaveNLastModelHook(3),
    ]
    train_loop(train_context, train_builder, valid_builder, validation,
               hook=hooks.ComposeTrainLoopHook(loop_hooks))
def main(config_file: str):
    """Train a BraTS model that uses a frozen, pre-trained auxiliary network.

    The auxiliary network is loaded from ``others.model_dir`` at checkpoint
    ``others.test_at`` and has ``provide_features`` enabled. It is then put in
    eval mode with gradients disabled and handed to both the train step and
    the validation predict step.

    Args:
        config_file: Path to a YAML training configuration. When ``None``,
            ``train_brats_auxiliary_feat.yaml`` inside ``dirs.CONFIG_DIR`` is used.

    Raises:
        ValueError: If the configuration's ``others`` section lacks
            ``model_dir`` or ``test_at``. The auxiliary model is mandatory;
            previously this surfaced later as an opaque ``UnboundLocalError``
            on ``test_model``.
    """
    if config_file is None:
        config_file = os.path.join(dirs.CONFIG_DIR, 'train_brats_auxiliary_feat.yaml')

    context = ctx.TorchTrainContext('cuda')
    context.load_from_config(config_file)

    others = context.config.others
    # Fail fast with the same message the ensemble test script uses for this case.
    if not hasattr(others, 'model_dir') or not hasattr(others, 'test_at'):
        raise ValueError('missing "model_dir" or "test_at" entry in the configuration (others)')

    mf = mgt.ModelFiles.from_model_dir(others.model_dir)
    checkpoint_path = mgt.model_service.find_checkpoint_file(
        mf.weight_checkpoint_dir, others.test_at)
    model = mgt.model_service.load_model_from_parameters(
        mf.model_path(), with_optimizer=False)
    model.provide_features = True
    mgt.model_service.load_checkpoint(checkpoint_path, model)
    test_model = model.to(context.device)
    test_model.eval()
    # Keep the auxiliary network fixed: its weights receive no gradients during training.
    for param in test_model.parameters():
        param.requires_grad = False

    build_train = data.BuildData(
        build_dataset=data.BuildParametrizableDataset(),
        build_sampler=data.BuildSelectionSampler(),
    )
    build_valid = data.BuildData(
        build_dataset=data.BuildParametrizableDataset(),
    )

    train_steps = [SpecialTrainStep(test_model), step.EvalStep()]
    train = trainloop.Train(train_steps, only_validate=False)

    subject_assembler = assembler.SubjectAssembler()
    validate = trainloop.ValidateSubject(
        [SpecialSegmentationPredictStep(test_model)],
        [step.ExtractSubjectInfoStep(), EvalSubjectStep()],
        subject_assembler,
        entries=('probabilities', 'net_predictions'),
    )

    hook = hooks.ComposeTrainLoopHook([
        hooks.TensorboardXHook(),
        hooks.ConsoleLogHook(),
        hooks.SaveBestModelHook(),
        hooks.SaveNLastModelHook(3),
    ])
    train(context, build_train, build_valid, validate, hook=hook)
def main(config_file):
    """Evaluate the auxiliary-segmentation BraTS model on the test set.

    Args:
        config_file: Path to a YAML test configuration, or ``None`` to use
            ``test_brats_auxiliary_segm.yaml`` inside ``dirs.CONFIG_DIR``.
    """
    if config_file is None:
        config_file = os.path.join(dirs.CONFIG_DIR, 'test_brats_auxiliary_segm.yaml')

    test_context = ctx.TorchTestContext('cuda')
    test_context.load_from_config(config_file)

    test_builder = data.BuildData(build_dataset=data.BuildParametrizableDataset())

    predict_steps = [SegmentationPredictStep()]
    per_subject_steps = [step.ExtractSubjectInfoStep(), EvalSubjectStep()]
    subj_assembler = assembler.SubjectAssembler()
    # `entries` restricts which prediction output entries the test loop handles.
    test_loop = loop.Test(predict_steps, per_subject_steps, subj_assembler,
                          entries=('probabilities', 'orig_prediction'))

    loop_hook = hooks.ReducedComposeTestLoopHook([
        hooks.ConsoleTestLogHook(),
        hooks.WriteTestMetricsCsvHook('metrics.csv'),
        WriteHook(),
    ])
    test_loop(test_context, test_builder, hook=loop_hook)
def main(config_file):
    """Evaluate the aleatoric-uncertainty BraTS model on the test set.

    Args:
        config_file: Path to a YAML test configuration, or ``None`` to use
            ``test_brats_aleatoric.yaml`` inside ``dirs.CONFIG_DIR``.

    Raises:
        ValueError: If the configuration's ``others`` section has no
            ``is_log_sigma`` entry.
    """
    if config_file is None:
        config_file = os.path.join(dirs.CONFIG_DIR, 'test_brats_aleatoric.yaml')

    test_context = ctx.TorchTestContext('cuda')
    test_context.load_from_config(config_file)

    test_builder = data.BuildData(build_dataset=data.BuildParametrizableDataset())

    others = test_context.config.others
    if not hasattr(others, 'is_log_sigma'):
        raise ValueError('"is_log_sigma" entry missing in configuration file')
    log_sigma_flag = others.is_log_sigma

    predict_steps = [AleatoricPredictStep(log_sigma_flag)]
    per_subject_steps = [step.ExtractSubjectInfoStep(), step.EvalSubjectStep()]
    subj_assembler = assembler.SubjectAssembler()
    test_loop = loop.Test(predict_steps, per_subject_steps, subj_assembler, entries=None)

    loop_hook = hooks.ReducedComposeTestLoopHook([
        hooks.ConsoleTestLogHook(),
        hooks.WriteTestMetricsCsvHook('metrics.csv'),
        WriteHook(),
    ])
    test_loop(test_context, test_builder, hook=loop_hook)
def main(config_file: str):
    """Train the aleatoric-uncertainty BraTS model.

    Args:
        config_file: Path to a YAML training configuration, or ``None`` to use
            ``train_brats_aleatoric.yaml`` inside ``dirs.CONFIG_DIR``.

    Raises:
        ValueError: If the configuration's ``others`` section has no
            ``is_log_sigma`` entry.
    """
    if config_file is None:
        config_file = os.path.join(dirs.CONFIG_DIR, 'train_brats_aleatoric.yaml')

    train_context = ctx.TorchTrainContext('cuda')
    train_context.load_from_config(config_file)

    train_builder = data.BuildData(
        build_dataset=data.BuildParametrizableDataset(),
        build_sampler=data.BuildSelectionSampler(),
    )
    valid_builder = data.BuildData(build_dataset=data.BuildParametrizableDataset())

    others = train_context.config.others
    if not hasattr(others, 'is_log_sigma'):
        raise ValueError('"is_log_sigma" entry missing in configuration file')
    log_sigma_flag = others.is_log_sigma

    # The same flag configures both the training loss and the prediction step.
    train_loop = trainloop.Train(
        [TrainStepWithEval(loss.AleatoricLoss(log_sigma_flag))], only_validate=False)

    subj_assembler = assembler.SubjectAssembler()
    validation = trainloop.ValidateSubject(
        [AleatoricPredictStep(is_log_sigma=log_sigma_flag)],
        [step.ExtractSubjectInfoStep(), step.EvalSubjectStep()],
        subj_assembler,
        entries=None,
    )

    loop_hooks = [
        hooks.TensorboardXHook(),
        hooks.ConsoleLogHook(),
        hooks.SaveBestModelHook(),
        hooks.SaveNLastModelHook(3),
    ]
    train_loop(train_context, train_builder, valid_builder, validation,
               hook=hooks.ComposeTrainLoopHook(loop_hooks))
def main(config_file, config_id):
    """Evaluate the baseline BraTS model, optionally with Monte-Carlo prediction.

    Args:
        config_file: Path to a YAML test configuration. When ``None``, a file
            under ``dirs.CONFIG_DIR`` is selected from ``config_id``.
        config_id: Used only when ``config_file`` is ``None``. ``'baseline'``,
            ``'baseline_mc'``, ``'center'`` and ``'center_mc'`` select the
            matching ``test_brats_<id>.yaml``; ``'cv0'``..``'cv4'`` select
            ``baseline_cv/test_brats_baseline_cv<k>.yaml``; anything else falls
            back to ``test_brats_baseline.yaml``.

    If the configuration's ``others`` section has an ``mc`` entry, prediction
    uses ``customstep.McPredictStep`` followed by a multi-prediction summary;
    otherwise a single probabilistic segmentation prediction is made.
    """
    if config_file is None:
        named_configs = (
            ('baseline', 'test_brats_baseline.yaml'),
            ('baseline_mc', 'test_brats_baseline_mc.yaml'),
            ('center', 'test_brats_center.yaml'),
            ('center_mc', 'test_brats_center_mc.yaml'),
        )
        cv_ids = tuple('cv{}'.format(fold) for fold in range(5))
        if config_id in cv_ids:
            relative = ('baseline_cv', 'test_brats_baseline_cv{}.yaml'.format(config_id[-1]))
        else:
            file_name = next(
                (name for key, name in named_configs if config_id == key),
                'test_brats_baseline.yaml',
            )
            relative = (file_name,)
        config_file = os.path.join(dirs.CONFIG_DIR, *relative)

    test_context = ctx.TorchTestContext('cuda')
    test_context.load_from_config(config_file)

    test_builder = data.BuildData(build_dataset=data.BuildParametrizableDataset())

    others = test_context.config.others
    if not hasattr(others, 'mc'):
        predict_steps = [step.SegmentationPredictStep(do_probs=True)]
    else:
        predict_steps = [
            customstep.McPredictStep(others.mc),
            customstep.MultiPredictionSummary(),
        ]

    per_subject_steps = [step.ExtractSubjectInfoStep(), EvalSubjectStep()]
    subj_assembler = assembler.SubjectAssembler()
    test_loop = loop.Test(predict_steps, per_subject_steps, subj_assembler,
                          entries=('probabilities',))

    loop_hook = hooks.ReducedComposeTestLoopHook([
        hooks.ConsoleTestLogHook(),
        hooks.WriteTestMetricsCsvHook('metrics.csv'),
        WriteHook(),
    ])
    test_loop(test_context, test_builder, hook=loop_hook)
def main(config_file):
    """Evaluate an ensemble of trained BraTS models on the test set.

    Every directory listed in ``others.model_dir`` (a single string is also
    accepted) is loaded at checkpoint ``others.test_at``, frozen, and passed to
    ``EnsemblePredictionStep``.

    Args:
        config_file: Path to a YAML test configuration, or ``None`` to use
            ``test_brats_ensemble.yaml`` inside ``dirs.CONFIG_DIR``.

    Raises:
        ValueError: If ``others.model_dir`` or ``others.test_at`` is missing
            from the configuration.
    """
    if config_file is None:
        config_file = os.path.join(dirs.CONFIG_DIR, 'test_brats_ensemble.yaml')

    test_context = ctx.TorchTestContext('cuda')
    test_context.load_from_config(config_file)

    test_builder = data.BuildData(build_dataset=data.BuildParametrizableDataset())

    others = test_context.config.others
    if not (hasattr(others, 'model_dir') and hasattr(others, 'test_at')):
        raise ValueError('missing "model_dir" or "test_at" entry in the configuration (others)')

    directories = others.model_dir
    if isinstance(directories, str):
        directories = [directories]

    def load_frozen(directory):
        # Load one member at checkpoint `test_at` and disable its gradients.
        files = mgt.ModelFiles.from_model_dir(directory)
        ckpt = mgt.model_service.find_checkpoint_file(files.weight_checkpoint_dir, others.test_at)
        net = mgt.model_service.load_model_from_parameters(files.model_path(), with_optimizer=False)
        mgt.model_service.load_checkpoint(ckpt, net)
        net = net.to(test_context.device)
        net.eval()
        for weight in net.parameters():
            weight.requires_grad = False
        return net

    ensemble = []
    for position, directory in enumerate(directories, start=1):
        logging.info('load additional model [%s/%s] %s',
                     position, len(directories), os.path.basename(directory))
        ensemble.append(load_frozen(directory))

    predict_steps = [EnsemblePredictionStep(ensemble), customstep.MultiPredictionSummary()]
    per_subject_steps = [step.ExtractSubjectInfoStep(), EvalSubjectStep()]
    subj_assembler = assembler.SubjectAssembler()
    # entries=None: all output entries are kept.
    test_loop = loop.Test(predict_steps, per_subject_steps, subj_assembler, entries=None)

    loop_hook = hooks.ReducedComposeTestLoopHook([
        hooks.ConsoleTestLogHook(),
        hooks.WriteTestMetricsCsvHook('metrics.csv'),
        WriteHook(),
    ])
    test_loop(test_context, test_builder, hook=loop_hook)
def main(config_file):
    """Evaluate the BraTS model that uses a frozen auxiliary feature network.

    The auxiliary network is loaded from ``others.model_dir`` at checkpoint
    ``others.test_at`` and has ``provide_features`` enabled. It is then put in
    eval mode with gradients disabled and passed to ``SegmentationPredictStep``.

    Args:
        config_file: Path to a YAML test configuration, or ``None`` to use
            ``test_brats_auxiliary_feat.yaml`` inside ``dirs.CONFIG_DIR``.

    Raises:
        ValueError: If the configuration's ``others`` section lacks
            ``model_dir`` or ``test_at``. The auxiliary model is mandatory;
            previously this surfaced later as an opaque ``UnboundLocalError``
            on ``test_model``.
    """
    if config_file is None:
        config_file = os.path.join(dirs.CONFIG_DIR, 'test_brats_auxiliary_feat.yaml')

    context = ctx.TorchTestContext('cuda')
    context.load_from_config(config_file)

    build_test = data.BuildData(
        build_dataset=data.BuildParametrizableDataset(),
    )

    others = context.config.others
    # Fail fast with the same message the ensemble test script uses for this case.
    if not hasattr(others, 'model_dir') or not hasattr(others, 'test_at'):
        raise ValueError('missing "model_dir" or "test_at" entry in the configuration (others)')

    mf = mgt.ModelFiles.from_model_dir(others.model_dir)
    checkpoint_path = mgt.model_service.find_checkpoint_file(
        mf.weight_checkpoint_dir, others.test_at)
    model = mgt.model_service.load_model_from_parameters(
        mf.model_path(), with_optimizer=False)
    model.provide_features = True
    mgt.model_service.load_checkpoint(checkpoint_path, model)
    test_model = model.to(context.device)
    test_model.eval()
    # Keep the auxiliary network fixed: its weights receive no gradients.
    for param in test_model.parameters():
        param.requires_grad = False

    test_steps = [SegmentationPredictStep(test_model)]
    subject_steps = [step.ExtractSubjectInfoStep(), EvalSubjectStep()]
    subject_assembler = assembler.SubjectAssembler()
    test = loop.Test(test_steps, subject_steps, subject_assembler,
                     entries=('probabilities', 'segm_probabilities'))

    hook = hooks.ReducedComposeTestLoopHook([
        hooks.ConsoleTestLogHook(),
        hooks.WriteTestMetricsCsvHook('metrics.csv'),
        WriteHook(),
    ])
    test(context, build_test, hook=hook)