def val_dataloader(self):
    """Build the validation ``DataLoader`` for the configured dataset.

    Reads ``self.hparams.dataset``, ``self.hparams.data_dir`` and
    ``self.hparams.batch_size``.

    Returns:
        A ``DataLoader`` over ``MiniImagenet(train=False, test=True)`` with
        ``shuffle=False``, four workers and pinned memory.

    Raises:
        ValueError: if ``self.hparams.dataset`` is neither ``'miniimagenet'``
            nor ``'miniimagenetgenerated'``. Previously this case crashed
            with an ``UnboundLocalError`` on ``dataset``.
    """
    dataset_name = self.hparams.dataset
    if dataset_name == 'miniimagenet':
        # NOTE(review): RandomResizedCrop is a random augmentation, so the
        # validation inputs differ between passes -- confirm this is intended.
        transform_val = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset = MiniImagenet(root=self.hparams.data_dir,
                               train=False,
                               test=True,
                               transform=transform_val)
    elif dataset_name == 'miniimagenetgenerated':
        # No crop here: the images are only converted and normalised.
        transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset = MiniImagenet(root=self.hparams.data_dir,
                               train=False,
                               test=True,
                               transform=transform_val,
                               dataset=self.hparams.dataset)
    else:
        raise ValueError(
            "Unsupported dataset {0!r}; expected 'miniimagenet' or "
            "'miniimagenetgenerated'".format(dataset_name))

    return DataLoader(dataset,
                      batch_size=self.hparams.batch_size,
                      num_workers=4,
                      pin_memory=True,
                      shuffle=False)
def main(args):
    """Train a VQ-VAE, logging reconstructions and saving checkpoints.

    Attributes read from ``args``: ``output_folder``, ``dataset``,
    ``data_folder``, ``batch_size``, ``num_workers``, ``device``,
    ``hidden_size``, ``k``, ``ckp``, ``lr``, ``tmodel`` and ``num_epochs``.

    Side effects: writes TensorBoard events under ``./logs/<output_folder>``
    and saves ``best.pt`` plus one ``model_<epoch>.pt`` per epoch under
    ``./models/<output_folder>`` (that directory is not created here).

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    supported = ('mnist', 'fashion-mnist', 'cifar10', 'miniimagenet')
    if args.dataset not in supported:
        # Fail before any log directory or download is touched.
        raise ValueError('Unsupported dataset {0!r}; expected one of: {1}'.format(
            args.dataset, ', '.join(supported)))

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    if args.dataset == 'miniimagenet':
        num_channels = 3
        transform = transforms.Compose([
            transforms.RandomResizedCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        valid_dataset = MiniImagenet(args.data_folder, valid=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
    else:
        dataset_cls, num_channels = {
            'mnist': (datasets.MNIST, 1),
            'fashion-mnist': (datasets.FashionMNIST, 1),
            'cifar10': (datasets.CIFAR10, 3),
        }[args.dataset]
        # One mean/std entry per image channel: a three-value Normalize does
        # not match the single-channel MNIST / Fashion-MNIST tensors.
        stats = (0.5,) * num_channels
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(stats, stats),
        ])
        # Define the train & test datasets; the test split doubles as the
        # validation split.
        train_dataset = dataset_cls(args.data_folder, train=True,
                                    download=True, transform=transform)
        test_dataset = dataset_cls(args.data_folder, train=False,
                                   transform=transform)
        valid_dataset = test_dataset

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False,
        drop_last=True, num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size,
                               args.k).to(args.device)
    if args.ckp != "":
        # map_location lets a checkpoint saved on another device be loaded.
        model.load_state_dict(torch.load(args.ckp, map_location=args.device))
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Previously unbound when no --tmodel was given, which crashed the first
    # train() call with UnboundLocalError.
    # NOTE(review): train() now receives None in that case -- confirm it
    # tolerates a missing target model.
    target_model = None
    if args.tmodel != '':
        net = vgg.VGG('VGG19')
        net = net.to(args.device)
        net = torch.nn.DataParallel(net)
        checkpoint = torch.load(args.tmodel, map_location=args.device)
        net.load_state_dict(checkpoint['net'])
        target_model = net

    def log_reconstruction(step):
        """Reconstruct the fixed batch and log it as an image grid."""
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, step)

    # Generate the samples first once
    log_reconstruction(0)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        print(epoch)
        train(train_loader, model, target_model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print("test loss:", loss)

        log_reconstruction(epoch + 1)

        # The first epoch always seeds best_loss; later epochs must improve.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1),
                  'wb') as f:
            torch.save(model.state_dict(), f)
def main(args):
    """Train a VAE, logging reconstructions and saving checkpoints.

    Attributes read from ``args``: ``output_folder``, ``dataset``,
    ``data_folder``, ``batch_size``, ``num_workers``, ``device``,
    ``hidden_size``, ``z``, ``lr`` and ``num_epochs``.

    Side effects: writes TensorBoard events under
    ``./logs_vae/<output_folder>`` and saves ``best.pt`` plus one
    ``model_<epoch>.pt`` per epoch under ``./models_vae/<output_folder>``
    (that directory is not created here).

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    supported = ('mnist', 'fashion-mnist', 'cifar10', 'miniimagenet')
    if args.dataset not in supported:
        # Fail before any log directory or download is touched.
        raise ValueError('Unsupported dataset {0!r}; expected one of: {1}'.format(
            args.dataset, ', '.join(supported)))

    writer = SummaryWriter('./logs_vae/{0}'.format(args.output_folder))
    save_filename = './models_vae/{0}'.format(args.output_folder)

    if args.dataset == 'miniimagenet':
        num_channels = 3
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        valid_dataset = MiniImagenet(args.data_folder, valid=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
    else:
        # (dataset class, channels, mean, std). Fashion-MNIST previously used
        # a three-value mean/std although its images have a single channel.
        dataset_cls, num_channels, mean, std = {
            'mnist': (datasets.MNIST, 1, (0.1307, ), (0.3081, )),
            'fashion-mnist': (datasets.FashionMNIST, 1, (0.5, ), (0.5, )),
            'cifar10': (datasets.CIFAR10, 3, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        }[args.dataset]
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        # Define the train & test datasets; the test split doubles as the
        # validation split.
        train_dataset = dataset_cls(args.data_folder, train=True,
                                    download=True, transform=transform)
        test_dataset = dataset_cls(args.data_folder, train=False,
                                   transform=transform)
        valid_dataset = test_dataset

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False,
        drop_last=True, num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VAE(num_channels, args.hidden_size, args.z).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)

    def log_reconstruction(step):
        """Reconstruct the fixed batch and log it as an image grid."""
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, step)

    # Generate the samples first once
    log_reconstruction(0)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        train(epoch, train_loader, model, optimizer, args, writer)
        loss = test(valid_loader, model, args, writer)

        log_reconstruction(epoch + 1)

        # The first epoch always seeds best_loss; later epochs must improve.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1),
                  'wb') as f:
            torch.save(model.state_dict(), f)
def main(args):
    """Train a GatedPixelCNN prior on top of a pre-trained VQ-VAE.

    Attributes read from ``args``: ``output_folder``, ``dataset``,
    ``data_folder``, ``batch_size``, ``num_workers``, ``device``,
    ``hidden_size_vae``, ``k``, ``model`` (path of the VQ-VAE state dict),
    ``hidden_size_prior``, ``num_layers``, ``lr`` and ``num_epochs``.

    Side effects: writes TensorBoard events under ``./logs/<output_folder>``
    and saves the best prior to ``./models/<output_folder>/prior.pt`` (that
    directory is not created here).

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    supported = ('mnist', 'fashion-mnist', 'cifar10', 'miniimagenet')
    if args.dataset not in supported:
        # Fail before any log directory or download is touched.
        raise ValueError('Unsupported dataset {0!r}; expected one of: {1}'.format(
            args.dataset, ', '.join(supported)))

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}/prior.pt'.format(args.output_folder)

    if args.dataset == 'miniimagenet':
        num_channels = 3
        transform = transforms.Compose([
            transforms.RandomResizedCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        valid_dataset = MiniImagenet(args.data_folder, valid=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
    else:
        dataset_cls, num_channels = {
            'mnist': (datasets.MNIST, 1),
            'fashion-mnist': (datasets.FashionMNIST, 1),
            'cifar10': (datasets.CIFAR10, 3),
        }[args.dataset]
        # One mean/std entry per image channel: a three-value Normalize does
        # not match the single-channel MNIST / Fashion-MNIST tensors.
        stats = (0.5,) * num_channels
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(stats, stats),
        ])
        # Define the train & test datasets; the test split doubles as the
        # validation split.
        train_dataset = dataset_cls(args.data_folder, train=True,
                                    download=True, transform=transform)
        test_dataset = dataset_cls(args.data_folder, train=False,
                                   transform=transform)
        valid_dataset = test_dataset

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False,
        drop_last=True, num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    # The VQ-VAE is loaded from disk and switched to eval mode; only the
    # prior's parameters are handed to the optimizer.
    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                               args.k).to(args.device)
    with open(args.model, 'rb') as f:
        # map_location lets a checkpoint saved on another device be loaded.
        state_dict = torch.load(f, map_location=args.device)
    model.load_state_dict(state_dict)
    model.eval()

    # NOTE(review): n_classes is hard-coded to 32 (the original kept a
    # commented-out len(train_dataset._label_encoder) alternative) --
    # confirm 32 covers every label of the chosen dataset.
    prior = GatedPixelCNN(args.k, args.hidden_size_prior, args.num_layers,
                          n_classes=32).to(args.device)
    optimizer = torch.optim.Adam(prior.parameters(), lr=args.lr)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        print(epoch)
        train(train_loader, model, prior, optimizer, args, writer)
        # The validation loss is not properly computed since
        # the classes in the train and valid splits of Mini-Imagenet
        # do not overlap.
        loss = test(valid_loader, model, prior, args, writer)

        # The first epoch always seeds best_loss; later epochs must improve.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open(save_filename, 'wb') as f:
                torch.save(prior.state_dict(), f)
def main(args):
    """Sample images from a trained prior and log them to TensorBoard.

    Loads a VQ-VAE from ``args.model`` and a ten-class GatedPixelCNN prior
    from the fixed path ``models/mnist-10-prior/prior.pt``. For each of
    ``args.num_epochs`` iterations it decodes 16 class-conditional samples
    (labels ``0..9`` repeated) and logs the grid as ``'Full sampling
    result'``.

    Attributes read from ``args``: ``output_folder``, ``dataset``,
    ``data_folder``, ``device``, ``hidden_size_vae``, ``k``, ``model``,
    ``hidden_size_prior``, ``num_layers`` and ``num_epochs``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    supported = ('mnist', 'fashion-mnist', 'cifar10', 'miniimagenet')
    if args.dataset not in supported:
        # Fail before any log directory or download is touched.
        raise ValueError('Unsupported dataset {0!r}; expected one of: {1}'.format(
            args.dataset, ', '.join(supported)))

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))

    if args.dataset == 'miniimagenet':
        num_channels = 3
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        valid_dataset = MiniImagenet(args.data_folder, valid=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
    else:
        dataset_cls, num_channels = {
            'mnist': (datasets.MNIST, 1),
            'fashion-mnist': (datasets.FashionMNIST, 1),
            'cifar10': (datasets.CIFAR10, 3),
        }[args.dataset]
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, )),
        ])
        # The train split is still constructed because it is the call that
        # passes download=True.
        train_dataset = dataset_cls(args.data_folder, train=True,
                                    download=True, transform=transform)
        test_dataset = dataset_cls(args.data_folder, train=False,
                                   transform=transform)
        valid_dataset = test_dataset

    # Only the test loader is needed: nothing is trained or validated here.
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                               args.k).to(args.device)
    with open(args.model, 'rb') as f:
        # map_location lets a checkpoint saved on another device be loaded.
        state_dict = torch.load(f, map_location=args.device)
    model.load_state_dict(state_dict)
    model.eval()

    prior = GatedPixelCNN(args.k, args.hidden_size_prior, args.num_layers,
                          n_classes=10).to(args.device)
    # NOTE(review): the prior checkpoint path is hard-coded.
    state_dict = torch.load('models/mnist-10-prior/prior.pt',
                            map_location=args.device)
    prior.load_state_dict(state_dict)

    # Labels 0..9 repeated over the 16 samples. Built once, and placed on
    # args.device (not a bare .cuda()) so it matches the models' device.
    label_ = torch.tensor([int(i % 10) for i in range(16)]).to(args.device)

    for epoch in range(args.num_epochs):
        reconstruction = model.decode(
            prior.generate(label=label_, batch_size=16, sample_index=20))
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1),
                         normalize=True)
        writer.add_image('Full sampling result', grid, epoch + 1)
def main(args):
    """Train a VQ-VAE on a built-in dataset or an image-folder dataset.

    ``args.dataset`` selects MNIST, Fashion-MNIST, CIFAR-10 or MiniImagenet;
    any other value loads ``ImageFolder`` datasets from
    ``<data_folder>/train`` and ``<data_folder>/val``, cropped to
    ``args.image_size``.

    Attributes read from ``args``: ``output_folder``, ``dataset``,
    ``data_folder``, ``image_size`` (image-folder case only), ``batch_size``,
    ``num_workers``, ``device``, ``hidden_size``, ``k``, ``lr`` and
    ``num_epochs``.

    Side effects: writes TensorBoard events under ``./logs/<output_folder>``,
    overwrites ``true.png`` and ``rec.png`` in the working directory, and
    saves ``best.pt`` plus one ``model_<epoch>.pt`` per epoch under
    ``./models/<output_folder>`` (that directory is not created here).
    """
    writer = SummaryWriter("./logs/{0}".format(args.output_folder))
    save_filename = "./models/{0}".format(args.output_folder)

    builtin_datasets = ("mnist", "fashion-mnist", "cifar10")
    if args.dataset in builtin_datasets:
        dataset_cls, num_channels = {
            "mnist": (datasets.MNIST, 1),
            "fashion-mnist": (datasets.FashionMNIST, 1),
            "cifar10": (datasets.CIFAR10, 3),
        }[args.dataset]
        # One mean/std entry per image channel: a three-value Normalize does
        # not match the single-channel MNIST / Fashion-MNIST tensors.
        stats = (0.5,) * num_channels
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(stats, stats),
            ]
        )
        # Define the train & test datasets; the test split doubles as the
        # validation split.
        train_dataset = dataset_cls(
            args.data_folder, train=True, download=True, transform=transform
        )
        test_dataset = dataset_cls(
            args.data_folder, train=False, transform=transform
        )
        valid_dataset = test_dataset
    elif args.dataset == "miniimagenet":
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(128),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(
            args.data_folder, train=True, download=True, transform=transform
        )
        valid_dataset = MiniImagenet(
            args.data_folder, valid=True, download=True, transform=transform
        )
        test_dataset = MiniImagenet(
            args.data_folder, test=True, download=True, transform=transform
        )
        num_channels = 3
    else:
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(args.image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        # Define the train & valid datasets; the validation folder is reused
        # as the test set.
        train_dataset = ImageFolder(
            os.path.join(args.data_folder, "train"), transform=transform
        )
        valid_dataset = ImageFolder(
            os.path.join(args.data_folder, "val"), transform=transform
        )
        test_dataset = valid_dataset
        num_channels = 3

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    save_image(fixed_grid, "true.png")
    writer.add_image("original", fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    def log_reconstruction(step):
        """Reconstruct the fixed batch, save it to rec.png and log it."""
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
        save_image(grid, "rec.png")
        writer.add_image("reconstruction", grid, step)

    # Generate the samples first once
    log_reconstruction(0)

    best_loss = -1
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print(epoch, "test loss: ", loss)

        log_reconstruction(epoch + 1)

        # The first epoch always seeds best_loss; later epochs must improve.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open("{0}/best.pt".format(save_filename), "wb") as f:
                torch.save(model.state_dict(), f)
        with open("{0}/model_{1}.pt".format(save_filename, epoch + 1), "wb") as f:
            torch.save(model.state_dict(), f)
def main(args):
    """Generate class-conditional image datasets from a trained prior.

    For every class, three batches are sampled from the prior (``generate``,
    ``generate_dst`` and ``generate_sparsemax``), decoded by the VQ-VAE and
    saved as JPEGs under ``./data/<output_folder>/dataset_softmax``,
    ``dataset_dst`` and ``dataset_sparsemax``. A CSV with ``filename,label``
    rows for the softmax images is written to
    ``./VQVAE/models/<output_folder>/``. The largest values returned by
    ``generate_dst`` / ``generate_sparsemax`` are pickled to
    ``max_num_<dataset>_<sample_size>_correct_dst.pkl`` once all classes
    have been processed. None of these directories are created here.

    Attributes read from ``args``: ``dataset``, ``output_folder``,
    ``data_folder``, ``device``, ``hidden_size_vae``, ``k``, ``model``,
    ``hidden_size_prior``, ``num_layers`` and ``prior``.

    Raises:
        ValueError: if ``args.dataset`` is not ``'miniimagenet'`` or
            ``'cifar10'``, the only datasets with generation settings.
            Previously other names crashed later with ``UnboundLocalError``.
    """
    if args.dataset not in ('miniimagenet', 'cifar10'):
        raise ValueError(
            'Unsupported dataset {0!r}: generation settings are only defined '
            'for miniimagenet and cifar10'.format(args.dataset))

    # Human-readable class names. Not referenced again in this function;
    # kept as a reference table.
    if args.dataset == 'miniimagenet':
        readable_labels = {
            38: 'organ', 42: 'prayer_rug', 31: 'file', 61: 'cliff',
            58: 'consomme', 59: 'hotdog', 21: 'aircraft_carrier',
            14: 'French_bulldog', 28: 'cocktail_shaker', 63: 'ear',
            3: 'green_mamba', 4: 'harvestman', 17: 'Arctic_fox',
            32: 'fire_screen', 11: 'komondor', 43: 'reel', 18: 'ladybug',
            45: 'snorkel', 24: 'beer_bottle', 36: 'lipstick', 5: 'toucan',
            0: 'house_finch', 16: 'miniature_poodle', 50: 'tile_roof',
            15: 'Newfoundland', 46: 'solar_dish', 10: 'Gordon_setter',
            7: 'dugong', 52: 'unicycle', 20: 'rock_beauty', 48: 'stage',
            22: 'ashcan', 34: 'hair_slide', 30: 'dome',
            13: 'Tibetan_mastiff', 53: 'upright', 62: 'bolete',
            2: 'triceratops', 40: 'pencil_box', 26: 'chime',
            47: 'spider_web', 51: 'tobacco_shop', 60: 'orange', 49: 'tank',
            8: 'Walker_hound', 23: 'barrel', 6: 'jellyfish',
            33: 'frying_pan', 9: 'Saluki', 37: 'oboe', 1: 'robin',
            19: 'three-toed_sloth', 39: 'parallel_bars', 55: 'worm_fence',
            27: 'clog', 41: 'photocopier', 25: 'carousel', 29: 'dishrag',
            57: 'street_sign', 35: 'holster', 12: 'boxer', 56: 'yawl',
            54: 'wok', 44: 'slot',
        }
    else:
        readable_labels = {
            0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
            5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck',
        }

    writer = SummaryWriter('./VQVAE/logs/{0}'.format(args.output_folder))

    if args.dataset == 'cifar10':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Define the train & test datasets
        train_dataset = datasets.CIFAR10(args.data_folder, train=True,
                                         download=True, transform=transform)
        test_dataset = datasets.CIFAR10(args.data_folder, train=False,
                                        transform=transform)
        valid_dataset = test_dataset
        num_channels = 3
    else:
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        valid_dataset = MiniImagenet(args.data_folder, valid=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
        num_channels = 3

    # Only the test loader is needed: nothing is trained or validated here.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                               args.k).to(args.device)
    with open(args.model, 'rb') as f:
        # map_location lets a checkpoint saved on another device be loaded.
        state_dict = torch.load(f, map_location=args.device)
    model.load_state_dict(state_dict)
    model.eval()

    if args.dataset == 'miniimagenet':
        label_encoder = train_dataset._label_encoder
        print("number of training classes:", len(label_encoder))
        n_classes = len(label_encoder)
        shape = (32, 32)
        csv_filename = 'miniimagenet_generated.csv'
        sample_size = 25
        # Class index -> encoder key, built once instead of a linear search
        # per saved image. setdefault keeps the first key for an index,
        # matching the previous list(...).index(...) lookup.
        index_to_name = {}
        for name, index in label_encoder.items():
            index_to_name.setdefault(index, name)
    else:
        print("number of training classes:", 10)
        n_classes = 10
        shape = (8, 8)
        csv_filename = 'cifar10_generated.csv'
        sample_size = 1000
        index_to_name = None
    yrange = range(n_classes)

    prior = GatedPixelCNN(args.k, args.hidden_size_prior, args.num_layers,
                          n_classes=n_classes).to(args.device)
    with open(args.prior, 'rb') as f:
        state_dict = torch.load(f, map_location=args.device)
    prior.load_state_dict(state_dict)
    prior.eval()

    def save_batch(images, subdir, class_index):
        """Save each image as <class>_<index>.jpg; return the file names."""
        images = images.cpu()  # one host copy per batch, not one per image
        names = []
        for im, image in enumerate(images):
            name = '{0}_{1}.jpg'.format(str(class_index).zfill(2),
                                        str(im).zfill(3))
            save_image(image,
                       './data/{0}/{1}/{2}'.format(args.output_folder,
                                                   subdir, name),
                       range=(-1, 1), normalize=True)
            names.append(name)
        return names

    # maximum number of kept dimensions
    max_num_dst = 0
    max_num_sparsemax = 0
    csv_path = './VQVAE/models/{0}/{1}'.format(args.output_folder,
                                               csv_filename)
    # newline='' is what the csv module asks for on files it writes to.
    with torch.no_grad(), open(csv_path, 'w', newline='') as f:
        # Distinct name so the TensorBoard `writer` is not shadowed.
        csv_writer = csv.writer(f)
        csv_writer.writerow(['filename', 'label'])
        for y in tqdm(yrange):
            label = torch.tensor([y])
            label = label.to(args.device)
            if index_to_name is not None:
                label_text = str(index_to_name[y])
            else:
                label_text = str(y)

            # Only the softmax images are listed in the CSV.
            z = prior.generate(label=label, shape=shape,
                               batch_size=sample_size)
            for name in save_batch(model.decode(z), 'dataset_softmax', y):
                csv_writer.writerow([name, label_text])

            z, num_dst = prior.generate_dst(label=label, shape=shape,
                                            batch_size=sample_size)
            if num_dst > max_num_dst:
                max_num_dst = num_dst
            save_batch(model.decode(z), 'dataset_dst', y)

            z, num_sparsemax = prior.generate_sparsemax(
                label=label, shape=shape, batch_size=sample_size)
            if num_sparsemax > max_num_sparsemax:
                max_num_sparsemax = num_sparsemax
            save_batch(model.decode(z), 'dataset_sparsemax', y)

    # Written once after all classes; the handle is now closed explicitly.
    stats_path = ("max_num_" + str(args.dataset) + "_" + str(sample_size)
                  + "_correct_dst.pkl")
    with open(stats_path, "wb") as stats_file:
        pkl.dump([max_num_dst, max_num_sparsemax], stats_file)
def main(args):
    """Train a ``VQVAE_res16`` model, logging reconstructions and checkpoints.

    Attributes read from ``args``: ``output_folder``, ``dataset``,
    ``data_folder``, ``batch_size``, ``num_workers``, ``device``,
    ``hidden_size``, ``k``, ``lr`` and ``num_epochs``.

    Side effects: prints progress messages, writes TensorBoard events under
    ``./logs/<output_folder>`` and saves ``best.pt`` plus one
    ``model_<epoch>.pt`` per epoch under ``./models/<output_folder>`` (that
    directory is not created here).

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    supported = ('mnist', 'fashion-mnist', 'cifar10', 'miniimagenet',
                 'gameRuns')
    if args.dataset not in supported:
        # Fail before any log directory or download is touched.
        raise ValueError('Unsupported dataset {0!r}; expected one of: {1}'.format(
            args.dataset, ', '.join(supported)))

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        num_channels = 3 if args.dataset == 'cifar10' else 1
        # One mean/std entry per image channel: a three-value Normalize does
        # not match the single-channel MNIST / Fashion-MNIST tensors.
        stats = (0.5,) * num_channels
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(stats, stats),
        ])
        if args.dataset == 'mnist':
            print(" Define the train & test datasets")
            train_dataset = datasets.MNIST(args.data_folder, train=True,
                                           download=True, transform=transform)
            test_dataset = datasets.MNIST(args.data_folder, train=False,
                                          transform=transform)
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder,
                                                  train=True, download=True,
                                                  transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder,
                                                 train=False,
                                                 transform=transform)
        else:
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder, train=True,
                                             download=True,
                                             transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder, train=False,
                                            transform=transform)
        # The test split doubles as the validation split.
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        valid_dataset = MiniImagenet(args.data_folder, valid=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
        num_channels = 3
    else:
        print(" GameRuns Define the train, valid & test datasets")
        # No transform is passed to GameRuns here.
        train_dataset = GameRuns(folder=args.data_folder,
                                 filename='concatAllTrain.hdf5')
        valid_dataset = GameRuns(folder=args.data_folder,
                                 filename='concatAllValid.hdf5')
        test_dataset = GameRuns(folder=args.data_folder,
                                filename='concatAllTest.hdf5')
        num_channels = 3

    print("Define the data loaders")
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False,
        drop_last=True, num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=16, shuffle=True)

    print("Fixed images for Tensorboard .")
    fixed_images, _ = next(iter(test_loader))
    print("Building Grid")
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    print("Building Model")
    # model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    model = VQVAE_res16(num_channels, args.hidden_size, args.k).to(args.device)
    # model = VQVAE_res8(num_channels, args.hidden_size, args.k).to(args.device)
    print("Model :")
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    def log_reconstruction(step):
        """Reconstruct the fixed batch and log it as an image grid."""
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, step)

    print("Generate the samples first once")
    log_reconstruction(0)

    best_loss = -1.
    print("Begin training")
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)

        log_reconstruction(epoch + 1)
        print("Epoch ", epoch, " Loss : ", loss)

        # The first epoch always seeds best_loss; later epochs must improve.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1),
                  'wb') as f:
            torch.save(model.state_dict(), f)