Example #1
#!/usr/bin/env python
import tensorflow as tf

from models.constrained_adversarial_autoencoder_Chen import constrained_adversarial_autoencoder_Chen
from trainers.ConstrainedAAE import ConstrainedAAE
from utils import Evaluation
from utils.default_config_setup import get_config, get_options, get_datasets, Dataset

tf.reset_default_graph()
dataset = Dataset.BRAINWEB
options = get_options(batchsize=8,
                      learningrate=0.001,
                      numEpochs=1,
                      zDim=128,
                      outputWidth=128,
                      outputHeight=128)
options['data']['dir'] = options['globals'][dataset.value]
datasetHC, datasetPC = get_datasets(options, dataset=dataset)
config = get_config(trainer=ConstrainedAAE,
                    options=options,
                    optimizer='ADAM',
                    intermediateResolutions=[16, 16],
                    dropout_rate=0.1,
                    dataset=datasetHC)

config.kappa = 1.0
config.scale = 10.0
config.rho = 1.0

# Create an instance of the model and train it
# the listing is truncated here; the remaining arguments follow the
# trainer(tf.Session(), config, network=...) pattern shown in Example #2
model = ConstrainedAAE(tf.Session(),
                       config,
                       network=constrained_adversarial_autoencoder_Chen)
model.train(datasetHC)
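
Once trained, the model could be scored the same way Example #2 does. A minimal sketch, assuming evaluate_optimal can be imported from the project's evaluation utilities (the utils.Evaluation module path is an assumption, not shown in these listings):

# hypothetical follow-up evaluation; see Example #2 for the full workflow
from utils.Evaluation import evaluate_optimal  # module path assumed

options['applyHyperIntensityPrior'] = False
evaluate_optimal(model, options, dataset)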
Example #2
File: run.py Project: irfixq/AE
import json
import os
from importlib.machinery import SourceFileLoader

import tensorflow as tf

# helper imports as in Example #1; the utils.Evaluation path for the
# evaluation helpers is assumed from the repository layout
from utils.default_config_setup import get_config, get_options, get_datasets, Dataset
from utils.Evaluation import (determine_threshold_on_labeled_patients,
                              evaluate_optimal, evaluate_with_threshold,
                              get_evaluation_dataset)


def main(args):
    # reset default graph
    tf.reset_default_graph()
    base_path = os.path.dirname(os.path.abspath(__file__))
    base_path_trainer = os.path.join(base_path, 'trainers', f'{args.trainer}.py')
    base_path_network = os.path.join(base_path, 'models', f'{args.model}.py')
    trainer = getattr(
        SourceFileLoader(args.trainer, base_path_trainer).load_module(),
        args.trainer)
    network = getattr(
        SourceFileLoader(args.model, base_path_network).load_module(),
        args.model)

    with open(os.path.join(base_path, args.config), 'r') as f:
        json_config = json.load(f)

    dataset = Dataset.BRAINWEB
    options = get_options(batchsize=args.batchsize,
                          learningrate=args.lr,
                          numEpochs=args.numEpochs,
                          zDim=args.zDim,
                          outputWidth=args.outputWidth,
                          outputHeight=args.outputHeight,
                          slices_start=args.slices_start,
                          slices_end=args.slices_end,
                          numMonteCarloSamples=args.numMonteCarloSamples,
                          config=json_config)
    options['data']['dir'] = options['globals'][dataset.value]
    dataset_hc, dataset_pc = get_datasets(options, dataset=dataset)
    config = get_config(trainer=trainer,
                        options=options,
                        optimizer=args.optimizer,
                        intermediateResolutions=args.intermediateResolutions,
                        dropout_rate=0.2,
                        dataset=dataset_hc)

    # handle additional Config parameters e.g. for GMVAE
    for arg in vars(args):
        if hasattr(config, arg):
            setattr(config, arg, getattr(args, arg))

    # Create an instance of the model and train it
    model = trainer(tf.Session(), config, network=network)

    # Train it
    model.train(dataset_hc)

    ########################
    #  Evaluate best dice  #
    ########################
    if not args.threshold:
        # no threshold given, but a dataset is => best-Dice evaluation on that specific dataset
        if args.ds:
            # evaluate specific dataset
            evaluate_optimal(model, options, args.ds)
            return
        else:
            # evaluate all datasets for best dice without hyper intensity prior
            options['applyHyperIntensityPrior'] = False
            evaluate_optimal(model, options, Dataset.BRAINWEB)
            evaluate_optimal(model, options, Dataset.MSLUB)
            evaluate_optimal(model, options, Dataset.MSISBI2015)

            # evaluate all datasets for best dice with hyper intensity prior
            options['applyHyperIntensityPrior'] = True
            evaluate_optimal(model, options, Dataset.BRAINWEB)
            evaluate_optimal(model, options, Dataset.MSLUB)
            evaluate_optimal(model, options, Dataset.MSISBI2015)

    ###############################################
    #  Evaluate generalization to other datasets  #
    ###############################################
    if args.threshold and args.ds:  # a threshold alone, without a dataset, is not a valid combination
        evaluate_with_threshold(model, options, args.threshold, args.ds)
    else:
        options['applyHyperIntensityPrior'] = False
        datasetBrainweb = get_evaluation_dataset(options, Dataset.BRAINWEB)
        _bestDiceVAL, _threshVAL = determine_threshold_on_labeled_patients(
            [datasetBrainweb], model, options, description='VAL')

        print(
            f"Optimal threshold on MS Lesion Validation Set without optimal postprocessing: {_threshVAL} (Dice-Score {_bestDiceVAL})"
        )

        # Re-evaluate with the previously determined threshold
        evaluate_with_threshold(model, options, _threshVAL, Dataset.BRAINWEB)
        evaluate_with_threshold(model, options, _threshVAL, Dataset.MSLUB)
        evaluate_with_threshold(model, options, _threshVAL, Dataset.MSISBI2015)
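
For reference, main(args) reads its options from an argument parser. A minimal sketch of a matching entry point, with flag names inferred from the attributes accessed above and defaults taken from Example #1 where available (illustrative only; the real run.py may define more options):

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # flag names mirror the attributes main() reads; defaults are illustrative
    parser.add_argument('--trainer', type=str, default='ConstrainedAAE')
    parser.add_argument('--model', type=str, default='constrained_adversarial_autoencoder_Chen')
    parser.add_argument('--config', type=str, default='config.json')
    parser.add_argument('--batchsize', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--numEpochs', type=int, default=1)
    parser.add_argument('--zDim', type=int, default=128)
    parser.add_argument('--outputWidth', type=int, default=128)
    parser.add_argument('--outputHeight', type=int, default=128)
    parser.add_argument('--optimizer', type=str, default='ADAM')
    parser.add_argument('--intermediateResolutions', type=int, nargs=2, default=[16, 16])
    parser.add_argument('--slices_start', type=int, default=20)
    parser.add_argument('--slices_end', type=int, default=130)
    parser.add_argument('--numMonteCarloSamples', type=int, default=0)
    parser.add_argument('--threshold', type=float, default=None)
    parser.add_argument('--ds', type=str, default=None)  # would name a Dataset member in the real script
    main(parser.parse_args())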