import logging
import math
import pprint
import sys
from collections import defaultdict
from datetime import datetime
from pathlib import Path

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

# Repo-local helpers (datasets, models, Visualizer, mkdir, device, mmd,
# GradientPenalty, make_scheduler) are assumed to be imported from the
# repository's own modules.


# P-BiGAN training on CelebA.
def train_pbigan(args):
    torch.manual_seed(args.seed)

    if args.mask == 'indep':
        data = IndepMaskedCelebA(obs_prob=args.obs_prob)
        mask_str = f'{args.mask}_{args.obs_prob}'
    elif args.mask == 'block':
        data = BlockMaskedCelebA(block_len=args.block_len)
        mask_str = f'{args.mask}_{args.block_len}'
    else:
        raise ValueError(f'unknown mask type: {args.mask}')

    data_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
                             drop_last=True)
    mask_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
                             drop_last=True)
    test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)

    decoder = ConvDecoder(args.latent)
    encoder = ConvEncoder(args.latent, args.flow, logprob=False)
    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)
    critic = ConvCritic(args.latent).to(device)

    optimizer = optim.Adam(pbigan.parameters(), lr=args.lr, betas=(.5, .9))
    critic_optimizer = optim.Adam(critic.parameters(), lr=args.lr,
                                  betas=(.5, .9))

    grad_penalty = GradientPenalty(critic, args.batch_size)
    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)

    path = '{}_{}_{}'.format(
        args.prefix, datetime.now().strftime('%m%d.%H%M%S'), mask_str)
    output_dir = Path('results') / 'celeba-pbigan' / path
    mkdir(output_dir)
    print(output_dir)

    if args.save_interval > 0:
        model_dir = mkdir(output_dir / 'model')

    with (output_dir / 'args.txt').open('w') as f:
        print(pprint.pformat(vars(args)), file=f)

    vis = Visualizer(output_dir, loss_xlim=(0, args.epoch))

    # Fixed test batch for plotting; next(iter(...)) replaces the
    # Python-2-style iter(...).next().
    test_x, test_mask, index = next(iter(test_loader))
    test_x = test_x.to(device)
    test_mask = test_mask.to(device).float()
    bbox = None
    if data.mask_loc is not None:
        bbox = [data.mask_loc[idx] for idx in index]

    n_critic = 5
    critic_updates = 0
    ae_weight = 0

    for epoch in range(args.epoch):
        loss_breakdown = defaultdict(float)

        # Turn on the autoencoding loss once epoch reaches ae_start.
        if epoch >= args.ae_start:
            ae_weight = args.ae

        for (x, mask, _), (_, mask_gen, _) in zip(data_loader, mask_loader):
            x = x.to(device)
            mask = mask.to(device).float()
            mask_gen = mask_gen.to(device).float()

            if critic_updates < n_critic:
                # Critic update: maximize the Wasserstein distance estimate.
                z_enc, z_gen, x_rec, x_gen, _ = pbigan(x, mask, ae=False)

                real_score = critic((x * mask, z_enc)).mean()
                fake_score = critic((x_gen * mask_gen, z_gen)).mean()
                w_dist = real_score - fake_score
                D_loss = -w_dist + grad_penalty((x * mask, z_enc),
                                                (x_gen * mask_gen, z_gen))

                critic_optimizer.zero_grad()
                D_loss.backward()
                critic_optimizer.step()

                loss_breakdown['D'] += D_loss.item()
                critic_updates += 1
            else:
                critic_updates = 0

                # Update generators' parameters
                for p in critic.parameters():
                    p.requires_grad_(False)

                z_enc, z_gen, x_rec, x_gen, ae_loss = pbigan(
                    x, mask, ae=(args.ae > 0))

                real_score = critic((x * mask, z_enc)).mean()
                fake_score = critic((x_gen * mask_gen, z_gen)).mean()
                G_loss = real_score - fake_score
                ae_loss = ae_loss * ae_weight
                loss = G_loss + ae_loss

                mmd_loss = 0
                if args.mmd > 0:
                    mmd_loss = mmd(z_enc, z_gen)
                    loss += mmd_loss * args.mmd

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_breakdown['G'] += G_loss.item()
                if torch.is_tensor(ae_loss):
                    loss_breakdown['AE'] += ae_loss.item()
                if torch.is_tensor(mmd_loss):
                    loss_breakdown['MMD'] += mmd_loss.item()
                loss_breakdown['total'] += loss.item()

                for p in critic.parameters():
                    p.requires_grad_(True)

        if scheduler:
            scheduler.step()

        vis.plot_loss(epoch, loss_breakdown)

        if epoch % args.plot_interval == 0:
            with torch.no_grad():
                pbigan.eval()
                z, z_gen, x_rec, x_gen, ae_loss = pbigan(test_x, test_mask)
                pbigan.train()
            vis.plot(epoch, test_x, test_mask, bbox, x_rec, x_gen)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'history': vis.history,
            'epoch': epoch,
            'args': args,
        }
        torch.save(model_dict, str(output_dir / 'model.pth'))
        if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
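# --- Helper sketches --------------------------------------------------------
# The loop above calls grad_penalty(...) and mmd(...), whose definitions are
# not included in this excerpt. Below are minimal, hedged sketches of what
# such helpers could look like: a WGAN-GP-style gradient penalty over (x, z)
# pairs matching the critic's tuple input, and a multi-bandwidth RBF-kernel
# MMD between two batches of latent codes. The `gamma` weight and the kernel
# `scales` are illustrative assumptions; the repository's actual
# implementations may differ.
from torch import autograd


class GradientPenalty:
    """Gradient penalty on interpolates of (x, z) pairs (WGAN-GP style)."""

    def __init__(self, critic, batch_size, gamma=10):
        self.critic = critic
        self.batch_size = batch_size
        self.gamma = gamma

    def __call__(self, real, fake):
        # real and fake are (x, z) tuples; interpolate each component with
        # the same per-sample mixing coefficient.
        eps = torch.rand(self.batch_size, device=real[0].device)
        interp = []
        for r, f in zip(real, fake):
            e = eps.view([-1] + [1] * (r.dim() - 1))
            interp.append((e * r + (1 - e) * f).detach().requires_grad_())
        score = self.critic(tuple(interp))
        grads = autograd.grad(score.sum(), interp, create_graph=True)
        # Penalize deviation of the joint gradient norm from 1.
        grad_norm = torch.cat(
            [g.view(g.size(0), -1) for g in grads], dim=1).norm(2, dim=1)
        return self.gamma * ((grad_norm - 1) ** 2).mean()


def mmd(x, y, scales=(.1, .5, 1., 2., 5.)):
    """MMD^2 between two batches of latent codes, mixture of RBF kernels."""
    d_xx = torch.cdist(x, x).pow(2)
    d_yy = torch.cdist(y, y).pow(2)
    d_xy = torch.cdist(x, y).pow(2)
    total = x.new_zeros(())
    for s in scales:
        total = total + (torch.exp(-d_xx / (2 * s)).mean()
                         + torch.exp(-d_yy / (2 * s)).mean()
                         - 2 * torch.exp(-d_xy / (2 * s)).mean())
    return total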
# P-BiGAN training on MNIST.
def train_pbigan(args):
    torch.manual_seed(args.seed)

    if args.mask == 'indep':
        data = IndepMaskedMNIST(obs_prob=args.obs_prob,
                                obs_prob_max=args.obs_prob_max)
        mask_str = f'{args.mask}_{args.obs_prob}_{args.obs_prob_max}'
    elif args.mask == 'block':
        data = BlockMaskedMNIST(block_len=args.block_len,
                                block_len_max=args.block_len_max)
        mask_str = f'{args.mask}_{args.block_len}_{args.block_len_max}'
    else:
        raise ValueError(f'unknown mask type: {args.mask}')

    data_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
                             drop_last=True)
    mask_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
                             drop_last=True)

    # Evaluate the training progress using 2000 examples from the training data
    test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)

    decoder = ConvDecoder(args.latent)
    encoder = ConvEncoder(args.latent, args.flow, logprob=False)
    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)
    critic = ConvCritic(args.latent).to(device)

    # Both optimizers use a fixed learning rate here; note the scheduler
    # below anneals from args.lr, so args.lr should match lrate.
    lrate = 1e-4
    optimizer = optim.Adam(pbigan.parameters(), lr=lrate, betas=(.5, .9))
    critic_optimizer = optim.Adam(critic.parameters(), lr=lrate,
                                  betas=(.5, .9))

    grad_penalty = GradientPenalty(critic, args.batch_size)
    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)

    path = '{}_{}_{}'.format(
        args.prefix, datetime.now().strftime('%m%d.%H%M%S'), mask_str)
    output_dir = Path('results') / 'mnist-pbigan' / path
    mkdir(output_dir)
    print(output_dir)

    if args.save_interval > 0:
        model_dir = mkdir(output_dir / 'model')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[
            logging.FileHandler(output_dir / 'log.txt'),
            logging.StreamHandler(sys.stdout),
        ],
    )

    with (output_dir / 'args.txt').open('w') as f:
        print(pprint.pformat(vars(args)), file=f)

    vis = Visualizer(output_dir)

    test_x, test_mask, index = next(iter(test_loader))
    test_x = test_x.to(device)
    test_mask = test_mask.to(device).float()
    bbox = None
    if data.mask_loc is not None:
        bbox = [data.mask_loc[idx] for idx in index]

    n_critic = 5
    critic_updates = 0
    ae_weight = 0
    # Keep the autoencoding loss off for the first ae_flat epochs, then
    # ramp it up linearly to args.ae over the remaining epochs.
    ae_flat = 100

    for epoch in range(args.epoch):
        loss_breakdown = defaultdict(float)

        if epoch > ae_flat:
            ae_weight = args.ae * (epoch - ae_flat) / (args.epoch - ae_flat)

        for (x, mask, _), (_, mask_gen, _) in zip(data_loader, mask_loader):
            x = x.to(device)
            mask = mask.to(device).float()
            mask_gen = mask_gen.to(device).float()

            # Critic update on every iteration.
            z_enc, z_gen, x_rec, x_gen, _ = pbigan(x, mask, ae=False)

            real_score = critic((x * mask, z_enc)).mean()
            fake_score = critic((x_gen * mask_gen, z_gen)).mean()
            w_dist = real_score - fake_score
            D_loss = -w_dist + grad_penalty((x * mask, z_enc),
                                            (x_gen * mask_gen, z_gen))

            critic_optimizer.zero_grad()
            D_loss.backward()
            critic_optimizer.step()

            loss_breakdown['D'] += D_loss.item()

            critic_updates += 1
            if critic_updates == n_critic:
                critic_updates = 0

                # Update generators' parameters
                for p in critic.parameters():
                    p.requires_grad_(False)

                z_enc, z_gen, x_rec, x_gen, ae_loss = pbigan(x, mask)

                real_score = critic((x * mask, z_enc)).mean()
                fake_score = critic((x_gen * mask_gen, z_gen)).mean()
                G_loss = real_score - fake_score
                ae_loss = ae_loss * ae_weight
                loss = G_loss + ae_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_breakdown['G'] += G_loss.item()
                loss_breakdown['AE'] += ae_loss.item()
                loss_breakdown['total'] += loss.item()

                for p in critic.parameters():
                    p.requires_grad_(True)

        if scheduler:
            scheduler.step()

        vis.plot_loss(epoch, loss_breakdown)

        if epoch % args.plot_interval == 0:
            with torch.no_grad():
                pbigan.eval()
                z, z_gen, x_rec, x_gen, ae_loss = pbigan(test_x, test_mask)
                pbigan.train()
            vis.plot(epoch, test_x, test_mask, bbox, x_rec, x_gen)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'history': vis.history,
            'epoch': epoch,
            'args': args,
        }
        torch.save(model_dict, str(output_dir / 'model.pth'))
        if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
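# make_scheduler(...) is referenced by all training loops but not defined in
# this excerpt. Its contract can be inferred from the call sites: it takes
# (optimizer, lr, min_lr, epochs) and may return None, since every caller
# guards with `if scheduler:`. A plausible minimal sketch is a geometric
# step decay from lr down to min_lr, disabled when min_lr < 0; this is an
# assumption, not necessarily the repository's actual schedule.
from torch.optim.lr_scheduler import StepLR


def make_scheduler(optimizer, lr, min_lr, epochs, steps=10):
    """Decay the learning rate from lr to min_lr in `steps` equal stages."""
    if min_lr < 0:
        return None   # annealing disabled
    step_size = max(epochs // steps, 1)
    gamma = (min_lr / lr) ** (1 / steps)
    return StepLR(optimizer, step_size=step_size, gamma=gamma)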
# P-VAE training on MNIST.
def train_pvae(args):
    torch.manual_seed(args.seed)

    if args.mask == 'indep':
        data = IndepMaskedMNIST(obs_prob=args.obs_prob,
                                obs_prob_max=args.obs_prob_max)
        mask_str = f'{args.mask}_{args.obs_prob}_{args.obs_prob_max}'
    elif args.mask == 'block':
        data = BlockMaskedMNIST(block_len=args.block_len,
                                block_len_max=args.block_len_max)
        mask_str = f'{args.mask}_{args.block_len}_{args.block_len_max}'
    else:
        raise ValueError(f'unknown mask type: {args.mask}')

    data_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
                             drop_last=True)

    # Evaluate the training progress using 2000 examples from the training data
    test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)

    decoder = ConvDecoder(args.latent)
    encoder = ConvEncoder(args.latent, args.flow, logprob=True)
    pvae = PVAE(encoder, decoder).to(device)

    optimizer = optim.Adam(pvae.parameters(), lr=args.lr)
    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)

    # Reusable buffer for latent samples when plotting generations.
    rand_z = torch.empty(args.batch_size, args.latent, device=device)

    path = '{}_{}_{}'.format(
        args.prefix, datetime.now().strftime('%m%d.%H%M%S'), mask_str)
    output_dir = Path('results') / 'mnist-pvae' / path
    mkdir(output_dir)
    print(output_dir)

    if args.save_interval > 0:
        model_dir = mkdir(output_dir / 'model')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[
            logging.FileHandler(output_dir / 'log.txt'),
            logging.StreamHandler(sys.stdout),
        ],
    )

    with (output_dir / 'args.txt').open('w') as f:
        print(pprint.pformat(vars(args)), file=f)

    vis = Visualizer(output_dir)

    test_x, test_mask, index = next(iter(test_loader))
    test_x = test_x.to(device)
    test_mask = test_mask.to(device).float()
    bbox = None
    if data.mask_loc is not None:
        bbox = [data.mask_loc[idx] for idx in index]

    # Sigmoid annealing of the KL weight from ~0 at epoch kl_off to ~1 at
    # epoch kl_on, centered midway between the two. The scale of 12 puts the
    # sigmoid at roughly +/-6 at the window edges; the guard uses max (not
    # min) so the slope spans the full annealing window and avoids division
    # by zero when kl_on == kl_off.
    kl_center = (args.kl_on + args.kl_off) / 2
    kl_scale = 12 / max(args.kl_on - args.kl_off, 1)

    for epoch in range(args.epoch):
        if epoch >= args.kl_on:
            kl_weight = 1
        elif epoch < args.kl_off:
            kl_weight = 0
        else:
            kl_weight = 1 / (1 + math.exp(-(epoch - kl_center) * kl_scale))

        loss_breakdown = defaultdict(float)
        for x, mask, _ in data_loader:
            x = x.to(device)
            mask = mask.to(device).float()

            optimizer.zero_grad()
            loss, _, _, _, loss_info = pvae(x, mask, args.k,
                                            kl_weight=kl_weight)
            loss.backward()
            optimizer.step()

            for name, val in loss_info.items():
                loss_breakdown[name] += val

        if scheduler:
            scheduler.step()

        vis.plot_loss(epoch, loss_breakdown)

        if epoch % args.plot_interval == 0:
            x_recon = pvae.impute(test_x, test_mask, args.k)
            with torch.no_grad():
                pvae.eval()
                rand_z.normal_()
                _, x_gen = decoder(rand_z)
                pvae.train()
            vis.plot(epoch, test_x, test_mask, bbox, x_recon, x_gen)

        model_dict = {
            'pvae': pvae.state_dict(),
            'history': vis.history,
            'epoch': epoch,
            'args': args,
        }
        torch.save(model_dict, str(output_dir / 'model.pth'))
        if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
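# A hedged usage sketch: an argparse entry point that wires up train_pvae.
# The flag names mirror the attributes read off `args` in the function above;
# the default values are illustrative assumptions, not the repository's
# documented settings.
import argparse


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=3)
    parser.add_argument('--mask', choices=['indep', 'block'], default='block')
    parser.add_argument('--obs-prob', type=float, default=.2)
    parser.add_argument('--obs-prob-max', type=float, default=None)
    parser.add_argument('--block-len', type=int, default=12)
    parser.add_argument('--block-len-max', type=int, default=None)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--latent', type=int, default=128)
    parser.add_argument('--flow', type=int, default=2)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--min-lr', type=float, default=-1,
                        help='negative value disables LR annealing')
    parser.add_argument('--epoch', type=int, default=500)
    parser.add_argument('--k', type=int, default=5,
                        help='importance samples per example')
    parser.add_argument('--kl-off', type=int, default=0)
    parser.add_argument('--kl-on', type=int, default=100)
    parser.add_argument('--plot-interval', type=int, default=10)
    parser.add_argument('--save-interval', type=int, default=0)
    parser.add_argument('--prefix', default='pvae')
    args = parser.parse_args()
    train_pvae(args)


if __name__ == '__main__':
    main()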