from __future__ import print_function import os,sys,inspect currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) parentdir = os.path.dirname(os.path.dirname(currentdir)) sys.path.insert(0,parentdir) from utils.args import args import setup.categories.ae_setup as AESetup from models.autoencoders import * from datasets.PADChest import PADChestBinaryTrainSplit if __name__ == "__main__": dataset = PADChestBinaryTrainSplit(root_path=os.path.join(args.root_path, "PADChest"), binary=True, expand_channels=False, downsample=64) model = Generic_VAE(dims=(1, 64, 64), max_channels=512, depth=12, n_hidden=512) #model = ALILikeAE(dims=(1, 64, 64)) AESetup.train_variational_autoencoder(args, model=model, dataset=dataset.get_D1_train(), BCE_Loss=False)
from __future__ import print_function import os, sys, inspect currentdir = os.path.dirname( os.path.abspath(inspect.getfile(inspect.currentframe()))) parentdir = os.path.dirname(os.path.dirname(currentdir)) sys.path.insert(0, parentdir) from utils.args import args import setup.categories.ae_setup as AESetup from models.autoencoders import * from datasets.PADChest import PADChestBinaryTrainSplit if __name__ == "__main__": dataset = PADChestBinaryTrainSplit(root_path=os.path.join( args.root_path, "PADChest"), binary=True, expand_channels=False, downsample=64) model = Generic_AE(dims=(1, 64, 64), max_channels=512, depth=12, n_hidden=512) #model = ALILikeAE(dims=(1, 64, 64)) AESetup.train_autoencoder(args, model=model, dataset=dataset.get_D1_train(), BCE_Loss=True)
'vaemseaeknn/8', 'vaebceaeknn/8', 'mseaeknn/8', 'bceaeknn/8', 'alivaemseaeknn/1', 'alivaebceaeknn/1', 'alimseaeknn/1', 'alibceaeknn/1', 'alivaemseaeknn/8', 'alivaebceaeknn/8', 'alimseaeknn/8', 'alibceaeknn/8', ] D1 = PADChestBinaryTrainSplit(root_path=os.path.join( args.root_path, 'PADChest'), binary=True) D164 = PADChestBinaryTrainSplit(root_path=os.path.join( args.root_path, "PADChest"), binary=True, downsample=64) args.D1 = 'PADChest' All_ODs = [ 'UniformNoise', 'NormalNoise', 'MNIST', 'FashionMNIST', 'NotMNIST', 'CIFAR100', 'CIFAR10',