def test(args, model):
    """Write per-example log-likelihood terms of ``model`` on the training split.

    The dataset loader is chosen from ``args.dataset`` and ``args.delta`` is
    recovered from ``args.model_path``. For each training image (at most the
    first 10000), the batch tensor is tiled 100 times along its first
    dimension and passed through ``model``. For each row of the model output,
    one line ``<image index> <delta> <log_p> <logdet> <label>`` is written to
    ``./test/ll_per_point_<string_args(args)>_.txt``.

    Args:
        args: parsed command-line namespace; reads ``dataset``, ``model_path``,
            ``img_size`` and ``n_channels`` and overwrites ``args.delta``.
        model: flow model called as ``model(batch)`` and returning
            ``(log_p, logdet, _)``.

    Raises:
        ValueError: if ``args.dataset`` is not ``"mnist"``, ``"fashion_mnist"``
            or ``"celeba"``.
    """
    if args.dataset == "mnist":
        dataset_f = memory_mnist
    elif args.dataset == "fashion_mnist":
        dataset_f = memory_fashion
    elif args.dataset == "celeba":
        dataset_f = celeba
    else:
        # Format the name into the message; passing it as a second argument
        # would render the message as a tuple.
        raise ValueError(f"Unknown dataset: {args.dataset}")

    # delta is parsed from the checkpoint path: take the last ";"-separated
    # field, keep the text before its first "_", then parse what follows "#".
    args.delta = float(args.model_path.split(";")[-1].split("_")[0].split("#")[1])
    repr_args = string_args(args)

    # The first argument is presumably the batch size (one image per step) —
    # confirm against the dataset helpers.
    train_loader, _, _, train_labels, _ = dataset_f(
        1, args.img_size, args.n_channels, return_y=True
    )

    # ``with`` guarantees the output file is closed even if the model raises.
    with open(f"./test/ll_per_point_{repr_args}_.txt", "w") as f, torch.no_grad():
        for ind, image in enumerate(train_loader):
            # TODO: figure out how to make this work without the repeat.
            image = image.repeat(100, 1, 1, 1).to(device)
            log_p, logdet, _ = model(image)
            label = train_labels[ind].item()
            for lp, ld in zip(log_p, logdet):
                print(ind, args.delta, lp.item(), ld.item(), label, file=f)
            # Images 0..9999 are processed, i.e. at most 10000 images.
            if ind >= 9999:
                break
def train(args, model, optimizer):
    """Train ``model`` for 200 epochs over a schedule of input-noise levels.

    Epoch ``i`` sets ``args.delta`` to the ``i``-th value returned by
    ``create_deltas_sequence(0.1, 0.005)``. It then trains on images perturbed
    with Gaussian noise of that standard deviation, plus uniform
    dequantization noise when ``args.tr_dq`` is set. After every epoch it:

    * appends ``<train loss>,<delta>,<epoch>`` to the train loss log;
    * saves a sample grid to ``sample/seq_sample_<args>_<epoch>.png``;
    * evaluates the validation loss, with perturbations controlled by
      ``args.te_dq`` / ``args.te_noise``, and appends
      ``<val loss>,<delta>,<epoch>`` to the test loss log.

    Every 10th epoch it also saves ``checkpoint/seq_model_<args>_<epoch>_.pt``.
    It then writes per-example
    ``delta log_p_val logdet_val log_p_train_val logdet_train_val`` lines to
    ``ll/seq_ll_<args>_<epoch>.txt``.

    The loss logs are ``losses/seq_losses_{train,test}_<args>_.txt``, named
    from ``string_args(args)`` as passed in. Per-epoch files use
    ``string_args(args)`` recomputed after ``args.delta`` is updated.

    Args:
        args: parsed command-line namespace (dataset, image/batch sizes,
            sampling settings, noise flags); ``args.delta`` is overwritten.
        model: flow model; ``model(x)`` returns ``(log_p, logdet, _)`` and
            ``model.reverse(z_list)`` produces images.
        optimizer: optimizer stepped after every training batch.

    Raises:
        ValueError: if ``args.dataset`` is not ``"mnist"`` or ``"fashion_mnist"``.
    """
    if args.dataset == "mnist":
        dataset_f = memory_mnist
    elif args.dataset == "fashion_mnist":
        dataset_f = memory_fashion
    else:
        # Fail fast instead of hitting an unbound ``dataset_f`` mid-training.
        raise ValueError(f"Unknown dataset: {args.dataset}")

    # Computed before the delta schedule overwrites ``args.delta``, so the
    # loss-log names reflect the arguments as passed in.
    repr_args = string_args(args)
    n_bins = 2.0**args.n_bits

    # Latents are drawn once and reused for every epoch's sample grid.
    z_shapes = calc_z_shapes(args.n_channels, args.img_size, args.n_flow, args.n_block)
    z_sample = [(torch.randn(args.n_sample, *z) * args.temp).to(device) for z in z_shapes]

    deltas = create_deltas_sequence(0.1, 0.005)
    args.delta = deltas[0]

    def perturb(image, dequantize, add_noise):
        """Return ``image`` plus optional dequantization and Gaussian noise."""
        # Out-of-place on purpose: ``Tensor.to`` returns the same tensor when it
        # already lives on ``device``, so ``+=`` could modify the loader's data.
        noisy = image
        if dequantize:
            noisy = noisy + torch.rand_like(image) / n_bins
        if add_noise:
            noisy = noisy + torch.randn_like(image) * args.delta
        return noisy

    train_loss_path = f"losses/seq_losses_train_{repr_args}_.txt"
    test_loss_path = f"losses/seq_losses_test_{repr_args}_.txt"
    with open(train_loss_path, "w", buffering=1) as f_train_loss, open(
        test_loss_path, "w", buffering=1
    ) as f_test_loss, tqdm(range(200)) as pbar:
        for i in pbar:
            # NOTE(review): assumes ``deltas`` has at least 200 entries.
            args.delta = deltas[i]
            repr_args = string_args(args)
            train_loader, val_loader, _ = dataset_f(
                args.batch, args.img_size, args.n_channels
            )

            train_losses = []
            for image in train_loader:
                optimizer.zero_grad()
                image = image.to(device)
                # Training always adds Gaussian noise; dequantization is optional.
                log_p, logdet, _ = model(perturb(image, args.tr_dq, True))
                loss, _, _ = calc_loss(
                    log_p, logdet.mean(), args.img_size, n_bins, args.n_channels
                )
                loss.backward()
                optimizer.step()
                train_losses.append(loss.item())
            print(f"{np.mean(train_losses)},{args.delta},{i + 1}", file=f_train_loss)

            with torch.no_grad():
                # NOTE(review): newer torchvision releases renamed ``range`` to
                # ``value_range`` — confirm against the installed version.
                utils.save_image(
                    model.reverse(z_sample).cpu().data,
                    f"sample/seq_sample_{repr_args}_{str(i + 1).zfill(6)}.png",
                    normalize=True,
                    nrow=10,
                    range=(-0.5, 0.5),
                )

                losses, logps, logdets = [], [], []
                for image in val_loader:
                    image = image.to(device)
                    log_p, logdet, _ = model(perturb(image, args.te_dq, args.te_noise))
                    loss, log_p, log_det = calc_loss(
                        log_p, logdet.mean(), args.img_size, n_bins, args.n_channels
                    )
                    losses.append(loss.item())
                    logps.append(log_p.item())
                    logdets.append(log_det.item())
                pbar.set_description(
                    f"Loss: {np.mean(losses):.5f}; logP: {np.mean(logps):.5f}; "
                    f"logdet: {np.mean(logdets):.5f}; delta: {args.delta:.5f}"
                )
                print(f"{np.mean(losses)},{args.delta},{i + 1}", file=f_test_loss)

                if (i + 1) % 10 == 0:
                    torch.save(
                        model.state_dict(),
                        f"checkpoint/seq_model_{repr_args}_{i + 1}_.pt",
                    )
                    _, val_loader, train_val_loader = dataset_f(
                        args.batch, args.img_size, args.n_channels
                    )
                    with open(f"ll/seq_ll_{repr_args}_{i + 1}.txt", "w") as f_ll:
                        # Each batch is perturbed from its own images. Pairing
                        # stops at the shorter loader instead of raising.
                        for image_val, image_train_val in zip(val_loader, train_val_loader):
                            log_p_val, logdet_val, _ = model(
                                perturb(image_val.to(device), args.te_dq, args.te_noise)
                            )
                            log_p_train_val, logdet_train_val, _ = model(
                                perturb(image_train_val.to(device), args.te_dq, args.te_noise)
                            )
                            for lpv, ldv, lptv, ldtv in zip(
                                log_p_val, logdet_val, log_p_train_val, logdet_train_val
                            ):
                                print(
                                    args.delta,
                                    lpv.item(),
                                    ldv.item(),
                                    lptv.item(),
                                    ldtv.item(),
                                    file=f_ll,
                                )
print( ind, args.delta, log_p[i].item(), logdet[i].item(), train_labels[ind].item(), file=f, ) if ind >= 9999: break f.close() if __name__ == "__main__": args = parser.parse_args() print(string_args(args)) device = args.device model_single = Glow( args.n_channels, args.n_flow, args.n_block, affine=args.affine, conv_lu=not args.no_lu, ) model = model_single model.load_state_dict(torch.load(args.model_path)) model = model.to(device) test(args, model)
def train(args, model, optimizer):
    """Train ``model`` on ``args.dataset``, resuming from the last checkpoint.

    Training images are perturbed with Gaussian noise of standard deviation
    ``args.delta``, plus uniform dequantization noise when ``args.tr_dq`` is
    set. Validation images are used unperturbed. After every epoch the function:

    * appends ``<train loss>,<delta>,<epoch>`` to
      ``losses/losses_train_<args>_.txt`` and ``<val loss>,<delta>,<epoch>`` to
      ``losses/losses_test_<args>_.txt``;
    * saves a sample grid to ``sample/sample_<args>_<epoch>.png``;
    * overwrites ``checkpoint/model_<args>_last_.pt`` (weights) and
      ``checkpoint/last_epoch_<args>.txt`` (number of completed epochs);
    * writes per-example
      ``delta log_p_val logdet_val log_p_train_val logdet_train_val`` lines to
      ``ll/ll_<args>_<epoch>.txt``.

    Here ``<args>`` is ``string_args(args)``. When both checkpoint files exist,
    the weights are restored and epoch numbering continues from the stored
    count; ``args.epochs`` further epochs are then run. Training stops early,
    before that epoch's checkpoint is written, once the validation loss from
    20 epochs ago is lower than every validation loss recorded since.

    Args:
        args: parsed command-line namespace (dataset, image/batch sizes,
            noise settings ``delta``/``tr_dq``, sampling settings, ``epochs``).
        model: flow model; ``model(x)`` returns ``(log_p, logdet, _)`` and
            ``model.reverse(z_list)`` produces images.
        optimizer: optimizer stepped after every training batch.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    if args.dataset == "mnist":
        dataset_f = memory_mnist
    elif args.dataset == "fashion_mnist":
        dataset_f = memory_fashion
    elif args.dataset == "celeba":
        dataset_f = celeba
    elif args.dataset == "ffhq_gan_32":
        dataset_f = ffhq_gan_32
    elif args.dataset == "cifar_horses_40":
        dataset_f = cifar_horses_40
    elif args.dataset == "ffhq_50":
        dataset_f = ffhq_50
    elif args.dataset == "cifar_horses_20":
        dataset_f = cifar_horses_20
    elif args.dataset == "cifar_horses_80":
        dataset_f = cifar_horses_80
    elif args.dataset == "mnist_30":
        dataset_f = mnist_30
    elif args.dataset == "mnist_gan_all":
        dataset_f = mnist_gan_all
    elif args.dataset == "mnist_pad":
        dataset_f = mnist_pad
    elif args.dataset == "cifar_horses_20_top":
        dataset_f = cifar_horses_20_top
    elif args.dataset == "cifar_horses_40_top":
        dataset_f = cifar_horses_40_top
    elif args.dataset == "cifar_horses_20_top_small_lr":
        dataset_f = cifar_horses_20_top_small_lr
    elif args.dataset == "cifar_horses_40_top_small_lr":
        dataset_f = cifar_horses_40_top_small_lr
    elif args.dataset == "arrows_small":
        dataset_f = arrows_small
    elif args.dataset == "arrows_big":
        dataset_f = arrows_big
    elif args.dataset == "cifar_20_picked_inds_2":
        dataset_f = cifar_20_picked_inds_2
    elif args.dataset == "cifar_40_picked_inds_2":
        dataset_f = cifar_40_picked_inds_2
    elif args.dataset == "cifar_40_picked_inds_3":
        dataset_f = cifar_40_picked_inds_3
    elif args.dataset == "cifar_20_picked_inds_3":
        dataset_f = cifar_20_picked_inds_3
    else:
        # Format the name into the message; passing it as a second argument
        # would render the message as a tuple.
        raise ValueError(f"Unknown dataset: {args.dataset}")

    # ``args`` is not modified below, so the name is computed once.
    repr_args = string_args(args)
    n_bins = 2.0**args.n_bits

    # Latents are drawn once and reused for every epoch's sample grid.
    z_shapes = calc_z_shapes(args.n_channels, args.img_size, args.n_flow, args.n_block)
    z_sample = [(torch.randn(args.n_sample, *z) * args.temp).to(device) for z in z_shapes]

    def as_tensor(batch):
        """Return the image tensor of a batch that may be a ``[images, ...]`` list."""
        return batch[0] if isinstance(batch, list) else batch

    def add_train_noise(image):
        """Return ``image`` plus optional dequantization noise and Gaussian noise."""
        # Out-of-place on purpose: ``Tensor.to`` returns the same tensor when it
        # already lives on ``device``, so ``+=`` could modify the loader's data.
        noisy = image
        if args.tr_dq:
            noisy = noisy + torch.rand_like(image) / n_bins
        return noisy + torch.randn_like(image) * args.delta

    last_model_path = f"checkpoint/model_{repr_args}_last_.pt"
    last_epoch_path = f"checkpoint/last_epoch_{repr_args}.txt"
    # Resume only when both the epoch counter and the weights exist, so that
    # restored weights are never paired with an epoch count reset to 0.
    try:
        with open(last_epoch_path, "r") as f_epoch:
            epoch_n = int(f_epoch.readline().strip())
        state_dict = torch.load(last_model_path)
    except FileNotFoundError:
        print("Training the model from scratch.")
        epoch_n = 0
    else:
        model.load_state_dict(state_dict)
        # NOTE(review): the model is left in eval mode for the resumed training;
        # this is harmless only if it has no train/eval-dependent layers —
        # confirm. The optimizer state is not checkpointed and starts fresh.
        model.eval()

    epoch_losses = []  # validation losses of the epochs run in this call
    train_loss_path = f"losses/losses_train_{repr_args}_.txt"
    test_loss_path = f"losses/losses_test_{repr_args}_.txt"
    with open(train_loss_path, "a", buffering=1) as f_train_loss, open(
        test_loss_path, "a", buffering=1
    ) as f_test_loss, tqdm(range(epoch_n, args.epochs + epoch_n)) as pbar:
        for i in pbar:
            train_loader, val_loader, _ = dataset_f(
                args.batch, args.img_size, args.n_channels
            )

            train_losses = []
            for image in train_loader:
                optimizer.zero_grad()
                image = as_tensor(image).to(device)
                log_p, logdet, _ = model(add_train_noise(image))
                loss, _, _ = calc_loss(
                    log_p, logdet.mean(), args.img_size, n_bins, args.n_channels
                )
                loss.backward()
                optimizer.step()
                train_losses.append(loss.item())
            print(f"{np.mean(train_losses)},{args.delta},{i + 1}", file=f_train_loss)

            with torch.no_grad():
                # NOTE(review): newer torchvision releases renamed ``range`` to
                # ``value_range`` — confirm against the installed version.
                utils.save_image(
                    model.reverse(z_sample).cpu().data,
                    f"sample/sample_{repr_args}_{str(i + 1).zfill(6)}.png",
                    normalize=True,
                    nrow=10,
                    range=(-0.5, 0.5),
                )

                # Validation uses the images as loaded, without added noise.
                losses, logps, logdets = [], [], []
                for image in val_loader:
                    image = as_tensor(image).to(device)
                    log_p, logdet, _ = model(image)
                    loss, log_p, log_det = calc_loss(
                        log_p, logdet.mean(), args.img_size, n_bins, args.n_channels
                    )
                    losses.append(loss.item())
                    logps.append(log_p.item())
                    logdets.append(log_det.item())
                pbar.set_description(
                    f"Loss: {np.mean(losses):.5f}; logP: {np.mean(logps):.5f}; "
                    f"logdet: {np.mean(logdets):.5f}; delta: {args.delta:.5f}"
                )
                current_loss = np.mean(losses)
                print(f"{current_loss},{args.delta},{i + 1}", file=f_test_loss)
                epoch_losses.append(current_loss)

                # Early stopping: stop when the loss from 20 epochs ago is
                # strictly lower than each of the 19 losses recorded since.
                if len(epoch_losses) >= 20 and epoch_losses[-20] < min(epoch_losses[-19:]):
                    break

                # Only the latest weights are kept; periodic per-epoch
                # checkpoints were dropped because they took too much space.
                torch.save(model.state_dict(), last_model_path)
                with open(last_epoch_path, "w") as f_epoch:
                    f_epoch.write(str(i + 1))

                _, val_loader, train_val_loader = dataset_f(
                    args.batch, args.img_size, args.n_channels
                )
                with open(f"ll/ll_{repr_args}_{i + 1}.txt", "w") as f_ll:
                    # Pairing stops at the shorter loader instead of raising.
                    for image_val, image_train_val in zip(val_loader, train_val_loader):
                        log_p_val, logdet_val, _ = model(as_tensor(image_val).to(device))
                        log_p_train_val, logdet_train_val, _ = model(
                            as_tensor(image_train_val).to(device)
                        )
                        for lpv, ldv, lptv, ldtv in zip(
                            log_p_val, logdet_val, log_p_train_val, logdet_train_val
                        ):
                            print(
                                args.delta,
                                lpv.item(),
                                ldv.item(),
                                lptv.item(),
                                ldtv.item(),
                                file=f_ll,
                            )