def main(args):
    """Build the GAN variant selected by ``args.model`` and train or evaluate it.

    Args:
        args: Parsed argument namespace. Attributes read here: ``model``
            (one of 'GAN', 'DCGAN', 'WGAN-CP', 'WGAN-GP'), ``is_train``
            (the string 'True' selects training, anything else selects
            evaluation) and, for evaluation, ``load_D`` / ``load_G``. The
            whole namespace is also handed to the model constructor and to
            ``get_data_loader``.

    Exits:
        Prints an error to stderr and calls ``sys.exit(-1)`` when
        ``args.model`` is not one of the known model names.
    """
    # NOTE(review): another ``def main(args)`` follows this one in the same
    # file and shadows it at import time — confirm which one is intended.
    # Local import: the bare ``exit`` builtin comes from the ``site`` module
    # and is meant for the REPL, so it is not guaranteed to exist in programs.
    import sys

    if args.model == 'GAN':
        model = GAN(args)
    elif args.model == 'DCGAN':
        model = DCGAN_MODEL(args)
    elif args.model == 'WGAN-CP':
        model = WGAN_CP(args)
    elif args.model == 'WGAN-GP':
        model = WGAN_GP(args)
    else:
        print("Model type non-existing. Try again.", file=sys.stderr)
        sys.exit(-1)

    # Load datasets to train and test loaders
    train_loader, test_loader = get_data_loader(args)

    # is_train is compared against the string 'True' (presumably a raw CLI
    # string rather than a bool — confirm against the argument parser).
    if args.is_train == 'True':
        model.train(train_loader)
    else:
        # Evaluate on test data using the given D/G checkpoints.
        model.evaluate(test_loader, args.load_D, args.load_G)
def main(args):
    """Prepare the data root, build the selected GAN variant, then train or evaluate it.

    Args:
        args: Parsed argument namespace. Attributes read here: ``dataset``,
            ``dataroot`` (created if it does not exist), ``batch_size``,
            ``epochs``, ``channels``, ``model`` (one of 'GAN', 'DCGAN',
            'WGAN-CP', 'WGAN-GP'), ``resume_training`` and ``is_train``
            (each compared against the string 'True'), and, for evaluation,
            ``load_D`` / ``load_G``.

    Exits:
        Prints an error to stderr and calls ``sys.exit(-1)`` when
        ``args.model`` is not one of the known model names.
    """
    # Local import: the bare ``exit`` builtin comes from the ``site`` module
    # and is meant for the REPL, so it is not guaranteed to exist in programs.
    import sys

    # -------------- prepare data --------------
    dataset = args.dataset
    dataroot = args.dataroot
    # exist_ok=True replaces the racy "exists? then makedirs" two-step.
    os.makedirs(dataroot, exist_ok=True)

    batch_size = args.batch_size
    epochs = args.epochs
    channels = args.channels
    model_name = args.model

    if model_name == 'GAN':
        # NOTE(review): GAN is built from (epochs, batch_size) while every
        # other model receives ``args`` — confirm against GAN's constructor.
        model = GAN(epochs, batch_size)
    elif model_name == 'DCGAN':
        model = DCGAN_MODEL(args)
    elif model_name == 'WGAN-CP':
        model = WGAN_CP(args)
    elif model_name == 'WGAN-GP':
        model = WGAN_GP(args)
    else:
        print("Model type non-existing. Try again.", file=sys.stderr)
        sys.exit(-1)

    workers = 0  # number of workers for dataloader; per the original note, 2 creates problems
    utils = Utils()
    train_loader, test_loader = utils.prepare_data(
        dataroot, batch_size, workers, dataset, model_name, channels
    )

    # Both flags are compared against the string 'True' (presumably raw CLI
    # strings rather than bools — confirm against the argument parser).
    resume_training = args.resume_training == 'True'
    if args.is_train == 'True':
        model.train(train_loader, resume_training)
    else:
        # Evaluate on test data using the given D/G checkpoints.
        model.evaluate(test_loader, args.load_D, args.load_G)