def main():
    """Build the WGAN-GP model and image data loaders, then train if requested.

    Configuration is read from the module-level ``opt`` (image/data paths,
    batch size, worker count, ``is_train``) and ``device``, which are defined
    outside this function. When ``opt.is_train == 'True'`` the model is
    trained on the ``'train'`` partition; otherwise only ``"Done!"`` is printed.

    NOTE(review): a second ``def main(args)`` follows in this source; if both
    live in the same module, that later definition shadows this one.
    """
    model = WGAN_GP(opt, device)
    print("using WGAN_GP model")

    # Per-channel mean/std normalization applied after ToTensor.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # cudnn.benchmark = True

    def _make_loader(partition, augmentations, shuffle):
        # Build a DataLoader over one dataset partition; `augmentations` are
        # inserted between the shared resize/crop and the tensor conversion.
        pipeline = transforms.Compose([
            # Resize replaces the deprecated (later removed) transforms.Scale;
            # an int size rescales the shorter edge, keeping the aspect ratio.
            transforms.Resize(128),
            transforms.CenterCrop(128),  # keep only the center of the rescaled image
            *augmentations,
            transforms.ToTensor(),
            normalize,
        ])
        return torch.utils.data.DataLoader(
            ImageLoader(opt.img_path, pipeline,
                        data_path=opt.data_path, partition=partition),
            batch_size=opt.batch_size, shuffle=shuffle,
            num_workers=opt.workers, pin_memory=True)

    # Preparing the training loader (with random augmentation).
    # NOTE(review): RandomCrop(128) after CenterCrop(128) never moves the crop
    # window because the input is already 128x128; shrink its size if real
    # random cropping was intended.
    train_loader = _make_loader(
        'train',
        [transforms.RandomCrop(128), transforms.RandomHorizontalFlip()],
        shuffle=True)
    print('Training loader prepared.')

    # Preparing the validation loader (deterministic preprocessing only).
    # NOTE(review): val_loader is built but not used anywhere below.
    val_loader = _make_loader('val', [], shuffle=False)
    print('Validation loader prepared.')

    # Start model training; is_train is compared as a string.
    if opt.is_train == 'True':
        model.train(train_loader)
    else:
        print("Done!")
def main(args):
    """Build the GAN variant selected by ``args.model``, then train or evaluate it.

    Args:
        args: parsed options. This function reads ``model``, ``is_train``,
            ``load_D`` and ``load_G`` from it, and passes ``args`` unchanged
            to the model constructor and to ``get_data_loader``.

    Raises:
        SystemExit: with code -1 if ``args.model`` is not one of
            ``'GAN'``, ``'DCGAN'``, ``'WGAN-CP'`` or ``'WGAN-GP'``.
    """
    import sys

    if args.model == 'GAN':
        model = GAN(args)
    elif args.model == 'DCGAN':
        model = DCGAN_MODEL(args)
    elif args.model == 'WGAN-CP':
        model = WGAN_CP(args)
    elif args.model == 'WGAN-GP':
        model = WGAN_GP(args)
    else:
        print("Model type non-existing. Try again.", file=sys.stderr)
        # sys.exit is always available; the bare exit() builtin is injected by
        # the `site` module and is meant for the interactive shell only.
        sys.exit(-1)

    # Load datasets into train and test loaders.
    train_loader, test_loader = get_data_loader(args)

    # is_train is compared as a string (presumably passed in from the
    # command line); anything other than 'True' runs evaluation instead.
    if args.is_train == 'True':
        model.train(train_loader)
    else:
        # Start evaluating on test data using the given checkpoints.
        model.evaluate(test_loader, args.load_D, args.load_G)