import logging
import math
import pprint
import sys
from collections import defaultdict
from datetime import datetime
from pathlib import Path

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

# The dataset wrappers (IndepMaskedCelebA, BlockMaskedCelebA, IndepMaskedMNIST,
# BlockMaskedMNIST), the models (ConvEncoder, ConvDecoder, ConvCritic, PBiGAN,
# PVAE) and the helpers (GradientPenalty, mmd, make_scheduler, mkdir,
# Visualizer) are assumed to be provided by the surrounding repository.

# Device used throughout; the original scripts presumably set this up once at
# module level.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# CelebA variant of the P-BiGAN trainer.
def train_pbigan_celeba(args):
    torch.manual_seed(args.seed)

    if args.mask == 'indep':
        data = IndepMaskedCelebA(obs_prob=args.obs_prob)
        mask_str = f'{args.mask}_{args.obs_prob}'
    elif args.mask == 'block':
        data = BlockMaskedCelebA(block_len=args.block_len)
        mask_str = f'{args.mask}_{args.block_len}'
    else:
        raise ValueError(f'unknown mask type: {args.mask}')

    data_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)
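    # A second, independently shuffled loader over the same dataset supplies
    # the masks paired with generated images in the fake branch (mask_gen).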
    mask_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)

    # A fixed batch drawn from this loader below is used to visualize
    # reconstruction and generation during training.
    test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)

    decoder = ConvDecoder(args.latent)
    encoder = ConvEncoder(args.latent, args.flow, logprob=False)
    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    critic = ConvCritic(args.latent).to(device)

    optimizer = optim.Adam(pbigan.parameters(), lr=args.lr, betas=(.5, .9))

    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.lr,
                                  betas=(.5, .9))

    grad_penalty = GradientPenalty(critic, args.batch_size)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)

    path = '{}_{}_{}'.format(args.prefix,
                             datetime.now().strftime('%m%d.%H%M%S'), mask_str)
    output_dir = Path('results') / 'celeba-pbigan' / path
    mkdir(output_dir)
    print(output_dir)

    if args.save_interval > 0:
        model_dir = mkdir(output_dir / 'model')

    with (output_dir / 'args.txt').open('w') as f:
        print(pprint.pformat(vars(args)), file=f)

    vis = Visualizer(output_dir, loss_xlim=(0, args.epoch))

    test_x, test_mask, index = next(iter(test_loader))
    test_x = test_x.to(device)
    test_mask = test_mask.to(device).float()
    bbox = None
    if data.mask_loc is not None:
        bbox = [data.mask_loc[idx] for idx in index]

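    # WGAN-style alternation: the critic takes n_critic updates, then the
    # encoder/decoder take one update (the else branch below).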
    n_critic = 5
    critic_updates = 0
    ae_weight = 0

    for epoch in range(args.epoch):
        loss_breakdown = defaultdict(float)

        if epoch >= args.ae_start:
            ae_weight = args.ae

        for (x, mask, _), (_, mask_gen, _) in zip(data_loader, mask_loader):
            x = x.to(device)
            mask = mask.to(device).float()
            mask_gen = mask_gen.to(device).float()

            if critic_updates < n_critic:
                z_enc, z_gen, x_rec, x_gen, _ = pbigan(x, mask, ae=False)

                real_score = critic((x * mask, z_enc)).mean()
                fake_score = critic((x_gen * mask_gen, z_gen)).mean()

                w_dist = real_score - fake_score
                D_loss = -w_dist + grad_penalty((x * mask, z_enc),
                                                (x_gen * mask_gen, z_gen))

                critic_optimizer.zero_grad()
                D_loss.backward()
                critic_optimizer.step()

                loss_breakdown['D'] += D_loss.item()

                critic_updates += 1
            else:
                critic_updates = 0

                # Update generators' parameters
                for p in critic.parameters():
                    p.requires_grad_(False)

                z_enc, z_gen, x_rec, x_gen, ae_loss = pbigan(x,
                                                             mask,
                                                             ae=(args.ae > 0))

                real_score = critic((x * mask, z_enc)).mean()
                fake_score = critic((x_gen * mask_gen, z_gen)).mean()

                G_loss = real_score - fake_score

                ae_loss = ae_loss * ae_weight
                loss = G_loss + ae_loss

                mmd_loss = 0
                if args.mmd > 0:
                    mmd_loss = mmd(z_enc, z_gen)
                    loss += mmd_loss * args.mmd

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_breakdown['G'] += G_loss.item()
                if torch.is_tensor(ae_loss):
                    loss_breakdown['AE'] += ae_loss.item()
                if torch.is_tensor(mmd_loss):
                    loss_breakdown['MMD'] += mmd_loss.item()
                loss_breakdown['total'] += loss.item()

                for p in critic.parameters():
                    p.requires_grad_(True)

        if scheduler:
            scheduler.step()

        vis.plot_loss(epoch, loss_breakdown)

        if epoch % args.plot_interval == 0:
            with torch.no_grad():
                pbigan.eval()
                z, z_gen, x_rec, x_gen, ae_loss = pbigan(test_x, test_mask)
                pbigan.train()
            vis.plot(epoch, test_x, test_mask, bbox, x_rec, x_gen)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'history': vis.history,
            'epoch': epoch,
            'args': args,
        }
        torch.save(model_dict, str(output_dir / 'model.pth'))
        if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
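

# The two helpers below are *sketches* of the repository utilities used above,
# added for self-containment; the originals may differ in details.
#
# GradientPenalty is assumed to implement the WGAN-GP penalty (Gulrajani et
# al., 2017) on (masked image, latent code) pairs, matching the call sites
# grad_penalty((x * mask, z_enc), (x_gen * mask_gen, z_gen)).
class GradientPenaltySketch:
    def __init__(self, critic, batch_size, lamb=10.0):
        self.critic = critic
        self.batch_size = batch_size
        self.lamb = lamb

    def __call__(self, real, fake):
        # One shared epsilon per sample, broadcast over each tensor in the pair.
        eps = torch.rand(self.batch_size, 1, device=real[0].device)
        interp = []
        for r, f in zip(real, fake):
            e = eps.view(-1, *([1] * (r.dim() - 1)))
            interp.append((e * r + (1 - e) * f.detach()).requires_grad_(True))
        score = self.critic(tuple(interp)).sum()
        grads = torch.autograd.grad(score, interp, create_graph=True)
        # Penalize deviation of the joint gradient norm from 1.
        grad_norm = torch.cat([g.reshape(self.batch_size, -1) for g in grads],
                              dim=1).norm(2, dim=1)
        return self.lamb * ((grad_norm - 1) ** 2).mean()


# mmd(z_enc, z_gen) is assumed to be a kernel MMD between encoded and prior
# latent batches; a minimal Gaussian-kernel (biased V-statistic) sketch:
def mmd_sketch(x, y, bandwidth=1.0):
    def k(a, b):
        return torch.exp(-torch.cdist(a, b).pow(2) / (2 * bandwidth ** 2))

    return k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()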


# MNIST variant of the P-BiGAN trainer.
def train_pbigan_mnist(args):
    torch.manual_seed(args.seed)

    if args.mask == 'indep':
        data = IndepMaskedMNIST(obs_prob=args.obs_prob,
                                obs_prob_max=args.obs_prob_max)
        mask_str = f'{args.mask}_{args.obs_prob}_{args.obs_prob_max}'
    elif args.mask == 'block':
        data = BlockMaskedMNIST(block_len=args.block_len,
                                block_len_max=args.block_len_max)
        mask_str = f'{args.mask}_{args.block_len}_{args.block_len_max}'
    else:
        raise ValueError(f'unknown mask type: {args.mask}')

    data_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)
    mask_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)

    # A fixed batch drawn from this loader below is used to monitor training
    # progress on the training data.
    test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)

    decoder = ConvDecoder(args.latent)
    encoder = ConvEncoder(args.latent, args.flow, logprob=False)
    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    critic = ConvCritic(args.latent).to(device)

    # Use the base learning rate that the scheduler below anneals, rather than
    # a hard-coded value, so the two stay consistent.
    optimizer = optim.Adam(pbigan.parameters(), lr=args.lr, betas=(.5, .9))

    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.lr,
                                  betas=(.5, .9))

    grad_penalty = GradientPenalty(critic, args.batch_size)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)

    path = '{}_{}_{}'.format(args.prefix,
                             datetime.now().strftime('%m%d.%H%M%S'), mask_str)
    output_dir = Path('results') / 'mnist-pbigan' / path
    mkdir(output_dir)
    print(output_dir)

    if args.save_interval > 0:
        model_dir = mkdir(output_dir / 'model')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[
            logging.FileHandler(output_dir / 'log.txt'),
            logging.StreamHandler(sys.stdout),
        ],
    )

    with (output_dir / 'args.txt').open('w') as f:
        print(pprint.pformat(vars(args)), file=f)

    vis = Visualizer(output_dir)

    test_x, test_mask, index = next(iter(test_loader))
    test_x = test_x.to(device)
    test_mask = test_mask.to(device).float()
    bbox = None
    if data.mask_loc is not None:
        bbox = [data.mask_loc[idx] for idx in index]

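    # Here the critic is updated on every batch and the generator once every
    # n_critic batches (the CelebA trainer instead skips the critic update on
    # generator steps).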
    n_critic = 5
    critic_updates = 0
    ae_weight = 0
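    # The AE weight stays at 0 for the first ae_flat epochs, then ramps
    # linearly up to args.ae by the final epoch.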
    ae_flat = 100

    for epoch in range(args.epoch):
        loss_breakdown = defaultdict(float)

        if epoch > ae_flat:
            ae_weight = args.ae * (epoch - ae_flat) / (args.epoch - ae_flat)

        for (x, mask, _), (_, mask_gen, _) in zip(data_loader, mask_loader):
            x = x.to(device)
            mask = mask.to(device).float()
            mask_gen = mask_gen.to(device).float()

            z_enc, z_gen, x_rec, x_gen, _ = pbigan(x, mask, ae=False)

            real_score = critic((x * mask, z_enc)).mean()
            fake_score = critic((x_gen * mask_gen, z_gen)).mean()

            w_dist = real_score - fake_score
            D_loss = -w_dist + grad_penalty((x * mask, z_enc),
                                            (x_gen * mask_gen, z_gen))

            critic_optimizer.zero_grad()
            D_loss.backward()
            critic_optimizer.step()

            loss_breakdown['D'] += D_loss.item()

            critic_updates += 1

            if critic_updates == n_critic:
                critic_updates = 0

                # Update generators' parameters
                for p in critic.parameters():
                    p.requires_grad_(False)

                z_enc, z_gen, x_rec, x_gen, ae_loss = pbigan(x, mask)

                real_score = critic((x * mask, z_enc)).mean()
                fake_score = critic((x_gen * mask_gen, z_gen)).mean()

                G_loss = real_score - fake_score

                ae_loss = ae_loss * ae_weight
                loss = G_loss + ae_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_breakdown['G'] += G_loss.item()
                loss_breakdown['AE'] += ae_loss.item()
                loss_breakdown['total'] += loss.item()

                for p in critic.parameters():
                    p.requires_grad_(True)

        if scheduler:
            scheduler.step()

        vis.plot_loss(epoch, loss_breakdown)

        if epoch % args.plot_interval == 0:
            with torch.no_grad():
                pbigan.eval()
                z, z_gen, x_rec, x_gen, ae_loss = pbigan(test_x, test_mask)
                pbigan.train()
            vis.plot(epoch, test_x, test_mask, bbox, x_rec, x_gen)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'history': vis.history,
            'epoch': epoch,
            'args': args,
        }
        torch.save(model_dict, str(output_dir / 'model.pth'))
        if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
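

# Sketches of two more repository helpers, under the assumptions implied by
# their call sites: make_scheduler returns None when no annealing is wanted
# (the trainers guard with `if scheduler:`), here taken to be min_lr < 0, and
# decays the learning rate from lr to min_lr over the run; mkdir creates the
# directory and returns its path (model_dir = mkdir(...)).
def make_scheduler_sketch(optimizer, lr, min_lr, epochs):
    if min_lr < 0:
        return None
    # Cosine decay from lr at epoch 0 to min_lr at the final epoch, expressed
    # as a multiplicative factor on the optimizer's base lr.
    return optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda epoch: (min_lr + (lr - min_lr)
                       * (1 + math.cos(math.pi * epoch / epochs)) / 2) / lr)


def mkdir_sketch(path):
    path.mkdir(parents=True, exist_ok=True)
    return path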


# Partial VAE trainer on MNIST.
def train_pvae(args):
    torch.manual_seed(args.seed)

    if args.mask == 'indep':
        data = IndepMaskedMNIST(obs_prob=args.obs_prob,
                                obs_prob_max=args.obs_prob_max)
        mask_str = f'{args.mask}_{args.obs_prob}_{args.obs_prob_max}'
    elif args.mask == 'block':
        data = BlockMaskedMNIST(block_len=args.block_len,
                                block_len_max=args.block_len_max)
        mask_str = f'{args.mask}_{args.block_len}_{args.block_len_max}'
    else:
        raise ValueError(f'unknown mask type: {args.mask}')

    data_loader = DataLoader(data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)

    # A fixed batch drawn from this loader below is used to monitor training
    # progress on the training data.
    test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)

    decoder = ConvDecoder(args.latent)
    encoder = ConvEncoder(args.latent, args.flow, logprob=True)
    pvae = PVAE(encoder, decoder).to(device)

    optimizer = optim.Adam(pvae.parameters(), lr=args.lr)
    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)

    rand_z = torch.empty(args.batch_size, args.latent, device=device)

    path = '{}_{}_{}'.format(args.prefix,
                             datetime.now().strftime('%m%d.%H%M%S'), mask_str)
    output_dir = Path('results') / 'mnist-pvae' / path
    mkdir(output_dir)
    print(output_dir)

    if args.save_interval > 0:
        model_dir = mkdir(output_dir / 'model')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[
            logging.FileHandler(output_dir / 'log.txt'),
            logging.StreamHandler(sys.stdout),
        ],
    )

    with (output_dir / 'args.txt').open('w') as f:
        print(pprint.pformat(vars(args)), file=f)

    vis = Visualizer(output_dir)

    test_x, test_mask, index = next(iter(test_loader))
    test_x = test_x.to(device)
    test_mask = test_mask.to(device).float()
    bbox = None
    if data.mask_loc is not None:
        bbox = [data.mask_loc[idx] for idx in index]

    # Sigmoid KL annealing: kl_weight ramps from ~0 at epoch kl_off to ~1 at
    # epoch kl_on; the factor of 12 spreads the window over roughly [-6, 6] in
    # logit space, and max(..., 1) guards against a degenerate window.
    kl_center = (args.kl_on + args.kl_off) / 2
    kl_scale = 12 / max(args.kl_on - args.kl_off, 1)

    for epoch in range(args.epoch):
        if epoch >= args.kl_on:
            kl_weight = 1
        elif epoch < args.kl_off:
            kl_weight = 0
        else:
            kl_weight = 1 / (1 + math.exp(-(epoch - kl_center) * kl_scale))
        loss_breakdown = defaultdict(float)
        for x, mask, _ in data_loader:
            x = x.to(device)
            mask = mask.to(device).float()

            optimizer.zero_grad()
            loss, _, _, _, loss_info = pvae(x,
                                            mask,
                                            args.k,
                                            kl_weight=kl_weight)
            loss.backward()
            optimizer.step()
            for name, val in loss_info.items():
                loss_breakdown[name] += val

        if scheduler:
            scheduler.step()

        vis.plot_loss(epoch, loss_breakdown)

        if epoch % args.plot_interval == 0:
            with torch.no_grad():
                pvae.eval()
                # Impute missing pixels and sample from the prior for plotting.
                x_recon = pvae.impute(test_x, test_mask, args.k)
                rand_z.normal_()
                _, x_gen = decoder(rand_z)
                pvae.train()
            vis.plot(epoch, test_x, test_mask, bbox, x_recon, x_gen)

        model_dict = {
            'pvae': pvae.state_dict(),
            'history': vis.history,
            'epoch': epoch,
            'args': args,
        }
        torch.save(model_dict, str(output_dir / 'model.pth'))
        if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
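

# A hypothetical command-line entry point showing how one of the trainers
# above might be invoked. The flag names mirror the attributes read by
# train_pvae; the defaults are illustrative, not the original scripts' values.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=3)
    parser.add_argument('--mask', choices=['indep', 'block'], default='block')
    parser.add_argument('--obs-prob', type=float, default=.2)
    parser.add_argument('--obs-prob-max', type=float, default=None)
    parser.add_argument('--block-len', type=int, default=12)
    parser.add_argument('--block-len-max', type=int, default=None)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--latent', type=int, default=128)
    parser.add_argument('--flow', type=int, default=2)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--min-lr', type=float, default=-1)
    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument('--k', type=int, default=5)
    parser.add_argument('--kl-off', type=int, default=10)
    parser.add_argument('--kl-on', type=int, default=20)
    parser.add_argument('--plot-interval', type=int, default=10)
    parser.add_argument('--save-interval', type=int, default=0)
    parser.add_argument('--prefix', default='pvae')
    train_pvae(parser.parse_args())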