def get_loaders(
    traindir,
    valdir,
    sz,
    bs,
    fp16=True,
    val_bs=None,
    workers=8,
    rect_val=False,
    min_scale=0.08,
    distributed=False,
    synthetic=False,
):
    """Create the training and validation loaders together with their samplers.

    Args:
        traindir: Root folder handed to ``datasets.ImageFolder`` for training.
        valdir: Folder handed to ``create_validation_set``.
        sz: Side length used for ``RandomResizedCrop`` (and for the synthetic
            batch shape ``(3, sz, sz)``).
        bs: Training batch size.
        fp16: Forwarded to ``BatchTransformDataLoader`` for both loaders.
        val_bs: Validation batch size; a falsy value falls back to ``bs``.
        workers: ``num_workers`` for every ``DataLoader`` built here.
        rect_val: Forwarded to ``create_validation_set``.
        min_scale: Lower bound of the ``RandomResizedCrop`` scale range.
        distributed: When true, a ``DistributedSampler`` is attached to the
            training set; also forwarded to ``create_validation_set``.
        synthetic: When true, the training loader is a ``SyntheticDataLoader``
            instead of a real ``DataLoader``.

    Returns:
        ``(train_loader, val_loader, train_sampler, val_sampler)`` where both
        loaders are wrapped in ``BatchTransformDataLoader``.
    """
    if not val_bs:
        val_bs = bs

    augmentations = transforms.Compose(
        [
            transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
    )
    # The image folder is opened even in synthetic mode because the sampler
    # returned to the caller is built from it.
    train_dataset = datasets.ImageFolder(traindir, augmentations)

    train_sampler = None
    if distributed:
        train_sampler = DistributedSampler(train_dataset, num_replicas=env_world_size(), rank=env_rank())

    if synthetic:
        print("Using synthetic dataloader")
        train_loader = SyntheticDataLoader(bs, (3, sz, sz))
    else:
        loader_kwargs = dict(
            batch_size=bs,
            # shuffle and sampler are mutually exclusive in DataLoader.
            shuffle=(train_sampler is None),
            num_workers=workers,
            pin_memory=True,
            collate_fn=fast_collate,
            sampler=train_sampler,
        )
        if util.is_set("PYTORCH_USE_SPAWN"):
            print("Using SPAWN method for dataloader")
            loader_kwargs["multiprocessing_context"] = "spawn"
        train_loader = torch.utils.data.DataLoader(train_dataset, **loader_kwargs)

    val_dataset, val_sampler = create_validation_set(
        valdir, val_bs, sz, rect_val=rect_val, distributed=distributed
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        num_workers=workers,
        pin_memory=True,
        collate_fn=fast_collate,
        batch_sampler=val_sampler,
    )

    wrapped_train = BatchTransformDataLoader(train_loader, fp16=fp16)
    wrapped_val = BatchTransformDataLoader(val_loader, fp16=fp16)
    return wrapped_train, wrapped_val, train_sampler, val_sampler
def __init__(self, indices, batch_size, distributed=True):
    """Store the indices and derive per-replica batch and sample counts.

    Args:
        indices: Sized collection of dataset indices to be split up.
        batch_size: Number of samples per batch.
        distributed: When true, world size and rank come from
            ``env_world_size()`` / ``env_rank()``; otherwise the sampler
            behaves as a single replica at rank 0.
    """
    self.indices = indices
    self.batch_size = batch_size

    self.world_size = env_world_size() if distributed else 1
    self.global_rank = env_rank() if distributed else 0

    # Every replica must run the same number of batches, even when the data
    # does not divide evenly, hence the ceiling.
    share_per_replica = len(self.indices) / self.world_size
    self.expected_num_batches = math.ceil(share_per_replica / self.batch_size)

    # Number of samples handed to each replica (a whole number of batches).
    self.num_samples = self.expected_num_batches * self.batch_size
def get_loaders(traindir, valdir, sz, bs, fp16=False, val_bs=None, workers=8,
                rect_val=False, min_scale=0.08, distributed=False):
    """Build wrapped training/validation loaders and return them with their samplers.

    Args:
        traindir: Root folder passed to ``datasets.ImageFolder`` for training.
        valdir: Folder passed to ``create_validation_set``.
        sz: Target size for ``RandomResizedCrop`` and for the validation set.
        bs: Training batch size.
        fp16: Forwarded to ``BatchTransformDataLoader`` for both loaders.
        val_bs: Validation batch size; falls back to ``bs`` when falsy.
        workers: ``num_workers`` used by both ``DataLoader`` instances.
        rect_val: Forwarded to ``create_validation_set``.
        min_scale: Lower bound of the random-resized-crop scale range.
        distributed: Attach a ``DistributedSampler`` to the training set and
            forward the flag to ``create_validation_set``.

    Returns:
        ``(train_loader, val_loader, train_sampler, val_sampler)``.
    """
    if not val_bs:
        val_bs = bs

    crop = transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0))
    flip = transforms.RandomHorizontalFlip()
    train_dataset = datasets.ImageFolder(traindir, transforms.Compose([crop, flip]))

    if distributed:
        train_sampler = DistributedSampler(
            train_dataset, num_replicas=env_world_size(), rank=env_rank())
    else:
        train_sampler = None

    # DataLoader forbids shuffle=True together with an explicit sampler.
    shuffle_train = train_sampler is None
    raw_train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=bs,
        shuffle=shuffle_train,
        num_workers=workers,
        pin_memory=True,
        collate_fn=fast_collate,
        sampler=train_sampler,
    )

    val_dataset, val_sampler = create_validation_set(
        valdir, val_bs, sz, rect_val=rect_val, distributed=distributed)
    raw_val_loader = torch.utils.data.DataLoader(
        val_dataset,
        num_workers=workers,
        pin_memory=True,
        collate_fn=fast_collate,
        batch_sampler=val_sampler,
    )

    train_loader = BatchTransformDataLoader(raw_train_loader, fp16=fp16)
    val_loader = BatchTransformDataLoader(raw_val_loader, fp16=fp16)
    return train_loader, val_loader, train_sampler, val_sampler
def main():
    """Restore a model, write one checkpoint, and dump real and generated images.

    Reads the module-level ``args`` and ``write_log``. The steps are:

    1. On the logging rank, create ``args.save``, set up the logger and dump
       the parsed arguments to ``args.yaml``.
    2. Optionally initialise the distributed process group.
    3. Build the data loaders, the model and the optimizer; when
       ``args.resume`` is given, load model (and optimizer) state from it and
       recover counters from the CSV logs next to the checkpoint.
    4. Run a single pass of the "epoch" loop that saves ``checkpt_<epoch>.pth``,
       writes a grid of real test images to ``real.jpg`` and writes one grid of
       generated samples per temperature to ``generated-T<t>.jpg``.

    Returns:
        None.
    """
    # os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))

        logger.info(args)

        # Persist the full argument set next to the logs.
        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log: logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(), rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log: logger.info("Distributed: success (%d/%d)" % (args.local_rank, distributed.get_world_size()))
        device = torch.device("cuda:%d" % torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    else:
        # Here ``device`` is the integer index returned by current_device(),
        # not a torch.device object.
        device = torch.cuda.current_device()  #

    # Cast to float32 and move to the selected device.
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = ['itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time', 'grad_norm']
    testcolumns = ['wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time', 'transport_cost']

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns).cuda()

    if args.distributed: model = dist_utils.DDP(model,
                                                device_ids=[args.local_rank],
                                                output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns, regularization_fns)

    # Fresh runs start with empty CSV logs containing only the header row.
    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log: logger.info(model)
    if write_log: logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
    if write_log: logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log: logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    # NOTE(review): any other value of args.optimizer leaves ``optimizer`` unbound.
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9,
                              nesterov=False)

    # restore parameters
    if args.resume is not None:
        print('resume from checkpoint')
        # Map every stored tensor onto this process's GPU.
        checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    # NOTE(review): fixed_z only exists when write_log is true, yet the
    # checkpoint below reads it whenever args.local_rank == 0 - confirm these
    # two conditions always coincide.
    if write_log: fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
        chkdir = args.save
        # The string literal below is a disabled branch kept by the author; it
        # is a no-op expression statement.
        ''' elif args.resume and args.validate: chkdir = os.path.dirname(args.resume) wall_clock = 0 itr = 0 best_loss = 0.0 begin_epoch = 0 '''
    else:
        # Recover counters from the CSV logs stored beside the checkpoint.
        chkdir = os.path.dirname(args.resume)
        filename = os.path.join(chkdir, 'test.csv')
        print(filename)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log: logger.info('Syncing machines before training')
        # All-reduce of a dummy tensor, used as a barrier across processes.
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    # The range holds exactly one value, so the body runs once.
    for epoch in range(begin_epoch, begin_epoch + 1):
        # compute test loss
        print('Evaluating')
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            # Unwrap the DDP container (if any) so the saved keys carry no
            # "module." prefix.
            if hasattr(model, 'module'):
                _state = model.module.state_dict()
            else:
                _state = model.state_dict()
            torch.save({
                "args": args,
                "state_dict": _state,  # model.module.state_dict() if torch.cuda.is_available() else model.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "fixed_z": fixed_z.cpu()
            }, os.path.join(args.save, "checkpt_%d.pth" % epoch))

        # save real and generate with different temperatures
        fig_num = 64
        if True:  # args.save_real:
            # Advance through the test loader; ``x`` ends up as the batch at
            # which the loop stopped (index 101, or the last batch if the
            # loader is shorter). ``real`` is assigned but never read.
            # NOTE(review): the original indentation is not recoverable from
            # this chunk; the image-saving lines are assumed to follow the
            # loop rather than sit inside it - confirm.
            for i, (x, y) in enumerate(test_loader):
                if i < 100:
                    pass
                elif i == 100:
                    real = x.size(0)
                else:
                    break
            if x.shape[0] > fig_num:
                x = x[:fig_num, ...]
            fig_filename = os.path.join(chkdir, "real.jpg")
            # Division by 255 presumably rescales uint8 pixels to [0, 1].
            save_image(x.float() / 255.0, fig_filename, nrow=8)

        if True:  # args.generate:
            print('\nGenerating images... ')
            # One fixed latent batch is reused for every temperature so the
            # grids are directly comparable.
            fixed_z = cvt(torch.randn(fig_num, *data_shape))
            nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
            for t in [ 1.0, 0.99, 0.98, 0.97,0.96,0.95,0.93,0.92,0.90,0.85,0.8,0.75,0.7,0.65,0.6]:
                # visualize samples and density
                fig_filename = os.path.join(chkdir, "generated-T%g.jpg" % t)
                utils.makedirs(os.path.dirname(fig_filename))
                # Scale the latents by the temperature before the reverse pass.
                generated_samples = model(t * fixed_z, reverse=True)
                x = unshift(generated_samples[0].view(-1, *data_shape), 8)
                save_image(x, fig_filename, nrow=nb)
def get_dataset(args):
    """Build the train/test data loaders for the dataset named by ``args.data``.

    Args:
        args: Namespace providing ``data`` (dataset name), ``datadir``,
            ``imagesize`` (``None`` selects the per-dataset default),
            ``batch_size``, ``test_batch_size``, ``nworkers`` and
            ``distributed``. For ``imagenet64`` the function overwrites
            ``args.imagesize`` with 64.

    Returns:
        ``(train_loader, test_loader, data_shape)`` where ``data_shape`` is
        ``(channels, im_size, im_size)``. Batches are ``(uint8 image tensor,
        int64 target tensor)`` pairs produced by the local ``fast_collate``.

    Raises:
        ValueError: If ``args.data`` is not one of the supported dataset names.
    """

    def resize_only(size):
        return tforms.Compose([tforms.Resize(size)])

    if args.data == "mnist":
        im_dim = 1
        im_size = 28 if args.imagesize is None else args.imagesize
        train_set = dset.MNIST(root=args.datadir, train=True, transform=resize_only(im_size), download=True)
        test_set = dset.MNIST(root=args.datadir, train=False, transform=resize_only(im_size), download=True)
    elif args.data == "svhn":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.SVHN(root=args.datadir, split="train", transform=resize_only(im_size), download=True)
        test_set = dset.SVHN(root=args.datadir, split="test", transform=resize_only(im_size), download=True)
    elif args.data == "cifar10":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.CIFAR10(
            root=args.datadir, train=True, transform=tforms.Compose([
                tforms.Resize(im_size),
                tforms.RandomHorizontalFlip(),
            ]), download=True
        )
        # NOTE(review): the test split is not resized, so a non-default
        # args.imagesize yields test images that do not match data_shape.
        test_set = dset.CIFAR10(root=args.datadir, train=False, transform=None, download=True)
    elif args.data == 'celebahq':
        im_dim = 3
        im_size = 256 if args.imagesize is None else args.imagesize
        train_set = CelebAHQ(train=True, root=args.datadir)
        test_set = CelebAHQ(train=False, root=args.datadir)
    elif args.data == 'imagenet64':
        im_dim = 3
        # This dataset only exists at 64x64; force the setting on args as well.
        if args.imagesize != 64:
            args.imagesize = 64
        im_size = 64
        train_set = Imagenet64(train=True, root=args.datadir)
        test_set = Imagenet64(train=False, root=args.datadir)
    elif args.data == 'lsun_church':
        im_dim = 3
        im_size = 64 if args.imagesize is None else args.imagesize
        train_set = dset.LSUN(
            'data', ['church_outdoor_train'], transform=tforms.Compose([
                tforms.Resize(96),
                tforms.RandomCrop(64),
                tforms.Resize(im_size),
            ])
        )
        test_set = dset.LSUN(
            'data', ['church_outdoor_val'], transform=tforms.Compose([
                tforms.Resize(96),
                tforms.RandomCrop(64),
                tforms.Resize(im_size),
            ])
        )
    else:
        # Previously an unknown name fell through and failed later with an
        # UnboundLocalError on im_dim; fail early with a clear message.
        raise ValueError("Unknown dataset: {!r}".format(args.data))

    data_shape = (im_dim, im_size, im_size)

    def fast_collate(batch):
        """Stack (image, target) pairs into a uint8 image tensor and int64 targets."""
        imgs = [sample[0] for sample in batch]
        targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
        tensor = torch.zeros((len(imgs), im_dim, im_size, im_size), dtype=torch.uint8)
        for i, img in enumerate(imgs):
            nump_array = np.asarray(img, dtype=np.uint8)
            if nump_array.ndim < 3:
                # Two-dimensional (single-channel) image: add a channel axis.
                nump_array = np.expand_dims(nump_array, axis=-1)
            # Move the trailing channel axis to the front (HWC -> CHW).
            nump_array = np.rollaxis(nump_array, 2)
            tensor[i] += torch.from_numpy(nump_array)
        return tensor, targets

    # A sampler of None together with shuffle=True reproduces the plain
    # single-process loader; with a DistributedSampler shuffle must be False.
    train_sampler = (DistributedSampler(train_set,
                                        num_replicas=env_world_size(),
                                        rank=env_rank()) if args.distributed else None)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set, batch_size=args.batch_size, shuffle=(train_sampler is None),
        sampler=train_sampler, num_workers=args.nworkers, pin_memory=True,
        collate_fn=fast_collate
    )

    test_sampler = (DistributedSampler(test_set,
                                       num_replicas=env_world_size(),
                                       rank=env_rank()) if args.distributed else None)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set, batch_size=args.test_batch_size, shuffle=False,
        sampler=test_sampler, num_workers=args.nworkers, pin_memory=True,
        collate_fn=fast_collate
    )

    return train_loader, test_loader, data_shape
# help='Shutdown instance at the end of training or failure') # parser.add_argument('--auto-shutdown-success-delay-mins', default=10, type=int, # help='how long to wait until shutting down on success') # parser.add_argument('--auto-shutdown-failure-delay-mins', default=60, type=int, # help='how long to wait before shutting down on error') return parser cudnn.benchmark = True args = get_parser().parse_args() # torch.manual_seed(args.seed) nvals = 2 ** args.nbits # Only want master rank logging is_master = (not args.distributed) or (dist_utils.env_rank() == 0) is_rank0 = args.local_rank == 0 write_log = is_rank0 and is_master def add_noise(x, nbits=8): if nbits < 8: x = x // (2 ** (8 - nbits)) if args.add_noise: noise = x.new().resize_as_(x).uniform_() else: noise = 1 / 2 return x.add_(noise).div_(2 ** nbits) def shift(x, nbits=8):
def main():
    """Train (or only validate) the model, logging metrics and saving checkpoints.

    Reads the module-level ``args`` and ``write_log``. The steps are:

    1. On the logging rank, create ``args.save``, set up the logger and dump
       the parsed arguments to ``args.yaml``.
    2. Optionally initialise the distributed process group.
    3. Build data loaders, model and optimizer; when ``args.resume`` is given,
       restore model/optimizer state and recover counters from the CSV logs
       stored next to the checkpoint.
    4. For every epoch: train (unless ``args.validate``), save ``checkpt.pth``,
       and every ``args.val_freq`` epochs evaluate on the test loader, update
       ``best.pth`` and write a grid of generated samples.

    Changes relative to the previous revision:

    * The checkpoint unwraps the model with ``hasattr(model, 'module')``
      instead of keying on ``torch.cuda.is_available()``, which raised
      ``AttributeError`` for a non-DDP model on a CUDA machine.
    * ``fixed_z`` is created on every process that later saves a checkpoint
      (``args.local_rank == 0``), not only when ``write_log`` is true.
    * A dead ``rv`` assignment and an unused ``sh`` local were removed.

    Returns:
        None.

    Raises:
        ValueError: If the model returns NaN or inf bits/dim during training.
    """
    # os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))

        logger.info(args)

        # Persist the full argument set next to the logs.
        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log:
            logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))

    # get device
    # device = torch.device("cuda:%d"%torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    # NOTE(review): the device is pinned to CPU here although the resume and
    # distributed paths below still call .cuda() - confirm this is intended.
    device = "cpu"
    # Cast to float32 and move to the selected device.
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = [
        'itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time',
        'grad_norm'
    ]
    testcolumns = [
        'wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time',
        'transport_cost'
    ]

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)
    # model = model.cuda()
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns,
                                                     regularization_fns)

    # Fresh runs start with empty CSV logs containing only the header row.
    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log:
        logger.info(model)
    if write_log:
        logger.info("Number of trainable parameters: {}".format(
            count_parameters(model)))
    if write_log:
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log:
        logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    # NOTE(review): any other value of args.optimizer leaves ``optimizer`` unbound.
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9,
                              nesterov=False)

    # restore parameters
    if args.resume is not None:
        # Map every stored tensor onto this process's GPU.
        checkpt = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization. The checkpoint written by local rank 0 stores
    # fixed_z, so it must exist there even when this process does not log.
    if write_log or args.local_rank == 0:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
    else:
        # Recover counters from the CSV logs stored beside the checkpoint.
        chkdir = os.path.dirname(args.resume)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log:
            logger.info('Syncing machines before training')
        # All-reduce of a dummy tensor, used as a barrier across processes.
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, args.num_epochs + 1):
        if not args.validate:
            model.train()

            with open(trainlog, 'a') as f:
                if write_log:
                    csvlogger = csv.DictWriter(f, traincolumns)

                for _, (x, y) in enumerate(train_loader):
                    start = time.time()
                    update_lr(optimizer, itr)
                    optimizer.zero_grad()

                    # cast data and move to device
                    x = add_noise(cvt(x), nbits=args.nbits)
                    # x = x.clamp_(min=0, max=1)
                    # compute loss
                    bpd, (x, z), reg_states = compute_bits_per_dim(x, model)
                    if np.isnan(bpd.data.item()):
                        raise ValueError('model returned nan during training')
                    elif np.isinf(bpd.data.item()):
                        raise ValueError('model returned inf during training')

                    loss = bpd
                    if regularization_coeffs:
                        # Weighted sum of the regularisers with non-zero weight.
                        reg_loss = sum(reg_state * coeff
                                       for reg_state, coeff in zip(
                                           reg_states, regularization_coeffs)
                                       if coeff != 0)
                        loss = loss + reg_loss
                    total_time = count_total_time(model)

                    loss.backward()
                    nfe_opt = count_nfe(model)
                    if write_log:
                        steps_meter.update(nfe_opt)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)

                    optimizer.step()

                    itr_time = time.time() - start
                    wall_clock += itr_time

                    batch_size = x.size(0)
                    # The leading 1. becomes the process count once the
                    # tensor has been summed across workers.
                    metrics = torch.tensor([
                        1., batch_size,
                        loss.item(),
                        bpd.item(), nfe_opt, grad_norm, *reg_states
                    ]).float()

                    total_gpus, batch_total, r_loss, r_bpd, r_nfe, r_grad_norm, *rv = dist_utils.sum_tensor(
                        metrics).cpu().numpy()

                    if write_log:
                        time_meter.update(itr_time)
                        bpd_meter.update(r_bpd / total_gpus)
                        loss_meter.update(r_loss / total_gpus)
                        grad_meter.update(r_grad_norm / total_gpus)
                        tt_meter.update(total_time)

                        fmt = '{:.4f}'
                        logdict = {
                            'itr': itr,
                            'wall': fmt.format(wall_clock),
                            'itr_time': fmt.format(itr_time),
                            'loss': fmt.format(r_loss / total_gpus),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'total_time': fmt.format(total_time),
                            'fe': r_nfe / total_gpus,
                            'grad_norm': fmt.format(r_grad_norm / total_gpus),
                        }
                        if regularization_coeffs:
                            rv = tuple(v_ / total_gpus for v_ in rv)
                            logdict = append_regularization_csv_dict(
                                logdict, regularization_fns, rv)
                        csvlogger.writerow(logdict)

                        if itr % args.log_freq == 0:
                            log_message = (
                                "Itr {:06d} | Wall {:.3e}({:.2f}) | "
                                "Time/Itr {:.2f}({:.2f}) | BPD {:.2f}({:.2f}) | "
                                "Loss {:.2f}({:.2f}) | "
                                "FE {:.0f}({:.0f}) | Grad Norm {:.3e}({:.3e}) | "
                                "TT {:.2f}({:.2f})".format(
                                    itr, wall_clock, wall_clock / (itr + 1),
                                    time_meter.val, time_meter.avg,
                                    bpd_meter.val, bpd_meter.avg,
                                    loss_meter.val, loss_meter.avg,
                                    steps_meter.val, steps_meter.avg,
                                    grad_meter.val, grad_meter.avg,
                                    tt_meter.val, tt_meter.avg))
                            if regularization_coeffs:
                                log_message = append_regularization_to_log(
                                    log_message, regularization_fns, rv)
                            logger.info(log_message)

                    itr += 1

        # compute test loss
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            # Unwrap the DDP container only when the model really is wrapped;
            # CUDA availability says nothing about that.
            if hasattr(model, 'module'):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            torch.save(
                {
                    "args": args,
                    "state_dict": state_dict,
                    "optim_state_dict": optimizer.state_dict(),
                    "fixed_z": fixed_z.cpu()
                }, os.path.join(args.save, "checkpt.pth"))

        # NOTE(review): the original indentation of this file was lost; the
        # sample visualisation and the validate-only break are assumed to sit
        # inside this validation branch - confirm.
        if epoch % args.val_freq == 0 or args.validate:
            with open(testlog, 'a') as f:
                if write_log:
                    csvlogger = csv.DictWriter(f, testcolumns)
                with torch.no_grad():
                    start = time.time()
                    if write_log:
                        logger.info("validating...")

                    lossmean = 0.
                    meandist = 0.
                    steps = 0
                    tt = 0.
                    for i, (x, y) in enumerate(test_loader):
                        x = shift(cvt(x), nbits=args.nbits)
                        loss, (x, z), _ = compute_bits_per_dim(x, model)
                        dist = (x.view(x.size(0), -1) -
                                z).pow(2).mean(dim=-1).mean()
                        # Running averages over the batches seen so far.
                        meandist = i / (i + 1) * dist + meandist / (i + 1)
                        lossmean = i / (i + 1) * lossmean + loss / (i + 1)

                        tt = i / (i + 1) * tt + count_total_time(model) / (i + 1)
                        steps = i / (i + 1) * steps + count_nfe(model) / (i + 1)

                    loss = lossmean.item()
                    metrics = torch.tensor([1., loss, meandist, steps]).float()

                    total_gpus, r_bpd, r_mdist, r_steps = dist_utils.sum_tensor(
                        metrics).cpu().numpy()
                    eval_time = time.time() - start

                    if write_log:
                        fmt = '{:.4f}'
                        logdict = {
                            'epoch': epoch,
                            'eval_time': fmt.format(eval_time),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'wall': fmt.format(wall_clock),
                            'total_time': fmt.format(tt),
                            'transport_cost': fmt.format(r_mdist / total_gpus),
                            'fe': '{:.2f}'.format(r_steps / total_gpus)
                        }

                        csvlogger.writerow(logdict)

                        logger.info(
                            "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}, Steps {:.4f}, TT {:.2f}, Transport Cost {:.2e}"
                            .format(epoch, eval_time, r_bpd / total_gpus,
                                    r_steps / total_gpus, tt,
                                    r_mdist / total_gpus))

                    loss = r_bpd / total_gpus

                    # Keep a copy of the best checkpoint seen so far.
                    if loss < best_loss and args.local_rank == 0:
                        best_loss = loss
                        shutil.copyfile(os.path.join(args.save, "checkpt.pth"),
                                        os.path.join(args.save, "best.pth"))

            # visualize samples and density
            if write_log:
                with torch.no_grad():
                    fig_filename = os.path.join(args.save, "figs",
                                                "{:04d}.jpg".format(epoch))
                    utils.makedirs(os.path.dirname(fig_filename))
                    generated_samples, _, _ = model(fixed_z, reverse=True)
                    generated_samples = generated_samples.view(-1, *data_shape)
                    nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
                    save_image(unshift(generated_samples, nbits=args.nbits),
                               fig_filename,
                               nrow=nb)
            # Validation-only runs stop after a single evaluation pass.
            if args.validate:
                break