def get_recons(loader, hidden_size=256):
    """Return an 8-per-row image grid of VQ-VAE reconstructions for one batch.

    Loads the pretrained ImageNet VQ-VAE checkpoint matching ``hidden_size``:
    widths below 256 live under ``models/imagenet/hs_<n>/``, the full-width
    model at ``models/imagenet/best.pt``.

    Args:
        loader: iterable yielding ``(images, labels)`` batches.
        hidden_size: latent width of the checkpoint to load.

    Returns:
        A ``make_grid`` tensor of the reconstructed batch (on CPU).
    """
    model = VectorQuantizedVAE(3, hidden_size, 512).to(DEVICE)
    if hidden_size < 256:
        ckpt = torch.load(
            "./models/imagenet/hs_{}/best.pt".format(hidden_size))
    else:
        # BUG FIX: the original called .format(hidden_size) on a literal with
        # no placeholder — a no-op; the constant path was always loaded.
        ckpt = torch.load("./models/imagenet/best.pt")
    model.load_state_dict(ckpt)
    # generate_samples expects an argparse-style namespace; fake a minimal one.
    args = type('', (), {})()
    args.device = DEVICE
    gen_img, _ = next(iter(loader))
    reconstruction = generate_samples(gen_img, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8)
    return grid
def main(args):
    """Train an AffineClassifier on top of a frozen, pretrained VQ-VAE encoder.

    Expects `args` to provide: output_folder, data_folder, batch_size,
    num_workers, model_path, hidden_size, k, device, lr, num_epochs.
    Side effects: TensorBoard logs under ./logs/<output_folder>, checkpoints
    under ./models/<output_folder> (best.pt plus per-epoch snapshots).
    Relies on module-level `train` / `test` helpers whose loss semantics are
    not visible in this file.
    """
    writer = SummaryWriter("./logs/{0}".format(args.output_folder))
    save_filename = "./models/{0}".format(args.output_folder)
    # Define the train, valid & test datasets.
    # Only whale identities with at least 10 images are kept.
    all_dataset = TripletWhaleDataset(args.data_folder, min_instance_count=10)
    from torch.utils.data import random_split
    # 60/20/20 random split; valid takes the remainder so sizes always sum to k.
    k = len(all_dataset)
    train_s = int(k * 0.6)
    test_s = int(k * 0.2)
    valid_s = k - train_s - test_s
    train_dataset, test_dataset, valid_dataset = random_split(
        all_dataset, (train_s, test_s, valid_s)
    )
    num_channels = 3
    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)
    # Fixed images for Tensorboard.
    # Batches appear to carry more than three fields; only the first three are used.
    fixed_images, _, fixed_ids = next(iter(test_loader))[:3]
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image("original", fixed_grid, 0)
    n_classes = len(all_dataset.index_to_class)
    # Load the pretrained VQ-VAE and keep only its encoder, in eval mode.
    vae = VectorQuantizedVAE(num_channels, args.hidden_size, args.k)
    vae.load_state_dict(torch.load(args.model_path))
    encoder = vae.encoder.eval().to(args.device)
    model = AffineClassifier(all_dataset, encoder, n_classes).to(args.device).train()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    best_loss = -1.0  # placeholder; the epoch-0 check below always overwrites it
    for epoch in range(args.num_epochs):
        # Anomaly detection is expensive; presumably left on for debugging.
        with torch.autograd.set_detect_anomaly(True):
            train(train_loader, model, optimizer, args, writer)
        loss = test(valid_loader, model, args, writer)
        # Save the best model (by validation loss) plus a per-epoch snapshot.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open("{0}/best.pt".format(save_filename), "wb") as f:
                torch.save(model.state_dict(), f)
        with open("{0}/model_{1}.pt".format(save_filename, epoch + 1), "wb") as f:
            torch.save(model.state_dict(), f)
def main(args):
    """Train the AffineCropper's generator against a frozen VQ-VAE encoder.

    Only the cropper's generator sub-network is optimized; the VQ-VAE encoder
    is loaded pretrained and kept in eval mode. Side effects: TensorBoard
    logs under ./logs/<output_folder>, generator checkpoints under
    ./models/<output_folder>. Relies on module-level `train` / `test` /
    `generate_samples` helpers defined elsewhere.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    if args.dataset == 'whales':
        # Define the train, valid & test datasets
        all_dataset = WhaleDataset(args.data_folder, image_transformation=BASIC_IMAGE_T)
        from torch.utils.data import random_split
        # 60/20/20 random split; valid takes the remainder.
        k = len(all_dataset)
        train_s = int(k * .6)
        test_s = int(k * .2)
        valid_s = k - train_s - test_s
        train_dataset, test_dataset, valid_dataset = random_split(all_dataset, (train_s, test_s, valid_s))
        num_channels = 3
    # Define the data loaders.
    # NOTE(review): only the 'whales' branch defines the datasets; any other
    # args.dataset value would leave train_dataset unbound here — confirm callers.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False,
                                               num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
                                               drop_last=True, num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)
    # Fixed images for Tensorboard
    fixed_images, _, fixed_ids = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)
    # Pretrained VQ-VAE; only its frozen encoder is used downstream.
    vae = VectorQuantizedVAE(num_channels, args.hidden_size, args.k)
    vae.load_state_dict(torch.load(args.model_path))
    encoder = vae.encoder.eval().to(args.device)
    affine_resample = AffineCropper(all_dataset, args.hidden_size, encoder).to(args.device)
    # Only the generator sub-module is put in train mode / optimized.
    generator = affine_resample.generator.train()
    optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr)
    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, fixed_ids, affine_resample, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('cropped', grid, 0)
    best_loss = -1.  # placeholder; the epoch-0 check below always overwrites it
    for epoch in range(args.num_epochs):
        # Anomaly detection is expensive; presumably left on for debugging.
        with torch.autograd.set_detect_anomaly(True):
            train(train_loader, encoder, affine_resample, optimizer, args, writer)
        loss = test(valid_loader, encoder, affine_resample, args, writer)
        # Log cropped reconstructions of the fixed batch each epoch.
        reconstruction = generate_samples(fixed_images, fixed_ids, affine_resample, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
        writer.add_image('cropped', grid, epoch + 1)
        # Save best-by-validation-loss generator plus a per-epoch snapshot.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(generator.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
            torch.save(generator.state_dict(), f)
def get_model(ae_type, hidden_size, k, num_channels):
    """Build a VQ-VAE-backed ImageNet classifier with fixed pretrained weights.

    The checkpoint path is hard-coded; ``ae_type`` and ``hidden_size`` do not
    select a different file (the commented-out format hints this was temporary).
    """
    checkpoint_path = "models/imagenet/hs_32_4/best.pt"  # hard-coded; hidden_size not interpolated
    backbone = VectorQuantizedVAE(num_channels, hidden_size, k)
    classifier = ImgnetClassifier(backbone, hidden_size)
    backbone.load_state_dict(torch.load(checkpoint_path))
    return classifier
def val_test(args):
    """Run a short validation-only pass (4 epochs max) over a saved VQ-VAE.

    Builds the model (and a GAN discriminator when the reconstruction loss is
    "gan"), optionally reloads saved state, then repeatedly evaluates on the
    validation split, logging losses/latent metrics/images to TensorBoard and
    re-saving state each epoch. No training step is performed here.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    input_dim = 3
    model = VectorQuantizedVAE(input_dim, args.hidden_size, args.k, args.enc_type, args.dec_type)
    # if torch.cuda.device_count() > 1 and args.device == "cuda":
    #     model = torch.nn.DataParallel(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    discriminators = {}
    # A reconstruction discriminator is only needed for the GAN loss variant.
    if args.recons_loss == "gan":
        recons_disc = Discriminator(input_dim, args.img_res, args.input_type).to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(), lr=args.disc_lr, amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]
    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)
    # Optionally resume model/optimizer/discriminator state from save_filename.
    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer, discriminators)
    else:
        start_epoch = 0
    stop_patience = args.stop_patience  # NOTE(review): read but never used below
    best_loss = torch.tensor(np.inf)
    # Hard cap of 4 evaluation epochs regardless of args.num_epochs.
    for epoch in tqdm(range(start_epoch, 4), file=sys.stdout):
        val_loss_dict, z = train_util.test(get_losses, model, valid_loader, args, discriminators, True)
        # if args.weights == "init" and epoch==1:
        #     epoch+=1
        #     break
        train_util.log_recons_img_grid(recons_input_img, model, epoch + 1, args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, epoch + 1, args.device, writer)
        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, optimizer, discriminators, val_loss_dict["recons_loss"],
                              best_loss, args.recons_loss, epoch, save_filename)
        print(val_loss_dict)
def get_model(ae_type, hidden_size, k, num_channels):
    """Build an ImageNet classifier on a pretrained VQ-VAE or VAE backbone.

    Args:
        ae_type: "vqvae" or "vae" — selects backbone architecture and checkpoint.
        hidden_size: latent width; picks the matching vqvae checkpoint directory.
        k: codebook size for the VQ-VAE.
        num_channels: input image channels.

    Returns:
        An ImgnetClassifier wrapping the loaded backbone.
    """
    if ae_type == "vqvae":
        # Width 256 is the original model stored at the top level; every other
        # width lives under its own hs_<n> directory. (The original code had
        # an elif for 128 that duplicated the else branch verbatim.)
        if hidden_size == 256:
            CKPT_DIR = "models/imagenet/best.pt"
        else:
            CKPT_DIR = "models/imagenet/hs_{}/best.pt".format(hidden_size)
        model = VectorQuantizedVAE(num_channels, hidden_size, k)
        imgnetclassifier = ImgnetClassifier(model, hidden_size)
    elif ae_type == "vae":
        CKPT_DIR = "models/imagenet_vae.pt"
        model = VAE(num_channels, hidden_size, 4096)
        imgnetclassifier = ImgnetClassifier(model, 4)
    # NOTE(review): any other ae_type raises NameError on CKPT_DIR below —
    # same behavior as the original; consider an explicit ValueError.
    ckpt = torch.load(CKPT_DIR)
    model.load_state_dict(ckpt)
    return imgnetclassifier
def get_model(model, hidden_size, k, num_channels, resolution, num_classes):
    """Build an ImageNet classifier over a pretrained VQ-VAE / VAE / AAE backbone.

    Args:
        model: "vqvae", "vae" or "aae"; anything else yields a None backbone.
        hidden_size: latent width; selects the checkpoint and, when <= 0,
            skips checkpoint loading entirely.
        k: VQ-VAE codebook size.
        num_channels, resolution: input image geometry.
        num_classes: unused here (kept for signature compatibility).

    Returns:
        An ImgnetClassifier wrapping the (possibly unloaded) backbone.
    """
    if model == "vqvae":
        # Width 256 is the original top-level checkpoint; all other widths
        # live under hs_<n>/. (The original had an elif for 128 that
        # duplicated the else branch verbatim — collapsed here.)
        if hidden_size == 256:
            CKPT_DIR = "models/imagenet/best.pt"
        else:
            CKPT_DIR = "models/imagenet/hs_{}/best.pt".format(hidden_size)
        model = VectorQuantizedVAE(num_channels, hidden_size, k)
    elif model == "vae":
        CKPT_DIR = f"models/imagenet_hs_128_{hidden_size}_vae.pt"
        model = VAE(num_channels, hidden_size, hidden_size)
    elif model == "aae":
        CKPT_DIR = f"models/aae/imagenet_hs_32_{hidden_size}/best.pt"
        model = AAE(32, num_channels, hidden_size)
    else:
        model = None
    imgnetclassifier = ImgnetClassifier(model, hidden_size, resolution)
    if hidden_size > 0:
        # NOTE(review): for an unknown `model` string this raises NameError
        # (CKPT_DIR unbound) — unchanged from the original behavior.
        ckpt = torch.load(CKPT_DIR)
        model.load_state_dict(ckpt)
    return imgnetclassifier
def get_model(model, hidden_size, num_channels, resolution, enc_type, dec_type, num_classes, k=512):
    """Build an image classifier on top of a (possibly pretrained) AE backbone.

    Args:
        model: "vqvae", "vae", "acai", or "supervised".
        hidden_size: latent width; <= 0 skips checkpoint loading.
        num_channels, resolution: input geometry.
        enc_type, dec_type: encoder/decoder depth identifiers.
        num_classes: classifier output size.
        k: VQ-VAE codebook size.

    NOTE(review): the checkpoint path reads a module-level `args`
    (recons_loss, train_dataset, img_res) — TODO pass these in explicitly.
    """
    # BUG FIX: keep the selector string before `model` is rebound to a network
    # instance. The original compared the rebound object with "supervised",
    # which was always unequal, making the supervised branch unreachable.
    model_type = model
    CKPT_DIR = f"models/{model_type}_{args.recons_loss}/{args.train_dataset}/depth_{enc_type}_{dec_type}_hs_{args.img_res}_{hidden_size}/best.pt"
    if model_type == "vqvae":
        model = VectorQuantizedVAE(num_channels, hidden_size, k, enc_type, dec_type)
    elif model_type == "vae":
        model = VAE(num_channels, hidden_size, enc_type, dec_type)
    elif model_type == "acai":
        model = ACAI(resolution, num_channels, hidden_size, enc_type, dec_type)
    else:
        model = None
    if model_type != "supervised":
        imgclassifier = ImgClassifier(model, hidden_size, resolution, num_classes)
    else:
        imgclassifier = SupervisedImgClassifier(hidden_size, enc_type, resolution, num_classes)
    # Load pretrained backbone weights only when a backbone was actually built
    # (the supervised classifier has no autoencoder checkpoint to restore).
    if hidden_size > 0 and model is not None:
        ckpt = torch.load(CKPT_DIR)
        model.load_state_dict(ckpt["model"])
    return imgclassifier
def main(args):
    """Generate samples from a pretrained VQ-VAE and save them as an image grid.

    When args.input == "MNIST" the source batch comes from the MNIST test
    split; otherwise a standard-normal noise batch of the same shape is used.
    """
    print("Start Time =", datetime.now().strftime("%H:%M:%S"))
    if args.input == "MNIST":
        preprocess = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
        # Define the train & test dataSets
        mnist_test = datasets.MNIST("MNIST", train=False, download=True, transform=preprocess)
        # Define the data loaders
        loader = torch.utils.data.DataLoader(
            mnist_test, batch_size=100, shuffle=True, pin_memory=True)
        source_batch, _ = next(iter(loader))
    else:
        source_batch = torch.normal(mean=0, std=1, size=(100, 1, 28, 28))
    net = VectorQuantizedVAE(1, args.hidden_size_vae, args.k).to(args.device)
    with open(args.model, 'rb') as checkpoint_file:
        net.load_state_dict(torch.load(checkpoint_file))
    samples = generate_samples(source_batch, net, args)
    save_image(make_grid(samples, nrow=10),
               './generatedImages/{0}.png'.format(args.filename))
    print("End Time =", datetime.now().strftime("%H:%M:%S"))
), batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True) test_loader = torch.utils.data.DataLoader(eval('datasets.' + DATASET)( '../data/{}/'.format(DATASET), train=False, transform=preproc_transform, ), batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True) model = VectorQuantizedVAE(INPUT_DIM, DIM, K).to(DEVICE) print(model) opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True) def train(): train_loss = [] for batch_idx, (x, _) in enumerate(train_loader): start_time = time.time() x = x.to(DEVICE) opt.zero_grad() x_tilde, z_e_x, z_q_x = model(x) z_q_x.retain_grad()
def main(args):
    """Train a VQ-VAE on mnist / fashion-mnist / cifar10 / PubTabNet / miniimagenet.

    Side effects: TensorBoard logs under ./logs/<output_folder>, checkpoints
    under ./models/<output_folder> (best.pt plus per-epoch snapshots).
    Relies on module-level `train` / `test` / `generate_samples` / `eprint`.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        # NOTE(review): a 3-channel normalization is also used for the
        # 1-channel MNIST datasets — confirm this is tolerated downstream.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder, train=True,
                                           download=True, transform=transform)
            test_dataset = datasets.MNIST(args.data_folder, train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder, train=True,
                                                  download=True, transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder, train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder, train=True,
                                             download=True, transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder, train=False,
                                            transform=transform)
            num_channels = 3
        # These datasets ship no separate validation split; reuse the test set.
        valid_dataset = test_dataset
    elif args.dataset == 'PubTabNet':
        # No ToTensor here — PubTabNet presumably yields tensors already; verify.
        transform = transforms.Compose(
            [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_dataset = PubTabNet(args.data_folder, args.data_name, 'TRAIN', transform=transform)
        test_dataset = PubTabNet(args.data_folder, args.data_name, 'VAL', transform=transform)
        valid_dataset = test_dataset
        num_channels = 3
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        valid_dataset = MiniImagenet(args.data_folder, valid=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
        num_channels = 3
    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)
    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=4, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)
    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=4, range=(-1, 1), normalize=True)
    writer.add_image('reconstruction', grid, 0)
    best_loss = -1.  # placeholder; the epoch-0 check below always overwrites it
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer, epoch)
        loss, _ = test(valid_loader, model, args, writer)
        eprint('Validataion loss at epoch %d: Loss = %.4f' % (epoch, loss))
        # Log reconstructions of the fixed batch after every epoch.
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=4, range=(-1, 1), normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)
        # Save best-by-validation-loss model plus a per-epoch snapshot.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
            torch.save(model.state_dict(), f)
def main(args):
    """Train a GatedPixelCNN prior over the latents of a frozen, pretrained VQ-VAE.

    The VQ-VAE weights are loaded from args.model and kept in eval mode; only
    the prior is optimized. The best prior (by validation loss) is saved to
    ./models/<output_folder>/prior.pt. Relies on module-level `train` / `test`.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}/prior.pt'.format(args.output_folder)
    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder, train=True,
                                           download=True, transform=transform)
            test_dataset = datasets.MNIST(args.data_folder, train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder, train=True,
                                                  download=True, transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder, train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder, train=True,
                                             download=True, transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder, train=False,
                                            transform=transform)
            num_channels = 3
        # No separate validation split; reuse the test set.
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        valid_dataset = MiniImagenet(args.data_folder, valid=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
        num_channels = 3
    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)
    # Save the label encoder
    # with open('./models/{0}/labels.json'.format(args.output_folder), 'w') as f:
    #     json.dump(train_dataset._label_encoder, f)
    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)
    # The frozen pretrained VQ-VAE supplies the discrete latents the prior models.
    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae, args.k).to(args.device)
    with open(args.model, 'rb') as f:
        state_dict = torch.load(f)
    model.load_state_dict(state_dict)
    model.eval()
    # n_classes is hard-coded to 32 (the label-encoder-based value is commented out).
    prior = GatedPixelCNN(args.k, args.hidden_size_prior,
                          args.num_layers, n_classes=32).to(args.device)
    # args.num_layers, n_classes=len(train_dataset._label_encoder)).to(args.device)
    optimizer = torch.optim.Adam(prior.parameters(), lr=args.lr)
    best_loss = -1.  # placeholder; the epoch-0 check below always overwrites it
    for epoch in range(args.num_epochs):
        print(epoch)
        train(train_loader, model, prior, optimizer, args, writer)
        # The validation loss is not properly computed since
        # the classes in the train and valid splits of Mini-Imagenet
        # do not overlap.
        loss = test(valid_loader, model, prior, args, writer)
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open(save_filename, 'wb') as f:
                torch.save(prior.state_dict(), f)
def main(args):
    """Seeded VQ-VAE training loop driven by load_data(args).

    Seeds all RNGs from args.manualseed for reproducibility, trains for
    args.num_epochs epochs, and writes TensorBoard logs to args.log_dir and
    checkpoints to args.save_dir (best.pt plus model_<epoch>.pt every
    args.save_step epochs). Relies on module-level `load_data`, `train`,
    `test`, and `generate_samples`.

    (Roughly 60 lines of commented-out, superseded dataset-construction code
    were removed; load_data(args) replaced them.)
    """
    # Seed every RNG we touch so runs are reproducible.
    random.seed(args.manualseed)
    torch.manual_seed(args.manualseed)
    torch.cuda.manual_seed_all(args.manualseed)
    np.random.seed(args.manualseed)
    torch.backends.cudnn.deterministic = True

    writer = SummaryWriter(args.log_dir)
    save_filename = args.save_dir

    # load_data builds all three splits from args.
    dataloader = load_data(args)
    train_loader = dataloader['train']
    valid_loader = dataloader['valid']
    test_loader = dataloader['test']

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(valid_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(args.nc, args.hidden_size, args.k).to(args.device)
    if len(args.gpu_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=args.gpu_ids,
                                      output_device=args.gpu_ids[0])
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Log reconstructions once before any training.
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.  # placeholder; the epoch-1 check below always overwrites it
    for epoch in range(1, args.num_epochs + 1):
        train(train_loader, model, optimizer, args, writer, epoch)
        # NOTE(review): the "validation" loss here is computed on the test split.
        loss, _ = test(test_loader, model, args, writer)
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)
        if (epoch == 1) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        if (epoch % args.save_step) == 0:
            with open('{0}/model_{1}.pt'.format(save_filename, epoch), 'wb') as f:
                torch.save(model.state_dict(), f)
def main(args):
    """Train a VQ-VAE on MNIST, logging originals/reconstructions to TensorBoard.

    Prints wall-clock start/end times, saves best.pt (by test-split loss) and
    per-epoch snapshots under ./models/<output_folder>. Relies on module-level
    `train` / `test` / `generate_samples` helpers.
    """
    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")
    print("Start Time =", current_time)
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])])
    # Define the train & test dataSets
    train_set = datasets.MNIST(args.data_folder, train=True,
                               download=True, transform=transform)
    test_set = datasets.MNIST(args.data_folder, train=False,
                              download=True, transform=transform)
    num_channels = 1
    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              num_workers=args.num_workers,
                                              batch_size=16,
                                              shuffle=False)
    # Fixed images for TensorBoard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)
    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    writer.add_graph(model, fixed_images.to(args.device))  # get model structure on tensorboard
    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('reconstruction at start', grid, 0)
    img_list = []  # accumulates one reconstruction grid per epoch
    best_loss = -1.  # placeholder; the epoch-0 check below always overwrites it
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        # NOTE(review): evaluation runs on the test split; no held-out valid set.
        loss, _ = test(test_loader, model, args, writer)
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
        # {:f} renders the integer epoch as e.g. "1.000000" in the tag/message.
        writer.add_image('reconstruction at epoch {:f}'.format(epoch + 1), grid, epoch + 1)
        print("loss = {:f} at epoch {:f}".format(loss, epoch + 1))
        writer.add_scalar('loss/testing_loss', loss, epoch + 1)
        img_list.append(grid)
        # Save best-by-loss model plus a per-epoch snapshot.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
            torch.save(model.state_dict(), f)
    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")
    print("End Time =", current_time)
def _extract_sizes(summary, params_re, forward_re):
    """Pull (params MB, forward/backward MB) floats out of a torchsummary report."""
    return (float(re.search(params_re, summary).group(1)),
            float(re.search(forward_re, summary).group(1)))


def model_summary(model_type, img_res, hidden_size, enc_type, dec_type, loss,
                  batch_size, device=torch.device("cuda:1"), verbose=True):
    """Summarize parameter/activation memory (MB) for a model and its discriminators.

    Args:
        model_type: "acai", "vqvae" or "vae".
        img_res, hidden_size, enc_type, dec_type: architecture parameters.
        loss: reconstruction-loss name; selects which critic (if any) to size.
        batch_size: batch size passed to torchsummary.
        device: device the throwaway models are instantiated on.
        verbose: print the raw summaries and per-part sizes.

    Returns:
        (encoder, decoder, discriminators) where encoder/decoder are
        {"params": MB, "forward": MB} dicts and discriminators maps
        "interp_disc"/"recons_disc" to (params MB, forward MB) tuples; the
        recons entry doubles the forward size (real + fake passes).
    """
    pattern = re.compile(r"Params size \(MB\):(.*)\n")
    pattern2 = re.compile(r"Forward/backward pass size \(MB\):(.*)\n")
    input_dim = 3
    enc_input_size = (input_dim, img_res, img_res)
    # Decoders consume the 4x-downsampled latent map.
    dec_input_size = (hidden_size, img_res // 4, img_res // 4)
    # BUG FIX: removed a stray pdb.set_trace() left over from debugging.
    if verbose:
        print(f"model:{model_type}")
        print(f"depth:{enc_type}_{dec_type}")
    if model_type == "acai":
        model = ACAI(img_res, input_dim, hidden_size, enc_type, dec_type).to(device)
    elif model_type == "vqvae":
        model = VectorQuantizedVAE(input_dim, hidden_size, enc_type=enc_type,
                                   dec_type=dec_type).to(device)
    elif model_type == "vae":
        model = VAE(input_dim, hidden_size, enc_type=enc_type,
                    dec_type=dec_type).to(device)
    encoder_summary, _ = torchsummary.summary_string(model.encoder, enc_input_size,
                                                     device=device, batch_size=batch_size)
    decoder_summary, _ = torchsummary.summary_string(model.decoder, dec_input_size,
                                                     device=device, batch_size=batch_size)
    if verbose:
        print(encoder_summary)
        print(decoder_summary)
    discriminators = {}
    if model_type == "acai":
        # ACAI always carries an interpolation discriminator.
        disc = Discriminator(input_dim, img_res, "image").to(device)
        disc_summary, _ = torchsummary.summary_string(disc, enc_input_size,
                                                      device=device, batch_size=batch_size)
        disc_param_size, disc_forward_size = _extract_sizes(disc_summary, pattern, pattern2)
        discriminators["interp_disc"] = (disc_param_size, disc_forward_size)
    # Reconstruction-loss critic, selected by the loss name (same chain and
    # substring matching as before — "comp" is checked before "comp_2"/"comp_6").
    if loss == "gan":
        recons_disc = Discriminator(input_dim, img_res, "image").to(device)
    elif loss == "comp":
        recons_disc = AnchorComparator(input_dim * 2, img_res, "image").to(device)
    elif "comp_2" in loss:
        recons_disc = ClubbedPermutationComparator(input_dim * 2, img_res, "image").to(device)
    elif "comp_6" in loss:
        recons_disc = FullPermutationComparator(input_dim * 2, img_res, "image").to(device)
    else:
        recons_disc = None
    if recons_disc is not None:
        disc_summary, _ = torchsummary.summary_string(recons_disc, enc_input_size,
                                                      device=device, batch_size=batch_size)
        disc_param_size, disc_forward_size = _extract_sizes(disc_summary, pattern, pattern2)
        # Forward size doubled: the critic sees both real and reconstructed images.
        discriminators["recons_disc"] = (disc_param_size, 2 * disc_forward_size)
    encoder_param_size, encoder_forward_size = _extract_sizes(encoder_summary, pattern, pattern2)
    decoder_param_size, decoder_forward_size = _extract_sizes(decoder_summary, pattern, pattern2)
    if verbose:
        if "ACAI" in str(type(model)):
            print(
                f"discriminator:\n\tparams:{disc_param_size}\n\tforward:{disc_forward_size}"
            )
        if loss == "gan":
            print(
                f"reconstruction discriminator:\n\tparams:{disc_param_size}\n\tforward:{disc_forward_size}"
            )
        print(
            f"encoder:\n\tparams:{encoder_param_size}\n\tforward:{encoder_forward_size}"
        )
        print(
            f"decoder:\n\tparams:{decoder_param_size}\n\tforward:{decoder_forward_size}"
        )
    encoder = {"params": encoder_param_size, "forward": encoder_forward_size}
    decoder = {"params": decoder_param_size, "forward": decoder_forward_size}
    return encoder, decoder, discriminators
def main(args):
    """Train (or, in test_mode, only evaluate) a VQ-VAE on clevr/carla-style data.

    Side effects: TensorBoard logs under ./logs/<output_folder>, checkpoints
    (recent.pt and best.pt) under ./models/<output_folder>. Relies on
    module-level `train` / `test` / `test_old` / `generate_samples`.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder, train=True,
                                           download=True, transform=transform)
            test_dataset = datasets.MNIST(args.data_folder, train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder, train=True,
                                                  download=True, transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder, train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder, train=True,
                                             download=True, transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder, train=False,
                                            transform=transform)
            num_channels = 3
        valid_dataset = test_dataset
    elif args.dataset == 'clevr':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        import socket
        # Host-specific dataset roots; the unconditional assignment below wins.
        if "Alien" in socket.gethostname():
            dataset_name = "/media/mihir/dataset/clevr_veggies/"
        else:
            dataset_name = "/projects/katefgroup/datasets/clevr_veggies/"
        dataset_name = '/home/mprabhud/dataset/clevr_veggies'
        # Define the train, valid & test datasets
        train_dataset = Clevr(dataset_name, mod=args.modname, train=True,
                              transform=transform, object_level=args.object_level)
        valid_dataset = Clevr(dataset_name, mod=args.modname, valid=True,
                              transform=transform, object_level=args.object_level)
        test_dataset = Clevr(dataset_name, mod=args.modname, test=True,
                             transform=transform, object_level=args.object_level)
        num_channels = 3
    elif args.dataset == 'carla':
        if args.use_depth:
            # RGB-D input: 4 channels.
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5))
            ])
            num_channels = 4
        else:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
            num_channels = 3
        import socket
        # Host-specific dataset roots; the last assignment always wins.
        if "Alien" in socket.gethostname():
            dataset_name = "/media/mihir/dataset/clevr_veggies/"
        else:
            dataset_name = '/home/shamitl/datasets/carla'
        dataset_name = "/projects/katefgroup/datasets/carla/"
        dataset_name = '/home/mprabhud/dataset/carla'
        # Define the train, valid & test datasets
        train_dataset = Clevr(dataset_name, mod=args.modname, train=True, transform=transform,
                              object_level=args.object_level, use_depth=args.use_depth)
        valid_dataset = Clevr(dataset_name, mod=args.modname, valid=True, transform=transform,
                              object_level=args.object_level, use_depth=args.use_depth)
        test_dataset = Clevr(dataset_name, mod=args.modname, test=True, transform=transform,
                             object_level=args.object_level, use_depth=args.use_depth)
    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)
    # Fixed images for Tensorboard.
    # (A redundant, immediately-overwritten fetch from train_loader was removed.)
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)
    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.object_level,
                               args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # BUG FIX: the original used `args.load_model is not ""`, an identity
    # comparison that is unreliable for strings (and a SyntaxWarning on
    # modern Python); equality is intended.
    if args.load_model != "":
        with open(args.load_model, 'rb') as f:
            state_dict = torch.load(f)
        model.load_state_dict(state_dict)
    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('reconstruction', grid, 0)
    best_loss = -1.  # placeholder; the epoch-0 check below always overwrites it
    for epoch in range(args.num_epochs):
        if not args.test_mode:
            train(train_loader, model, optimizer, args, writer, epoch)
            loss, _ = test_old(valid_loader, model, args, writer)
            reconstruction = generate_samples(fixed_images, model, args)
            grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
            writer.add_image('reconstruction', grid, epoch + 1)
            # Always keep the most recent weights, plus the best-by-val-loss copy.
            with open('{0}/recent.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
            if (epoch == 0) or (loss < best_loss):
                best_loss = loss
                with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                    torch.save(model.state_dict(), f)
        else:
            # Evaluation-only mode: no optimization, no checkpointing.
            test(train_loader, model, args, writer)
            reconstruction = generate_samples(fixed_images, model, args)
            grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
            writer.add_image('reconstruction', grid, epoch + 1)
def main(args):
    """Train a VectorQuantizedVAE, optionally with an adversarial
    reconstruction discriminator/comparator, using the train_util helpers.

    Reads from `args`: output_folder, device, hidden_size, k, enc_type,
    dec_type, recons_loss, img_res, input_type, lr, disc_lr, lr_patience,
    threshold, stop_patience, weights, num_epochs, batch_size.
    Side effects: writes TensorBoard logs under ./logs/<output_folder> and
    checkpoints under ./models/<output_folder>.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args)
    num_channels = 3

    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)
    model.to(args.device)

    # Optional adversarial / comparator loss on reconstructions.  Stored in a
    # dict keyed by name so train()/test() can iterate uniformly.
    discriminators = {}
    input_dim = 3
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr, amsgrad=True)
        recons_disc_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            recons_disc_opt, "min", patience=args.lr_patience, factor=0.5,
            threshold=args.threshold, threshold_mode="abs", min_lr=1e-7)

        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    ae_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, "min", patience=args.lr_patience, factor=0.5,
        threshold=args.threshold, threshold_mode="abs", min_lr=1e-7)

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(
                discriminators[disc][0])

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    # Log the fixed input batch and its initial reconstruction once, after
    # any DataParallel wrapping.  (Previously this pair of calls also ran a
    # second time earlier, duplicating work and the step-0 images.)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    train_util.log_recons_img_grid(recons_input_img, model, 0, args.device,
                                   writer)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    stop_patience = args.stop_patience  # kept for the disabled early stop below
    best_loss = torch.tensor(np.inf)

    for epoch in tqdm(range(start_epoch, args.num_epochs), file=sys.stdout):
        try:
            train(epoch, train_loader, model, optimizer, args, writer,
                  discriminators)
        except RuntimeError as err:
            # Typically CUDA OOM: report the traceback and the batch size,
            # then abort with a NON-ZERO status so callers/schedulers see the
            # failure (was exit(0), which signalled success).
            print("".join(
                traceback.TracebackException.from_exception(err).format()),
                file=sys.stderr)
            print("*******")
            print(err, file=sys.stderr)
            print(f"batch_size:{args.batch_size}", file=sys.stderr)
            sys.exit(1)

        val_loss_dict, z = train_util.test(get_losses, model, valid_loader,
                                           args, discriminators)

        train_util.log_recons_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        # save_state tracks/saves best checkpoints internally.
        train_util.save_state(model, optimizer, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)

        # early stop check
        # if val_loss_dict["recons_loss"] - best_loss < args.threshold:
        #     stop_patience -= 1
        # else:
        #     stop_patience = args.stop_patience
        # if stop_patience == 0:
        #     print("training early stopped!")
        #     break

        ae_lr_scheduler.step(val_loss_dict["recons_loss"])
        if args.recons_loss != "mse":
            recons_disc_lr_scheduler.step(val_loss_dict["recons_disc_loss"])
def main(args):
    """Sample class-conditional images from a trained GatedPixelCNN prior and
    decode them with a trained VQ-VAE, logging grids to TensorBoard.

    Reads from `args`: output_folder, dataset, data_folder, batch_size,
    num_workers, hidden_size_vae, hidden_size_prior, num_layers, k, model,
    device, num_epochs.  Nothing is saved to disk by this script.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder, train=True,
                                           download=True, transform=transform)
            test_dataset = datasets.MNIST(args.data_folder, train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder, train=True,
                                                  download=True,
                                                  transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder, train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder, train=True,
                                             download=True, transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder, train=False,
                                            transform=transform)
            num_channels = 3
        # No dedicated validation split here: validate on the test set.
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        valid_dataset = MiniImagenet(args.data_folder, valid=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
        num_channels = 3

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False,
        drop_last=True, num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                               args.k).to(args.device)
    with open(args.model, 'rb') as f:
        state_dict = torch.load(f)
        model.load_state_dict(state_dict)
    model.eval()

    prior = GatedPixelCNN(args.k, args.hidden_size_prior, args.num_layers,
                          n_classes=10).to(args.device)
    # NOTE(review): checkpoint path is hardcoded and n_classes is fixed at
    # 10 — confirm this is only meant for the mnist-10 prior.  Also note the
    # prior is never put into eval() mode here; verify that is intended.
    state_dict = torch.load('models/mnist-10-prior/prior.pt')
    prior.load_state_dict(state_dict)

    # The conditioning labels 0..9 repeated over the batch are loop-invariant,
    # so build the tensor once on the configured device (was rebuilt every
    # epoch with a hardcoded .cuda(), which broke CPU runs).
    label_ = torch.tensor([int(i % 10) for i in range(16)]).to(args.device)
    for epoch in range(args.num_epochs):
        # Each epoch draws a fresh batch of prior samples and logs the grid.
        reconstruction = model.decode(
            prior.generate(label=label_, batch_size=16, sample_index=20))
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1),
                         normalize=True)
        writer.add_image('Full sampling result', grid, epoch + 1)
def main(args):
    """Train a VectorQuantizedVAE on one of several image datasets and log
    originals/reconstructions to TensorBoard.

    Reads from `args`: output_folder, dataset, data_folder, image_size,
    batch_size, num_workers, hidden_size, k, device, lr, num_epochs.
    Side effects: TensorBoard logs under ./logs/<output_folder>, checkpoints
    under ./models/<output_folder>, and true.png / rec.png in the cwd.
    """
    writer = SummaryWriter("./logs/{0}".format(args.output_folder))
    save_filename = "./models/{0}".format(args.output_folder)
    if args.dataset in ["mnist", "fashion-mnist", "cifar10"]:
        # NOTE(review): a 3-channel Normalize is applied even for the
        # single-channel mnist / fashion-mnist branches below — confirm the
        # tensors really have 3 channels or this transform raises at runtime.
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        if args.dataset == "mnist":
            # Define the train & test datasets
            train_dataset = datasets.MNIST(
                args.data_folder, train=True, download=True, transform=transform
            )
            test_dataset = datasets.MNIST(
                args.data_folder, train=False, transform=transform
            )
            num_channels = 1
        elif args.dataset == "fashion-mnist":
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(
                args.data_folder, train=True, download=True, transform=transform
            )
            test_dataset = datasets.FashionMNIST(
                args.data_folder, train=False, transform=transform
            )
            num_channels = 1
        elif args.dataset == "cifar10":
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(
                args.data_folder, train=True, download=True, transform=transform
            )
            test_dataset = datasets.CIFAR10(
                args.data_folder, train=False, transform=transform
            )
            num_channels = 3
        # These datasets ship no validation split: validate on the test set.
        valid_dataset = test_dataset
    elif args.dataset == "miniimagenet":
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(128),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(
            args.data_folder, train=True, download=True, transform=transform
        )
        valid_dataset = MiniImagenet(
            args.data_folder, valid=True, download=True, transform=transform
        )
        test_dataset = MiniImagenet(
            args.data_folder, test=True, download=True, transform=transform
        )
        num_channels = 3
    else:
        # Fallback: any dataset laid out as <data_folder>/train and
        # <data_folder>/val ImageFolder directories.
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(args.image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        # Define the train, valid & test datasets
        train_dataset = ImageFolder(
            os.path.join(args.data_folder, "train"), transform=transform
        )
        valid_dataset = ImageFolder(
            os.path.join(args.data_folder, "val"), transform=transform
        )
        test_dataset = valid_dataset
        num_channels = 3

    # NOTE(review): shuffle=False on the training loader is unusual for SGD
    # training — confirm this is intentional.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16, shuffle=True)

    # Fixed images for Tensorboard (also saved to disk for quick inspection)
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    save_image(fixed_grid, "true.png")
    writer.add_image("original", fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once (untrained baseline at step 0)
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    save_image(grid, "rec.png")
    writer.add_image("reconstruction", grid, 0)

    # best_loss starts at -1, but the `epoch == 0` guard below seeds it with
    # the first real validation loss before any comparison matters.
    best_loss = -1
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print(epoch, "test loss: ", loss)
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
        save_image(grid, "rec.png")  # overwritten each epoch
        writer.add_image("reconstruction", grid, epoch + 1)
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open("{0}/best.pt".format(save_filename), "wb") as f:
                torch.save(model.state_dict(), f)
        # A per-epoch snapshot is always kept in addition to best.pt.
        with open("{0}/model_{1}.pt".format(save_filename, epoch + 1), "wb") as f:
            torch.save(model.state_dict(), f)
def main(args):
    """Generate a labelled image dataset from a trained GatedPixelCNN prior
    decoded by a trained VQ-VAE.

    For each class label, samples are drawn with three decoding strategies
    (softmax via generate(), plus generate_dst() and generate_sparsemax()),
    saved as JPEGs under ./data/<output_folder>/dataset_{softmax,dst,sparsemax},
    and a filename/label CSV is written for the softmax set.  The maximum
    number of kept dimensions reported by the dst/sparsemax generators is
    pickled at the end.

    Reads from `args`: dataset, output_folder, data_folder, batch_size,
    num_workers, hidden_size_vae, hidden_size_prior, num_layers, k, model,
    prior, device.
    """
    if args.dataset == 'miniimagenet':
        # Human-readable names for the 64 miniImagenet training classes.
        # NOTE(review): readable_labels is never referenced below — presumably
        # kept for debugging; confirm before relying on it.
        readable_labels = {
            38: 'organ', 42: 'prayer_rug', 31: 'file', 61: 'cliff',
            58: 'consomme', 59: 'hotdog', 21: 'aircraft_carrier',
            14: 'French_bulldog', 28: 'cocktail_shaker', 63: 'ear',
            3: 'green_mamba', 4: 'harvestman', 17: 'Arctic_fox',
            32: 'fire_screen', 11: 'komondor', 43: 'reel', 18: 'ladybug',
            45: 'snorkel', 24: 'beer_bottle', 36: 'lipstick', 5: 'toucan',
            0: 'house_finch', 16: 'miniature_poodle', 50: 'tile_roof',
            15: 'Newfoundland', 46: 'solar_dish', 10: 'Gordon_setter',
            7: 'dugong', 52: 'unicycle', 20: 'rock_beauty', 48: 'stage',
            22: 'ashcan', 34: 'hair_slide', 30: 'dome', 13: 'Tibetan_mastiff',
            53: 'upright', 62: 'bolete', 2: 'triceratops', 40: 'pencil_box',
            26: 'chime', 47: 'spider_web', 51: 'tobacco_shop', 60: 'orange',
            49: 'tank', 8: 'Walker_hound', 23: 'barrel', 6: 'jellyfish',
            33: 'frying_pan', 9: 'Saluki', 37: 'oboe', 1: 'robin',
            19: 'three-toed_sloth', 39: 'parallel_bars', 55: 'worm_fence',
            27: 'clog', 41: 'photocopier', 25: 'carousel', 29: 'dishrag',
            57: 'street_sign', 35: 'holster', 12: 'boxer', 56: 'yawl',
            54: 'wok', 44: 'slot'
        }
    elif args.dataset == 'cifar10':
        readable_labels = {
            0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
            5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'
        }
    writer = SummaryWriter('./VQVAE/logs/{0}'.format(args.output_folder))
    # NOTE(review): save_filename is not used below in this function.
    save_filename = './VQVAE/models/{0}/prior.pt'.format(args.output_folder)
    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder, train=True,
                                           download=True, transform=transform)
            test_dataset = datasets.MNIST(args.data_folder, train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder,
                                                  train=True, download=True,
                                                  transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder,
                                                 train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder,
                                             train=True, download=True,
                                             transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder,
                                            train=False, transform=transform)
            num_channels = 3
        # No dedicated validation split: validate on the test set.
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        valid_dataset = MiniImagenet(args.data_folder, valid=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
        num_channels = 3

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
        batch_size=args.batch_size, shuffle=False, drop_last=True,
        num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
        batch_size=1, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    # Load the trained VQ-VAE decoder (frozen, eval mode).
    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                               args.k).to(args.device)
    with open(args.model, 'rb') as f:
        state_dict = torch.load(f)
        model.load_state_dict(state_dict)
    model.eval()

    # Per-dataset sampling configuration: latent grid shape, class range,
    # output CSV name, and samples per class.
    if args.dataset == 'miniimagenet':
        print("number of training classes:",
              len(train_dataset._label_encoder))
        n_classes = len(train_dataset._label_encoder)
        shape = (32, 32)
        yrange = range(n_classes)
        csv_filename = 'miniimagenet_generated.csv'
        sample_size = 25
    elif args.dataset == 'cifar10':
        print("number of training classes:", 10)
        n_classes = 10
        shape = (8, 8)
        yrange = range(n_classes)
        csv_filename = 'cifar10_generated.csv'
        sample_size = 1000

    # Load the trained prior (frozen, eval mode).
    prior = GatedPixelCNN(args.k, args.hidden_size_prior, args.num_layers,
                          n_classes=n_classes).to(args.device)
    with open(args.prior, 'rb') as f:
        state_dict = torch.load(f)
        prior.load_state_dict(state_dict)
    prior.eval()

    # maximum number of kept dimensions
    max_num_dst = 0
    max_num_sparsemax = 0
    with torch.no_grad():
        f = open(
            './VQVAE/models/{0}/{1}'.format(args.output_folder, csv_filename),
            'w')
        with f:
            # NOTE(review): this rebinds `writer` from the SummaryWriter above
            # to a csv.writer; the SummaryWriter is unused afterwards, but the
            # shadowing is easy to trip over when editing.
            writer = csv.writer(f)
            writer.writerow(['filename', 'label'])
            for y in tqdm(yrange):
                label = torch.tensor([y])
                label = label.to(args.device)
                # 1) Standard (softmax) sampling — saved and listed in the CSV.
                z = prior.generate(label=label, shape=shape,
                                   batch_size=sample_size)
                x = model.decode(z)
                for im in range(x.shape[0]):
                    save_image(x.cpu()[im],
                               './data/{0}/dataset_softmax/{1}_{2}.jpg'.format(
                                   args.output_folder, str(y).zfill(2),
                                   str(im).zfill(3)),
                               range=(-1, 1), normalize=True)
                    if args.dataset == 'miniimagenet':
                        # Reverse-lookup the original string label for class y.
                        y_str = list(train_dataset._label_encoder.keys())[list(
                            train_dataset._label_encoder.values()).index(y)]
                        writer.writerow([
                            '{0}_{1}.jpg'.format(
                                str(y).zfill(2), str(im).zfill(3)),
                            str(y_str)
                        ])
                    elif args.dataset == 'cifar10':
                        writer.writerow([
                            '{0}_{1}.jpg'.format(
                                str(y).zfill(2), str(im).zfill(3)),
                            str(y)
                        ])
                # 2) DST sampling — also tracks the max kept dimensions.
                z, num_dst = prior.generate_dst(label=label, shape=shape,
                                                batch_size=sample_size)
                x = model.decode(z)
                if num_dst > max_num_dst:
                    max_num_dst = num_dst
                for im in range(x.shape[0]):
                    save_image(x.cpu()[im],
                               './data/{0}/dataset_dst/{1}_{2}.jpg'.format(
                                   args.output_folder, str(y).zfill(2),
                                   str(im).zfill(3)),
                               range=(-1, 1), normalize=True)
                # 3) Sparsemax sampling — also tracks the max kept dimensions.
                z, num_sparsemax = prior.generate_sparsemax(
                    label=label, shape=shape, batch_size=sample_size)
                x = model.decode(z)
                if num_sparsemax > max_num_sparsemax:
                    max_num_sparsemax = num_sparsemax
                for im in range(x.shape[0]):
                    save_image(
                        x.cpu()[im],
                        './data/{0}/dataset_sparsemax/{1}_{2}.jpg'.format(
                            args.output_folder, str(y).zfill(2),
                            str(im).zfill(3)),
                        range=(-1, 1), normalize=True)
    # Persist the kept-dimension maxima for later analysis.
    pkl.dump([max_num_dst, max_num_sparsemax], open(
        "max_num_" + str(args.dataset) + "_" + str(sample_size) +
        "_correct_dst.pkl", "wb"))
train=True, download=True, transform=preproc_transform, ), batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True) test_loader = torch.utils.data.DataLoader(eval('datasets.' + DATASET)( '../data/{}/'.format(DATASET), train=False, transform=preproc_transform), batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True) autoencoder = VectorQuantizedVAE(INPUT_DIM, VAE_DIM, K).to(DEVICE) autoencoder.load_state_dict(torch.load('models/{}_vqvae.pt'.format(DATASET))) autoencoder.eval() model = GatedPixelCNN(K, DIM, N_LAYERS).to(DEVICE) criterion = nn.CrossEntropyLoss().to(DEVICE) opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True) def train(): train_loss = [] for batch_idx, (x, label) in enumerate(train_loader): start_time = time.time() x = x.to(DEVICE) label = label.to(DEVICE)
def main(args):
    """Train a VQ-VAE, optionally guided by a pretrained VGG19 target model.

    Reads from `args`: output_folder, dataset, data_folder, batch_size,
    num_workers, hidden_size, k, device, lr, num_epochs, ckp (optional
    VQ-VAE checkpoint), tmodel (optional VGG19 checkpoint).
    Side effects: TensorBoard logs under ./logs/<output_folder> and
    checkpoints under ./models/<output_folder>.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder, train=True,
                                           download=True, transform=transform)
            test_dataset = datasets.MNIST(args.data_folder, train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder,
                                                  train=True, download=True,
                                                  transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder,
                                                 train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder,
                                             train=True, download=True,
                                             transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder,
                                            train=False, transform=transform)
            num_channels = 3
        # No dedicated validation split: validate on the test set.
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        valid_dataset = MiniImagenet(args.data_folder, valid=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
        num_channels = 3

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
        batch_size=args.batch_size, shuffle=False, drop_last=True,
        num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
        batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size,
                               args.k).to(args.device)
    if args.ckp != "":
        # Warm-start the VQ-VAE from an existing checkpoint.
        model.load_state_dict(torch.load(args.ckp))
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Optional pretrained VGG19 target network.  Default to None: previously
    # target_model was left unbound when args.tmodel == '', so the train()
    # call below crashed with UnboundLocalError.
    target_model = None
    if args.tmodel != '':
        net = vgg.VGG('VGG19')
        net = net.to(args.device)
        net = torch.nn.DataParallel(net)
        checkpoint = torch.load(args.tmodel)
        net.load_state_dict(checkpoint['net'])
        target_model = net

    # Generate the samples first once (untrained baseline at step 0)
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    # best_loss starts at -1; the `epoch == 0` guard seeds it with the first
    # real validation loss before any comparison matters.
    best_loss = -1.
    for epoch in range(args.num_epochs):
        print(epoch)
        train(train_loader, model, target_model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print("test loss:", loss)
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        # A per-epoch snapshot is always kept in addition to best.pt.
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1),
                  'wb') as f:
            torch.save(model.state_dict(), f)
def main(args):
    """Train a VectorQuantizedVAE on pickled Atari trajectory data.

    Builds train/validation splits by merging per-game pickled partition and
    label dictionaries (skipping the held-out game args.out_game), then runs
    the usual train/validate/checkpoint loop, logging 3-channel slices of the
    13-channel inputs to TensorBoard.

    Reads from `args`: output_folder, dataset, data_folder, out_game,
    batch_size, num_workers, hidden_size, k, device, lr, num_epochs.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    if args.dataset == 'atari':
        # NOTE(review): this transform is never applied below — the custom
        # Dataset presumably handles preprocessing itself; confirm.
        transform = transforms.Compose([
            transforms.RandomResizedCrop(84),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dictDir = args.data_folder+'data_dict/'
        dataDir = args.data_folder+'data_traj/'
        #dictDir = args.data_folder+'test_dict/'
        #dataDir = args.data_folder+'test_traj/'
        all_partition = defaultdict(list)
        all_labels = defaultdict(list)
        # Datasets
        # Merge every per-game "*partition*"/"*labels*" pickle, skipping any
        # file whose name contains the held-out game.
        for dictionary in os.listdir(dictDir):
            ########
            if args.out_game not in dictionary:
            #######
                dfile = open(dictDir+dictionary, 'rb')
                d = pickle.load(dfile)
                dfile.close()
                if("partition" in dictionary):
                    # Partition lists (sample IDs) are concatenated per split.
                    for key in d:
                        all_partition[key] += d[key]
                elif("labels" in dictionary):
                    # Labels are keyed per sample ID; later files overwrite.
                    for key in d:
                        all_labels[key] = d[key]
                else:
                    print("Error: Unexpected data dictionary")
        #partition = # IDs
        #labels = # Labels
        # Generators
        training_set = Dataset(all_partition['train'], all_labels, dataDir)
        train_loader = data.DataLoader(training_set,
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True)
        validation_set = Dataset(all_partition['validation'], all_labels,
            dataDir)
        valid_loader = data.DataLoader(validation_set,
            batch_size=args.batch_size, shuffle=True, drop_last=False,
            num_workers=args.num_workers, pin_memory=True)
        # NOTE(review): the "test" loader also draws from the validation set.
        test_loader = data.DataLoader(validation_set, batch_size=16,
            shuffle=True)
        # 13 stacked input channels; model reconstructs 4 output channels.
        input_channels = 13
        output_channels = 4

    # Define the data loaders (superseded by the Atari loaders above)
    # train_loader = torch.utils.data.DataLoader(train_dataset,
    #     batch_size=args.batch_size, shuffle=False,
    #     num_workers=args.num_workers, pin_memory=True)
    # valid_loader = torch.utils.data.DataLoader(valid_dataset,
    #     batch_size=args.batch_size, shuffle=False, drop_last=True,
    #     num_workers=args.num_workers, pin_memory=True)
    # test_loader = torch.utils.data.DataLoader(test_dataset,
    #     batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, fixed_y = next(iter(test_loader))
    # Keep only the first 3 channels so the grid renders as RGB.
    fixed_y = fixed_y[:,0:3,:,:]
    fixed_grid = make_grid(fixed_y, nrow=8, range=(0, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(input_channels, output_channels,
        args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once (untrained baseline at step 0)
    reconstruction = generate_samples(fixed_images, model, args)
    reconstruction_image = reconstruction[:,0:3,:,:]
    grid = make_grid(reconstruction_image.cpu(), nrow=8, range=(0, 1),
        normalize=True)
    writer.add_image('reconstruction', grid, 0)

    # best_loss starts at -1.; the `epoch == 0` guard below seeds it with the
    # first real validation loss before any comparison matters.
    best_loss = -1.
    print("Starting to train...")
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print("Finished Epoch: " + str(epoch) + " Validation Loss: " +
            str(loss))
        reconstruction = generate_samples(fixed_images, model, args)
        reconstruction_image = reconstruction[:,0:3,:,:]
        grid = make_grid(reconstruction_image.cpu(), nrow=8, range=(0, 1),
            normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        # A per-epoch snapshot is always kept in addition to best.pt.
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1),
                  'wb') as f:
            torch.save(model.state_dict(), f)