def forward_and_reverse_output_shape(self, in_channel, data, levels=3, depth=4):
    glow = Glow(in_channel, levels, depth)
    z, logdet, eps = glow(data)
    height, width = data.shape[2], data.shape[3]
    # CIFAR example with levels = 3, initial shape [4, 3, 32, 32]:
    #   iter 1 -> z: [4, 12, 16, 16] (squeeze outside the loop)
    #   iter 2 -> z: [4, 24, 8, 8]   (squeeze + split)
    #   iter 3 -> z: [4, 48, 4, 4]   (squeeze + split)
    assert list(z.shape) == [4, in_channel * 4 * 2**(levels - 1), 4, 4]
    assert list(logdet.shape) == [4]  # because batch_size = 4
    # split is executed whenever the level index < levels, i.e. levels - 1 = 2 times in total
    assert len(eps) == levels - 1
    factor = 1
    for e in eps:
        factor *= 2
        # example: first eps -> take z's shape from iter 1 and halve the channels: [4, 12 / 2, 16, 16]
        assert list(e.shape) == [4, in_channel * factor, height // factor, width // factor]
    # In total depth * levels = 4 * 3 = 12 flow steps, i.e. 12 instances each of
    # ActNorm, InvConv and AffineCoupling.
    #   ActNorm        = 2 trainable parameters
    #   InvConv        = 3 trainable parameters
    #   AffineCoupling = 6 trainable parameters (3 conv layers, each with weight + bias)
    #   ZeroConv       = 4 trainable parameters (2 conv layers, each with weight + bias)
    # 12 * (2 + 3 + 6) + 4 = 136
    assert len(list(glow.parameters())) == (levels * depth) * (2 + 3 + 6) + 4
    for param in glow.parameters():
        assert param.requires_grad

    # reverse: for CIFAR with levels = 3 we expect z to have shape [4, 48, 4, 4]
    z = glow.reverse(z, eps)
    assert list(z.shape) == [4, 3, 32, 32]
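# The shape bookkeeping asserted above can be collected into a small helper. This is
# only a sketch derived from the assertions in the test: the i-th split emits eps with
# channels doubled per level, and the final z ends up with in_channel * 4 * 2**(levels - 1)
# channels at spatial size H / 2**levels. The helper name is illustrative, not model code.
def expected_glow_shapes(batch, channels, height, width, levels):
    eps_shapes = []
    factor = 1
    for _ in range(levels - 1):
        factor *= 2
        eps_shapes.append([batch, channels * factor, height // factor, width // factor])
    z_shape = [batch, channels * 4 * 2 ** (levels - 1), height // 2 ** levels, width // 2 ** levels]
    return z_shape, eps_shapes


# CIFAR-10 example, matching the assertions above.
assert expected_glow_shapes(4, 3, 32, 32, levels=3) == (
    [4, 48, 4, 4],
    [[4, 6, 16, 16], [4, 12, 8, 8]],
)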
def main(args):
    # we will probably only be using 1 GPU, so this should be fine
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"running on {device}")

    # set all random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    global best_loss

    if args.generate_samples:
        print("generating samples")

    # load data (example for CIFAR-10 training)
    train_set, test_set = get_dataloader(args.dataset, args.batch_size)
    input_channels = channels_from_dataset(args.dataset)
    print(f"number of input channels: {input_channels}")

    # instantiate model (small network to make sure the training script works)
    net = Glow(in_channels=input_channels,
               depth=args.amt_flow_steps,
               levels=args.amt_levels,
               use_normalization=args.norm_method)
    # code for the rosalinty model:
    # net = RosGlow(input_channels, args.amt_flow_steps, args.amt_levels)
    net = net.to(device)

    print(f"training for {args.num_epochs} epochs.")
    start_epoch = 0

    # load a checkpoint when resuming
    if args.resume:
        print(f"resuming from checkpoint new_checkpoints/best_{args.dataset.lower()}.pth.tar.")
        # raise an error if no checkpoint directory is found
        assert os.path.isdir("new_checkpoints")
        checkpoint = torch.load(f"new_checkpoints/best_{args.dataset.lower()}.pth.tar")
        net.load_state_dict(checkpoint["model"])
        best_loss = checkpoint["test_loss"]
        start_epoch = checkpoint["epoch"]

    loss_function = FlowNLL().to(device)
    optimizer = optim.Adam(net.parameters(), lr=float(args.lr))
    # a warm-up scheduler appears in the reference code but is not mentioned in the paper:
    # scheduler = sched.LambdaLR(optimizer, lambda s: min(1., s / args.warmup_iters))

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        print(f"training epoch {epoch}")
        train(net, train_set, device, optimizer, loss_function, epoch)
        # test every 10 epochs
        if epoch % 10 == 0:
            print(f"testing epoch {epoch}")
            test(net, test_set, device, loss_function, epoch, args.generate_samples,
                 args.amt_levels, args.dataset, args.n_samples)
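# The resume branch above loads a dict with "model", "test_loss" and "epoch" keys from
# new_checkpoints/best_<dataset>.pth.tar. A minimal sketch of the matching save step is
# shown below; where exactly it is called (e.g. inside test() when a new best loss is
# reached) is an assumption, not shown in this excerpt.
import os
import torch


def save_checkpoint(net, test_loss, epoch, dataset):
    os.makedirs("new_checkpoints", exist_ok=True)
    state = {
        "model": net.state_dict(),
        "test_loss": test_loss,
        "epoch": epoch,
    }
    torch.save(state, f"new_checkpoints/best_{dataset.lower()}.pth.tar")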
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size, epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, learn_top, y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers, cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer, warmup, fresh, logittransform, gan, disc_lr, sn, flowgan, eval_every, ld_on_samples, weight_gan, weight_prior, weight_logdet, jac_reg_lambda, affine_eps, no_warm_up, optim_name, clamp, svd_every, eval_only, no_actnorm, affine_scale_eps, actnorm_max_scale, no_conv_actnorm, affine_max_scale, actnorm_eps, init_sample, no_split, disc_arch, weight_entropy_reg, db): check_manual_seed(seed) ds = check_dataset(dataset, dataroot, augment, download) image_shape, num_classes, train_dataset, test_dataset = ds # Note: unsupported for now multi_class = False train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True) test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=n_workers, drop_last=False) model = Glow(image_shape, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, num_classes, learn_top, y_condition, logittransform, sn, affine_eps, no_actnorm, affine_scale_eps, actnorm_max_scale, no_conv_actnorm, affine_max_scale, actnorm_eps, no_split) model = model.to(device) if disc_arch == 'mine': discriminator = mine.Discriminator(image_shape[-1]) elif disc_arch == 'biggan': discriminator = cgan_models.Discriminator( image_channels=image_shape[-1], conditional_D=False) elif disc_arch == 'dcgan': discriminator = DCGANDiscriminator(image_shape[0], 64, image_shape[-1]) elif disc_arch == 'inv': discriminator = InvDiscriminator( image_shape, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, num_classes, learn_top, y_condition, logittransform, sn, affine_eps, no_actnorm, affine_scale_eps, actnorm_max_scale, no_conv_actnorm, affine_max_scale, actnorm_eps, no_split) discriminator = discriminator.to(device) D_optimizer = optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()), lr=disc_lr, betas=(.5, .99), weight_decay=0) if optim_name == 'adam': optimizer = optim.Adam(model.parameters(), lr=lr, betas=(.5, .99), weight_decay=0) elif optim_name == 'adamax': optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5) if not no_warm_up: lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) iteration_fieldnames = [ 'global_iteration', 'fid', 'sample_pad', 'train_bpd', 'eval_bpd', 'pad', 'batch_real_acc', 'batch_fake_acc', 'batch_acc' ] iteration_logger = CSVLogger(fieldnames=iteration_fieldnames, filename=os.path.join(output_dir, 'iteration_log.csv')) iteration_fieldnames = [ 'global_iteration', 'condition_num', 'max_sv', 'min_sv', 'inverse_condition_num', 'inverse_max_sv', 'inverse_min_sv' ] svd_logger = CSVLogger(fieldnames=iteration_fieldnames, filename=os.path.join(output_dir, 'svd_log.csv')) # test_iter = test_loader.__iter__() N_inception = 1000 x_real_inception = torch.cat([ test_iter.__next__()[0].to(device) for _ in range(N_inception // args.batch_size + 1) ], 0)[:N_inception] x_real_inception = x_real_inception + .5 x_for_recon = test_iter.__next__()[0].to(device) def gan_step(engine, batch): assert not y_condition if 'iter_ind' in dir(engine): engine.iter_ind += 1 else: engine.iter_ind = -1 losses = 
{} model.train() discriminator.train() x, y = batch x = x.to(device) def run_noised_disc(discriminator, x): x = uniform_binning_correction(x)[0] return discriminator(x) real_acc = fake_acc = acc = 0 if weight_gan > 0: fake = generate_from_noise(model, x.size(0), clamp=clamp) D_real_scores = run_noised_disc(discriminator, x.detach()) D_fake_scores = run_noised_disc(discriminator, fake.detach()) ones_target = torch.ones((x.size(0), 1), device=x.device) zeros_target = torch.zeros((x.size(0), 1), device=x.device) D_real_accuracy = torch.sum( torch.round(F.sigmoid(D_real_scores)) == ones_target).float() / ones_target.size(0) D_fake_accuracy = torch.sum( torch.round(F.sigmoid(D_fake_scores)) == zeros_target).float() / zeros_target.size(0) D_real_loss = F.binary_cross_entropy_with_logits( D_real_scores, ones_target) D_fake_loss = F.binary_cross_entropy_with_logits( D_fake_scores, zeros_target) D_loss = (D_real_loss + D_fake_loss) / 2 gp = gradient_penalty( x.detach(), fake.detach(), lambda _x: run_noised_disc(discriminator, _x)) D_loss_plus_gp = D_loss + 10 * gp D_optimizer.zero_grad() D_loss_plus_gp.backward() D_optimizer.step() # Train generator fake = generate_from_noise(model, x.size(0), clamp=clamp, guard_nans=False) G_loss = F.binary_cross_entropy_with_logits( run_noised_disc(discriminator, fake), torch.ones((x.size(0), 1), device=x.device)) # Trace real_acc = D_real_accuracy.item() fake_acc = D_fake_accuracy.item() acc = .5 * (D_fake_accuracy.item() + D_real_accuracy.item()) z, nll, y_logits, (prior, logdet) = model.forward(x, None, return_details=True) train_bpd = nll.mean().item() loss = 0 if weight_gan > 0: loss = loss + weight_gan * G_loss if weight_prior > 0: loss = loss + weight_prior * -prior.mean() if weight_logdet > 0: loss = loss + weight_logdet * -logdet.mean() if weight_entropy_reg > 0: _, _, _, (sample_prior, sample_logdet) = model.forward(fake, None, return_details=True) # notice this is actually "decreasing" sample likelihood. 
loss = loss + weight_entropy_reg * (sample_prior.mean() + sample_logdet.mean()) # Jac Reg if jac_reg_lambda > 0: # Sample x_samples = generate_from_noise(model, args.batch_size, clamp=clamp).detach() x_samples.requires_grad_() z = model.forward(x_samples, None, return_details=True)[0] other_zs = torch.cat([ split._last_z2.view(x.size(0), -1) for split in model.flow.splits ], -1) all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1) sample_foward_jac = compute_jacobian_regularizer(x_samples, all_z, n_proj=1) _, c2, h, w = model.prior_h.shape c = c2 // 2 zshape = (batch_size, c, h, w) randz = torch.randn(zshape).to(device) randz = torch.autograd.Variable(randz, requires_grad=True) images = model(z=randz, y_onehot=None, temperature=1, reverse=True, batch_size=0) other_zs = [split._last_z2 for split in model.flow.splits] all_z = [randz] + other_zs sample_inverse_jac = compute_jacobian_regularizer_manyinputs( all_z, images, n_proj=1) # Data x.requires_grad_() z = model.forward(x, None, return_details=True)[0] other_zs = torch.cat([ split._last_z2.view(x.size(0), -1) for split in model.flow.splits ], -1) all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1) data_foward_jac = compute_jacobian_regularizer(x, all_z, n_proj=1) _, c2, h, w = model.prior_h.shape c = c2 // 2 zshape = (batch_size, c, h, w) z.requires_grad_() images = model(z=z, y_onehot=None, temperature=1, reverse=True, batch_size=0) other_zs = [split._last_z2 for split in model.flow.splits] all_z = [z] + other_zs data_inverse_jac = compute_jacobian_regularizer_manyinputs( all_z, images, n_proj=1) # loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac ) loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac + data_foward_jac + data_inverse_jac) if not eval_only: optimizer.zero_grad() loss.backward() if not db: assert max_grad_clip == max_grad_norm == 0 if max_grad_clip > 0: torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip) if max_grad_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) # Replace NaN gradient with 0 for p in model.parameters(): if p.requires_grad and p.grad is not None: g = p.grad.data g[g != g] = 0 optimizer.step() if engine.iter_ind % 100 == 0: with torch.no_grad(): fake = generate_from_noise(model, x.size(0), clamp=clamp) z = model.forward(fake, None, return_details=True)[0] print("Z max min") print(z.max().item(), z.min().item()) if (fake != fake).float().sum() > 0: title = 'NaNs' else: title = "Good" grid = make_grid((postprocess(fake.detach().cpu(), dataset)[:30]), nrow=6).permute(1, 2, 0) plt.figure(figsize=(10, 10)) plt.imshow(grid) plt.axis('off') plt.title(title) plt.savefig( os.path.join(output_dir, f'sample_{engine.iter_ind}.png')) if engine.iter_ind % eval_every == 0: def check_all_zero_except_leading(x): return x % 10**np.floor(np.log10(x)) == 0 if engine.iter_ind == 0 or check_all_zero_except_leading( engine.iter_ind): torch.save( model.state_dict(), os.path.join(output_dir, f'ckpt_sd_{engine.iter_ind}.pt')) model.eval() with torch.no_grad(): # Plot recon fpath = os.path.join(output_dir, '_recon', f'recon_{engine.iter_ind}.png') sample_pad = run_recon_evolution( model, generate_from_noise(model, args.batch_size, clamp=clamp).detach(), fpath) print( f"Iter: {engine.iter_ind}, Recon Sample PAD: {sample_pad}") pad = run_recon_evolution(model, x_for_recon, fpath) print(f"Iter: {engine.iter_ind}, Recon PAD: {pad}") pad = pad.item() sample_pad = sample_pad.item() # Inception score sample = torch.cat([ generate_from_noise(model, 
args.batch_size, clamp=clamp) for _ in range(N_inception // args.batch_size + 1) ], 0)[:N_inception] sample = sample + .5 if (sample != sample).float().sum() > 0: print("Sample NaNs") raise else: fid = run_fid(x_real_inception.clamp_(0, 1), sample.clamp_(0, 1)) print(f'fid: {fid}, global_iter: {engine.iter_ind}') # Eval BPD eval_bpd = np.mean([ model.forward(x.to(device), None, return_details=True)[1].mean().item() for x, _ in test_loader ]) stats_dict = { 'global_iteration': engine.iter_ind, 'fid': fid, 'train_bpd': train_bpd, 'pad': pad, 'eval_bpd': eval_bpd, 'sample_pad': sample_pad, 'batch_real_acc': real_acc, 'batch_fake_acc': fake_acc, 'batch_acc': acc } iteration_logger.writerow(stats_dict) plot_csv(iteration_logger.filename) model.train() if engine.iter_ind + 2 % svd_every == 0: model.eval() svd_dict = {} ret = utils.computeSVDjacobian(x_for_recon, model) D_for, D_inv = ret['D_for'], ret['D_inv'] cn = float(D_for.max() / D_for.min()) cn_inv = float(D_inv.max() / D_inv.min()) svd_dict['global_iteration'] = engine.iter_ind svd_dict['condition_num'] = cn svd_dict['max_sv'] = float(D_for.max()) svd_dict['min_sv'] = float(D_for.min()) svd_dict['inverse_condition_num'] = cn_inv svd_dict['inverse_max_sv'] = float(D_inv.max()) svd_dict['inverse_min_sv'] = float(D_inv.min()) svd_logger.writerow(svd_dict) # plot_utils.plot_stability_stats(output_dir) # plot_utils.plot_individual_figures(output_dir, 'svd_log.csv') model.train() if eval_only: sys.exit() # Dummy losses['total_loss'] = torch.mean(nll).item() return losses def eval_step(engine, batch): model.eval() x, y = batch x = x.to(device) with torch.no_grad(): if y_condition: y = y.to(device) z, nll, y_logits = model(x, y) losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class, reduction='none') else: z, nll, y_logits = model(x, None) losses = compute_loss(nll, reduction='none') return losses trainer = Engine(gan_step) # else: # trainer = Engine(step) checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=5, n_saved=1, require_empty=False) trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, { 'model': model, 'optimizer': optimizer }) monitoring_metrics = ['total_loss'] RunningAverage(output_transform=lambda x: x['total_loss']).attach( trainer, 'total_loss') evaluator = Engine(eval_step) # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss(lambda x, y: torch.mean(x), output_transform=lambda x: (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach( evaluator, 'total_loss') if y_condition: monitoring_metrics.extend(['nll']) RunningAverage(output_transform=lambda x: x['nll']).attach( trainer, 'nll') # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss(lambda x, y: torch.mean(x), output_transform=lambda x: (x['nll'], torch.empty(x['nll'].shape[0]))).attach( evaluator, 'nll') pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) # load pre-trained model if given if saved_model: print("Loading...") print(saved_model) loaded = torch.load(saved_model) # if 'Glow' in str(type(loaded)): # model = loaded # else: # raise # # if 'Glow' in str(type(loaded)): # # loaded = loaded.state_dict() model.load_state_dict(loaded) model.set_actnorm_init() if saved_optimizer: optimizer.load_state_dict(torch.load(saved_optimizer)) file_name, ext = os.path.splitext(saved_model) resume_epoch = int(file_name.split('_')[-1]) @trainer.on(Events.STARTED) def resume_training(engine): engine.state.epoch = resume_epoch engine.state.iteration = 
resume_epoch * len( engine.state.dataloader) @trainer.on(Events.STARTED) def init(engine): if saved_model: return model.train() print("Initializing Actnorm...") init_batches = [] init_targets = [] if n_init_batches == 0: model.set_actnorm_init() return with torch.no_grad(): if init_sample: generate_from_noise(model, args.batch_size * args.n_init_batches) else: for batch, target in islice(train_loader, None, n_init_batches): init_batches.append(batch) init_targets.append(target) init_batches = torch.cat(init_batches).to(device) assert init_batches.shape[0] == n_init_batches * batch_size if y_condition: init_targets = torch.cat(init_targets).to(device) else: init_targets = None model(init_batches, init_targets) @trainer.on(Events.EPOCH_COMPLETED) def evaluate(engine): evaluator.run(test_loader) if not no_warm_up: scheduler.step() metrics = evaluator.state.metrics losses = ', '.join( [f"{key}: {value:.2f}" for key, value in metrics.items()]) print(f'Validation Results - Epoch: {engine.state.epoch} {losses}') timer = Timer(average=True) timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) @trainer.on(Events.EPOCH_COMPLETED) def print_times(engine): pbar.log_message( f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]' ) timer.reset() trainer.run(train_loader, epochs)
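# generate_from_noise(...) is called throughout gan_step above but is not defined in this
# excerpt. Based on the commented-out version further down (which samples z with the shape
# of the learned prior and runs the model in reverse), it presumably looks roughly like the
# sketch below; the clamp range and the NaN guarding are assumptions.
def generate_from_noise(model, batch_size, clamp=None, guard_nans=True):
    _, c2, h, w = model.prior_h.shape
    c = c2 // 2
    randz = torch.randn((batch_size, c, h, w), device=next(model.parameters()).device)
    randz.requires_grad_(True)
    images = model(z=randz, y_onehot=None, temperature=1, reverse=True, batch_size=batch_size)
    if clamp:
        images = torch.clamp(images, -0.5, 0.5)
    if guard_nans:
        images = torch.nan_to_num(images)  # stop NaNs from propagating into the discriminator
    return images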
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size, epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, learn_top, y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers, cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer, fresh): device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0' check_manual_seed(seed) ds = check_dataset(dataset, dataroot, augment, download) image_shape, num_classes, train_dataset, test_dataset = ds # Note: unsupported for now multi_class = False train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True) test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=n_workers, drop_last=False) model = Glow(image_shape, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, num_classes, learn_top, y_condition) model = model.to(device) optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5) def step(engine, batch): model.train() optimizer.zero_grad() x, y = batch x = x.to(device) if y_condition: y = y.to(device) z, nll, y_logits = model(x, y) losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class) else: z, nll, y_logits = model(x, None) losses = compute_loss(nll) losses['total_loss'].backward() if max_grad_clip > 0: torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip) if max_grad_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) optimizer.step() return losses def eval_step(engine, batch): model.eval() x, y = batch x = x.to(device) with torch.no_grad(): if y_condition: y = y.to(device) z, nll, y_logits = model(x, y) losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class, reduction='none') else: z, nll, y_logits = model(x, None) losses = compute_loss(nll, reduction='none') return losses trainer = Engine(step) checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=1, n_saved=2, require_empty=False) trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, { 'model': model, 'optimizer': optimizer }) monitoring_metrics = ['total_loss'] RunningAverage(output_transform=lambda x: x['total_loss']).attach( trainer, 'total_loss') evaluator = Engine(eval_step) # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss(lambda x, y: torch.mean(x), output_transform=lambda x: (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach( evaluator, 'total_loss') if y_condition: monitoring_metrics.extend(['nll']) RunningAverage(output_transform=lambda x: x['nll']).attach( trainer, 'nll') # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss(lambda x, y: torch.mean(x), output_transform=lambda x: (x['nll'], torch.empty(x['nll'].shape[0]))).attach( evaluator, 'nll') pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) # load pre-trained model if given if saved_model: model.load_state_dict(torch.load(saved_model)) model.set_actnorm_init() if saved_optimizer: optimizer.load_state_dict(torch.load(saved_optimizer)) file_name, ext = os.path.splitext(saved_model) resume_epoch = int(file_name.split('_')[-1]) @trainer.on(Events.STARTED) def resume_training(engine): engine.state.epoch = resume_epoch engine.state.iteration = resume_epoch * len( engine.state.dataloader) @trainer.on(Events.STARTED) def init(engine): model.train() init_batches = [] init_targets = [] with 
torch.no_grad(): for batch, target in islice(train_loader, None, n_init_batches): init_batches.append(batch) init_targets.append(target) init_batches = torch.cat(init_batches).to(device) assert init_batches.shape[0] == n_init_batches * batch_size if y_condition: init_targets = torch.cat(init_targets).to(device) else: init_targets = None model(init_batches, init_targets) @trainer.on(Events.EPOCH_COMPLETED) def evaluate(engine): evaluator.run(test_loader) metrics = evaluator.state.metrics losses = ', '.join( [f"{key}: {value:.2f}" for key, value in metrics.items()]) print(f'Validation Results - Epoch: {engine.state.epoch} {losses}') timer = Timer(average=True) timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) @trainer.on(Events.EPOCH_COMPLETED) def print_times(engine): pbar.log_message( f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]' ) timer.reset() trainer.run(train_loader, epochs)
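# compute_loss and compute_loss_y are used by step/eval_step above but are not part of this
# excerpt. Below is a sketch consistent with how their return values are consumed (a dict
# with "total_loss", plus "nll" in the class-conditional case); the exact implementation in
# the original repository may differ.
import torch
import torch.nn.functional as F


def compute_loss(nll, reduction="mean"):
    losses = {"nll": torch.mean(nll) if reduction == "mean" else nll}
    losses["total_loss"] = losses["nll"]
    return losses


def compute_loss_y(nll, y_logits, y_weight, y, multi_class, reduction="mean"):
    losses = {"nll": torch.mean(nll) if reduction == "mean" else nll}
    if multi_class:
        # multi-label targets: one sigmoid per class
        loss_classes = F.binary_cross_entropy_with_logits(y_logits, y.float(), reduction=reduction)
    else:
        # one-hot targets: standard cross-entropy against the argmax label
        loss_classes = F.cross_entropy(y_logits, torch.argmax(y, dim=1), reduction=reduction)
    losses["loss_classes"] = loss_classes
    losses["total_loss"] = losses["nll"] + y_weight * loss_classes
    return losses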
    for (lpv, ldv, lptv, ldtv) in zip(log_p_val, logdet_val, log_p_train_val, logdet_train_val):
        print(
            args.delta,
            lpv.item(),
            ldv.item(),
            lptv.item(),
            ldtv.item(),
            file=f_ll,
        )
    f_ll.close()
    f_train_loss.close()
    f_test_loss.close()


if __name__ == "__main__":
    args = parser.parse_args()
    print(string_args(args))

    device = args.device

    model = Glow(
        args.n_channels,
        args.n_flow,
        args.n_block,
        affine=args.affine,
        conv_lu=not args.no_lu,
    )
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    train(args, model, optimizer)
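# The parser used in the __main__ block above is not shown in this excerpt. A minimal
# sketch of the arguments that block actually reads; the names are inferred from the code
# and the defaults and help strings are illustrative only.
import argparse

parser = argparse.ArgumentParser(description="Glow trainer")
parser.add_argument("--device", default="cuda", help="device to train on")
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
parser.add_argument("--n_channels", type=int, default=3, help="input image channels")
parser.add_argument("--n_flow", type=int, default=32, help="flow steps per block (K)")
parser.add_argument("--n_block", type=int, default=4, help="number of blocks (L)")
parser.add_argument("--affine", action="store_true", help="use affine instead of additive coupling")
parser.add_argument("--no_lu", action="store_true", help="plain invertible 1x1 conv, no LU decomposition")
parser.add_argument("--delta", type=float, default=0.0, help="value logged alongside the NLL terms")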
def main( dataset, dataset2, dataroot, download, augment, batch_size, eval_batch_size, nlls_batch_size, epochs, nb_step, saved_model, seed, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, learn_top, y_condition, y_weight, max_grad_clip, max_grad_norm, lr, lr_test, n_workers, cuda, n_init_batches, output_dir, saved_optimizer, warmup, every_epoch, ): device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0" check_manual_seed(seed) ds = check_dataset(dataset, dataroot, augment, download) ds2 = check_dataset(dataset2, dataroot, augment, download) image_shape, num_classes, train_dataset, test_dataset = ds image_shape2, num_classes2, train_dataset_2, test_dataset_2 = ds2 assert(image_shape == image_shape2) data1 = [] data2 = [] for k in range(nlls_batch_size): dataaux, targetaux = test_dataset[k] data1.append(dataaux) dataaux, targetaux = test_dataset_2[k] data2.append(dataaux) # Note: unsupported for now multi_class = False train_loader = data.DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True, ) test_loader = data.DataLoader( test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=n_workers, drop_last=False, ) model = Glow( image_shape, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, num_classes, learn_top, y_condition, ) model = model.to(device) optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5) lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup) # noqa scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) def step(engine, batch): model.train() optimizer.zero_grad() x, y = batch x = x.to(device) if y_condition: y = y.to(device) z, nll, y_logits = model(x, y) losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class) else: z, nll, y_logits = model(x, None) losses = compute_loss(nll) losses["total_loss"].backward() if max_grad_clip > 0: torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip) if max_grad_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) optimizer.step() return losses def eval_step(engine, batch): model.eval() x, y = batch x = x.to(device) with torch.no_grad(): if y_condition: y = y.to(device) z, nll, y_logits = model(x, y) losses = compute_loss_y( nll, y_logits, y_weight, y, multi_class, reduction="none" ) else: z, nll, y_logits = model(x, None) losses = compute_loss(nll, reduction="none") return losses trainer = Engine(step) checkpoint_handler = ModelCheckpoint( output_dir, "glow", n_saved=2, require_empty=False ) trainer.add_event_handler( Events.EPOCH_COMPLETED, checkpoint_handler, {"model": model, "optimizer": optimizer}, ) monitoring_metrics = ["total_loss"] RunningAverage(output_transform=lambda x: x["total_loss"]).attach( trainer, "total_loss" ) evaluator = Engine(eval_step) # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss( lambda x, y: torch.mean(x), output_transform=lambda x: ( x["total_loss"], torch.empty(x["total_loss"].shape[0]), ), ).attach(evaluator, "total_loss") if y_condition: monitoring_metrics.extend(["nll"]) RunningAverage(output_transform=lambda x: x["nll"]).attach(trainer, "nll") # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss( lambda x, y: torch.mean(x), output_transform=lambda x: (x["nll"], torch.empty(x["nll"].shape[0])), ).attach(evaluator, "nll") pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) # load pre-trained 
model if given if saved_model: model.load_state_dict(torch.load(saved_model)['model']) model.set_actnorm_init() if saved_optimizer: optimizer.load_state_dict(torch.load(saved_optimizer)['opt']) file_name, ext = os.path.splitext(saved_model) resume_epoch = int(file_name.split("_")[-1])/1e3 @trainer.on(Events.STARTED) def resume_training(engine): engine.state.epoch = resume_epoch engine.state.iteration = resume_epoch * len(engine.state.dataloader) @trainer.on(Events.STARTED) def init(engine): model.train() init_batches = [] init_targets = [] with torch.no_grad(): print(train_loader) for batch, target in islice(train_loader, None, n_init_batches): init_batches.append(batch) init_targets.append(target) init_batches = torch.cat(init_batches).to(device) assert init_batches.shape[0] == n_init_batches * batch_size if y_condition: init_targets = torch.cat(init_targets).to(device) else: init_targets = None model(init_batches, init_targets) @trainer.on(Events.EPOCH_COMPLETED) def evaluate(engine): evaluator.run(test_loader) scheduler.step() metrics = evaluator.state.metrics losses = ", ".join([f"{key}: {value:.2f}" for key, value in metrics.items()]) print(f"Validation Results - Epoch: {engine.state.epoch} {losses}") timer = Timer(average=True) timer.attach( trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED, ) @trainer.on(Events.EPOCH_COMPLETED) def print_times(engine): pbar.log_message( f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]" ) timer.reset() # @trainer.on(Events.EPOCH_COMPLETED) # def eval_likelihood(engine): # global_nlls(output_dir, engine.state.epoch, data1, data2, model, dataset1_name = dataset, dataset2_name = dataset2, nb_step = nb_step, every_epoch = every_epoch, optim_default = partial(optim.SGD, lr=1e-5, momentum = 0.)) trainer.run(train_loader, epochs)
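# data1 and data2 (test images from the two datasets) are collected above but only consumed
# by the commented-out eval_likelihood hook. A minimal sketch of comparing the model's
# per-sample NLL on both sets; the helper name is illustrative and the units follow whatever
# model(x, None) returns as its second output.
def per_sample_nll(model, images, device, batch_size=64):
    model.eval()
    nlls = []
    with torch.no_grad():
        for i in range(0, len(images), batch_size):
            x = torch.stack(images[i:i + batch_size]).to(device)
            _, nll, _ = model(x, None)
            nlls.append(nll.cpu())
    return torch.cat(nlls)


# Example usage, e.g. after training:
# nll_in = per_sample_nll(model, data1, device)
# nll_out = per_sample_nll(model, data2, device)
# print(f"mean NLL {dataset}: {nll_in.mean():.3f}, {dataset2}: {nll_out.mean():.3f}")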
def main( dataset, dataroot, download, augment, batch_size, eval_batch_size, epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, learn_top, y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers, cuda, n_init_batches, output_dir, saved_optimizer, warmup, classifier_weight ): device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0" wandb.init(project=args.dataset) check_manual_seed(seed) image_shape = (64,64,3) # if args.dataset == "task1": num_classes = 24 # else : num_classes = 40 num_classes = 40 # Note: unsupported for now multi_class = True #It's True but this variable doesn't be used now # if args.dataset == "task1": # dataset_train = CLEVRDataset(root_folder=args.dataroot,img_folder=args.dataroot+'images/') # train_loader = DataLoader(dataset_train,batch_size=args.batch_size,shuffle=True,drop_last=True) # else : # dataset_train = CelebALoader(root_folder=args.dataroot) #'/home/arg/courses/machine_learning/homework/deep_learning_and_practice/Lab7/dataset/task_2/' # train_loader = DataLoader(dataset_train,batch_size=args.batch_size,shuffle=True,drop_last=True) dataset_train = CelebALoader(root_folder=args.dataroot) #'/home/arg/courses/machine_learning/homework/deep_learning_and_practice/Lab7/dataset/task_2/' train_loader = DataLoader(dataset_train,batch_size=args.batch_size,shuffle=True,drop_last=True) model = Glow( image_shape, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, num_classes, learn_top, y_condition, ) model = model.to(device) optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5) lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup) # noqa scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) wandb.watch(model) def step(engine, batch): model.train() optimizer.zero_grad() x, y = batch x = x.to(device) if y_condition: y = y.to(device) z, nll, y_logits = model(x, y) ### x: torch.Size([batchsize, 3, 64, 64]); y: torch.Size([batchsize, 24]); z: torch.Size([batchsize, 48, 8, 8]) losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class) else: z, nll, y_logits = model(x, None) losses = compute_loss(nll) losses["total_loss"].backward() if max_grad_clip > 0: torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip) if max_grad_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) optimizer.step() return losses trainer = Engine(step) checkpoint_handler = ModelCheckpoint( output_dir, "glow", n_saved=None, require_empty=False ) ### n_saved (Optional[int]) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept. 
trainer.add_event_handler( Events.EPOCH_COMPLETED, checkpoint_handler, {"model": model, "optimizer": optimizer}, ) monitoring_metrics = ["total_loss"] RunningAverage(output_transform=lambda x: x["total_loss"]).attach( trainer, "total_loss" ) pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) if saved_model: model.load_state_dict(torch.load(saved_model, map_location="cpu")['model']) model.set_actnorm_init() @trainer.on(Events.STARTED) def init(engine): model.train() init_batches = [] init_targets = [] with torch.no_grad(): for batch, target in islice(train_loader, None, n_init_batches): init_batches.append(batch) init_targets.append(target) init_batches = torch.cat(init_batches).to(device) assert init_batches.shape[0] == n_init_batches * batch_size if y_condition: init_targets = torch.cat(init_targets).to(device) else: init_targets = None model(init_batches, init_targets) # evaluator = evaluation_model(args.classifier_weight) # @trainer.on(Events.EPOCH_COMPLETED) # def evaluate(engine): # if args.dataset == "task1": # model.eval() # with torch.no_grad(): # test_conditions = get_test_conditions(args.dataroot).cuda() # predict_x = postprocess(model(y_onehot=test_conditions, temperature=1, reverse=True)).float() # score = evaluator.eval(predict_x, test_conditions) # save_image(predict_x.float(), args.output_dir+f"/Epoch{engine.state.epoch}_score{score:.3f}.png", normalize=True) # test_conditions = get_new_test_conditions(args.dataroot).cuda() # predict_x = postprocess(model(y_onehot=test_conditions, temperature=1, reverse=True)).float() # newscore = evaluator.eval(predict_x.float(), test_conditions) # save_image(predict_x.float(), args.output_dir+f"/Epoch{engine.state.epoch}_newscore{newscore:.3f}.png", normalize=True) # print(f"Iter: {engine.state.iteration} score:{score:.3f} newscore:{newscore:.3f} ") # wandb.log({"score": score, "new_score": newscore}) trainer.run(train_loader, epochs)
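# The commented-out evaluate hook above shows how samples would be generated from label
# conditions. A standalone sketch of that path; get_test_conditions, postprocess and
# save_image are the helpers referenced there, and calling this once per epoch is an
# assumption rather than part of this excerpt.
def sample_conditional(model, dataroot, output_dir, epoch, temperature=1):
    model.eval()
    with torch.no_grad():
        test_conditions = get_test_conditions(dataroot).cuda()
        predict_x = postprocess(
            model(y_onehot=test_conditions, temperature=temperature, reverse=True)
        ).float()
        save_image(predict_x, f"{output_dir}/Epoch{epoch}_samples.png", normalize=True)
    model.train()
    return predict_x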
    batch_size=args.batch_size, shuffle=True)

image_size = train_dataset[0][0].size()
print('size of train data: %d' % len(train_dataset))
print('size of test data: %d' % len(test_dataset))
print('image size: %s' % str(image_size))

# Model
print('==> Model')
model = Glow(image_size, args.channels_h, args.K, args.L,
             save_memory=args.save_memory).to(device)
# print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                             weight_decay=args.weight_decay)


def train(epoch):
    # warmup
    lr = min(args.lr * epoch / 10, args.lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    model.train()
    sum_loss = 0
    count = 0
    for iteration, batch in enumerate(train_loader, 1):
        batch = batch[0].to(device)
n_sample = 4
temp = 0.7
n_bits = 5
n_bins = 2. ** n_bits
img_channels = 1

model = Glow(img_channels, n_flow, n_block, affine=affine)

# fixed latent samples, kept around for visualisation
z_sample = []
z_shapes = calc_z_shapes(img_channels, img_size, n_flow, n_block)
for z in z_shapes:
    z_new = torch.randn(n_sample, *z) * temp
    z_sample.append(z_new.to('cuda'))

model.to('cuda')
optimizer = Adam(model.parameters(), lr=1e-4)

plot = False
i = 0
total_loss = []
for i in range(100):
    for image, _ in tqdm(train_loader):
        optimizer.zero_grad()
        image = image.to('cuda')
        # dequantize with uniform noise before the forward pass
        log_p, logdet, out = model(image + torch.rand_like(image) / n_bins)
        loss, log_p, log_det = calc_loss(log_p, logdet, img_size, img_channels, n_bins)
        loss.backward()
        optimizer.step()
        writer.add_scalar('loss', loss.cpu().item(), i)
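# z_sample is prepared above but never decoded in this excerpt. In the rosinality-style API
# this snippet appears to follow, the fixed latents are typically decoded every so often to
# monitor sample quality; a sketch of that step (saving a grid with torchvision.utils is an
# assumption):
from torchvision import utils


def save_samples(model, z_sample, step, path='samples'):
    with torch.no_grad():
        images = model.reverse(z_sample)  # decode the fixed latents back to image space
    utils.save_image(images.cpu(), f'{path}/{step:06d}.png', normalize=True, nrow=2)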
def main( dataset, augment, batch_size, eval_batch_size, epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, learn_top, y_condition, extra_condition, sp_condition, d_condition, yd_condition, y_weight, d_weight, max_grad_clip, max_grad_norm, lr, n_workers, cuda, n_init_batches, output_dir, missing, saved_optimizer, warmup, ): print(output_dir) device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0" print(device) check_manual_seed(seed) print("augmenting?", augment) train_dataset, test_dataset = check_dataset(dataset, augment, missing) image_shape = (32, 32, 3) multi_class = False if yd_condition: num_classes = 2 num_domains = 10 #num_classes = 10+2 #multi_class=True elif d_condition: num_classes = 10 num_domains = 0 else: num_classes = 2 num_domains = 0 #print("num classes", num_classes) train_loader = data.DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True, ) test_loader = data.DataLoader( test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=n_workers, drop_last=False, ) model = Glow(image_shape, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, num_classes, num_domains, learn_top, y_condition, extra_condition, sp_condition, d_condition, yd_condition) model = model.to(device) optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5) lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup) # noqa scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) def step(engine, batch): model.train() optimizer.zero_grad() x, y, d, yd = batch x = x.to(device) if y_condition: y = y.to(device) z, nll, y_logits, spare = model(x, y) losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class) elif d_condition: d = d.to(device) z, nll, d_logits, spare = model(x, d) losses = compute_loss_y(nll, d_logits, d_weight, d, multi_class) elif yd_condition: y, d, yd = y.to(device), d.to(device), yd.to(device) z, nll, y_logits, d_logits = model(x, yd) losses = compute_loss_yd(nll, y_logits, y_weight, y, d_logits, d_weight, d) else: print("none") z, nll, y_logits, spare = model(x, None) losses = compute_loss(nll) losses["total_loss"].backward() if max_grad_clip > 0: torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip) if max_grad_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) optimizer.step() return losses def eval_step(engine, batch): model.eval() x, y, d, yd = batch x = x.to(device) with torch.no_grad(): if y_condition: y = y.to(device) z, nll, y_logits, none_logits = model(x, y) losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class, reduction="none") elif d_condition: d = d.to(device) z, nll, d_logits, non_logits = model(x, d) losses = compute_loss_y(nll, d_logits, d_weight, d, multi_class, reduction="none") elif yd_condition: y, d, yd = y.to(device), d.to(device), yd.to(device) z, nll, y_logits, d_logits = model(x, yd) losses = compute_loss_yd(nll, y_logits, y_weight, y, d_logits, d_weight, d, reduction="none") else: z, nll, y_logits, d_logits = model(x, None) losses = compute_loss(nll, reduction="none") #print(losses, "losssssess") return losses trainer = Engine(step) checkpoint_handler = ModelCheckpoint(output_dir, "glow", save_interval=1, n_saved=2, require_empty=False) trainer.add_event_handler( Events.EPOCH_COMPLETED, checkpoint_handler, { "model": model, "optimizer": optimizer }, ) monitoring_metrics = ["total_loss"] 
RunningAverage(output_transform=lambda x: x["total_loss"]).attach( trainer, "total_loss") evaluator = Engine(eval_step) # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss( lambda x, y: torch.mean(x), output_transform=lambda x: ( x["total_loss"], torch.empty(x["total_loss"].shape[0]), ), ).attach(evaluator, "total_loss") if y_condition or d_condition or yd_condition: monitoring_metrics.extend(["nll"]) RunningAverage(output_transform=lambda x: x["nll"]).attach( trainer, "nll") # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss( lambda x, y: torch.mean(x), output_transform=lambda x: (x["nll"], torch.empty(x["nll"].shape[0])), ).attach(evaluator, "nll") pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) # load pre-trained model if given if saved_model: model.load_state_dict(torch.load(saved_model)) model.set_actnorm_init() if saved_optimizer: optimizer.load_state_dict(torch.load(saved_optimizer)) file_name, ext = os.path.splitext(saved_model) resume_epoch = int(file_name.split("_")[-1]) @trainer.on(Events.STARTED) def resume_training(engine): engine.state.epoch = resume_epoch engine.state.iteration = resume_epoch * len( engine.state.dataloader) @trainer.on(Events.STARTED) def init(engine): model.train() init_batches = [] init_targets = [] init_domains = [] init_yds = [] with torch.no_grad(): for batch, target, domain, yd in islice(train_loader, None, n_init_batches): init_batches.append(batch) init_targets.append(target) init_domains.append(domain) init_yds.append(yd) init_batches = torch.cat(init_batches).to(device) assert init_batches.shape[0] == n_init_batches * batch_size if y_condition: init_targets = torch.cat(init_targets).to(device) model(init_batches, init_targets) elif d_condition: init_domains = torch.cat(init_domains).to(device) model(init_batches, init_domains) elif yd_condition: init_yds = torch.cat(init_yds).to(device) model(init_batches, init_yds) else: init_targets = None model(init_batches, init_targets) @trainer.on(Events.EPOCH_COMPLETED) def evaluate(engine): evaluator.run(test_loader) #print("done") scheduler.step() metrics = evaluator.state.metrics losses = ", ".join( [f"{key}: {value:.8f}" for key, value in metrics.items()]) print(f"Validation Results - Epoch: {engine.state.epoch} {losses}") def score_function(engine): val_loss = engine.state.metrics['total_loss'] return -val_loss name = "best_" val_handler = ModelCheckpoint(output_dir, name, score_function=score_function, score_name="val_loss", n_saved=1, require_empty=False) evaluator.add_event_handler( Events.EPOCH_COMPLETED, val_handler, {"model": model}, ) timer = Timer(average=True) timer.attach( trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED, ) @trainer.on(Events.EPOCH_COMPLETED) def print_times(engine): pbar.log_message( f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]" ) timer.reset() trainer.run(train_loader, epochs)
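# compute_loss_yd combines the NLL with both a label head and a domain head, and is consumed
# above exactly like compute_loss_y. Its implementation is not in this excerpt; the sketch
# below is a guess at its shape, assuming one-hot y/d targets and a cross-entropy term per head.
def compute_loss_yd(nll, y_logits, y_weight, y, d_logits, d_weight, d, reduction="mean"):
    losses = {"nll": torch.mean(nll) if reduction == "mean" else nll}
    loss_y = F.cross_entropy(y_logits, torch.argmax(y, dim=1), reduction=reduction)
    loss_d = F.cross_entropy(d_logits, torch.argmax(d, dim=1), reduction=reduction)
    losses["total_loss"] = losses["nll"] + y_weight * loss_y + d_weight * loss_d
    return losses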
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size, epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, learn_top, y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers, cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer, warmup, fresh, logittransform, gan, disc_lr): device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0' check_manual_seed(seed) ds = check_dataset(dataset, dataroot, augment, download) image_shape, num_classes, train_dataset, test_dataset = ds # Note: unsupported for now multi_class = False train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True) test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=n_workers, drop_last=False) model = Glow(image_shape, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, num_classes, learn_top, y_condition, logittransform) model = model.to(device) if gan: # Debug model = mine.Generator(32, 1).to(device) optimizer = optim.Adam(model.parameters(), lr=lr, betas=(.5, .99), weight_decay=0) discriminator = mine.Discriminator(image_shape[-1]) discriminator = discriminator.to(device) D_optimizer = optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()), lr=disc_lr, betas=(.5, .99), weight_decay=0) else: optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5) # lr_lambda = lambda epoch: lr * min(1., epoch+1 / warmup) # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) i = 0 def step(engine, batch): model.train() optimizer.zero_grad() x, y = batch x = x.to(device) if y_condition: y = y.to(device) z, nll, y_logits = model(x, y) losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class) else: z, nll, y_logits = model(x, None) losses = compute_loss(nll) losses['total_loss'].backward() if max_grad_clip > 0: torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip) if max_grad_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) optimizer.step() return losses def gan_step(engine, batch): assert not y_condition if 'iter_ind' in dir(engine): engine.iter_ind += 1 else: engine.iter_ind = -1 losses = {} model.train() discriminator.train() x, y = batch x = x.to(device) # def generate_from_noise(batch_size): # _, c2, h, w = model.prior_h.shape # c = c2 // 2 # zshape = (batch_size, c, h, w) # randz = torch.autograd.Variable(torch.randn(zshape), requires_grad=True).to(device) # images = model(z= randz, y_onehot=None, temperature=1, reverse=True,batch_size=batch_size) # return images def generate_from_noise(batch_size): zshape = (batch_size, 32, 1, 1) randz = torch.randn(zshape).to(device) images = model(randz) return images / 2 def run_noised_disc(discriminator, x): x = uniform_binning_correction(x)[0] return discriminator(x) # Train Disc fake = generate_from_noise(x.size(0)) D_real_scores = run_noised_disc(discriminator, x.detach()) D_fake_scores = run_noised_disc(discriminator, fake.detach()) ones_target = torch.ones((x.size(0), 1), device=x.device) zeros_target = torch.zeros((x.size(0), 1), device=x.device) # D_real_accuracy = torch.sum(torch.round(F.sigmoid(D_real_scores)) == ones_target).float() / ones_target.size(0) # D_fake_accuracy = torch.sum(torch.round(F.sigmoid(D_fake_scores)) == zeros_target).float() / zeros_target.size(0) D_real_loss = F.binary_cross_entropy_with_logits( D_real_scores, 
ones_target) D_fake_loss = F.binary_cross_entropy_with_logits( D_fake_scores, zeros_target) D_loss = (D_real_loss + D_fake_loss) / 2 gp = gradient_penalty(x.detach(), fake.detach(), lambda _x: run_noised_disc(discriminator, _x)) D_loss_plus_gp = D_loss + 10 * gp D_optimizer.zero_grad() D_loss_plus_gp.backward() D_optimizer.step() # Train generator fake = generate_from_noise(x.size(0)) G_loss = F.binary_cross_entropy_with_logits( run_noised_disc(discriminator, fake), torch.ones((x.size(0), 1), device=x.device)) losses['total_loss'] = G_loss # G-step optimizer.zero_grad() losses['total_loss'].backward() params = list(model.parameters()) gnorm = [p.grad.norm() for p in params] optimizer.step() # if max_grad_clip > 0: # torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip) # if max_grad_norm > 0: # torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) if engine.iter_ind % 50 == 0: grid = make_grid((postprocess(fake.detach().cpu())[:30]), nrow=6).permute(1, 2, 0) plt.figure(figsize=(10, 10)) plt.imshow(grid) plt.axis('off') plt.savefig( os.path.join(output_dir, f'sample_{engine.iter_ind}.png')) grid = make_grid( (postprocess(uniform_binning_correction(x)[0].cpu())[:30]), nrow=6).permute(1, 2, 0) plt.figure(figsize=(10, 10)) plt.imshow(grid) plt.axis('off') plt.savefig(os.path.join(output_dir, f'data_{engine.iter_ind}.png')) return losses def eval_step(engine, batch): model.eval() x, y = batch x = x.to(device) with torch.no_grad(): if y_condition: y = y.to(device) z, nll, y_logits = model(x, y) losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class, reduction='none') else: z, nll, y_logits = model(x, None) losses = compute_loss(nll, reduction='none') return losses if gan: trainer = Engine(gan_step) else: trainer = Engine(step) checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=1, n_saved=2, require_empty=False) trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, { 'model': model, 'optimizer': optimizer }) monitoring_metrics = ['total_loss'] RunningAverage(output_transform=lambda x: x['total_loss']).attach( trainer, 'total_loss') evaluator = Engine(eval_step) # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss(lambda x, y: torch.mean(x), output_transform=lambda x: (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach( evaluator, 'total_loss') if y_condition: monitoring_metrics.extend(['nll']) RunningAverage(output_transform=lambda x: x['nll']).attach( trainer, 'nll') # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss(lambda x, y: torch.mean(x), output_transform=lambda x: (x['nll'], torch.empty(x['nll'].shape[0]))).attach( evaluator, 'nll') pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) # load pre-trained model if given if saved_model: model.load_state_dict(torch.load(saved_model)) model.set_actnorm_init() if saved_optimizer: optimizer.load_state_dict(torch.load(saved_optimizer)) file_name, ext = os.path.splitext(saved_model) resume_epoch = int(file_name.split('_')[-1]) @trainer.on(Events.STARTED) def resume_training(engine): engine.state.epoch = resume_epoch engine.state.iteration = resume_epoch * len( engine.state.dataloader) # @trainer.on(Events.STARTED) # def init(engine): # model.train() # init_batches = [] # init_targets = [] # with torch.no_grad(): # for batch, target in islice(train_loader, None, # n_init_batches): # init_batches.append(batch) # init_targets.append(target) # init_batches = 
torch.cat(init_batches).to(device) # assert init_batches.shape[0] == n_init_batches * batch_size # if y_condition: # init_targets = torch.cat(init_targets).to(device) # else: # init_targets = None # model(init_batches, init_targets) # @trainer.on(Events.EPOCH_COMPLETED) # def evaluate(engine): # evaluator.run(test_loader) # # scheduler.step() # metrics = evaluator.state.metrics # losses = ', '.join([f"{key}: {value:.2f}" for key, value in metrics.items()]) # myprint(f'Validation Results - Epoch: {engine.state.epoch} {losses}') timer = Timer(average=True) timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) @trainer.on(Events.EPOCH_COMPLETED) def print_times(engine): pbar.log_message( f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]' ) timer.reset() trainer.run(train_loader, epochs)
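# run_noised_disc above dequantizes its input with uniform_binning_correction before the
# discriminator sees it. That helper is not shown in this excerpt; below is a sketch of the
# usual form (uniform noise of width 1/n_bins plus the matching per-sample log-likelihood
# correction), with the 8-bit default as an assumption.
import math


def uniform_binning_correction(x, n_bits=8):
    b, c, h, w = x.size()
    n_bins = 2 ** n_bits
    chw = c * h * w
    # add uniform noise in [0, 1/n_bins) so discrete pixel values become continuous
    x = x + torch.zeros_like(x).uniform_(0, 1.0 / n_bins)
    # per-sample correction term for the log-likelihood objective
    objective = -math.log(n_bins) * chw * torch.ones(b, device=x.device)
    return x, objective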