def val_test(args):
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = VectorQuantizedVAE(input_dim, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)
    # if torch.cuda.device_count() > 1 and args.device == "cuda":
    #     model = torch.nn.DataParallel(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    discriminators = {}
    if args.recons_loss == "gan":
        recons_disc = Discriminator(input_dim, args.img_res, args.input_type).to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(), lr=args.disc_lr, amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer, discriminators)
    else:
        start_epoch = 0

    stop_patience = args.stop_patience
    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, 4), file=sys.stdout):
        val_loss_dict, z = train_util.test(get_losses, model, valid_loader, args, discriminators, True)
        # if args.weights == "init" and epoch == 1:
        #     epoch += 1
        #     break

        train_util.log_recons_img_grid(recons_input_img, model, epoch + 1, args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, epoch + 1, args.device, writer)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, optimizer, discriminators, val_loss_dict["recons_loss"],
                              best_loss, args.recons_loss, epoch, save_filename)
        print(val_loss_dict)
def main(args):
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder, train=True,
                                           download=True, transform=transform)
            test_dataset = datasets.MNIST(args.data_folder, train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder, train=True,
                                                  download=True, transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder, train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder, train=True,
                                             download=True, transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder, train=False,
                                            transform=transform)
            num_channels = 3
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        valid_dataset = MiniImagenet(args.data_folder, valid=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
        num_channels = 3

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size, shuffle=False,
                                               num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size, shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    if args.ckp != "":
        model.load_state_dict(torch.load(args.ckp))
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    target_model = None  # avoid a NameError in train() when --tmodel is not given
    if args.tmodel != '':
        net = vgg.VGG('VGG19')
        net = net.to(args.device)
        net = torch.nn.DataParallel(net)
        checkpoint = torch.load(args.tmodel)
        net.load_state_dict(checkpoint['net'])
        target_model = net

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        print(epoch)
        # if epoch < 100:
        #     args.lr = 1e-5
        # if epoch > 100 and epoch < 400:
        #     args.lr = 2e-5
        train(train_loader, model, target_model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print("test loss:", loss)

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
            torch.save(model.state_dict(), f)
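# NOTE: most variants in this section call `generate_samples(fixed_images, model, args)`,
# but the helper itself is never shown. A minimal sketch, assuming the model's forward
# pass returns (x_tilde, z_e_x, z_q_x) as in the training fragment further below; the
# exact signature and behavior are an assumption, not the original implementation.
def generate_samples(images, model, args):
    with torch.no_grad():
        images = images.to(args.device)
        x_tilde, _, _ = model(images)
    return x_tilde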
def main(args):
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder, train=True,
                                           download=True, transform=transform)
            test_dataset = datasets.MNIST(args.data_folder, train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder, train=True,
                                                  download=True, transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder, train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder, train=True,
                                             download=True, transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder, train=False,
                                            transform=transform)
            num_channels = 3
        valid_dataset = test_dataset
    elif args.dataset == 'PubTabNet':
        transform = transforms.Compose(
            [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_dataset = PubTabNet(args.data_folder, args.data_name, 'TRAIN',
                                  transform=transform)
        test_dataset = PubTabNet(args.data_folder, args.data_name, 'VAL',
                                 transform=transform)
        valid_dataset = test_dataset
        num_channels = 3
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder, train=True,
                                     download=True, transform=transform)
        valid_dataset = MiniImagenet(args.data_folder, valid=True,
                                     download=True, transform=transform)
        test_dataset = MiniImagenet(args.data_folder, test=True,
                                    download=True, transform=transform)
        num_channels = 3

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size, shuffle=False,
                                               num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size, shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=4, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=4, range=(-1, 1), normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer, epoch)
        loss, _ = test(valid_loader, model, args, writer)
        eprint('Validation loss at epoch %d: Loss = %.4f' % (epoch, loss))

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=4, range=(-1, 1), normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
            torch.save(model.state_dict(), f)
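# NOTE: `eprint` is used above but not defined in this section; it is presumably a
# stderr logger. A minimal sketch (assumption, not the original helper):
import sys

def eprint(*args, **kwargs):
    print(*args, file=sys.stderr, **kwargs)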
def main(args):
    # set manual seed
    random.seed(args.manualseed)
    torch.manual_seed(args.manualseed)
    torch.cuda.manual_seed_all(args.manualseed)
    np.random.seed(args.manualseed)
    torch.backends.cudnn.deterministic = True

    writer = SummaryWriter(args.log_dir)
    save_filename = args.save_dir

    # (The inline per-dataset setup used by the other variants, i.e. the
    # mnist/fashion-mnist/cifar10/miniimagenet transforms, datasets, and DataLoaders,
    # is commented out in this variant; loading is delegated to load_data(args).)
    dataloader = load_data(args)
    train_loader = dataloader['train']
    valid_loader = dataloader['valid']
    test_loader = dataloader['test']

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(valid_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(args.nc, args.hidden_size, args.k).to(args.device)
    if len(args.gpu_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=args.gpu_ids,
                                      output_device=args.gpu_ids[0])
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    # for epoch in range(args.num_epochs):
    for epoch in range(1, args.num_epochs + 1):
        train(train_loader, model, optimizer, args, writer, epoch)
        loss, _ = test(test_loader, model, args, writer)

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)

        if (epoch == 1) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        if (epoch % args.save_step) == 0:
            with open('{0}/model_{1}.pt'.format(save_filename, epoch), 'wb') as f:
                torch.save(model.state_dict(), f)
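# NOTE (assumption): when the model above is wrapped in torch.nn.DataParallel,
# model.state_dict() keys carry a "module." prefix, so the checkpoints written in
# this loop will not load into an unwrapped model later. A minimal sketch of a
# safer save helper, not part of the original code:
def save_checkpoint(model, path):
    # Unwrap DataParallel before saving so state-dict keys stay prefix-free.
    to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
    with open(path, 'wb') as f:
        torch.save(to_save.state_dict(), f)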
def main(args):
    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")
    print("Start Time =", current_time)

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])])

    # Define the train & test datasets
    train_set = datasets.MNIST(args.data_folder, train=True,
                               download=True, transform=transform)
    test_set = datasets.MNIST(args.data_folder, train=False,
                              download=True, transform=transform)
    num_channels = 1

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              num_workers=args.num_workers,
                                              batch_size=16, shuffle=False)

    # Fixed images for TensorBoard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    writer.add_graph(model, fixed_images.to(args.device))  # get model structure on tensorboard

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('reconstruction at start', grid, 0)

    img_list = []
    best_loss = -1.
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(test_loader, model, args, writer)

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
        writer.add_image('reconstruction at epoch {}'.format(epoch + 1), grid, epoch + 1)
        print("loss = {:f} at epoch {:d}".format(loss, epoch + 1))
        writer.add_scalar('loss/testing_loss', loss, epoch + 1)
        img_list.append(grid)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
            torch.save(model.state_dict(), f)

    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")
    print("End Time =", current_time)
def main(args):
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args)

    num_channels = 3
    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)
    model.to(args.device)

    discriminators = {}
    input_dim = 3
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(input_dim * 2, args.img_res,
                                                       args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(input_dim * 2, args.img_res,
                                                    args.input_type).to(args.device)

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr, amsgrad=True)
        recons_disc_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            recons_disc_opt, "min", patience=args.lr_patience, factor=0.5,
            threshold=args.threshold, threshold_mode="abs", min_lr=1e-7)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    ae_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, "min", patience=args.lr_patience, factor=0.5,
        threshold=args.threshold, threshold_mode="abs", min_lr=1e-7)

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(discriminators[disc][0])

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    # Fixed images for Tensorboard; generate the samples first once
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    train_util.log_recons_img_grid(recons_input_img, model, 0, args.device, writer)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer, discriminators)
    else:
        start_epoch = 0

    stop_patience = args.stop_patience
    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, args.num_epochs), file=sys.stdout):
        try:
            train(epoch, train_loader, model, optimizer, args, writer, discriminators)
        except RuntimeError as err:
            print("".join(traceback.TracebackException.from_exception(err).format()),
                  file=sys.stderr)
            print("*******")
            print(err, file=sys.stderr)
            print(f"batch_size:{args.batch_size}", file=sys.stderr)
            sys.exit(1)  # exit with a failure code (the original exited with 0) after a training error

        val_loss_dict, z = train_util.test(get_losses, model, valid_loader, args, discriminators)

        train_util.log_recons_img_grid(recons_input_img, model, epoch + 1, args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, epoch + 1, args.device, writer)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, optimizer, discriminators, val_loss_dict["recons_loss"],
                              best_loss, args.recons_loss, epoch, save_filename)

        # early stop check
        # if val_loss_dict["recons_loss"] - best_loss < args.threshold:
        #     stop_patience -= 1
        # else:
        #     stop_patience = args.stop_patience
        # if stop_patience == 0:
        #     print("training early stopped!")
        #     break

        ae_lr_scheduler.step(val_loss_dict["recons_loss"])
        if args.recons_loss != "mse":
            recons_disc_lr_scheduler.step(val_loss_dict["recons_disc_loss"])
                                           shuffle=False, num_workers=NUM_WORKERS,
                                           pin_memory=True)
test_loader = torch.utils.data.DataLoader(
    eval('datasets.' + DATASET)(
        '../data/{}/'.format(DATASET), train=False,
        transform=preproc_transform,
    ),
    batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True)

model = VectorQuantizedVAE(INPUT_DIM, DIM, K).to(DEVICE)
print(model)
opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)


def train():
    train_loss = []
    for batch_idx, (x, _) in enumerate(train_loader):
        start_time = time.time()
        x = x.to(DEVICE)

        opt.zero_grad()
        x_tilde, z_e_x, z_q_x = model(x)
        z_q_x.retain_grad()

        loss_recons = F.mse_loss(x_tilde, x)
        loss_recons.backward(retain_graph=True)
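        # --- The original snippet is truncated here. What follows is a hedged
        # completion of the standard VQ-VAE update (straight-through gradient copy,
        # codebook loss, commitment loss); BETA is an assumed commitment weight,
        # not defined in the fragment above. ---
        # Straight-through estimator: route the decoder's gradient at z_q_x
        # back into the encoder output z_e_x.
        z_e_x.backward(z_q_x.grad, retain_graph=True)

        # Codebook loss: pull the selected codebook vectors toward the encoder output.
        loss_vq = F.mse_loss(z_q_x, z_e_x.detach())
        loss_vq.backward(retain_graph=True)

        # Commitment loss: keep the encoder output close to its codebook assignment.
        loss_commit = BETA * F.mse_loss(z_e_x, z_q_x.detach())
        loss_commit.backward()

        opt.step()
        train_loss.append(loss_recons.item())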
def main(args):
    writer = SummaryWriter("./logs/{0}".format(args.output_folder))
    save_filename = "./models/{0}".format(args.output_folder)

    if args.dataset in ["mnist", "fashion-mnist", "cifar10"]:
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        if args.dataset == "mnist":
            # Define the train & test datasets
            train_dataset = datasets.MNIST(
                args.data_folder, train=True, download=True, transform=transform
            )
            test_dataset = datasets.MNIST(
                args.data_folder, train=False, transform=transform
            )
            num_channels = 1
        elif args.dataset == "fashion-mnist":
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(
                args.data_folder, train=True, download=True, transform=transform
            )
            test_dataset = datasets.FashionMNIST(
                args.data_folder, train=False, transform=transform
            )
            num_channels = 1
        elif args.dataset == "cifar10":
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(
                args.data_folder, train=True, download=True, transform=transform
            )
            test_dataset = datasets.CIFAR10(
                args.data_folder, train=False, transform=transform
            )
            num_channels = 3
        valid_dataset = test_dataset
    elif args.dataset == "miniimagenet":
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(128),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(
            args.data_folder, train=True, download=True, transform=transform
        )
        valid_dataset = MiniImagenet(
            args.data_folder, valid=True, download=True, transform=transform
        )
        test_dataset = MiniImagenet(
            args.data_folder, test=True, download=True, transform=transform
        )
        num_channels = 3
    else:
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(args.image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        # Define the train, valid & test datasets
        train_dataset = ImageFolder(
            os.path.join(args.data_folder, "train"), transform=transform
        )
        valid_dataset = ImageFolder(
            os.path.join(args.data_folder, "val"), transform=transform
        )
        test_dataset = valid_dataset
        num_channels = 3

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    save_image(fixed_grid, "true.png")
    writer.add_image("original", fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    save_image(grid, "rec.png")
    writer.add_image("reconstruction", grid, 0)

    best_loss = -1
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print(epoch, "test loss: ", loss)

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
        save_image(grid, "rec.png")
        writer.add_image("reconstruction", grid, epoch + 1)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open("{0}/best.pt".format(save_filename), "wb") as f:
                torch.save(model.state_dict(), f)
        with open("{0}/model_{1}.pt".format(save_filename, epoch + 1), "wb") as f:
            torch.save(model.state_dict(), f)
def main(args):
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder, train=True,
                                           download=True, transform=transform)
            test_dataset = datasets.MNIST(args.data_folder, train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder, train=True,
                                                  download=True, transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder, train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder, train=True,
                                             download=True, transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder, train=False,
                                            transform=transform)
            num_channels = 3
        valid_dataset = test_dataset
    elif args.dataset == 'clevr':
        transform = transforms.Compose([
            # transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        import socket
        if "Alien" in socket.gethostname():
            dataset_name = "/media/mihir/dataset/clevr_veggies/"
        else:
            dataset_name = "/projects/katefgroup/datasets/clevr_veggies/"
        dataset_name = '/home/mprabhud/dataset/clevr_veggies'  # overrides the branches above, as in the original
        # Define the train, valid & test datasets
        train_dataset = Clevr(dataset_name, mod=args.modname, train=True,
                              transform=transform, object_level=args.object_level)
        valid_dataset = Clevr(dataset_name, mod=args.modname, valid=True,
                              transform=transform, object_level=args.object_level)
        test_dataset = Clevr(dataset_name, mod=args.modname, test=True,
                             transform=transform, object_level=args.object_level)
        num_channels = 3
    elif args.dataset == 'carla':
        if args.use_depth:
            transform = transforms.Compose([
                # transforms.RandomResizedCrop(128),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5))
            ])
            num_channels = 4
        else:
            transform = transforms.Compose([
                # transforms.RandomResizedCrop(128),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
            num_channels = 3
        import socket
        if "Alien" in socket.gethostname():
            dataset_name = "/media/mihir/dataset/clevr_veggies/"
        else:
            dataset_name = '/home/shamitl/datasets/carla'
        dataset_name = "/projects/katefgroup/datasets/carla/"
        dataset_name = '/home/mprabhud/dataset/carla'  # last assignment wins, as in the original
        # Define the train, valid & test datasets
        train_dataset = Clevr(dataset_name, mod=args.modname, train=True,
                              transform=transform, object_level=args.object_level,
                              use_depth=args.use_depth)
        valid_dataset = Clevr(dataset_name, mod=args.modname, valid=True,
                              transform=transform, object_level=args.object_level,
                              use_depth=args.use_depth)
        test_dataset = Clevr(dataset_name, mod=args.modname, test=True,
                             transform=transform, object_level=args.object_level,
                             use_depth=args.use_depth)
    # elif args.dataset == 'miniimagenet':
    #     (the miniimagenet setup used by the other variants is commented out here)

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size, shuffle=False,
                                               num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size, shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size,
                               args.object_level, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    if args.load_model != "":  # the original used `is not ""`, which tests identity, not equality
        with open(args.load_model, 'rb') as f:
            state_dict = torch.load(f)
            model.load_state_dict(state_dict)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        if not args.test_mode:
            train(train_loader, model, optimizer, args, writer, epoch)
            loss, _ = test_old(valid_loader, model, args, writer)

            reconstruction = generate_samples(fixed_images, model, args)
            grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
            writer.add_image('reconstruction', grid, epoch + 1)

            with open('{0}/recent.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
            if (epoch == 0) or (loss < best_loss):
                best_loss = loss
                with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                    torch.save(model.state_dict(), f)
        else:
            test(train_loader, model, args, writer)
            reconstruction = generate_samples(fixed_images, model, args)
            grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
            writer.add_image('reconstruction', grid, epoch + 1)
def main(args):
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    if args.dataset == 'atari':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(84),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dictDir = args.data_folder + 'data_dict/'
        dataDir = args.data_folder + 'data_traj/'
        # dictDir = args.data_folder + 'test_dict/'
        # dataDir = args.data_folder + 'test_traj/'

        all_partition = defaultdict(list)
        all_labels = defaultdict(list)

        # Datasets
        for dictionary in os.listdir(dictDir):
            if args.out_game not in dictionary:  # skip files belonging to the held-out game
                dfile = open(dictDir + dictionary, 'rb')
                d = pickle.load(dfile)
                dfile.close()
                if "partition" in dictionary:
                    for key in d:
                        all_partition[key] += d[key]
                elif "labels" in dictionary:
                    for key in d:
                        all_labels[key] = d[key]
                else:
                    print("Error: Unexpected data dictionary")
        # partition = # IDs
        # labels = # Labels

        # Generators
        training_set = Dataset(all_partition['train'], all_labels, dataDir)
        train_loader = data.DataLoader(training_set,
                                       batch_size=args.batch_size, shuffle=True,
                                       num_workers=args.num_workers, pin_memory=True)
        validation_set = Dataset(all_partition['validation'], all_labels, dataDir)
        valid_loader = data.DataLoader(validation_set,
                                       batch_size=args.batch_size, shuffle=True,
                                       drop_last=False,
                                       num_workers=args.num_workers, pin_memory=True)
        test_loader = data.DataLoader(validation_set, batch_size=16, shuffle=True)
        input_channels = 13
        output_channels = 4

    # (The generic DataLoader setup from the other variants is commented out in this one.)

    # Fixed images for Tensorboard
    fixed_images, fixed_y = next(iter(test_loader))
    fixed_y = fixed_y[:, 0:3, :, :]
    fixed_grid = make_grid(fixed_y, nrow=8, range=(0, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(input_channels, output_channels,
                               args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    reconstruction_image = reconstruction[:, 0:3, :, :]
    grid = make_grid(reconstruction_image.cpu(), nrow=8, range=(0, 1), normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    print("Starting to train...")
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print("Finished Epoch: " + str(epoch) + " Validation Loss: " + str(loss))

        reconstruction = generate_samples(fixed_images, model, args)
        reconstruction_image = reconstruction[:, 0:3, :, :]
        grid = make_grid(reconstruction_image.cpu(), nrow=8, range=(0, 1), normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
            torch.save(model.state_dict(), f)
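# NOTE: the `test` helper that every loop above calls as `loss, _ = test(...)` is not
# defined in this section. A minimal sketch of what the call sites appear to expect
# (mean reconstruction and VQ losses over the loader); the logging side effects on
# `writer` are omitted and the exact return signature is an assumption:
def test(data_loader, model, args, writer):
    model.eval()
    loss_recons, loss_vq = 0., 0.
    with torch.no_grad():
        for images, _ in data_loader:
            images = images.to(args.device)
            x_tilde, z_e_x, z_q_x = model(images)
            loss_recons += F.mse_loss(x_tilde, images)
            loss_vq += F.mse_loss(z_q_x, z_e_x)
    loss_recons /= len(data_loader)
    loss_vq /= len(data_loader)
    return loss_recons.item(), loss_vq.item()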