Code example #1
File: main_constrainedAAE_Chen.py  Project: irfixq/AE
# NOTE: the tensorflow, trainer and network imports below are assumptions,
# reconstructed from the import layout of the other examples in this project.
import tensorflow as tf

from models.constrained_adversarial_autoencoder_Chen import constrained_adversarial_autoencoder_Chen
from trainers.ConstrainedAAE import ConstrainedAAE
from utils import Evaluation
from utils.default_config_setup import get_config, get_options, get_datasets, Dataset

tf.reset_default_graph()
dataset = Dataset.BRAINWEB
options = get_options(batchsize=8,
                      learningrate=0.001,
                      numEpochs=1,
                      zDim=128,
                      outputWidth=128,
                      outputHeight=128)
options['data']['dir'] = options["globals"][dataset.value]
datasetHC, datasetPC = get_datasets(options, dataset=dataset)
config = get_config(trainer=ConstrainedAAE,
                    options=options,
                    optimizer='ADAM',
                    intermediateResolutions=[16, 16],
                    dropout_rate=0.1,
                    dataset=datasetHC)

config.kappa = 1.0
config.scale = 10.0
config.rho = 1.0

# Create an instance of the model
model = ConstrainedAAE(tf.Session(),
                       config,
                       network=constrained_adversarial_autoencoder_Chen)

# Train it
model.train(datasetHC)
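
The example above imports Evaluation but never calls it, so the listing is likely truncated. A minimal sketch of the missing evaluation step, assuming the same Evaluation.evaluate call used in the other examples of this project:

# Assumed continuation; mirrors the evaluation step of the other examples.
Evaluation.evaluate(datasetPC, model, options, description=f"{type(datasetHC).__name__}-{options['threshold']}", epoch=str(options['train']['numEpochs']))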
Code example #2
File: main_GMVAE_You.py  Project: irfixq/AE
#!/usr/bin/env python
import tensorflow as tf

from models.gaussian_mixture_variational_autoencoder_You import gaussian_mixture_variational_autoencoder_You
from trainers.GMVAE_spatial import GMVAE_spatial
from utils import Evaluation
from utils.default_config_setup import get_config, get_options, get_datasets, Dataset

tf.reset_default_graph()
dataset = Dataset.BRAINWEB
options = get_options(batchsize=8, learningrate=5e-5, numEpochs=1, zDim=128, outputWidth=128, outputHeight=128)
options['data']['dir'] = options["globals"][dataset.value]
datasetHC, datasetPC = get_datasets(options, dataset=dataset)
config = get_config(trainer=GMVAE_spatial, options=options, optimizer='ADAM', intermediateResolutions=[8, 8], dropout_rate=0.1, dataset=datasetHC)

config.dim_c = 9
config.dim_z = 1
config.dim_w = 1
config.c_lambda = 1
config.restore_lr = 1e-3
config.restore_steps = 0
config.tv_lambda = -1.0

# Create an instance of the model
model = GMVAE_spatial(tf.Session(), config, network=gaussian_mixture_variational_autoencoder_You)

# Train it
model.train(datasetHC)

# Evaluate
Evaluation.evaluate(datasetPC, model, options, description=f"{type(datasetHC).__name__}-{options['threshold']}", epoch=str(options['train']['numEpochs']))
Code example #3
# NOTE: the tensorflow, trainer and model imports below are assumptions,
# reconstructed from the import layout of the other examples in this project.
import tensorflow as tf

from models import autoencoder
from trainers.AE import AE
from utils import Evaluation
from utils.default_config_setup import get_config, get_options, get_datasets, Dataset

tf.reset_default_graph()
dataset = Dataset.BRAINWEB
options = get_options(batchsize=128,
                      learningrate=0.0001,
                      numEpochs=2,
                      zDim=128,
                      outputWidth=128,
                      outputHeight=128)
options['data']['dir'] = options["globals"][dataset.value]
datasetHC, datasetPC = get_datasets(options, dataset=dataset)
config = get_config(trainer=AE,
                    options=options,
                    optimizer='ADAM',
                    intermediateResolutions=[8, 8],
                    dropout_rate=0.2,
                    dataset=datasetHC)

# Create an instance of the model
model = AE(tf.Session(), config, network=autoencoder.autoencoder)

# Train it
model.train(datasetHC)

# Evaluate
Evaluation.evaluate(
    datasetPC,
    model,
    options,
    description=f"{type(datasetHC).__name__}-{options['threshold']}",
    epoch=str(options['train']['numEpochs']))  # final argument and closing paren assumed; mirrors example #2
Code example #4
File: run.py  Project: irfixq/AE
# NOTE: imports reconstructed as assumptions from the function body below.
# base_path and the evaluation helpers (evaluate_optimal, evaluate_with_threshold,
# get_evaluation_dataset, determine_threshold_on_labeled_patients) are defined
# elsewhere in run.py / the project's utilities and are not shown in this listing.
import json
import os
from importlib.machinery import SourceFileLoader

import tensorflow as tf

from utils.default_config_setup import get_config, get_options, get_datasets, Dataset


def main(args):
    # reset default graph
    tf.reset_default_graph()
    base_path_trainer = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'trainers',
        f'{args.trainer}.py')
    base_path_network = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'models',
        f'{args.model}.py')
    trainer = getattr(
        SourceFileLoader(args.trainer, base_path_trainer).load_module(),
        args.trainer)
    network = getattr(
        SourceFileLoader(args.model, base_path_network).load_module(),
        args.model)

    with open(os.path.join(base_path, args.config), 'r') as f:
        json_config = json.load(f)

    dataset = Dataset.BRAINWEB
    options = get_options(batchsize=args.batchsize,
                          learningrate=args.lr,
                          numEpochs=args.numEpochs,
                          zDim=args.zDim,
                          outputWidth=args.outputWidth,
                          outputHeight=args.outputHeight,
                          slices_start=args.slices_start,
                          slices_end=args.slices_end,
                          numMonteCarloSamples=args.numMonteCarloSamples,
                          config=json_config)
    options['data']['dir'] = options["globals"][dataset.value]
    dataset_hc, dataset_pc = get_datasets(options, dataset=dataset)
    config = get_config(trainer=trainer,
                        options=options,
                        optimizer=args.optimizer,
                        intermediateResolutions=args.intermediateResolutions,
                        dropout_rate=0.2,
                        dataset=dataset_hc)

    # handle additional Config parameters e.g. for GMVAE
    for arg in vars(args):
        if hasattr(config, arg):
            setattr(config, arg, getattr(args, arg))

    # Create an instance of the model
    model = trainer(tf.Session(), config, network=network)

    # Train it
    model.train(dataset_hc)

    ########################
    #  Evaluate best dice  #
    ########################
    if not args.threshold:
        # no threshold given, but a dataset => best-dice evaluation on that specific dataset
        if args.ds:
            # evaluate specific dataset
            evaluate_optimal(model, options, args.ds)
            return
        else:
            # evaluate all datasets for best dice without hyper-intensity prior
            options['applyHyperIntensityPrior'] = False
            evaluate_optimal(model, options, Dataset.BRAINWEB)
            evaluate_optimal(model, options, Dataset.MSLUB)
            evaluate_optimal(model, options, Dataset.MSISBI2015)

            # evaluate all datasets for best dice with hyper-intensity prior
            options['applyHyperIntensityPrior'] = True
            evaluate_optimal(model, options, Dataset.BRAINWEB)
            evaluate_optimal(model, options, Dataset.MSLUB)
            evaluate_optimal(model, options, Dataset.MSISBI2015)

    ###############################################
    #  Evaluate generalization to other datasets  #
    ###############################################
    if args.threshold and args.ds:  # a threshold alone, without a dataset, is not valid
        evaluate_with_threshold(model, options, args.threshold, args.ds)
    else:
        options['applyHyperIntensityPrior'] = False
        datasetBrainweb = get_evaluation_dataset(options, Dataset.BRAINWEB)
        _bestDiceVAL, _threshVAL = determine_threshold_on_labeled_patients(
            [datasetBrainweb], model, options, description='VAL')

        print(
            f"Optimal threshold on MS Lesion Validation Set without optimal postprocessing: {_threshVAL} (Dice-Score {_bestDiceVAL})"
        )

        # Re-evaluate with the previously determined threshold
        evaluate_with_threshold(model, options, _threshVAL, Dataset.BRAINWEB)
        evaluate_with_threshold(model, options, _threshVAL, Dataset.MSLUB)
        evaluate_with_threshold(model, options, _threshVAL, Dataset.MSISBI2015)
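
The listing above shows only main(args); the argument parser that feeds it is not part of the snippet. A minimal sketch of how such a script might be invoked, assuming hypothetical flag names that mirror the attributes main accesses; every default below is a placeholder, not a value taken from the project:

if __name__ == '__main__':
    import argparse

    # Hypothetical CLI wrapper: flag names mirror the attributes read in main;
    # all defaults are placeholders, not the project's actual defaults.
    parser = argparse.ArgumentParser()
    parser.add_argument('--trainer', default='AE')
    parser.add_argument('--model', default='autoencoder')
    parser.add_argument('--config', default='config.json')
    parser.add_argument('--batchsize', type=int, default=8)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--numEpochs', type=int, default=1)
    parser.add_argument('--zDim', type=int, default=128)
    parser.add_argument('--outputWidth', type=int, default=128)
    parser.add_argument('--outputHeight', type=int, default=128)
    parser.add_argument('--slices_start', type=int, default=0)
    parser.add_argument('--slices_end', type=int, default=-1)
    parser.add_argument('--numMonteCarloSamples', type=int, default=0)
    parser.add_argument('--optimizer', default='ADAM')
    parser.add_argument('--intermediateResolutions', type=int, nargs=2, default=[8, 8])
    parser.add_argument('--threshold', type=float, default=None)
    parser.add_argument('--ds', default=None)  # would need mapping to a Dataset enum member
    main(parser.parse_args())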