def go(arg):
    """Train a VAE whose decoder output conditions a gated pixelCNN.

    Loads the dataset named by ``arg.task`` ('mnist' or 'imagenet64'),
    builds encoder/decoder/pixelCNN modules, optimizes the summed
    reconstruction + KL loss with Adam, logs losses to tensorboard, and
    after every epoch writes a 12x12 grid of sampled images and saves
    the model.

    :param arg: parsed command-line options (data/log dirs, batch size,
        lr, epochs, zsize, vae_depth, channels, num_layers, kernel_size,
        limit, ...).
    :raises ValueError: if ``arg.task`` is not recognized.
    """
    tbw = SummaryWriter(log_dir=arg.tb_dir)

    ## Load the data
    if arg.task == 'mnist':
        # Pad 28x28 MNIST digits to 32x32 (power-of-two spatial size).
        transform = Compose([Pad(padding=2), ToTensor()])
        trainset = torchvision.datasets.MNIST(
            root=arg.data_dir, train=True, download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=arg.batch_size, shuffle=True, num_workers=2)
        testset = torchvision.datasets.MNIST(
            root=arg.data_dir, train=False, download=True, transform=transform)
        testloader = torch.utils.data.DataLoader(
            testset, batch_size=arg.batch_size, shuffle=False, num_workers=2)
        C, H, W = 1, 32, 32
    elif arg.task == 'imagenet64':
        transform = Compose([ToTensor()])
        trainset = torchvision.datasets.ImageFolder(
            root=arg.data_dir + os.sep + 'train', transform=transform)
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=arg.batch_size, shuffle=True, num_workers=2)
        testset = torchvision.datasets.ImageFolder(
            root=arg.data_dir + os.sep + 'valid', transform=transform)
        testloader = torch.utils.data.DataLoader(
            testset, batch_size=arg.batch_size, shuffle=False, num_workers=2)
        C, H, W = 3, 64, 64
    else:
        # ValueError is more precise than a bare Exception and is still
        # caught by any caller that catches Exception.
        raise ValueError('Task not recognized.')

    ## Set up the model
    krn = arg.kernel_size
    pad = krn // 2   # 'same' padding for an odd kernel size
    OUTCN = 64       # channels produced by the decoder, consumed by the pixelCNN

    encoder = models.ImEncoder(in_size=(H, W), zsize=arg.zsize,
                               depth=arg.vae_depth, colors=C)
    decoder = models.ImDecoder(in_size=(H, W), zsize=arg.zsize,
                               depth=arg.vae_depth, out_channels=OUTCN)
    pixcnn = models.CGated((C, H, W), (arg.zsize,), arg.channels,
                           num_layers=arg.num_layers, k=krn, padding=pad)

    # Optionally resume from previously saved weights.
    if options.loadPreModel:
        encoder, decoder, pixcnn = loadModel(encoder, decoder, pixcnn)

    mods = [encoder, decoder, pixcnn]

    if torch.cuda.is_available():
        for m in mods:
            m.cuda()

    print('Constructed network', encoder, decoder, pixcnn)

    # Latents for sampling: 12 distinct codes, each repeated 6 times,
    # giving 72 rows per grid (zeros-seeded + test-seeded = 144 total).
    sample_zs = torch.randn(12, arg.zsize)
    sample_zs = sample_zs.unsqueeze(1).expand(12, 6, -1).contiguous().view(72, 1, -1).squeeze(1)

    # A sample of 144 square images with 3 channels, of the chosen resolution
    # (144 so we can arrange them in a 12 by 12 grid)
    sample_init_zeros = torch.zeros(72, C, H, W)
    sample_init_seeds = torch.zeros(72, C, H, W)

    sh, sw = H // SEEDFRAC, W // SEEDFRAC

    # Init second half of sample with patches from test set, to seed the sampling
    testbatch = util.readn(testloader, n=12)
    testbatch = testbatch.unsqueeze(1).expand(12, 6, C, H, W).contiguous().view(72, 1, C, H, W).squeeze(1)
    sample_init_seeds[:, :, :sh, :] = testbatch[:, :, :sh, :]

    # A single optimizer over the parameters of all three modules.
    params = []
    for m in mods:
        params.extend(m.parameters())
    optimizer = Adam(params, lr=arg.lr)

    instances_seen = 0
    for epoch in range(arg.epochs):

        # Train
        err_tr = []
        for m in mods:
            m.train(True)

        for i, (input, _) in enumerate(tqdm.tqdm(trainloader)):
            if arg.limit is not None and i * arg.batch_size > arg.limit:
                break

            # Prepare the input; targets are the 8-bit pixel values.
            b = input.size(0)
            if torch.cuda.is_available():
                input = input.cuda()
            target = (input.detach() * 255).long()

            # Forward pass
            zs = encoder(input)
            kl_loss = util.kl_loss(*zs)
            z = util.sample(*zs)
            out = decoder(z)
            rec = pixcnn(input, out)

            # reduction='none' replaces the long-deprecated reduce=False;
            # sum per-pixel losses over each image.
            rec_loss = cross_entropy(rec, target, reduction='none').view(b, -1).sum(dim=1)
            loss = (rec_loss + kl_loss).mean()

            instances_seen += input.size(0)
            tbw.add_scalar('pixel-models/vae/training/kl-loss',
                           kl_loss.mean().item(), instances_seen)
            tbw.add_scalar('pixel-models/vae/training/rec-loss',
                           rec_loss.mean().item(), instances_seen)
            err_tr.append(loss.item())

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate
        # - we evaluate on the test set, since this is only a simple
        #   reproduction experiment; make sure to split off a validation set
        #   if you want to tune hyperparameters for something important
        err_te = []
        for m in mods:
            m.train(False)

        with torch.no_grad():  # no gradients needed at test time
            for i, (input, _) in enumerate(tqdm.tqdm(testloader)):
                if arg.limit is not None and i * arg.batch_size > arg.limit:
                    break

                b = input.size(0)
                if torch.cuda.is_available():
                    input = input.cuda()
                target = (input * 255).long()

                zs = encoder(input)
                kl_loss = util.kl_loss(*zs)
                z = util.sample(*zs)
                out = decoder(z)
                rec = pixcnn(input, out)

                rec_loss = cross_entropy(rec, target, reduction='none').view(b, -1).sum(dim=1)
                loss = (rec_loss + kl_loss).mean()

                err_te.append(loss.item())

        tbw.add_scalar('pixel-models/test-loss', sum(err_te) / len(err_te), epoch)
        print('epoch={:02}; training loss: {:.3f}; test loss: {:.3f}'.format(
            epoch, sum(err_tr) / len(err_tr), sum(err_te) / len(err_te)))

        # Sample: one grid from all-zero canvases, one seeded with the top
        # rows of test images.
        for m in mods:
            m.train(False)

        sample_zeros = draw_sample(sample_init_zeros, decoder, pixcnn, sample_zs,
                                   seedsize=(0, 0))
        sample_seeds = draw_sample(sample_init_seeds, decoder, pixcnn, sample_zs,
                                   seedsize=(sh, W))
        sample = torch.cat([sample_zeros, sample_seeds], dim=0)

        torchvision.utils.save_image(
            sample, 'myResults/sample_{:02d}.png'.format(epoch), nrow=12, padding=0)

        saveModel(encoder, decoder, pixcnn)
def go(arg):
    """Train a class-conditional gated pixelCNN on MNIST or CIFAR-10.

    Minimizes per-pixel cross-entropy (reported in bits via LOG2E),
    logs to tensorboard, and after each epoch draws two 72-image sample
    grids (one from blank canvases, one seeded with the top rows of test
    images) conditioned on one-hot class vectors.

    :param arg: parsed command-line options (task, model, data/log dirs,
        batch size, lr, epochs, channels, num_layers, kernel_size,
        eval_every, limit, ...).
    :raises ValueError: if ``arg.task`` or ``arg.model`` is not recognized.
    """
    tbw = SummaryWriter(log_dir=arg.tb_dir)

    ## Load the data
    if arg.task == 'mnist':
        trainset = torchvision.datasets.MNIST(
            root=arg.data_dir, train=True, download=True, transform=ToTensor())
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=arg.batch_size, shuffle=True, num_workers=2)
        testset = torchvision.datasets.MNIST(
            root=arg.data_dir, train=False, download=True, transform=ToTensor())
        testloader = torch.utils.data.DataLoader(
            testset, batch_size=arg.batch_size, shuffle=False, num_workers=2)
        C, H, W = 1, 28, 28
        CLS = 10
    elif arg.task == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(
            root=arg.data_dir, train=True, download=True, transform=ToTensor())
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=arg.batch_size, shuffle=True, num_workers=2)
        testset = torchvision.datasets.CIFAR10(
            root=arg.data_dir, train=False, download=True, transform=ToTensor())
        testloader = torch.utils.data.DataLoader(
            testset, batch_size=arg.batch_size, shuffle=False, num_workers=2)
        C, H, W = 3, 32, 32
        CLS = 10
    else:
        # ValueError is more precise than Exception and still backward
        # compatible with callers catching Exception.
        raise ValueError('Task {} not recognized.'.format(arg.task))

    ## Set up the model
    if arg.model == 'gated':
        model = models.CGated((C, H, W), (CLS,), arg.channels,
                              num_layers=arg.num_layers, k=arg.kernel_size,
                              padding=arg.kernel_size // 2)
    else:
        raise ValueError('model "{}" not recognized'.format(arg.model))

    print('Constructed network', model)

    # A sample of 144 square images with 3 channels, of the chosen resolution
    # (144 so we can arrange them in a 12 by 12 grid)
    sample_init_zeros = torch.zeros(72, C, H, W)
    sample_init_seeds = torch.zeros(72, C, H, W)

    sh, sw = H // SEEDFRAC, W // SEEDFRAC

    # Init second half of sample with patches from test set, to seed the sampling.
    # testloader is not shuffled, so repeated readn calls see the same order.
    testbatch = util.readn(testloader, n=12)
    testcls_seeds = util.readn(testloader, n=12, cls=True, maxval=CLS)
    testbatch = testbatch.unsqueeze(1).expand(12, 6, C, H, W).contiguous().view(72, 1, C, H, W).squeeze(1)
    sample_init_seeds[:, :, :sh, :] = testbatch[:, :, :sh, :]
    testcls_seeds = testcls_seeds.unsqueeze(1).expand(12, 6, CLS).contiguous().view(72, 1, CLS).squeeze(1)

    # Get classes for the unseeded part (test instances 12..23)
    testcls_zeros = util.readn(testloader, n=24, cls=True, maxval=CLS)[12:]
    testcls_zeros = testcls_zeros.unsqueeze(1).expand(12, 6, CLS).contiguous().view(72, 1, CLS).squeeze(1)

    optimizer = Adam(model.parameters(), lr=arg.lr)

    if torch.cuda.is_available():
        model.cuda()
        # Mixed-precision training via NVIDIA apex; GPU-only, so kept
        # inside the CUDA branch (NOTE(review): original flattened text is
        # ambiguous about this placement — confirm against upstream).
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    instances_seen = 0
    for epoch in range(arg.epochs):

        # Train
        err_tr = []
        model.train(True)

        for i, (input, classes) in enumerate(tqdm.tqdm(trainloader)):
            if arg.limit is not None and i * arg.batch_size > arg.limit:
                break

            # Prepare the input; targets are the 8-bit pixel values.
            classes = util.one_hot(classes, CLS)
            if torch.cuda.is_available():
                input, classes = input.cuda(), classes.cuda()
            target = (input.detach() * 255).long()

            # Forward pass
            result = model(input, classes)
            loss = cross_entropy(result, target)
            loss = loss * util.LOG2E  # Convert from nats to bits

            instances_seen += input.size(0)
            tbw.add_scalar('pixel-models/training-loss', loss.item(), instances_seen)
            err_tr.append(loss.item())

            # Backward pass (loss scaling for mixed precision)
            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()

        if epoch % arg.eval_every == 0 and epoch != 0:
            with torch.no_grad():
                # Evaluate
                # - we evaluate on the test set, since this is only a simple
                #   reproduction experiment; make sure to split off a
                #   validation set if you want to tune hyperparameters for
                #   something important
                err_te = []
                model.train(False)

                for i, (input, classes) in enumerate(tqdm.tqdm(testloader)):
                    if arg.limit is not None and i * arg.batch_size > arg.limit:
                        break

                    classes = util.one_hot(classes, CLS)
                    if torch.cuda.is_available():
                        input, classes = input.cuda(), classes.cuda()
                    target = (input * 255).long()

                    result = model(input, classes)
                    loss = cross_entropy(result, target)
                    loss = loss * util.LOG2E  # Convert from nats to bits

                    err_te.append(loss.item())

                tbw.add_scalar('pixel-models/test-loss',
                               sum(err_te) / len(err_te), epoch)
                print('epoch={:02}; training loss: {:.3f}; test loss: {:.3f}'.
                      format(epoch, sum(err_tr) / len(err_tr),
                             sum(err_te) / len(err_te)))

        # Sample
        model.train(False)

        sample_zeros = draw_sample(sample_init_zeros, testcls_zeros, model,
                                   seedsize=(0, 0), batch_size=arg.batch_size)
        sample_seeds = draw_sample(sample_init_seeds, testcls_seeds, model,
                                   seedsize=(sh, W), batch_size=arg.batch_size)
        sample = torch.cat([sample_zeros, sample_seeds], dim=0)

        utils.save_image(sample, 'sample_{:02d}.png'.format(epoch),
                         nrow=12, padding=0)