예제 #1
0
from __future__ import print_function
import os,sys,inspect

# Path bootstrap: find the directory containing this file, climb two levels,
# and put that directory at the front of sys.path. The project-local imports
# that follow (utils.args, setup..., models..., datasets...) presumably live
# under that directory. This must execute before those imports.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
# Two dirname() calls: the grandparent of this file's directory.
parentdir = os.path.dirname(os.path.dirname(currentdir))
# Insert at index 0 so this tree takes precedence over other sys.path entries.
sys.path.insert(0,parentdir)

from utils.args import args

import setup.categories.ae_setup as AESetup
from models.autoencoders import *
from datasets.NIH_Chest import NIHChestBinaryTrainSplit

if __name__ == "__main__":
    # Train a Generic_VAE on the binary NIH Chest split (the D1 training set
    # returned by get_D1_train()) via AESetup.train_variational_autoencoder.
    #
    # The dataset's downsampling and the model's input dims must describe the
    # same image size, so both are derived from this single value rather than
    # repeating the literal in two places.
    resolution = 64
    dataset = NIHChestBinaryTrainSplit(
        root_path=os.path.join(args.root_path, "NIHCC"),
        binary=True,
        # Presumably keeps images single-channel, matching dims[0] == 1
        # below -- TODO confirm in NIHChestBinaryTrainSplit.
        expand_channels=False,
        downsample=resolution,
    )
    model = Generic_VAE(dims=(1, resolution, resolution),
                        max_channels=512, depth=12, n_hidden=512)
    # Alternative architecture that was tried here previously:
    #     model = ALILikeVAE(dims=(1, resolution, resolution))
    # BCE_Loss=False disables the BCE reconstruction loss (presumably MSE is
    # used instead -- see AESetup).
    AESetup.train_variational_autoencoder(args,
                                          model=model,
                                          dataset=dataset.get_D1_train(),
                                          BCE_Loss=False)

sys.path.insert(0, parentdir)

import models as Models
import global_vars as Global
from utils.args import args

import categories.classifier_setup as CLSetup
from models.classifiers import NIHDenseBinary, NIHChestVGG
from datasets.NIH_Chest import NIHChestBinaryTrainSplit

if __name__ == "__main__":
    # Train the NIHChestVGG classifier on the binary NIH Chest split using
    # CLSetup.train_classifier.
    nih_root = os.path.join(args.root_path, "NIHCC")
    dataset = NIHChestBinaryTrainSplit(root_path=nih_root, binary=True)
    model = NIHChestVGG()
    # Fetched after the model is built, preserving the original call order.
    d1_train = dataset.get_D1_train()
    CLSetup.train_classifier(args, model=model, dataset=d1_train)

    # Reference task list (disabled; kept for documentation only).
    # NOTE(review): KLogisticSetup, DeepEnsembleSetup and PCNNSetup are not
    # imported above, so this list cannot be re-enabled as-is.
    #
    # task_list = [
    #     # The list of models,   The function that does the training,    Can I skip-test?,   suffix of the operation.
    #     # The procedures that can be skip-test are the ones that we can determine
    #     # whether we have done them before without instantiating the network architecture or dataset.
    #     # saves quite a lot of time when possible.
    #     (Global.dataset_reference_classifiers, CLSetup.train_classifier,            True, ['base0']),
    #     (Global.dataset_reference_classifiers, KLogisticSetup.train_classifier,     True, ['KLogistic']),
    #     (Global.dataset_reference_classifiers, DeepEnsembleSetup.train_classifier,  True, ['DE.%d'%i for i in range(5)]),
    #     (Global.dataset_reference_autoencoders, AESetup.train_BCE_AE,               False, []),
    #     (Global.dataset_reference_autoencoders, AESetup.train_MSE_AE,               False, []),
    #     (Global.dataset_reference_vaes, AESetup.train_variational_autoencoder,      False, []),
    #     (Global.dataset_reference_pcnns, PCNNSetup.train_pixelcnn,                  False, []),
    # ]
예제 #3
0
from __future__ import print_function
import os, sys, inspect

# Path bootstrap: find the directory containing this file, climb two levels,
# and put that directory at the front of sys.path. The project-local imports
# that follow (utils.args, setup..., models..., datasets...) presumably live
# under that directory. This must execute before those imports.
currentdir = os.path.dirname(
    os.path.abspath(inspect.getfile(inspect.currentframe())))
# Two dirname() calls: the grandparent of this file's directory.
parentdir = os.path.dirname(os.path.dirname(currentdir))
# Insert at index 0 so this tree takes precedence over other sys.path entries.
sys.path.insert(0, parentdir)

from utils.args import args

import setup.categories.ae_setup as AESetup
from models.autoencoders import Generic_VAE, Generic_AE, Residual_AE
from datasets.NIH_Chest import NIHChestBinaryTrainSplit

if __name__ == "__main__":
    # Train a Residual_AE on the binary NIH Chest split (the D1 training set
    # returned by get_D1_train()) via AESetup.train_autoencoder.
    #
    # The dataset's downsampling and the model's input dims must describe the
    # same image size, so both are derived from this single value rather than
    # repeating the literal in two places.
    resolution = 64
    dataset = NIHChestBinaryTrainSplit(
        root_path=os.path.join(args.root_path, "NIHCC"),
        binary=True,
        # Presumably keeps images single-channel, matching dims[0] == 1
        # below -- TODO confirm in NIHChestBinaryTrainSplit.
        expand_channels=False,
        downsample=resolution,
    )
    model = Residual_AE(dims=(1, resolution, resolution))
    # BCE_Loss=False disables the BCE reconstruction loss (presumably MSE is
    # used instead -- see AESetup).
    AESetup.train_autoencoder(args,
                              model=model,
                              dataset=dataset.get_D1_train(),
                              BCE_Loss=False)
예제 #4
0
    if not args.load or not os.path.exists(
            os.path.join(
                args.experiment_path, "all_embs_UC3_ppd_%d_d1_%s.npy" %
                (args.points_per_d2, args.dataset))):
        assert args.dataset in ['NIHCC', 'PADChest']
        if args.dataset.lower() == 'nihcc':
            D164 = NIHChestBinaryTrainSplit(root_path=os.path.join(
                args.root_path, 'NIHCC'),
                                            downsample=64)
        elif args.dataset.lower() == "padchest":
            D164 = PADChestBinaryTrainSplit(root_path=os.path.join(
                args.root_path, "PADChest"),
                                            binary=True,
                                            downsample=64)

        D1 = D164.get_D1_train()

        emb = args.embedding_function.lower()
        assert emb in ["vae", "ae", "ali"]
        dummy_args = EasyDict()
        dummy_args.exp = "foo"
        dummy_args.experiment_path = args.experiment_path
        if args.encoder_loss.lower() == "bce":
            tag = "BCE"
        else:
            tag = "MSE"
        if emb == "vae":
            model = Global.dataset_reference_vaes[args.dataset][0]()
            home_path = Models.get_ref_model_path(dummy_args,
                                                  model.__class__.__name__,
                                                  D164.name,