def _load_attr_from_file(subdir, name):
    """Load ``<this file's dir>/<subdir>/<name>.py`` and return its attribute ``name``.

    Replaces the deprecated ``SourceFileLoader(...).load_module()`` with the
    documented ``importlib.util`` recipe. The module is still registered in
    ``sys.modules`` under ``name`` before it executes, as ``load_module()`` did.
    """
    import importlib.util
    import sys

    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), subdir, f'{name}.py')
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    spec.loader.exec_module(module)
    return getattr(module, name)


def main(args):
    """Train the selected trainer/network on BrainWeb, then evaluate it.

    The trainer class is loaded from ``trainers/<args.trainer>.py`` and the
    network from ``models/<args.model>.py`` (both relative to this file). The
    JSON config is read from ``base_path/<args.config>``.

    Evaluation flow:
      * No ``args.threshold`` and ``args.ds`` given: best-dice evaluation on
        ``args.ds`` only, then return.
      * No ``args.threshold`` and no ``args.ds``: best-dice evaluation on all
        datasets, without and then with the hyper-intensity prior. Then
        determine a threshold on BrainWeb and evaluate all datasets with it.
      * ``args.threshold`` and ``args.ds`` given: evaluate ``args.ds`` with the
        given threshold.
      * ``args.threshold`` without ``args.ds``: the threshold is ignored (a
        notice is printed). A threshold is determined on BrainWeb as above.
    """
    # reset default graph
    tf.reset_default_graph()

    trainer = _load_attr_from_file('trainers', args.trainer)
    network = _load_attr_from_file('models', args.model)

    # NOTE(review): ``base_path`` is not defined in this function; presumably a
    # module-level global defined elsewhere in the file.
    with open(os.path.join(base_path, args.config), 'r') as f:
        json_config = json.load(f)

    dataset = Dataset.BRAINWEB
    options = get_options(batchsize=args.batchsize, learningrate=args.lr, numEpochs=args.numEpochs, zDim=args.zDim,
                          outputWidth=args.outputWidth, outputHeight=args.outputHeight,
                          slices_start=args.slices_start, slices_end=args.slices_end,
                          numMonteCarloSamples=args.numMonteCarloSamples, config=json_config)
    options['data']['dir'] = options["globals"][dataset.value]
    # Only the first split is used for training; the second is not needed here.
    dataset_hc, _ = get_datasets(options, dataset=dataset)
    config = get_config(trainer=trainer, options=options, optimizer=args.optimizer,
                        intermediateResolutions=args.intermediateResolutions,
                        dropout_rate=0.2, dataset=dataset_hc)

    # handle additional Config parameters e.g. for GMVAE: any CLI argument whose
    # name matches an existing config attribute overrides it.
    for arg in vars(args):
        if hasattr(config, arg):
            setattr(config, arg, getattr(args, arg))

    # Create an instance of the model and train it
    model = trainer(tf.Session(), config, network=network)
    model.train(dataset_hc)

    all_datasets = (Dataset.BRAINWEB, Dataset.MSLUB, Dataset.MSISBI2015)

    ########################
    #  Evaluate best dice  #
    ########################
    if not args.threshold:
        if args.ds:
            # no threshold but a dataset => best dice evaluation on that dataset only
            evaluate_optimal(model, options, args.ds)
            return
        # evaluate all datasets for best dice, first without and then with the
        # hyper-intensity prior
        for apply_prior in (False, True):
            options['applyHyperIntensityPrior'] = apply_prior
            for eval_dataset in all_datasets:
                evaluate_optimal(model, options, eval_dataset)

    ###############################################
    #  Evaluate generalization to other datasets  #
    ###############################################
    if args.threshold and args.ds:
        evaluate_with_threshold(model, options, args.threshold, args.ds)
    else:
        if args.threshold:
            # A threshold without a dataset is not a valid combination; make that visible
            # instead of silently dropping the user's value.
            print("Warning: a threshold was given without a dataset (args.ds); it is ignored "
                  "and a threshold is determined on BrainWeb instead.")
        options['applyHyperIntensityPrior'] = False
        dataset_brainweb = get_evaluation_dataset(options, Dataset.BRAINWEB)
        best_dice_val, thresh_val = determine_threshold_on_labeled_patients(
            [dataset_brainweb], model, options, description='VAL')
        print(
            f"Optimal threshold on MS Lesion Validation Set without optimal postprocessing: {thresh_val} (Dice-Score {best_dice_val})"
        )
        # Re-evaluate with the previously determined threshold
        for eval_dataset in all_datasets:
            evaluate_with_threshold(model, options, thresh_val, eval_dataset)
# Module-level driver: builds a ConstrainedAAE on BrainWeb with fixed hyperparameters.
# NOTE(review): these statements run at import time. This looks like a separate
# example script merged into this file — confirm that is intended.
from models.constrained_adversarial_autoencoder_Chen import constrained_adversarial_autoencoder_Chen
from trainers.ConstrainedAAE import ConstrainedAAE
from utils import Evaluation  # NOTE(review): unused in the visible lines
from utils.default_config_setup import get_config, get_options, get_datasets, Dataset

# NOTE(review): ``tf`` is not imported in these lines; presumably imported elsewhere in the file.
tf.reset_default_graph()
dataset = Dataset.BRAINWEB
options = get_options(batchsize=8, learningrate=0.001, numEpochs=1, zDim=128, outputWidth=128, outputHeight=128)
# Point the data loader at this dataset's directory from the "globals" section.
options['data']['dir'] = options["globals"][dataset.value]
# datasetHC feeds get_config below; datasetPC is not used in the visible lines.
datasetHC, datasetPC = get_datasets(options, dataset=dataset)
config = get_config(trainer=ConstrainedAAE, options=options, optimizer='ADAM', intermediateResolutions=[16, 16], dropout_rate=0.1, dataset=datasetHC)
# Trainer-specific hyperparameters set directly on the config object.
config.kappa = 1.0
config.scale = 10.0
config.rho = 1.0
# Create an instance of the model.
# NOTE(review): no ``model.train(...)`` call follows in these lines, despite the
# original "and train it" comment.
model = ConstrainedAAE(tf.Session(), config, network=constrained_adversarial_autoencoder_Chen)
def get_evaluation_dataset(options, dataset=Dataset.BRAINWEB):
    """Return the second split produced by ``get_datasets`` for ``dataset``.

    Side effect: ``options['data']['dir']`` is overwritten in place with the
    directory registered for ``dataset`` under ``options["globals"]``.
    """
    data_root = options["globals"][dataset.value]
    options['data']['dir'] = data_root
    splits = get_datasets(options, dataset=dataset)
    return splits[1]