def main(args):
    """Train a Glow model on the dataset selected via ``args``.

    Seeds every RNG for reproducibility, builds loaders and the model,
    optionally resumes from the best checkpoint, then runs the
    train/periodic-test loop.
    """
    # we're probably only using 1 GPU, so a single device string is fine
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"running on {device}")

    # set random seed for all RNG sources we use
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # best_loss is shared with module-level checkpointing state
    # (declare once; the previous duplicate declaration inside the resume
    # branch was redundant)
    global best_loss

    if args.generate_samples:
        print("generating samples")

    # load data
    # example for CIFAR-10 training:
    train_set, test_set = get_dataloader(args.dataset, args.batch_size)
    input_channels = channels_from_dataset(args.dataset)
    print(f"amount of input channels: {input_channels}")

    # instantiate model
    # baby network to make sure training script works
    net = Glow(in_channels=input_channels,
               depth=args.amt_flow_steps,
               levels=args.amt_levels,
               use_normalization=args.norm_method)
    # code for rosalinty model
    # net = RosGlow(input_channels, args.amt_flow_steps, args.amt_levels)
    net = net.to(device)

    print(f"training for {args.num_epochs} epochs.")
    start_epoch = 0
    if args.resume:
        # Fix: report the directory we actually load from — the old message
        # pointed at "checkpoints/" while the load used "new_checkpoints/".
        checkpoint_path = f"new_checkpoints/best_{args.dataset.lower()}.pth.tar"
        print(f"resuming from checkpoint found in {checkpoint_path}.")
        # raise an error if no checkpoint directory is found
        # (an explicit raise survives `python -O`, unlike assert)
        if not os.path.isdir("new_checkpoints"):
            raise FileNotFoundError(
                "checkpoint directory 'new_checkpoints' not found")
        checkpoint = torch.load(checkpoint_path)
        net.load_state_dict(checkpoint["model"])
        best_loss = checkpoint["test_loss"]
        start_epoch = checkpoint["epoch"]

    loss_function = FlowNLL().to(device)
    optimizer = optim.Adam(net.parameters(), lr=float(args.lr))
    # scheduler found in code, no mention in paper
    # scheduler = sched.LambdaLR(
    #     optimizer, lambda s: min(1., s / args.warmup_iters))

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        print(f"training epoch {epoch}")
        train(net, train_set, device, optimizer, loss_function, epoch)
        # evaluate (and optionally sample) every 10 epochs
        if epoch % 10 == 0:
            print(f"testing epoch {epoch}")
            test(net, test_set, device, loss_function, epoch,
                 args.generate_samples, args.amt_levels, args.dataset,
                 args.n_samples)
def main(args, kwargs):
    """Measure Glow reconstruction error on MNIST test images.

    Loads a trained model from ``args.output_folder``, samples a batch,
    re-decodes it from the cached latents (``model._last_z``) with and
    without ``use_last_split``, and prints/plots the L2 reconstruction
    errors.  ``kwargs`` is accepted but unused in this body.
    """
    output_folder = args.output_folder
    model_name = args.model_name
    # hyper-parameters were serialized alongside the checkpoint
    with open(os.path.join(output_folder, 'hparams.json')) as json_file:
        hparams = json.load(json_file)
    image_shape, num_classes, _, test_mnist = get_MNIST(False, hparams['dataroot'], hparams['download'])
    test_loader = data.DataLoader(test_mnist, batch_size=32, shuffle=False, num_workers=6, drop_last=False)
    # grab a single fixed batch of real data for the data-reconstruction check
    x, y = test_loader.__iter__().__next__()
    # NOTE(review): `device` is not defined in this function — presumably a
    # module-level global; confirm it is set before main() is called.
    x = x.to(device)
    # 'logittransform'/'sn' may be absent in older hparams files, default False
    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                 hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'],
                 hparams['LU_decomposed'], num_classes, hparams['learn_top'],
                 hparams['y_condition'],
                 False if 'logittransform' not in hparams else hparams['logittransform'],
                 False if 'sn' not in hparams else hparams['sn'])
    model.load_state_dict(torch.load(os.path.join(output_folder, model_name)))
    # mark actnorm layers as initialized so loading doesn't re-trigger DDI
    model.set_actnorm_init()
    model = model.to(device)
    model = model.eval()
    with torch.no_grad():
        # ipdb.set_trace()
        # sample, then reconstruct from the latents cached by the forward pass;
        # `use_last_split=True` reuses the stored split-prior halves and is
        # expected to reconstruct better than re-sampling them
        images = model(y_onehot=None, temperature=1, batch_size=32, reverse=True).cpu()
        better_dup_images = model(y_onehot=None, temperature=1, z=model._last_z, reverse=True, use_last_split=True).cpu()
        dup_images = model(y_onehot=None, temperature=1, z=model._last_z, reverse=True).cpu()
        worse_dup_images = model(y_onehot=None, temperature=1, z=model._last_z, reverse=True).cpu()
        # mean per-image squared L2 distance between original samples and
        # each reconstruction variant
        l2_err = torch.pow((images - dup_images).view(images.shape[0], -1), 2).sum(-1).mean()
        better_l2_err = torch.pow((images - better_dup_images).view(images.shape[0], -1), 2).sum(-1).mean()
        worse_l2_err = torch.pow((images - worse_dup_images).view(images.shape[0], -1), 2).sum(-1).mean()
        print(l2_err, better_l2_err, worse_l2_err)
        plot_imgs([images, dup_images, better_dup_images, worse_dup_images], '_recons')
    # with torch.no_grad():
    # ipdb.set_trace()
    # same check on real data: encode x, then decode its latents back
    z, nll, y_logits = model(x, None)
    better_dup_images = model(y_onehot=None, temperature=1, z=z, reverse=True, use_last_split=True).cpu()
    plot_imgs([x, better_dup_images], '_data_recons2')
    fpath = os.path.join(output_folder, '_recon_evoluation.png')
    pad = run_recon_evolution(model, x, fpath)
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         fresh):
    """Train a Glow model with pytorch-ignite.

    Builds data loaders and the model, wires up an ignite ``Engine`` with
    checkpointing, running-average metrics, a progress bar, data-dependent
    actnorm initialization, per-epoch evaluation and timing, then runs
    training for ``epochs`` epochs.  Resumes from ``saved_model`` /
    ``saved_optimizer`` when given.
    """
    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0'
    check_manual_seed(seed)
    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size,
                                  shuffle=False, num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition)
    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    def step(engine, batch):
        # One optimization step: forward, NLL (+ optional classification
        # loss when y_condition), backward, gradient clipping, update.
        model.train()
        optimizer.zero_grad()
        x, y = batch
        x = x.to(device)
        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)
        losses['total_loss'].backward()
        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        return losses

    def eval_step(engine, batch):
        # Evaluation step: same losses, no grad, per-sample (reduction='none')
        # so the ignite Loss metric can average them itself.
        model.eval()
        x, y = batch
        x = x.to(device)
        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll, y_logits, y_weight, y,
                                        multi_class, reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')
        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=1,
                                         n_saved=2, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model, 'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    # (the dummy torch.empty target is only there to satisfy Loss's API)
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x: (x['total_loss'],
                                     torch.empty(x['total_loss'].shape[0]))).attach(
        evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x: (x['nll'],
                                         torch.empty(x['nll'].shape[0]))).attach(
            evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        # actnorm was already data-initialized in the saved run
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        # checkpoint files are named ..._<epoch><ext>; recover the epoch
        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        # Data-dependent initialization of actnorm: run one forward pass
        # over n_init_batches batches before training starts.
        model.train()
        init_batches = []
        init_targets = []
        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)
            init_batches = torch.cat(init_batches).to(device)
            assert init_batches.shape[0] == n_init_batches * batch_size
            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None
            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])
        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer, start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
except (KeyboardInterrupt, SystemExit): check_save(model_single, optimizer, args, z_sample, i, save=True) raise if __name__ == '__main__': args = parser.parse_args() if len(args.load_path) > 0: args.startiter = int(args.load_path[:-3].split('_')[-1]) print(args) model_single = Glow( 1, args.n_flow, args.n_block, affine=args.affine, conv_lu=not args.no_lu ).cpu() if len(args.load_path) > 0: model_single.load_state_dict(torch.load(args.load_path, map_location=lambda storage, loc: storage)) model_single.initialize() gc.collect() torch.cuda.empty_cache() model = model_single model = model.to(device) optimizer = optim.Adam(model.parameters(), lr=args.lr) if len(args.load_path) > 0: optim_path = '/'.join(args.load_path.split('/')[:-1]) optimizer.load_state_dict(torch.load(os.path.join(optim_path, 'optimizer.pth'), map_location=lambda storage, loc: storage)) gc.collect() torch.cuda.empty_cache() train(args, model, optimizer)
hparams = json.load(json_file) ds = check_dataset(args.dataset, args.dataroot, True, args.download) ds2 = check_dataset(args.dataset2, args.dataroot, True, args.download) image_shape, num_classes, train_dataset, test_dataset = ds image_shape2, num_classes2, train_dataset_2, test_dataset_2 = ds2 model = Glow(image_shape, hparams['hidden_channels'], hparams['K'], hparams['L'], hparams['actnorm_scale'], hparams['flow_permutation'], hparams['flow_coupling'], hparams['LU_decomposed'], num_classes, hparams['learn_top'], hparams['y_condition']) dic = torch.load(checkpoint_path) if 'model' in dic.keys(): model.load_state_dict(dic["model"]) else: model.load_state_dict(dic) model.set_actnorm_init() model = model.to(device) model = model.eval() if args.optim_type == "ADAM": optim_default = partial(optim.Adam, lr=args.lr_test) elif args.optim_type == "SGD": optim_default = partial(optim.SGD, lr=args.lr_test, momentum=args.momentum) if args.limited_data is not None:
def main(
    dataset,
    dataset2,
    dataroot,
    download,
    augment,
    batch_size,
    eval_batch_size,
    nlls_batch_size,
    epochs,
    nb_step,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    y_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    lr_test,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    saved_optimizer,
    warmup,
    every_epoch,
):
    """Train Glow on ``dataset`` while holding out a second dataset.

    Identical ignite training setup to the single-dataset script, plus:
    it loads ``dataset2`` (must share the image shape), collects the first
    ``nlls_batch_size`` test samples of each dataset into ``data1``/``data2``
    (used by the commented-out likelihood-evaluation handler), and applies
    a linear-warmup LR schedule stepped once per epoch.
    """
    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    ds2 = check_dataset(dataset2, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds
    image_shape2, num_classes2, train_dataset_2, test_dataset_2 = ds2
    # both datasets must produce identically-shaped images
    assert(image_shape == image_shape2)

    # fixed evaluation samples from each dataset's test split
    data1 = []
    data2 = []
    for k in range(nlls_batch_size):
        dataaux, targetaux = test_dataset[k]
        data1.append(dataaux)
        dataaux, targetaux = test_dataset_2[k]
        data2.append(dataaux)

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=n_workers,
        drop_last=False,
    )

    model = Glow(
        image_shape,
        hidden_channels,
        K,
        L,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
        num_classes,
        learn_top,
        y_condition,
    )
    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    # linear warmup to the base LR over `warmup` epochs
    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def step(engine, batch):
        # one training step: forward, loss, backward, clip, update
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        # per-sample losses (reduction="none") so the Loss metric averages
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(
                    nll, y_logits, y_weight, y, multi_class, reduction="none"
                )
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction="none")

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(
        output_dir, "glow", n_saved=2, require_empty=False
    )
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {"model": model, "optimizer": optimizer},
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss"
    )

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    # (dummy torch.empty target satisfies the Loss metric's (pred, y) API)
    Loss(
        lambda x, y: torch.mean(x),
        output_transform=lambda x: (
            x["total_loss"],
            torch.empty(x["total_loss"].shape[0]),
        ),
    ).attach(evaluator, "total_loss")

    if y_condition:
        monitoring_metrics.extend(["nll"])
        RunningAverage(output_transform=lambda x: x["nll"]).attach(trainer, "nll")

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(
            lambda x, y: torch.mean(x),
            output_transform=lambda x: (x["nll"], torch.empty(x["nll"].shape[0])),
        ).attach(evaluator, "nll")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        # checkpoints here store a dict with 'model'/'opt' entries
        model.load_state_dict(torch.load(saved_model)['model'])
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer)['opt'])

        file_name, ext = os.path.splitext(saved_model)
        # NOTE(review): filenames apparently encode iterations in thousands,
        # so /1e3 recovers the epoch — yields a float epoch; confirm this is
        # intended before relying on engine.state.epoch arithmetic.
        resume_epoch = int(file_name.split("_")[-1])/1e3

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        # data-dependent actnorm initialization on the first batches
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            print(train_loader)
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        # step the warmup schedule once per epoch
        scheduler.step()
        metrics = evaluator.state.metrics
        losses = ", ".join([f"{key}: {value:.2f}" for key, value in metrics.items()])
        print(f"Validation Results - Epoch: {engine.state.epoch} {losses}")

    timer = Timer(average=True)
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]"
        )
        timer.reset()

    # @trainer.on(Events.EPOCH_COMPLETED)
    # def eval_likelihood(engine):
    #     global_nlls(output_dir, engine.state.epoch, data1, data2, model, dataset1_name = dataset, dataset2_name = dataset2, nb_step = nb_step, every_epoch = every_epoch, optim_default = partial(optim.SGD, lr=1e-5, momentum = 0.))

    trainer.run(train_loader, epochs)
def main(args):
    """Sample images from trained conditional Glow models, one per seed.

    For each seed in the comma-separated ``args.seed``: loads the "best"
    checkpoint from ``args.experiment_folder/<seed>/``, then generates and
    saves image grids depending on which conditioning mode the run used —
    class-conditional (``y_condition``), domain-conditional (``d_condition``),
    joint class+domain (``yd_condition``), or unconditional.  Class labels
    map to malaria-cell classes (Uninfected / Parasitized); domains map to
    patient slide IDs.
    """
    seed_list = [int(item) for item in args.seed.split(',')]
    for seed in seed_list:
        device = torch.device("cuda")
        experiment_folder = args.experiment_folder + '/' + str(seed) + '/'
        print(experiment_folder)
        #model_name = 'glow_checkpoint_'+ str(args.chk)+'.pth'
        # pick the checkpoint whose filename contains 'best'
        # NOTE(review): if several match, the last wins; if none, model_name
        # is unbound below — confirm the folder always has one 'best' file.
        for thing in os.listdir(experiment_folder):
            if 'best' in thing:
                model_name = thing
        print(model_name)
        # seed every RNG and force deterministic cuDNN for reproducible samples
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        with open(experiment_folder + 'hparams.json') as json_file:
            hparams = json.load(json_file)
        image_shape = (32, 32, 3)
        # class/domain counts depend on the conditioning mode of the run
        if hparams['y_condition']:
            num_classes = 2
            num_domains = 0
        elif hparams['d_condition']:
            num_classes = 10
            num_domains = 0
        elif hparams['yd_condition']:
            num_classes = 2
            num_domains = 10
        else:
            num_classes = 2
            num_domains = 0
        model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                     hparams['L'], hparams['actnorm_scale'],
                     hparams['flow_permutation'], hparams['flow_coupling'],
                     hparams['LU_decomposed'], num_classes, num_domains,
                     hparams['learn_top'], hparams['y_condition'],
                     hparams['extra_condition'], hparams['sp_condition'],
                     hparams['d_condition'], hparams['yd_condition'])
        print('loading model')
        model.load_state_dict(torch.load(experiment_folder + model_name))
        # actnorm was data-initialized during training; don't redo it
        model.set_actnorm_init()
        model = model.to(device)
        model = model.eval()

        if hparams['y_condition']:
            print('y_condition')

            def sample(model, temp=args.temperature):
                # 1000 samples per class: interleaved one-hot rows split
                # into class-0 (even rows) and class-1 (odd rows)
                with torch.no_grad():
                    if hparams['y_condition']:
                        print("extra", hparams['extra_condition'])
                        y = torch.eye(num_classes)
                        y = torch.cat(1000 * [y])
                        print(y.size())
                        y_0 = y[::2, :].to(
                            device)  # number hardcoded in model for now
                        y_1 = y[1::2, :].to(device)
                        print(y_0.size())
                        print(y_0)
                        print(y_1)
                        print(y_1.size())
                        images0 = model(z=None, y_onehot=y_0,
                                        temperature=temp, reverse=True,
                                        batch_size=1000)
                        images1 = model(z=None, y_onehot=y_1,
                                        temperature=temp, reverse=True,
                                        batch_size=1000)
                return images0, images1

            images0, images1 = sample(model)
            os.makedirs(experiment_folder + 'generations/Uninfected',
                        exist_ok=True)
            os.makedirs(experiment_folder + 'generations/Parasitized',
                        exist_ok=True)
            # save every individual sample, then an 8x8 preview grid per class
            for i in range(images0.size(0)):
                torchvision.utils.save_image(
                    images0[i, :, :, :], experiment_folder +
                    'generations/Uninfected/sample_{}.png'.format(i))
                torchvision.utils.save_image(
                    images1[i, :, :, :], experiment_folder +
                    'generations/Parasitized/sample_{}.png'.format(i))
            images_concat0 = torchvision.utils.make_grid(
                images0[:64, :, :, :], nrow=int(64**0.5), padding=2,
                pad_value=255)
            torchvision.utils.save_image(images_concat0,
                                         experiment_folder + '/uninfected.png')
            images_concat1 = torchvision.utils.make_grid(
                images1[:64, :, :, :], nrow=int(64**0.5), padding=2,
                pad_value=255)
            torchvision.utils.save_image(
                images_concat1, experiment_folder + '/parasitized.png')
        elif hparams['d_condition']:
            print('d_cond')

            def sample_d(model, idx, batch_size=1000, temp=args.temperature):
                # 1000 samples conditioned on a single domain one-hot (idx)
                with torch.no_grad():
                    if hparams['d_condition']:
                        y_0 = torch.zeros([batch_size, 10], device='cuda:0')
                        y_0[:, idx] = torch.ones(batch_size)
                        y_0.to(device)
                        print(y_0)
                        # y_1 = torch.zeros([batch_size, 201], device='cuda:0')
                        # y_1[:, 157] = torch.ones(batch_size)
                        # y_1.to(device)
                        # y = torch.eye(num_classes)
                        # y = torch.cat(1000 * [y])
                        # print(y.size())
                        # y_0 = y[::2, :].to(device)  # number hardcoded in model for now
                        # y_1 = y[1::2, :].to(device)
                        # print(y_0.size())
                        # print(y_0)
                        # print(y_1)
                        # print(y_1.size())
                        images0 = model(z=None, y_onehot=y_0,
                                        temperature=temp, reverse=True,
                                        batch_size=1000)
                        # images1 = model(z=None, y_onehot=y_1, temperature=1.0, reverse=True, batch_size=1000)
                return images0

            # the ten patient-slide domains the model was trained on
            for idx, dom in enumerate(["C116P77ThinF", "C132P93ThinF",
                                       "C137P98ThinF", "C180P141NThinF",
                                       "C182P143NThinF", \
                                       "C184P145ThinF", "C39P4thinF",
                                       'C59P20thinF', "C68P29N",
                                       "C99P60ThinF"]):
                images0 = sample_d(model, idx)
                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Uninfected/', exist_ok=True)
                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Parasitized/', exist_ok=True)
                # os.makedirs(experiment_folder + 'generations/C59P20thinF/Uninfected/', exist_ok=True)
                # os.makedirs(experiment_folder + 'generations/C59P20thinF/Parasitized/', exist_ok=True)
                for i in range(images0.size(0)):
                    torchvision.utils.save_image(
                        images0[i, :, :, :],
                        experiment_folder + 'generations/' + dom +
                        '/Uninfected/sample_{}.png'.format(i))
                    #torchvision.utils.save_image(images1[i, :, :, :], experiment_folder + 'generations/C59P20thinF/Parasitized/sample_{}.png'.format(i))
                images_concat0 = torchvision.utils.make_grid(
                    images0[:25, :, :, :], nrow=int(25**0.5), padding=2,
                    pad_value=255)
                torchvision.utils.save_image(images_concat0,
                                             experiment_folder + dom + '.png')
                # images_concat1 = torchvision.utils.make_grid(images1[:64,:,:,:], nrow=int(64 ** 0.5), padding=2, pad_value=255)
                # torchvision.utils.save_image(images_concat1, experiment_folder + 'C59P20thinF.png')
        elif hparams['yd_condition']:
            def sample_YD(model, idx, batch_size=1000, temp=args.temperature):
                # joint condition: first 2 dims = class one-hot,
                # remaining dims = domain one-hot at (idx + 2)
                with torch.no_grad():
                    if hparams['yd_condition']:
                        y_0 = torch.zeros([batch_size, 12], device='cuda:0')
                        y_0[:, 0] = torch.ones(batch_size)
                        y_0[:, idx + 2] = torch.ones(batch_size)
                        y_0.to(device)
                        print(y_0)
                        y_1 = torch.zeros([batch_size, 12], device='cuda:0')
                        y_1[:, 1] = torch.ones(batch_size)
                        y_1[:, idx + 2] = torch.ones(batch_size)
                        y_1.to(device)
                        print(y_1)
                        images0 = model(z=None, y_onehot=y_0,
                                        temperature=temp, reverse=True,
                                        batch_size=1000)
                        images1 = model(z=None, y_onehot=y_1,
                                        temperature=temp, reverse=True,
                                        batch_size=1000)
                return images0, images1

            def sample_DD(model, idx, batch_size=1000, temp=args.temperature):
                # alternative 20-dim encoding (class x domain flattened);
                # defined but not called below
                with torch.no_grad():
                    if hparams['yd_condition']:
                        y_1 = torch.zeros([batch_size, 20], device='cuda:0')
                        y_1[:, idx] = torch.ones(batch_size)
                        y_1.to(device)
                        print(y_1)
                        y_0 = torch.zeros([batch_size, 20], device='cuda:0')
                        y_0[:, idx + 10] = torch.ones(batch_size)
                        y_0.to(device)
                        print(y_0)
                        images0 = model(z=None, y_onehot=y_0,
                                        temperature=temp, reverse=True,
                                        batch_size=1000)
                        images1 = model(z=None, y_onehot=y_1,
                                        temperature=temp, reverse=True,
                                        batch_size=1000)
                return images0, images1

            for idx, dom in enumerate(
                    ["C116P77ThinF", "C132P93ThinF", "C137P98ThinF",
                     "C180P141NThinF", "C182P143NThinF", \
                     "C184P145ThinF", "C39P4thinF", 'C59P20thinF',
                     "C68P29N", "C99P60ThinF"]):
                images0, images1 = sample_YD(model, idx)
                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Uninfected/', exist_ok=True)
                os.makedirs(experiment_folder + 'generations/' + dom +
                            '/Parasitized/', exist_ok=True)
                for i in range(images0.size(0)):
                    torchvision.utils.save_image(
                        images0[i, :, :, :],
                        experiment_folder + 'generations/' + dom +
                        '/Uninfected/sample_{}.png'.format(i))
                    torchvision.utils.save_image(
                        images1[i, :, :, :],
                        experiment_folder + 'generations/' + dom +
                        '/Parasitized/sample_{}.png'.format(i))
                images_concat0 = torchvision.utils.make_grid(
                    images0[:64, :, :, :], nrow=int(64**0.5), padding=2,
                    pad_value=255)
                torchvision.utils.save_image(
                    images_concat0, experiment_folder + dom +
                    str(args.temperature) + '_uninfected.png')
                images_concat1 = torchvision.utils.make_grid(
                    images1[:64, :, :, :], nrow=int(64**0.5), padding=2,
                    pad_value=255)
                torchvision.utils.save_image(
                    images_concat1, experiment_folder + dom +
                    str(args.temperature) + '_parasitized.png')
        else:
            def sample(model, temp=args.temperature):
                # unconditional: 1000 samples at the requested temperature
                with torch.no_grad():
                    images = model(z=None, y_onehot=None, temperature=temp,
                                   reverse=True, batch_size=1000)
                return images

            images = sample(model)
            os.makedirs('unconditioned/' + str(seed) + '/generations/' +
                        experiment_folder[:-3], exist_ok=True)
            for i in range(images.size(0)):
                torchvision.utils.save_image(
                    images[i, :, :, :],
                    'unconditioned/' + str(seed) + '/generations/' +
                    experiment_folder[:-2] + 'sample_{}.png'.format(i))
            images_concat = torchvision.utils.make_grid(
                images[:64, :, :, :], nrow=int(64**0.5), padding=2,
                pad_value=255)
            torchvision.utils.save_image(
                images_concat, 'unconditioned/' + str(seed) + '/' +
                experiment_folder[:-3] + '.png')
model_name = 'glow_model_1.pth' with open(os.path.join(output_folder, 'hparams.json')) as json_file: hparams = json.load(json_file) image_shape, num_classes, _, test_mnist = get_MNIST(False, hparams['dataroot'], hparams['download']) model = Glow( image_shape, hparams['hidden_channels'], hparams['K'], hparams['L'], hparams['actnorm_scale'], hparams['flow_permutation'], hparams['flow_coupling'], hparams['LU_decomposed'], num_classes, hparams['learn_top'], hparams['y_condition'], False if 'logittransform' not in hparams else hparams['logittransform']) model.load_state_dict(torch.load(os.path.join(output_folder, model_name))) model.set_actnorm_init() model = model.to(device) model = model.eval() def sample(model): with torch.no_grad(): assert not hparams['y_condition'] y = None images = model(y_onehot=y, temperature=1, reverse=True, batch_size=32) # images = postprocess(model(y_onehot=y, temperature=1, reverse=True)) return images.cpu()
hparams['dataroot'] = '../mutual-information' image_shape, num_classes, _, test_cifar = get_CIFAR10(hparams['augment'], hparams['dataroot'], hparams['download']) image_shape, num_classes, _, test_svhn = get_SVHN(hparams['augment'], hparams['dataroot'], hparams['download']) model = Glow(image_shape, hparams['hidden_channels'], hparams['K'], hparams['L'], hparams['actnorm_scale'], hparams['flow_permutation'], hparams['flow_coupling'], hparams['LU_decomposed'], num_classes, hparams['learn_top'], hparams['y_condition']) model.load_state_dict(torch.load(output_folder + model_name)) model.set_actnorm_init() model = model.to(device) model = model.eval() def sample(model): with torch.no_grad(): if hparams['y_condition']: y = torch.eye(num_classes) y = y.repeat(batch_size // num_classes + 1) y = y[:32, :].to(device) # number hardcoded in model for now else: y = None
def main(args):
    """Generate 3-D image volumes from a trained Glow by latent interpolation.

    Loads a checkpoint from ``results/<args.name>``, determines a slice step
    size (either user-supplied or estimated from the two-point correlation /
    chord length of a sampled batch), then builds ``args.iter`` volumes by
    linearly interpolating between consecutive latent samples, optionally
    median-filtering / thresholding, and writing xy/xz/yz videos per modality.
    """
    # torch.manual_seed(args.seed)
    # Test loading and sampling
    output_folder = os.path.join('results', args.name)
    with open(os.path.join(output_folder, 'hparams.json')) as json_file:
        hparams = json.load(json_file)
    device = "cpu" if not torch.cuda.is_available() else "cuda:0"
    image_shape = (hparams['patch_size'], hparams['patch_size'],
                   args.n_modalities)
    num_classes = 1
    print('Loading model...')
    model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
                 hparams['L'], hparams['actnorm_scale'],
                 hparams['flow_permutation'], hparams['flow_coupling'],
                 hparams['LU_decomposed'], num_classes, hparams['learn_top'],
                 hparams['y_condition'])
    model_chkpt = torch.load(
        os.path.join(output_folder, 'checkpoints', args.model))
    model.load_state_dict(model_chkpt['model'])
    model.set_actnorm_init()
    model = model.to(device)
    # Build images
    model.eval()
    temperature = args.temperature
    if args.steps is None:
        # automatically calculate step size if no step size
        fig_dir = os.path.join(output_folder, 'stepnum_results')
        if not os.path.exists(fig_dir):
            os.mkdir(fig_dir)
        print('No step size entered')
        # Create sample of images to estimate chord length
        with torch.no_grad():
            mean, logs = model.prior(None, None)
            z = gaussian_sample(mean, logs, temperature)
            images_raw = model(z=z, temperature=temperature, reverse=True)
            # sanitize decoder output: replace NaN/Inf, clamp to data range
            images_raw[torch.isnan(images_raw)] = 0.5
            images_raw[torch.isinf(images_raw)] = 0.5
            images_raw = torch.clamp(images_raw, -0.5, 0.5)
            images_out = np.transpose(
                np.squeeze(images_raw[:, args.step_modality, :, :].cpu().numpy()),
                (1, 0, 2))
        # Threshold images and compute covariances
        if args.binary_data:
            thresh = 0
        else:
            thresh = threshold_otsu(images_out)
        images_bin = np.greater(images_out, thresh)
        x_cov = two_point_correlation(images_bin, 0)
        y_cov = two_point_correlation(images_bin, 1)
        # Compute chord length: fit a line through the origin of the
        # averaged covariance decay over the first N lags
        cov_avg = np.mean(
            np.mean(np.concatenate((x_cov, y_cov), axis=2), axis=0), axis=0)
        N = 5
        S20, _ = curve_fit(straight_line_at_origin(cov_avg[0]), range(0, N),
                           cov_avg[0:N])
        l_pore = np.abs(cov_avg[0] / S20)
        steps = int(l_pore)
        print('Calculated step size: {}'.format(steps))
    else:
        print('Using user-entered step size {}...'.format(args.steps))
        steps = args.steps
    # Build desired number of volumes
    for iter_vol in range(args.iter):
        if args.iter == 1:
            stack_dir = os.path.join(output_folder, 'image_stacks',
                                     args.save_name)
            print('Sampling images, saving to {}...'.format(args.save_name))
        else:
            # suffix each volume with a zero-padded index
            stack_dir = os.path.join(
                output_folder, 'image_stacks',
                args.save_name + '_' + str(iter_vol).zfill(3))
            print('Sampling images, saving to {}_'.format(args.save_name) +
                  str(iter_vol).zfill(3) + '...')
        if not os.path.exists(stack_dir):
            os.makedirs(stack_dir)
        with torch.no_grad():
            mean, logs = model.prior(None, None)
            # interpolation weights 1 -> 0 over `steps` slices
            alpha = 1 - torch.reshape(torch.linspace(0, 1, steps=steps),
                                      (-1, 1, 1, 1))
            alpha = alpha.to(device)
            num_imgs = int(np.ceil(hparams['patch_size'] / steps) + 1)
            z = gaussian_sample(mean, logs, temperature)[:num_imgs, ...]
            # linear interpolation between consecutive latent samples gives
            # a smooth z-stack; truncate to patch_size slices
            z = torch.cat([
                alpha * z[i, ...] + (1 - alpha) * z[i + 1, ...]
                for i in range(num_imgs - 1)
            ])
            z = z[:hparams['patch_size'], ...]
            images_raw = model(z=z, temperature=temperature, reverse=True)
            images_raw[torch.isnan(images_raw)] = 0.5
            images_raw[torch.isinf(images_raw)] = 0.5
            images_raw = torch.clamp(images_raw, -0.5, 0.5)
            # apply median filter to output
            if args.med_filt is not None or args.binary_data:
                for m in range(args.n_modalities):
                    if args.binary_data:
                        SE = ball(1)
                    else:
                        SE = ball(args.med_filt)
                    images_np = np.squeeze(images_raw[:, m, :, :].cpu().numpy())
                    images_filt = median_filter(images_np, footprint=SE)
                    # Erode binary images
                    if args.binary_data:
                        images_filt = np.greater(images_filt, 0)
                        SE = ball(1)
                        images_filt = 1.0 * binary_erosion(images_filt,
                                                           selem=SE) - 0.5
                    images_raw[:, m, :, :] = torch.tensor(images_filt,
                                                          device=device)
            # three orthogonal views of the volume (xy, xz, yz)
            images1 = postprocess(images_raw).cpu()
            images2 = postprocess(torch.transpose(images_raw, 0, 2)).cpu()
            images3 = postprocess(torch.transpose(images_raw, 0, 3)).cpu()
            # apply Otsu thresholding to output
            if args.save_binary and not args.binary_data:
                thresh = threshold_otsu(images1.numpy())
                images1[images1 < thresh] = 0
                images1[images1 > thresh] = 255
                images2[images2 < thresh] = 0
                images2[images2 > thresh] = 255
                images3[images3 < thresh] = 0
                images3[images3 > thresh] = 255
            # # erode binary images by 1 px to correct for training image transformation
            # if args.binary_data:
            #     images1 = np.greater(images1.numpy(), 127)
            #     images2 = np.greater(images2.numpy(), 127)
            #     images3 = np.greater(images3.numpy(), 127)
            #     images1 = 255*torch.tensor(1.0*np.expand_dims(binary_erosion(np.squeeze(images1), selem=np.ones((1,2,2))), 1))
            #     images2 = 255*torch.tensor(1.0*np.expand_dims(binary_erosion(np.squeeze(images2), selem=np.ones((2,1,2))), 1))
            #     images3 = 255*torch.tensor(1.0*np.expand_dims(binary_erosion(np.squeeze(images3), selem=np.ones((2,2,1))), 1))
        # save video for each modality
        for m in range(args.n_modalities):
            if args.n_modalities > 1:
                save_dir = os.path.join(stack_dir, 'modality{}'.format(m))
            else:
                save_dir = stack_dir
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            write_video(images1[:, m, :, :], 'xy', hparams, save_dir)
            write_video(images2[:, m, :, :], 'xz', hparams, save_dir)
            write_video(images3[:, m, :, :], 'yz', hparams, save_dir)
    print('Finished!')
def main(
    dataset,
    dataroot,
    download,
    augment,
    batch_size,
    eval_batch_size,
    epochs,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    y_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    saved_optimizer,
    warmup,
    classifier_weight
):
    """Train a Glow model on 64x64 CelebA-style data with pytorch-ignite.

    Logs to wandb, checkpoints model+optimizer every epoch, and runs a
    data-dependent actnorm initialization pass at training start.

    Fix vs. original: the body referenced the global ``args`` object
    (``args.dataset``, ``args.dataroot``, ``args.batch_size``) instead of the
    function's own parameters, so calling main() directly ignored the
    caller's arguments and required a module-level ``args`` to exist.
    """
    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"
    # Use the `dataset` parameter, not the global `args.dataset`.
    wandb.init(project=dataset)
    check_manual_seed(seed)
    image_shape = (64, 64, 3)
    # Note: unsupported for now
    num_classes = 40
    multi_class = True  # currently unused downstream

    # Use the `dataroot` / `batch_size` parameters, not `args.*`.
    dataset_train = CelebALoader(root_folder=dataroot)
    train_loader = DataLoader(dataset_train, batch_size=batch_size,
                              shuffle=True, drop_last=True)

    model = Glow(
        image_shape,
        hidden_channels,
        K,
        L,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
        num_classes,
        learn_top,
        y_condition,
    )
    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    # NOTE: LambdaLR applies lr_lambda(0) at construction time, so the warmup
    # scaling affects the initial learning rate even though scheduler.step()
    # is never called afterwards in this function.
    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    wandb.watch(model)

    def step(engine, batch):
        """One optimization step: forward, NLL (+ optional y loss), clip, step."""
        model.train()
        optimizer.zero_grad()
        x, y = batch
        x = x.to(device)
        if y_condition:
            y = y.to(device)
            # x: (B, 3, 64, 64); y: (B, n_classes); z: latent from the flow
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)
        losses["total_loss"].backward()
        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        return losses

    trainer = Engine(step)

    # n_saved=None keeps every epoch's checkpoint on disk.
    checkpoint_handler = ModelCheckpoint(
        output_dir, "glow", n_saved=None, require_empty=False
    )
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {"model": model, "optimizer": optimizer},
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss"
    )
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    if saved_model:
        # Checkpoint stores a dict with a "model" entry (see checkpoint_handler).
        model.load_state_dict(torch.load(saved_model, map_location="cpu")['model'])
        model.set_actnorm_init()

    @trainer.on(Events.STARTED)
    def init(engine):
        """Data-dependent actnorm init: one forward pass over n_init_batches."""
        model.train()
        init_batches = []
        init_targets = []
        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)
            init_batches = torch.cat(init_batches).to(device)
            assert init_batches.shape[0] == n_init_batches * batch_size
            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None
            model(init_batches, init_targets)

    trainer.run(train_loader, epochs)
def main(dataset, dataroot, download, augment, n_workers, eval_batch_size,
         output_dir, db, glow_path, ckpt_name):
    """Analyze a trained Glow checkpoint on in-distribution, sampled, and OOD data.

    For each input batch, run_analysis() reports NaN percentages, condition
    number, dlogdet, bits-per-dim, and reconstruction error; results are
    written to ``results_{ckpt_name}.json`` in ``output_dir``.
    """
    (image_shape, num_classes, train_dataset, test_dataset) = check_dataset(
        dataset, dataroot, augment, download)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)
    # Fix: use next(iter(...)) instead of calling the __iter__/__next__
    # dunders directly.
    # NOTE(review): `device` is not defined in this function — presumably a
    # module-level global; confirm.
    x = next(iter(test_loader))[0].to(device)

    # OOD data
    ood_distributions = ['gaussian']
    # ood_distributions = ['gaussian', 'rademacher', 'texture3', 'svhn','tinyimagenet','lsun']

    # Same preprocessing pipeline the model was trained with: 32x32, 3-channel,
    # then the project's `preprocess` normalization.
    tr = transforms.Compose([])
    tr.transforms.append(transforms.ToPILImage())
    tr.transforms.append(transforms.Resize((32, 32)))
    tr.transforms.append(transforms.ToTensor())
    tr.transforms.append(one_to_three_channels)
    tr.transforms.append(preprocess)
    ood_tensors = [(out_name,
                    torch.stack([tr(img) for img in load_ood_data({
                        'name': out_name,
                        'ood_scale': 1,
                        'n_anom': eval_batch_size,
                    })]).to(device))
                   for out_name in ood_distributions]

    if 'sd' in glow_path:
        # State-dict checkpoint: rebuild the architecture from hparams.json
        # saved next to it. num_classes is hard-coded to 10 here.
        with open(os.path.join(os.path.dirname(glow_path), 'hparams.json'), 'r') as f:
            model_kwargs = json.load(f)
        model = Glow(
            (32, 32, 3),
            model_kwargs['hidden_channels'],
            model_kwargs['K'],
            model_kwargs['L'],
            model_kwargs['actnorm_scale'],
            model_kwargs['flow_permutation'],
            model_kwargs['flow_coupling'],
            model_kwargs['LU_decomposed'],
            10,
            model_kwargs['learn_top'],
            model_kwargs['y_condition'],
            model_kwargs['logittransform'],
            model_kwargs['sn'],
            model_kwargs['affine_eps'],
            model_kwargs['no_actnorm'],
            model_kwargs['affine_scale_eps'],
            model_kwargs['actnorm_max_scale'],
            model_kwargs['no_conv_actnorm'],
            model_kwargs['affine_max_scale'],
            model_kwargs['actnorm_eps'],
            model_kwargs['no_split']
        )
        model.load_state_dict(torch.load(glow_path))
        model.set_actnorm_init()
    else:
        # Whole pickled model object.
        model = torch.load(glow_path)
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        samples = generate_from_noise(model, eval_batch_size,
                                      clamp=False, guard_nans=False)

    stats = OrderedDict()
    # Rename loop variable so it no longer shadows the data batch `x`.
    for name, batch in [('data', x), ('samples', samples)] + ood_tensors:
        p_pxs, p_ims, cn, dlogdet, bpd, pad = run_analysis(
            batch, model,
            os.path.join(output_dir, f'recon_{ckpt_name}_{name}.jpeg'))
        stats[f"{name}-percent-pixels-nans"] = p_pxs
        stats[f"{name}-percent-imgs-nans"] = p_ims
        stats[f"{name}-cn"] = cn
        stats[f"{name}-dlogdet"] = dlogdet
        stats[f"{name}-bpd"] = bpd
        stats[f"{name}-recon-err"] = pad

    with open(os.path.join(output_dir, f'results_{ckpt_name}.json'), 'w') as fp:
        json.dump(stats, fp, indent=4)
def main(
    dataset,
    augment,
    batch_size,
    eval_batch_size,
    epochs,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    extra_condition,
    sp_condition,
    d_condition,
    yd_condition,
    y_weight,
    d_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    missing,
    saved_optimizer,
    warmup,
):
    """Train a conditional Glow on 32x32 data with pytorch-ignite.

    Supports three mutually exclusive conditioning modes, checked in this
    order: y_condition (class labels), d_condition (domain labels),
    yd_condition (joint label+domain). Checkpoints every epoch, tracks the
    best model by validation total_loss, and warms up the LR over `warmup`
    epochs (scheduler stepped once per epoch in `evaluate`).
    """
    print(output_dir)
    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"
    print(device)
    check_manual_seed(seed)
    print("augmenting?", augment)
    train_dataset, test_dataset = check_dataset(dataset, augment, missing)
    image_shape = (32, 32, 3)
    multi_class = False
    # Pick class/domain head sizes for the chosen conditioning mode.
    if yd_condition:
        num_classes = 2
        num_domains = 10
        #num_classes = 10+2
        #multi_class=True
    elif d_condition:
        num_classes = 10
        num_domains = 0
    else:
        num_classes = 2
        num_domains = 0
    #print("num classes", num_classes)
    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=n_workers,
        drop_last=False,
    )
    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 num_domains, learn_top, y_condition, extra_condition,
                 sp_condition, d_condition, yd_condition)
    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)
    # Linear LR warmup over the first `warmup` epochs, then constant.
    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def step(engine, batch):
        """One training step; dispatches on the active conditioning mode."""
        model.train()
        optimizer.zero_grad()
        # Batches carry (image, label, domain, joint-label) tuples.
        x, y, d, yd = batch
        x = x.to(device)
        if y_condition:
            y = y.to(device)
            z, nll, y_logits, spare = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        elif d_condition:
            d = d.to(device)
            z, nll, d_logits, spare = model(x, d)
            losses = compute_loss_y(nll, d_logits, d_weight, d, multi_class)
        elif yd_condition:
            y, d, yd = y.to(device), d.to(device), yd.to(device)
            z, nll, y_logits, d_logits = model(x, yd)
            losses = compute_loss_yd(nll, y_logits, y_weight, y,
                                     d_logits, d_weight, d)
        else:
            print("none")
            z, nll, y_logits, spare = model(x, None)
            losses = compute_loss(nll)
        losses["total_loss"].backward()
        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        return losses

    def eval_step(engine, batch):
        """Validation step: same dispatch as `step`, per-sample losses, no grads."""
        model.eval()
        x, y, d, yd = batch
        x = x.to(device)
        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits, none_logits = model(x, y)
                losses = compute_loss_y(nll, y_logits, y_weight, y,
                                        multi_class, reduction="none")
            elif d_condition:
                d = d.to(device)
                z, nll, d_logits, non_logits = model(x, d)
                losses = compute_loss_y(nll, d_logits, d_weight, d,
                                        multi_class, reduction="none")
            elif yd_condition:
                y, d, yd = y.to(device), d.to(device), yd.to(device)
                z, nll, y_logits, d_logits = model(x, yd)
                losses = compute_loss_yd(nll, y_logits, y_weight, y,
                                         d_logits, d_weight, d,
                                         reduction="none")
            else:
                z, nll, y_logits, d_logits = model(x, None)
                losses = compute_loss(nll, reduction="none")
        #print(losses, "losssssess")
        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, "glow", save_interval=1,
                                         n_saved=2, require_empty=False)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {
            "model": model,
            "optimizer": optimizer
        },
    )
    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss")
    evaluator = Engine(eval_step)
    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(
        lambda x, y: torch.mean(x),
        output_transform=lambda x: (
            x["total_loss"],
            torch.empty(x["total_loss"].shape[0]),
        ),
    ).attach(evaluator, "total_loss")
    if y_condition or d_condition or yd_condition:
        monitoring_metrics.extend(["nll"])
        RunningAverage(output_transform=lambda x: x["nll"]).attach(
            trainer, "nll")
        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(
            lambda x, y: torch.mean(x),
            output_transform=lambda x: (x["nll"],
                                        torch.empty(x["nll"].shape[0])),
        ).attach(evaluator, "nll")
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)
    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()
        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))
        # Epoch index is recovered from the checkpoint filename suffix.
        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split("_")[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        """Data-dependent actnorm init with the conditioning-appropriate targets."""
        model.train()
        init_batches = []
        init_targets = []
        init_domains = []
        init_yds = []
        with torch.no_grad():
            for batch, target, domain, yd in islice(train_loader, None,
                                                    n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)
                init_domains.append(domain)
                init_yds.append(yd)
            init_batches = torch.cat(init_batches).to(device)
            assert init_batches.shape[0] == n_init_batches * batch_size
            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
                model(init_batches, init_targets)
            elif d_condition:
                init_domains = torch.cat(init_domains).to(device)
                model(init_batches, init_domains)
            elif yd_condition:
                init_yds = torch.cat(init_yds).to(device)
                model(init_batches, init_yds)
            else:
                init_targets = None
                model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        #print("done")
        # LR warmup advances once per epoch, after validation.
        scheduler.step()
        metrics = evaluator.state.metrics
        losses = ", ".join(
            [f"{key}: {value:.8f}" for key, value in metrics.items()])
        print(f"Validation Results - Epoch: {engine.state.epoch} {losses}")

    def score_function(engine):
        # ModelCheckpoint keeps the highest score, so negate the loss.
        val_loss = engine.state.metrics['total_loss']
        return -val_loss

    name = "best_"
    val_handler = ModelCheckpoint(output_dir, name,
                                  score_function=score_function,
                                  score_name="val_loss", n_saved=1,
                                  require_empty=False)
    evaluator.add_event_handler(
        Events.EPOCH_COMPLETED,
        val_handler,
        {"model": model},
    )
    timer = Timer(average=True)
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]"
        )
        timer.reset()

    trainer.run(train_loader, epochs)
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, logittransform, gan, disc_lr):
    """Train either a Glow (MLE, `step`) or — when `gan` is set — a debug
    GAN whose generator replaces the flow (`gan_step`), using pytorch-ignite.

    In GAN mode the Glow model built below is discarded and swapped for
    mine.Generator; discriminator and generator are trained adversarially
    with a gradient penalty, and sample/data grids are dumped every 50 iters.
    """
    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0'
    check_manual_seed(seed)
    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds
    # Note: unsupported for now
    multi_class = False
    train_loader = data.DataLoader(train_dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size,
                                  shuffle=False, num_workers=n_workers,
                                  drop_last=False)
    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition, logittransform)
    model = model.to(device)
    if gan:
        # Debug
        # NOTE(review): the Glow model above is overwritten here — in GAN
        # mode the "generator" is mine.Generator, not the flow.
        model = mine.Generator(32, 1).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr, betas=(.5, .99),
                               weight_decay=0)
        discriminator = mine.Discriminator(image_shape[-1])
        discriminator = discriminator.to(device)
        D_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                        discriminator.parameters()),
                                 lr=disc_lr, betas=(.5, .99), weight_decay=0)
    else:
        optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)
    # lr_lambda = lambda epoch: lr * min(1., epoch+1 / warmup)
    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    i = 0  # NOTE(review): unused

    def step(engine, batch):
        """MLE training step for the flow (non-GAN mode)."""
        model.train()
        optimizer.zero_grad()
        x, y = batch
        x = x.to(device)
        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)
        losses['total_loss'].backward()
        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        return losses

    def gan_step(engine, batch):
        """Adversarial step: discriminator update (BCE + gradient penalty),
        then generator update; dumps image grids every 50 iterations."""
        assert not y_condition
        # Manual per-engine iteration counter (starts at -1, so first call is 0... 
        # NOTE(review): first call takes the `else` branch, so iter 0 triggers
        # the %50 logging below).
        if 'iter_ind' in dir(engine):
            engine.iter_ind += 1
        else:
            engine.iter_ind = -1
        losses = {}
        model.train()
        discriminator.train()
        x, y = batch
        x = x.to(device)

        # def generate_from_noise(batch_size):
        #     _, c2, h, w = model.prior_h.shape
        #     c = c2 // 2
        #     zshape = (batch_size, c, h, w)
        #     randz = torch.autograd.Variable(torch.randn(zshape), requires_grad=True).to(device)
        #     images = model(z= randz, y_onehot=None, temperature=1, reverse=True,batch_size=batch_size)
        #     return images

        def generate_from_noise(batch_size):
            # Debug generator takes a (B, 32, 1, 1) noise tensor; halved to
            # match the data range.
            zshape = (batch_size, 32, 1, 1)
            randz = torch.randn(zshape).to(device)
            images = model(randz)
            return images / 2

        def run_noised_disc(discriminator, x):
            # Dequantize inputs before the discriminator sees them.
            x = uniform_binning_correction(x)[0]
            return discriminator(x)

        # Train Disc
        fake = generate_from_noise(x.size(0))
        D_real_scores = run_noised_disc(discriminator, x.detach())
        D_fake_scores = run_noised_disc(discriminator, fake.detach())
        ones_target = torch.ones((x.size(0), 1), device=x.device)
        zeros_target = torch.zeros((x.size(0), 1), device=x.device)
        # D_real_accuracy = torch.sum(torch.round(F.sigmoid(D_real_scores)) == ones_target).float() / ones_target.size(0)
        # D_fake_accuracy = torch.sum(torch.round(F.sigmoid(D_fake_scores)) == zeros_target).float() / zeros_target.size(0)
        D_real_loss = F.binary_cross_entropy_with_logits(
            D_real_scores, ones_target)
        D_fake_loss = F.binary_cross_entropy_with_logits(
            D_fake_scores, zeros_target)
        D_loss = (D_real_loss + D_fake_loss) / 2
        gp = gradient_penalty(x.detach(), fake.detach(),
                              lambda _x: run_noised_disc(discriminator, _x))
        # Fixed gradient-penalty coefficient of 10.
        D_loss_plus_gp = D_loss + 10 * gp
        D_optimizer.zero_grad()
        D_loss_plus_gp.backward()
        D_optimizer.step()

        # Train generator
        fake = generate_from_noise(x.size(0))
        G_loss = F.binary_cross_entropy_with_logits(
            run_noised_disc(discriminator, fake),
            torch.ones((x.size(0), 1), device=x.device))
        losses['total_loss'] = G_loss
        # G-step
        optimizer.zero_grad()
        losses['total_loss'].backward()
        params = list(model.parameters())
        gnorm = [p.grad.norm() for p in params]  # NOTE(review): computed but unused
        optimizer.step()
        # if max_grad_clip > 0:
        #     torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        # if max_grad_norm > 0:
        #     torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        if engine.iter_ind % 50 == 0:
            # Periodic sample/data grid dumps for visual inspection.
            grid = make_grid((postprocess(fake.detach().cpu())[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.savefig(
                os.path.join(output_dir, f'sample_{engine.iter_ind}.png'))
            grid = make_grid(
                (postprocess(uniform_binning_correction(x)[0].cpu())[:30]),
                nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.savefig(os.path.join(output_dir, f'data_{engine.iter_ind}.png'))
        return losses

    def eval_step(engine, batch):
        """Validation step with per-sample losses (reduction='none')."""
        model.eval()
        x, y = batch
        x = x.to(device)
        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll, y_logits, y_weight, y,
                                        multi_class, reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')
        return losses

    if gan:
        trainer = Engine(gan_step)
    else:
        trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=1,
                                         n_saved=2, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })
    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')
    evaluator = Engine(eval_step)
    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')
    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')
        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)
    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()
        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))
        # Epoch index is recovered from the checkpoint filename suffix.
        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    # @trainer.on(Events.STARTED)
    # def init(engine):
    #     model.train()
    #     init_batches = []
    #     init_targets = []
    #     with torch.no_grad():
    #         for batch, target in islice(train_loader, None,
    #                                     n_init_batches):
    #             init_batches.append(batch)
    #             init_targets.append(target)
    #         init_batches = torch.cat(init_batches).to(device)
    #         assert init_batches.shape[0] == n_init_batches * batch_size
    #         if y_condition:
    #             init_targets = torch.cat(init_targets).to(device)
    #         else:
    #             init_targets = None
    #         model(init_batches, init_targets)

    # @trainer.on(Events.EPOCH_COMPLETED)
    # def evaluate(engine):
    #     evaluator.run(test_loader)
    #     # scheduler.step()
    #     metrics = evaluator.state.metrics
    #     losses = ', '.join([f"{key}: {value:.2f}" for key, value in metrics.items()])
    #     myprint(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, logittransform, gan, disc_lr, sn, flowgan, eval_every,
         ld_on_samples, weight_gan, weight_prior, weight_logdet,
         jac_reg_lambda, affine_eps, no_warm_up, optim_name, clamp, svd_every,
         eval_only, no_actnorm, affine_scale_eps, actnorm_max_scale,
         no_conv_actnorm, affine_max_scale, actnorm_eps, init_sample,
         no_split, disc_arch, weight_entropy_reg, db):
    """Train a flow-GAN: a Glow generator against a configurable discriminator.

    Loss is a weighted sum of GAN loss, prior term, logdet term, an optional
    sample-entropy regularizer, and an optional Jacobian regularizer; FID,
    reconstruction PAD, eval BPD, and Jacobian SVD statistics are logged to
    CSV periodically.

    Fixes vs. original:
    - operator precedence: ``engine.iter_ind + 2 % svd_every == 0`` parsed
      as ``iter_ind + (2 % svd_every)``, which is never 0 for
      ``svd_every > 2``, so the SVD-logging branch never ran; now
      ``(engine.iter_ind + 2) % svd_every == 0``.
    - ``args.batch_size`` / ``args.n_init_batches`` replaced with the
      function's own parameters (no reliance on a global ``args``).
    - ``iter()``/``next()`` builtins replace direct dunder calls.
    - bare ``raise`` with no active exception replaced by an explicit
      ``RuntimeError``.

    NOTE(review): ``device`` is not defined in this function — presumably a
    module-level global; confirm.
    """
    check_manual_seed(seed)
    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds
    # Note: unsupported for now
    multi_class = False
    train_loader = data.DataLoader(train_dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size,
                                  shuffle=False, num_workers=n_workers,
                                  drop_last=False)
    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition, logittransform, sn, affine_eps,
                 no_actnorm, affine_scale_eps, actnorm_max_scale,
                 no_conv_actnorm, affine_max_scale, actnorm_eps, no_split)
    model = model.to(device)

    # Discriminator architecture selected by flag.
    if disc_arch == 'mine':
        discriminator = mine.Discriminator(image_shape[-1])
    elif disc_arch == 'biggan':
        discriminator = cgan_models.Discriminator(
            image_channels=image_shape[-1], conditional_D=False)
    elif disc_arch == 'dcgan':
        discriminator = DCGANDiscriminator(image_shape[0], 64, image_shape[-1])
    elif disc_arch == 'inv':
        discriminator = InvDiscriminator(
            image_shape, hidden_channels, K, L, actnorm_scale,
            flow_permutation, flow_coupling, LU_decomposed, num_classes,
            learn_top, y_condition, logittransform, sn, affine_eps,
            no_actnorm, affine_scale_eps, actnorm_max_scale, no_conv_actnorm,
            affine_max_scale, actnorm_eps, no_split)
    discriminator = discriminator.to(device)
    D_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                    discriminator.parameters()),
                             lr=disc_lr, betas=(.5, .99), weight_decay=0)
    if optim_name == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr, betas=(.5, .99),
                               weight_decay=0)
    elif optim_name == 'adamax':
        optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)
    if not no_warm_up:
        # Linear LR warmup; stepped once per epoch in `evaluate`.
        lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lr_lambda)

    iteration_fieldnames = [
        'global_iteration', 'fid', 'sample_pad', 'train_bpd', 'eval_bpd',
        'pad', 'batch_real_acc', 'batch_fake_acc', 'batch_acc'
    ]
    iteration_logger = CSVLogger(fieldnames=iteration_fieldnames,
                                 filename=os.path.join(output_dir,
                                                       'iteration_log.csv'))
    iteration_fieldnames = [
        'global_iteration', 'condition_num', 'max_sv', 'min_sv',
        'inverse_condition_num', 'inverse_max_sv', 'inverse_min_sv'
    ]
    svd_logger = CSVLogger(fieldnames=iteration_fieldnames,
                           filename=os.path.join(output_dir, 'svd_log.csv'))

    # Fixed real-data pool for FID and a held-out batch for reconstruction.
    test_iter = iter(test_loader)
    N_inception = 1000
    x_real_inception = torch.cat([
        next(test_iter)[0].to(device)
        for _ in range(N_inception // batch_size + 1)
    ], 0)[:N_inception]
    # Shift from [-0.5, 0.5] to [0, 1] range for FID.
    x_real_inception = x_real_inception + .5
    x_for_recon = next(test_iter)[0].to(device)

    def gan_step(engine, batch):
        """One adversarial + likelihood-regularized step, plus periodic eval."""
        assert not y_condition
        # Manual per-engine iteration counter.
        if 'iter_ind' in dir(engine):
            engine.iter_ind += 1
        else:
            engine.iter_ind = -1
        losses = {}
        model.train()
        discriminator.train()
        x, y = batch
        x = x.to(device)

        def run_noised_disc(discriminator, x):
            # Dequantize before the discriminator sees the input.
            x = uniform_binning_correction(x)[0]
            return discriminator(x)

        real_acc = fake_acc = acc = 0
        if weight_gan > 0:
            # --- Discriminator update ---
            fake = generate_from_noise(model, x.size(0), clamp=clamp)
            D_real_scores = run_noised_disc(discriminator, x.detach())
            D_fake_scores = run_noised_disc(discriminator, fake.detach())
            ones_target = torch.ones((x.size(0), 1), device=x.device)
            zeros_target = torch.zeros((x.size(0), 1), device=x.device)
            D_real_accuracy = torch.sum(
                torch.round(F.sigmoid(D_real_scores)) ==
                ones_target).float() / ones_target.size(0)
            D_fake_accuracy = torch.sum(
                torch.round(F.sigmoid(D_fake_scores)) ==
                zeros_target).float() / zeros_target.size(0)
            D_real_loss = F.binary_cross_entropy_with_logits(
                D_real_scores, ones_target)
            D_fake_loss = F.binary_cross_entropy_with_logits(
                D_fake_scores, zeros_target)
            D_loss = (D_real_loss + D_fake_loss) / 2
            gp = gradient_penalty(
                x.detach(), fake.detach(),
                lambda _x: run_noised_disc(discriminator, _x))
            # Fixed gradient-penalty coefficient of 10.
            D_loss_plus_gp = D_loss + 10 * gp
            D_optimizer.zero_grad()
            D_loss_plus_gp.backward()
            D_optimizer.step()

            # --- Generator GAN loss ---
            fake = generate_from_noise(model, x.size(0), clamp=clamp,
                                       guard_nans=False)
            G_loss = F.binary_cross_entropy_with_logits(
                run_noised_disc(discriminator, fake),
                torch.ones((x.size(0), 1), device=x.device))

            # Trace
            real_acc = D_real_accuracy.item()
            fake_acc = D_fake_accuracy.item()
            acc = .5 * (D_fake_accuracy.item() + D_real_accuracy.item())

        z, nll, y_logits, (prior, logdet) = model.forward(
            x, None, return_details=True)
        train_bpd = nll.mean().item()

        # Weighted combination of GAN and likelihood terms.
        loss = 0
        if weight_gan > 0:
            loss = loss + weight_gan * G_loss
        if weight_prior > 0:
            loss = loss + weight_prior * -prior.mean()
        if weight_logdet > 0:
            loss = loss + weight_logdet * -logdet.mean()
        if weight_entropy_reg > 0:
            _, _, _, (sample_prior, sample_logdet) = model.forward(
                fake, None, return_details=True)
            # notice this is actually "decreasing" sample likelihood.
            loss = loss + weight_entropy_reg * (sample_prior.mean() +
                                                sample_logdet.mean())
        # Jac Reg
        if jac_reg_lambda > 0:
            # Sample: forward and inverse Jacobian penalties on generated data.
            x_samples = generate_from_noise(model, batch_size,
                                            clamp=clamp).detach()
            x_samples.requires_grad_()
            z = model.forward(x_samples, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            sample_foward_jac = compute_jacobian_regularizer(
                x_samples, all_z, n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            randz = torch.randn(zshape).to(device)
            randz = torch.autograd.Variable(randz, requires_grad=True)
            images = model(z=randz, y_onehot=None, temperature=1,
                           reverse=True, batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [randz] + other_zs
            sample_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # Data: same penalties on the real batch.
            x.requires_grad_()
            z = model.forward(x, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            data_foward_jac = compute_jacobian_regularizer(x, all_z, n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            z.requires_grad_()
            images = model(z=z, y_onehot=None, temperature=1, reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [z] + other_zs
            data_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac )
            loss = loss + jac_reg_lambda * (sample_foward_jac +
                                            sample_inverse_jac +
                                            data_foward_jac +
                                            data_inverse_jac)

        if not eval_only:
            optimizer.zero_grad()
            loss.backward()
            if not db:
                assert max_grad_clip == max_grad_norm == 0
            if max_grad_clip > 0:
                torch.nn.utils.clip_grad_value_(model.parameters(),
                                                max_grad_clip)
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)
            # Replace NaN gradient with 0 (g != g is the NaN test).
            for p in model.parameters():
                if p.requires_grad and p.grad is not None:
                    g = p.grad.data
                    g[g != g] = 0
            optimizer.step()

        if engine.iter_ind % 100 == 0:
            # Quick sanity dump: latent range and a sample grid.
            with torch.no_grad():
                fake = generate_from_noise(model, x.size(0), clamp=clamp)
                z = model.forward(fake, None, return_details=True)[0]
            print("Z max min")
            print(z.max().item(), z.min().item())
            if (fake != fake).float().sum() > 0:
                title = 'NaNs'
            else:
                title = "Good"
            grid = make_grid((postprocess(fake.detach().cpu(),
                                          dataset)[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.title(title)
            plt.savefig(
                os.path.join(output_dir, f'sample_{engine.iter_ind}.png'))

        if engine.iter_ind % eval_every == 0:
            def check_all_zero_except_leading(x):
                # True for 10, 200, 3000, ... (only the leading digit nonzero).
                return x % 10**np.floor(np.log10(x)) == 0
            if engine.iter_ind == 0 or check_all_zero_except_leading(
                    engine.iter_ind):
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir,
                                 f'ckpt_sd_{engine.iter_ind}.pt'))
            model.eval()
            with torch.no_grad():
                # Plot recon
                fpath = os.path.join(output_dir, '_recon',
                                     f'recon_{engine.iter_ind}.png')
                sample_pad = run_recon_evolution(
                    model,
                    generate_from_noise(model, batch_size,
                                        clamp=clamp).detach(), fpath)
                print(
                    f"Iter: {engine.iter_ind}, Recon Sample PAD: {sample_pad}")
                pad = run_recon_evolution(model, x_for_recon, fpath)
                print(f"Iter: {engine.iter_ind}, Recon PAD: {pad}")
                pad = pad.item()
                sample_pad = sample_pad.item()

                # Inception score
                sample = torch.cat([
                    generate_from_noise(model, batch_size, clamp=clamp)
                    for _ in range(N_inception // batch_size + 1)
                ], 0)[:N_inception]
                sample = sample + .5
                if (sample != sample).float().sum() > 0:
                    print("Sample NaNs")
                    # Explicit exception instead of a bare `raise` (which has
                    # no active exception here).
                    raise RuntimeError("NaNs in generated samples")
                else:
                    fid = run_fid(x_real_inception.clamp_(0, 1),
                                  sample.clamp_(0, 1))
                    print(f'fid: {fid}, global_iter: {engine.iter_ind}')

                # Eval BPD
                eval_bpd = np.mean([
                    model.forward(x.to(device), None,
                                  return_details=True)[1].mean().item()
                    for x, _ in test_loader
                ])
                stats_dict = {
                    'global_iteration': engine.iter_ind,
                    'fid': fid,
                    'train_bpd': train_bpd,
                    'pad': pad,
                    'eval_bpd': eval_bpd,
                    'sample_pad': sample_pad,
                    'batch_real_acc': real_acc,
                    'batch_fake_acc': fake_acc,
                    'batch_acc': acc
                }
                iteration_logger.writerow(stats_dict)
                plot_csv(iteration_logger.filename)
            model.train()

        # Fix: parenthesized — original `engine.iter_ind + 2 % svd_every`
        # parsed as `iter_ind + (2 % svd_every)` and never triggered.
        if (engine.iter_ind + 2) % svd_every == 0:
            model.eval()
            svd_dict = {}
            ret = utils.computeSVDjacobian(x_for_recon, model)
            D_for, D_inv = ret['D_for'], ret['D_inv']
            cn = float(D_for.max() / D_for.min())
            cn_inv = float(D_inv.max() / D_inv.min())
            svd_dict['global_iteration'] = engine.iter_ind
            svd_dict['condition_num'] = cn
            svd_dict['max_sv'] = float(D_for.max())
            svd_dict['min_sv'] = float(D_for.min())
            svd_dict['inverse_condition_num'] = cn_inv
            svd_dict['inverse_max_sv'] = float(D_inv.max())
            svd_dict['inverse_min_sv'] = float(D_inv.min())
            svd_logger.writerow(svd_dict)
            # plot_utils.plot_stability_stats(output_dir)
            # plot_utils.plot_individual_figures(output_dir, 'svd_log.csv')
            model.train()

        if eval_only:
            sys.exit()

        # Dummy
        losses['total_loss'] = torch.mean(nll).item()
        return losses

    def eval_step(engine, batch):
        """Validation step with per-sample losses (reduction='none')."""
        model.eval()
        x, y = batch
        x = x.to(device)
        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll, y_logits, y_weight, y,
                                        multi_class, reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')
        return losses

    trainer = Engine(gan_step)
    # else:
    #     trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=5,
                                         n_saved=1, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })
    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')
    evaluator = Engine(eval_step)
    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')
    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')
        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        print("Loading...")
        print(saved_model)
        loaded = torch.load(saved_model)
        # if 'Glow' in str(type(loaded)):
        #     model = loaded
        # else:
        #     raise
        # # if 'Glow' in str(type(loaded)):
        # #     loaded = loaded.state_dict()
        model.load_state_dict(loaded)
        model.set_actnorm_init()
        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))
        # Epoch index is recovered from the checkpoint filename suffix.
        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        """Data-dependent actnorm init (skipped when resuming a checkpoint)."""
        if saved_model:
            return
        model.train()
        print("Initializing Actnorm...")
        init_batches = []
        init_targets = []
        if n_init_batches == 0:
            model.set_actnorm_init()
            return
        with torch.no_grad():
            if init_sample:
                # Fix: use the function parameters, not the global `args`.
                generate_from_noise(model, batch_size * n_init_batches)
            else:
                for batch, target in islice(train_loader, None,
                                            n_init_batches):
                    init_batches.append(batch)
                    init_targets.append(target)
                init_batches = torch.cat(init_batches).to(device)
                assert init_batches.shape[0] == n_init_batches * batch_size
                if y_condition:
                    init_targets = torch.cat(init_targets).to(device)
                else:
                    init_targets = None
                model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        if not no_warm_up:
            scheduler.step()
        metrics = evaluator.state.metrics
        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])
        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
# Tail of a sampling helper whose `def` lies above this window: decode the
# noise `randz` back to image space via the reverse flow and return it.
images = model(z= randz, y_onehot=None, temperature=1, reverse=True,
               batch_size=batch_size)
return images

# --- checkpoint evaluation script: load a saved Glow, sample, and compute
# --- inception-score activations for FID-style evaluation.
# ipdb.set_trace()
iteration_fieldnames = ['global_iteration', 'fid']
iteration_logger = CSVLogger(fieldnames=iteration_fieldnames,
                             filename=os.path.join(output_dir,
                                                   'eval_log.csv'))
# for idx in tqdm(np.arange(10,100,10)):
for _ in range(1):
    idx = 100  # evaluate only checkpoint index 100
    # this is to initialize the u,v buffer in SpectralNorm blah...otherwise
    # state_dicts don't match...
    model(xs[:10].cuda(), None)
    saved_model = os.path.join(exp_dir, f"glow_model_{idx}.pth")
    model.load_state_dict(torch.load(saved_model))
    model.set_actnorm_init()
    # Sample
    with torch.no_grad():
        # NOTE(review): `generate_from_noise` is called here without a model
        # argument, unlike other call sites in this file — confirm the
        # helper's signature at its definition.
        fake = torch.cat([generate_from_noise(100) for _ in range(20)], 0)
    # Rescale and tile the single channel to 3 channels for the Inception
    # network input.
    x_is = 2 * fake
    x_is = x_is.repeat(1, 3, 1, 1).detach()

    # I have no clue why samples can contain nan....but it does...
    def _replace_nan_with_k_inplace(x, k):
        # In-place: overwrite every NaN entry of x with the constant k
        # (NaN != NaN picks out exactly the NaNs).
        mask = x != x
        x[mask] = k

    _replace_nan_with_k_inplace(x_is, -1)
    with torch.no_grad():
        issf, _, _, acts_fake = inception_score(x_is, cuda=True,
                                                batch_size=32, resize=True,
                                                splits=10, return_preds=True)
    # filter the ones with super large values: keep the 1800 activations with
    # the smallest absolute sums.
    idxs_ = np.argsort(np.abs(acts_fake).sum(-1))[:1800]
    acts_fake = acts_fake[idxs_]
# --- CelebA attribute-manipulation setup: load a trained conditional Glow
# --- and prepare positive/negative latent accumulators per attribute.
num_classes = 40  # CelebA has 40 binary attributes
Batch_Size = 4
dataset_test = CelebALoader(
    root_folder=hparams['dataroot']
)  #'/home/yellow/deep-learning-and-practice/hw7/dataset/task_2/'
test_loader = DataLoader(dataset_test,
                         batch_size=Batch_Size,
                         shuffle=False,
                         drop_last=True)
model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
             hparams['L'], hparams['actnorm_scale'],
             hparams['flow_permutation'], hparams['flow_coupling'],
             hparams['LU_decomposed'], num_classes, hparams['learn_top'],
             hparams['y_condition'])
# Checkpoint stores the weights under the 'model' key; load on CPU first.
model.load_state_dict(
    torch.load(output_folder + model_name, map_location="cpu")['model'])
model.set_actnorm_init()
model = model.to(device)
model = model.eval()
# attribute_list = [8] # Black_Hair
attribute_list = [20, 31, 33]  # Male, Smiling, Wavy_Hair, 24z No_Beard
# attribute_list = [11, 26, 31, 8, 6, 7] # Brown_Hair, Pale_Skin, Smiling, Black_Hair, Big_Lips, Big_Nose
# attribute_list = [i for i in range(40)]
N = 8
# Per-attribute accumulators of latents for images WITH / WITHOUT the
# attribute, filled by the loop below.
z_pos_list = [torch.Tensor([]).cuda() for i in range(len(attribute_list))]
z_neg_list = [torch.Tensor([]).cuda() for i in range(len(attribute_list))]
z_input_img = None
with torch.no_grad():
    # Loop body continues beyond this window.
    for i, (x, y) in enumerate(test_loader):
# --- MNIST Glow evaluation setup: restore hyper-parameters and weights,
# --- then define a conditional/unconditional sampling helper.
with open(output_folder + 'hparams.json') as json_file:
    hparams = json.load(json_file)

test_mnist = train.MyMNIST(train=False, download=False)
image_shape = (32, 32, 1)  # MNIST padded/resized to 32x32, single channel
num_classes = 10
batch_size = 512
model = Glow(image_shape, hparams['hidden_channels'], hparams['K'],
             hparams['L'], hparams['actnorm_scale'],
             hparams['flow_permutation'], hparams['flow_coupling'],
             hparams['LU_decomposed'], num_classes, hparams['learn_top'],
             hparams['y_condition'])
model.load_state_dict(torch.load(latest_model_path))
model.set_actnorm_init()
model = model.to(device)
model = model.eval()


def sample(model):
    # Body continues beyond this window; only the label construction is
    # visible here.
    with torch.no_grad():
        if hparams['y_condition']:
            # Build one-hot labels cycling through all classes.
            y = torch.eye(num_classes)
            # NOTE(review): Tensor.repeat on a 2-D tensor normally needs two
            # repeat factors (e.g. .repeat(k, 1)); a single argument here
            # looks like a dropped ", 1" — confirm against the original repo.
            y = y.repeat(batch_size // num_classes + 1)
            y = y[:32, :].to(device)  # number hardcoded in model for now
        else:
            y = None