def main(args):
    """Train a class-conditional GatedPixelCNN prior over VQ-VAE latent codes.

    A pretrained ``VectorQuantizedVAE`` is loaded from ``args.model`` and kept
    in eval mode; a ``GatedPixelCNN`` prior is then optimised with Adam for
    ``args.num_epochs`` epochs via the project's ``train`` helper. After every
    epoch ``test`` is run on the validation loader and, whenever the returned
    loss improves (or on the first epoch), the prior's ``state_dict`` is saved
    to ``./models/<output_folder>/prior.pt``.

    Args:
        args: Parsed command-line namespace. Reads ``dataset``,
            ``data_folder``, ``output_folder``, ``batch_size``,
            ``num_workers``, ``device``, ``model``, ``hidden_size_vae``,
            ``hidden_size_prior``, ``k``, ``num_layers``, ``lr`` and
            ``num_epochs``; it is also forwarded to ``train``/``test``.

    Raises:
        ValueError: If ``args.dataset`` is not one of ``'mnist'``,
            ``'fashion-mnist'``, ``'cifar10'`` or ``'miniimagenet'``.
    """
    supported = ('mnist', 'fashion-mnist', 'cifar10', 'miniimagenet')
    if args.dataset not in supported:
        # Fail before any side effect (log directory, downloads); previously an
        # unknown dataset only surfaced later as a NameError.
        raise ValueError('Unsupported dataset {0!r}; expected one of {1}'.format(
            args.dataset, supported))

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}/prior.pt'.format(args.output_folder)

    if args.dataset in ('mnist', 'fashion-mnist', 'cifar10'):
        num_channels = 1 if args.dataset in ('mnist', 'fashion-mnist') else 3
        # One mean/std entry per channel: recent torchvision versions reject a
        # 3-entry Normalize applied to 1-channel (MNIST) tensors.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,) * num_channels, (0.5,) * num_channels),
        ])
        dataset_cls = {
            'mnist': datasets.MNIST,
            'fashion-mnist': datasets.FashionMNIST,
            'cifar10': datasets.CIFAR10,
        }[args.dataset]
        # Define the train & test datasets; the test split doubles as validation.
        train_dataset = dataset_cls(args.data_folder, train=True,
                                    download=True, transform=transform)
        test_dataset = dataset_cls(args.data_folder, train=False,
                                   transform=transform)
        valid_dataset = test_dataset
        # All three torchvision datasets have 10 classes (matches the loaders
        # that later rebuild this prior with n_classes=10).
        n_classes = 10
    else:  # 'miniimagenet'
        num_channels = 3
        transform = transforms.Compose([
            transforms.RandomResizedCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        valid_dataset = MiniImagenet(args.data_folder, valid=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
        # One conditioning class per training label; the former hard-coded 32
        # could not index Mini-Imagenet's larger training label set.
        n_classes = len(train_dataset._label_encoder)

    # Define the data loaders.
    # NOTE(review): the training loader does not shuffle; consider shuffle=True
    # if the underlying dataset is stored ordered by class.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False,
        drop_last=True, num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=16, shuffle=True)

    # Fixed images for Tensorboard (``range=`` is the pre-0.10 torchvision
    # keyword, later renamed ``value_range``).
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                               args.k).to(args.device)
    with open(args.model, 'rb') as f:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        state_dict = torch.load(f, map_location=args.device)
    model.load_state_dict(state_dict)
    model.eval()

    prior = GatedPixelCNN(args.k, args.hidden_size_prior, args.num_layers,
                          n_classes=n_classes).to(args.device)
    optimizer = torch.optim.Adam(prior.parameters(), lr=args.lr)

    best_loss = float('inf')
    for epoch in range(args.num_epochs):
        print(epoch)
        train(train_loader, model, prior, optimizer, args, writer)
        # The validation loss is not properly computed since the classes in
        # the train and valid splits of Mini-Imagenet do not overlap.
        loss = test(valid_loader, model, prior, args, writer)

        # Always save after the first epoch, then only on improvement.
        if epoch == 0 or loss < best_loss:
            best_loss = loss
            with open(save_filename, 'wb') as f:
                torch.save(prior.state_dict(), f)

    writer.close()
def main(args):
    """Generate a labelled synthetic image dataset from a trained VQ-VAE prior.

    For every class index, the class-conditional ``GatedPixelCNN`` prior
    (loaded from ``args.prior``) is sampled with three methods:
    ``generate``, ``generate_dst`` and ``generate_sparsemax``. Each latent
    batch is decoded by the ``VectorQuantizedVAE`` loaded from
    ``args.model``.

    Decoded images are written as JPEGs named ``<class:02>_<index:03>.jpg``
    under ``./data/<output_folder>/dataset_softmax``, ``dataset_dst`` and
    ``dataset_sparsemax``, one directory per method. A ``filename,label``
    CSV indexing the ``dataset_softmax`` images is written to
    ``./VQVAE/models/<output_folder>/<dataset>_generated.csv``. The largest
    ``num_dst`` / ``num_sparsemax`` values returned by the sampler (the
    maximum number of kept dimensions) are pickled as a two-element list to
    ``max_num_<dataset>_<sample_size>_correct_dst.pkl`` in the working
    directory.

    Args:
        args: Parsed command-line namespace. Reads ``dataset``,
            ``data_folder``, ``output_folder``, ``device``, ``model``,
            ``prior``, ``hidden_size_vae``, ``hidden_size_prior``, ``k`` and
            ``num_layers``.

    Raises:
        ValueError: If ``args.dataset`` is not ``'cifar10'`` or
            ``'miniimagenet'``, the only datasets with generation settings.
    """
    import os

    # Per-dataset generation settings: latent grid shape, CSV file name, and
    # number of samples drawn per class for each sampling method.
    if args.dataset == 'miniimagenet':
        shape, csv_filename, sample_size = (32, 32), 'miniimagenet_generated.csv', 25
    elif args.dataset == 'cifar10':
        shape, csv_filename, sample_size = (8, 8), 'cifar10_generated.csv', 1000
    else:
        # Previously this ran through dataset setup and died with a NameError.
        raise ValueError(
            'Generation is only configured for cifar10 and miniimagenet, '
            'got {0!r}'.format(args.dataset))

    writer = SummaryWriter('./VQVAE/logs/{0}'.format(args.output_folder))

    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    if args.dataset == 'cifar10':
        transform = transforms.Compose([transforms.ToTensor(), normalize])
        # The training split is built only with download=True so the files
        # exist for the test split, which is created without downloading.
        datasets.CIFAR10(args.data_folder, train=True, download=True,
                         transform=transform)
        test_dataset = datasets.CIFAR10(args.data_folder, train=False,
                                        transform=transform)
        n_classes = 10
        index_to_name = None  # CSV labels are the raw class indices
    else:  # 'miniimagenet'
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
        n_classes = len(train_dataset._label_encoder)
        # Invert the name -> index encoder once (first name wins, as the old
        # list.index lookup did) instead of searching it for every image.
        index_to_name = {}
        for name, index in train_dataset._label_encoder.items():
            index_to_name.setdefault(index, name)
    num_channels = 3

    # Fixed images for Tensorboard
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                              shuffle=True)
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                               args.k).to(args.device)
    with open(args.model, 'rb') as f:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(f, map_location=args.device))
    model.eval()

    print("number of training classes:", n_classes)

    prior = GatedPixelCNN(args.k, args.hidden_size_prior, args.num_layers,
                          n_classes=n_classes).to(args.device)
    with open(args.prior, 'rb') as f:
        prior.load_state_dict(torch.load(f, map_location=args.device))
    prior.eval()

    # Create every output directory up front; save_image/open do not.
    out_dirs = {
        method: './data/{0}/dataset_{1}'.format(args.output_folder, method)
        for method in ('softmax', 'dst', 'sparsemax')
    }
    for out_dir in out_dirs.values():
        os.makedirs(out_dir, exist_ok=True)
    csv_dir = './VQVAE/models/{0}'.format(args.output_folder)
    os.makedirs(csv_dir, exist_ok=True)

    def _image_name(y, im):
        """File name of the ``im``-th generated image of class ``y``."""
        return '{0}_{1}.jpg'.format(str(y).zfill(2), str(im).zfill(3))

    def _save_batch(x, method, y):
        """Save every image of decoded batch ``x`` for ``method``; return count."""
        x = x.cpu()  # one device->host copy per batch, not one per image
        for im in range(x.shape[0]):
            save_image(x[im],
                       '{0}/{1}'.format(out_dirs[method], _image_name(y, im)),
                       range=(-1, 1), normalize=True)
        return x.shape[0]

    # maximum number of kept dimensions
    max_num_dst = 0
    max_num_sparsemax = 0
    csv_path = '{0}/{1}'.format(csv_dir, csv_filename)
    # newline='' is what the csv module expects for files it writes.
    with torch.no_grad(), open(csv_path, 'w', newline='') as f:
        csv_writer = csv.writer(f)  # distinct name: keeps the SummaryWriter
        csv_writer.writerow(['filename', 'label'])
        for y in tqdm(range(n_classes)):
            label = torch.tensor([y]).to(args.device)
            label_str = str(y) if index_to_name is None else str(index_to_name[y])

            z = prior.generate(label=label, shape=shape, batch_size=sample_size)
            n_saved = _save_batch(model.decode(z), 'softmax', y)
            # Only the softmax samples are indexed in the CSV.
            csv_writer.writerows(
                [_image_name(y, im), label_str] for im in range(n_saved))

            z, num_dst = prior.generate_dst(label=label, shape=shape,
                                            batch_size=sample_size)
            _save_batch(model.decode(z), 'dst', y)
            if num_dst > max_num_dst:
                max_num_dst = num_dst

            z, num_sparsemax = prior.generate_sparsemax(
                label=label, shape=shape, batch_size=sample_size)
            _save_batch(model.decode(z), 'sparsemax', y)
            if num_sparsemax > max_num_sparsemax:
                max_num_sparsemax = num_sparsemax

    pkl_filename = ("max_num_" + str(args.dataset) + "_" + str(sample_size)
                    + "_correct_dst.pkl")
    with open(pkl_filename, "wb") as pkl_file:  # closed deterministically
        pkl.dump([max_num_dst, max_num_sparsemax], pkl_file)

    writer.close()
def main(args):
    """Log class-conditional samples from a pretrained VQ-VAE and PixelCNN prior.

    Loads a ``VectorQuantizedVAE`` from ``args.model`` and a 10-class
    ``GatedPixelCNN`` prior. The prior is read from ``args.prior`` when the
    namespace provides it, otherwise from ``models/mnist-10-prior/prior.pt``
    (the previously hard-coded path). A grid of real test images is logged
    under the TensorBoard tag ``'original'``. Then, ``args.num_epochs``
    times, 16 prior samples (labels cycling 0-9) are decoded and logged
    under ``'Full sampling result'`` at steps ``1..num_epochs``. No
    training happens here.

    Args:
        args: Parsed command-line namespace. Reads ``dataset``,
            ``data_folder``, ``output_folder``, ``device``, ``model``,
            ``hidden_size_vae``, ``hidden_size_prior``, ``k``,
            ``num_layers``, ``num_epochs`` and optionally ``prior``.

    Raises:
        ValueError: If ``args.dataset`` is not one of ``'mnist'``,
            ``'fashion-mnist'``, ``'cifar10'`` or ``'miniimagenet'``.
    """
    supported = ('mnist', 'fashion-mnist', 'cifar10', 'miniimagenet')
    if args.dataset not in supported:
        # Previously an unknown dataset only surfaced later as a NameError.
        raise ValueError('Unsupported dataset {0!r}; expected one of {1}'.format(
            args.dataset, supported))

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))

    if args.dataset in ('mnist', 'fashion-mnist', 'cifar10'):
        # NOTE(review): these are MNIST statistics, also applied to cifar10,
        # while the grids below assume a [-1, 1] pixel range; confirm intent.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, )),
        ])
        dataset_cls = {
            'mnist': datasets.MNIST,
            'fashion-mnist': datasets.FashionMNIST,
            'cifar10': datasets.CIFAR10,
        }[args.dataset]
        num_channels = 1 if args.dataset in ('mnist', 'fashion-mnist') else 3
        # Only the test split is used, but it is built without download=True,
        # so construct the training split first to make sure files are fetched.
        dataset_cls(args.data_folder, train=True, download=True,
                    transform=transform)
        test_dataset = dataset_cls(args.data_folder, train=False,
                                   transform=transform)
    else:  # 'miniimagenet'
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
        num_channels = 3

    # Fixed images for Tensorboard
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16,
                                              shuffle=True)
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                               args.k).to(args.device)
    with open(args.model, 'rb') as f:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(f, map_location=args.device))
    model.eval()

    prior = GatedPixelCNN(args.k, args.hidden_size_prior, args.num_layers,
                          n_classes=10).to(args.device)
    prior_path = getattr(args, 'prior', None) or 'models/mnist-10-prior/prior.pt'
    with open(prior_path, 'rb') as f:
        prior.load_state_dict(torch.load(f, map_location=args.device))
    prior.eval()  # sampling only; the prior was never switched to eval before

    # 16 labels cycling through 0-9, placed on the configured device (was a
    # hard-coded .cuda(), which breaks CPU runs). Loop-invariant, built once.
    label_ = torch.tensor([i % 10 for i in range(16)]).to(args.device)
    with torch.no_grad():  # no backward pass here, so skip autograd bookkeeping
        for epoch in range(args.num_epochs):
            reconstruction = model.decode(
                prior.generate(label=label_, batch_size=16, sample_index=20))
            grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1),
                             normalize=True)
            writer.add_image('Full sampling result', grid, epoch + 1)

    writer.close()
transform=preproc_transform, ), batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True) test_loader = torch.utils.data.DataLoader(eval('datasets.' + DATASET)( '../data/{}/'.format(DATASET), train=False, transform=preproc_transform), batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True) autoencoder = VectorQuantizedVAE(INPUT_DIM, VAE_DIM, K).to(DEVICE) autoencoder.load_state_dict(torch.load('models/{}_vqvae.pt'.format(DATASET))) autoencoder.eval() model = GatedPixelCNN(K, DIM, N_LAYERS).to(DEVICE) criterion = nn.CrossEntropyLoss().to(DEVICE) opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True) def train(): train_loss = [] for batch_idx, (x, label) in enumerate(train_loader): start_time = time.time() x = x.to(DEVICE) label = label.to(DEVICE) # Get the latent codes for image x latents, _ = autoencoder.encode(x)