Example #1
0
File: loss.py Project: deJQK/CAT
 def __init__(self):
     """Build the perceptual-loss module around a frozen VGG19.

     Creates ``self.vgg`` (put in eval mode, with requires_grad switched off
     via ``util.set_requires_grad``), an L1 criterion, and five scalar weights
     (1/32, 1/16, 1/8, 1/4, 1) -- presumably one per VGG feature level used
     by the loss; confirm against the rest of the class.
     """
     super(VGGLoss, self).__init__()
     vgg = VGG19()
     self.vgg = vgg
     # Used purely as a fixed feature extractor.
     vgg.eval()
     util.set_requires_grad(vgg, False)
     self.criterion = nn.L1Loss()
     # 2**-5, 2**-4, 2**-3, 2**-2, 2**0 (exact in binary floating point).
     self.weights = [1.0 / 2 ** k for k in (5, 4, 3, 2, 0)]
 def optimize_parameters(self, steps):
     """Perform one combined discriminator + generator update.

     Both optimizers are zeroed before a single forward pass and stepped
     together after both backward passes. ``steps`` is accepted for interface
     compatibility and is not used here.
     """
     for optimizer in (self.optimizer_D, self.optimizer_G):
         optimizer.zero_grad()
     self.forward()
     # Discriminator backward pass with netD's requires_grad switched on.
     util.set_requires_grad(self.netD, True)
     self.backward_D()
     # Generator backward pass with netD's requires_grad switched off again.
     util.set_requires_grad(self.netD, False)
     self.backward_G()
     for optimizer in (self.optimizer_D, self.optimizer_G):
         optimizer.step()
 def optimize_parameters(self):
     """Perform one combined D + G update for a freshly sampled configuration.

     After both optimizers are zeroed, a configuration is drawn from
     ``self.configs.sample()`` and passed to ``forward``; both optimizers are
     stepped together once the D and G backward passes are done.
     """
     for optimizer in (self.optimizer_D, self.optimizer_G):
         optimizer.zero_grad()
     self.forward(config=self.configs.sample())
     # Discriminator backward pass with netD's requires_grad switched on.
     util.set_requires_grad(self.netD, True)
     self.backward_D()
     # Generator backward pass with netD's requires_grad switched off again.
     util.set_requires_grad(self.netD, False)
     self.backward_G()
     for optimizer in (self.optimizer_D, self.optimizer_G):
         optimizer.step()
Example #4
0
def optimize(opt):
    """Optimize StyleGAN2 latents for images of the 'cat' dataset.

    For every item in the selected index range of ``opt.partition`` this calls
    ``generator.optimize`` (with a mask estimated from near-zero pixel values)
    and writes, under ``opt.w_path`` and named after the source image file:

    * ``<name>.pth`` -- ``{'w': ckpt['current_z']}`` detached and moved to CPU;
    * ``<name>_loss.npz`` -- the loss returned by ``generator.optimize``
      (as a squeezed array under key ``loss``);
    * ``<name>_optimized_im.png`` -- ``ckpt['current_x'][0]`` converted with
      ``renormalize.as_image``.

    Items whose ``<name>.pth`` already exists are skipped, so an interrupted
    run can be resumed.

    Args:
        opt: options object; reads ``partition``, ``indices`` (None for the
            whole partition, else a ``(start, end)`` pair with ``end``
            exclusive and clamped to the partition size) and ``w_path``.
    """
    dataset_name = 'cat'
    generator_name = 'stylegan2'

    transform = data.get_transform(dataset_name, 'im2tensor')
    dset = data.get_dataset(dataset_name,
                            opt.partition,
                            load_w=False,
                            transform=transform)

    total = len(dset)
    if opt.indices is None:
        start_idx = 0
        end_idx = total
    else:
        start_idx = opt.indices[0]
        # Clamp so an over-long range stops at the end of the partition
        # rather than indexing dset out of range part-way through the run.
        end_idx = min(opt.indices[1], total)

    print("Optimizing dataset partition %s items %d to %d" %
          (opt.partition, start_idx, end_idx))

    generator = domain_generator.define_generator(generator_name,
                                                  dataset_name,
                                                  load_encoder=True)
    # Switch requires_grad off for the pretrained generator and encoder.
    util.set_requires_grad(False, generator.generator)
    util.set_requires_grad(False, generator.encoder)

    for i in range(start_idx, end_idx):
        (im, _, path) = dset[i]
        img_filename = os.path.splitext(os.path.basename(path))[0]

        print("Running %d / %d images: %s" % (i, end_idx, img_filename))

        output_filename = os.path.join(opt.w_path, img_filename)
        if os.path.isfile(output_filename + '.pth'):
            print(output_filename + '.pth found... skipping')
            continue

        # cat face dataset is already centered
        centered_im = im[None].cuda()
        # find zero values to estimate the mask
        # NOTE(review): centered_im is (1, C, H, W), so summing over axis 0
        # (the singleton batch axis) keeps per-channel values, and after
        # mask[:, :1] the mask reflects channel 0 alone. If "all channels
        # near zero" was intended, axis=1 is probably what was meant --
        # confirm before changing, since it alters which pixels are fitted.
        mask = torch.ones_like(centered_im)
        mask[torch.where(
            torch.sum(torch.abs(centered_im), axis=0, keepdims=True) < 0.02
        )] = 0
        mask = mask[:, :1, :, :].cuda()
        ckpt, loss = generator.optimize(centered_im, mask=mask)

        w_optimized = ckpt['current_z']
        loss = np.array(loss).squeeze()
        im_optimized = renormalize.as_image(ckpt['current_x'][0])
        torch.save({'w': w_optimized.detach().cpu()}, output_filename + '.pth')
        np.savez(output_filename + '_loss.npz', loss=loss)
        im_optimized.save(output_filename + '_optimized_im.png')
Example #5
0
 def optimize_parameters(self, steps):
     """Perform one combined D + G update, optionally using the style encoder.

     ``need_style_encoder`` is False whenever
     ``self.opt.student_no_style_encoder`` is truthy; otherwise it is True on
     every step except those that are multiples of
     ``self.opt.style_encoder_step``. The flag is passed to both ``forward``
     and ``backward_G``.
     """
     opt = self.opt
     need_style_encoder = (not opt.student_no_style_encoder
                           and steps % opt.style_encoder_step != 0)
     for optimizer in (self.optimizer_D, self.optimizer_G):
         optimizer.zero_grad()
     self.forward(config=self.configs.sample(),
                  need_style_encoder=need_style_encoder)
     # Discriminator backward pass with netD's requires_grad switched on.
     util.set_requires_grad(self.netD, True)
     self.backward_D()
     # Generator backward pass with netD's requires_grad switched off again.
     util.set_requires_grad(self.netD, False)
     self.backward_G(need_style_encoder=need_style_encoder)
     for optimizer in (self.optimizer_D, self.optimizer_G):
         optimizer.step()
Example #6
0
def optimize(opt):
    """Optimize StyleGAN2 latents for images of the 'celebahq' dataset.

    For every item in the selected index range of ``opt.partition`` this calls
    ``generator.optimize`` (without a mask) and writes, under ``opt.w_path``
    and named after the source image file:

    * ``<name>.pth`` -- ``{'w': ckpt['current_z']}`` detached and moved to CPU;
    * ``<name>_loss.npz`` -- the loss returned by ``generator.optimize``
      (as a squeezed array under key ``loss``);
    * ``<name>_optimized_im.png`` -- ``ckpt['current_x'][0]`` converted with
      ``renormalize.as_image``.

    Items whose ``<name>.pth`` already exists are skipped, so an interrupted
    run can be resumed.

    Args:
        opt: options object; reads ``partition``, ``indices`` (None for the
            whole partition, else a ``(start, end)`` pair with ``end``
            exclusive and clamped to the partition size) and ``w_path``.
    """
    dataset_name = 'celebahq'
    generator_name = 'stylegan2'
    transform = data.get_transform(dataset_name, 'imval')

    # we don't need the labels, so attribute doesn't really matter here
    dset = data.get_dataset(dataset_name,
                            opt.partition,
                            'Smiling',
                            load_w=False,
                            return_path=True,
                            transform=transform)
    total = len(dset)
    if opt.indices is None:
        start_idx = 0
        end_idx = total
    else:
        start_idx = opt.indices[0]
        # Clamp so an over-long range stops at the end of the partition
        # rather than indexing dset out of range part-way through the run.
        end_idx = min(opt.indices[1], total)

    print("Optimizing dataset partition %s items %d to %d" %
          (opt.partition, start_idx, end_idx))

    generator = domain_generator.define_generator(generator_name,
                                                  dataset_name,
                                                  load_encoder=True)
    # Switch requires_grad off for the pretrained generator and encoder.
    util.set_requires_grad(False, generator.generator)
    util.set_requires_grad(False, generator.encoder)

    for i in range(start_idx, end_idx):
        (image, _, path) = dset[i]
        image = image[None].cuda()
        img_filename = os.path.splitext(os.path.basename(path))[0]

        print("Running %d / %d images: %s" % (i, end_idx, img_filename))

        output_filename = os.path.join(opt.w_path, img_filename)
        if os.path.isfile(output_filename + '.pth'):
            print(output_filename + '.pth found... skipping')
            continue

        ckpt, loss = generator.optimize(image, mask=None)
        w_optimized = ckpt['current_z']
        loss = np.array(loss).squeeze()
        im_optimized = renormalize.as_image(ckpt['current_x'][0])
        torch.save({'w': w_optimized.detach().cpu()}, output_filename + '.pth')
        np.savez(output_filename + '_loss.npz', loss=loss)
        im_optimized.save(output_filename + '_optimized_im.png')
Example #7
0
def optimize(opt):
    """Optimize class-conditional StyleGAN2 latents for 'cifar10' images.

    For every item in the selected index range of ``opt.partition`` the
    image's class is predicted with the dataset's image classifier, and that
    predicted label (not the ground-truth one) is passed to
    ``generator.optimize``. ``ckpt['current_z']`` is saved as a NumPy array to
    ``<opt.w_path>/<partition>_<index:06d>.npy``; existing files are skipped,
    so an interrupted run can be resumed.

    Args:
        opt: options object; reads ``partition``, ``indices`` (None for the
            whole partition, else a ``(start, end)`` pair with ``end``
            exclusive and clamped to the partition size) and ``w_path``.
    """
    dataset_name = 'cifar10'
    generator_name = 'stylegan2-cc'  # class conditional stylegan
    transform = data.get_transform(dataset_name, 'imval')

    dset = data.get_dataset(dataset_name,
                            opt.partition,
                            load_w=False,
                            transform=transform)
    total = len(dset)
    if opt.indices is None:
        start_idx = 0
        end_idx = total
    else:
        start_idx = opt.indices[0]
        # Clamp so an over-long range stops at the end of the partition
        # rather than indexing dset out of range part-way through the run.
        end_idx = min(opt.indices[1], total)

    generator = domain_generator.define_generator(generator_name,
                                                  dataset_name,
                                                  load_encoder=False)
    # Switch requires_grad off for the pretrained generator.
    util.set_requires_grad(False, generator.generator)

    resnet = domain_classifier.define_classifier(dataset_name,
                                                 'imageclassifier')

    ### iterate ###
    for i in range(start_idx, end_idx):
        img, label = dset[i]

        print("Running img %d/%d" % (i, len(dset)))
        filename = os.path.join(opt.w_path, '%s_%06d.npy' % (opt.partition, i))
        if os.path.isfile(filename):
            print(filename + ' found... skipping')
            continue

        img = img[None].cuda()
        # The classifier's prediction conditions the generator.
        with torch.no_grad():
            pred_logit = resnet(img)
            _, pred_label = pred_logit.max(1)
            pred_label = pred_label.item()
        print("True label %d prd label %d" % (label, pred_label))
        ckpt, _ = generator.optimize(img, pred_label)
        current_z = ckpt['current_z'].detach().cpu().numpy()
        np.save(filename, current_z)
Example #8
0
def optimize(opt):
    """Optimize StyleGAN2 latents for images of the 'car' dataset.

    Each car photo is loaded as a PIL image and prepared in five steps:

    1. resize it to 512 px width;
    2. shift it so that the centre of its bounding box lands at the image
       centre;
    3. crop it with ``data.transforms.centercrop_tensor(..., 384, 512)``;
    4. paste it into the middle of a zero-filled 3x512x512 canvas;
    5. build a matching mask: zero outside the pasted region, and the mask
       returned by ``data.transforms.shift_tensor`` inside it.

    The canvas and mask are passed to ``generator.optimize``. Outputs are
    written under ``opt.w_path``, named after the source image file:

    * ``<name>.pth`` -- ``{'w': ckpt['current_z']}`` detached and moved to CPU;
    * ``<name>_loss.npz`` -- the loss returned by ``generator.optimize``
      (as a squeezed array under key ``loss``);
    * ``<name>_optimized_im.png`` -- ``ckpt['current_x'][0]`` converted with
      ``renormalize.as_image``.

    Items whose ``<name>.pth`` already exists are skipped, so an interrupted
    run can be resumed.

    Args:
        opt: options object; reads ``partition``, ``indices`` (None for the
            whole partition, else a ``(start, end)`` pair with ``end``
            exclusive and clamped to the partition size) and ``w_path``.
    """
    dataset_name = 'car'
    generator_name = 'stylegan2'

    transform = data.get_transform(dataset_name, 'im2tensor')

    # loads the PIL image
    dset = data.get_dataset(dataset_name,
                            opt.partition,
                            load_w=False,
                            transform=None)
    total = len(dset)

    if opt.indices is None:
        start_idx = 0
        end_idx = total
    else:
        start_idx = opt.indices[0]
        # Clamp so an over-long range stops at the end of the partition
        # rather than indexing dset out of range part-way through the run.
        end_idx = min(opt.indices[1], total)

    print("Optimizing dataset partition %s items %d to %d" %
          (opt.partition, start_idx, end_idx))

    generator = domain_generator.define_generator(generator_name,
                                                  dataset_name,
                                                  load_encoder=True)
    # Switch requires_grad off for the pretrained generator and encoder.
    util.set_requires_grad(False, generator.generator)
    util.set_requires_grad(False, generator.encoder)

    for i in range(start_idx, end_idx):
        (im, _, bbox, path) = dset[i]
        img_filename = os.path.splitext(os.path.basename(path))[0]

        print("Running %d / %d images: %s" % (i, end_idx, img_filename))

        output_filename = os.path.join(opt.w_path, img_filename)
        if os.path.isfile(output_filename + '.pth'):
            print(output_filename + '.pth found... skipping')
            continue

        # scale image to 512 width
        width, height = im.size
        ratio = 512 / width
        new_width = 512
        new_height = int(ratio * height)
        # Image.ANTIALIAS was removed in Pillow 10; it was an alias of
        # Image.LANCZOS, so this is the same resampling filter.
        new_im = im.resize((new_width, new_height), Image.LANCZOS)
        print(im.size)
        print(new_im.size)
        # bbox is indexed as (x0, y0, x1, y1) below.
        bbox = [int(x * ratio) for x in bbox]

        # shift to center the bbox annotation
        cx = (bbox[2] + bbox[0]) // 2
        cy = (bbox[3] + bbox[1]) // 2
        print("%d --> %d" % (cx, new_width // 2))
        print("%d --> %d" % (cy, new_height // 2))
        offset_x = new_width // 2 - cx
        offset_y = new_height // 2 - cy

        im_tensor = transform(new_im)
        im_tensor, mask = data.transforms.shift_tensor(im_tensor, offset_y,
                                                       offset_x)
        im_tensor = data.transforms.centercrop_tensor(im_tensor, 384, 512)
        mask = data.transforms.centercrop_tensor(mask, 384, 512)
        # now image size is at most 512 x 384 (could be smaller)

        # center car on 512x512 tensor
        disp_y = (512 - im_tensor.shape[1]) // 2
        disp_x = (512 - im_tensor.shape[2]) // 2
        centered_im = torch.zeros((3, 512, 512))
        centered_im[:, disp_y:disp_y + im_tensor.shape[1],
                    disp_x:disp_x + im_tensor.shape[2]] = im_tensor
        centered_mask = torch.zeros_like(centered_im)
        centered_mask[:, disp_y:disp_y + im_tensor.shape[1],
                      disp_x:disp_x + im_tensor.shape[2]] = mask

        # Batch dimension added; only the first mask channel is passed.
        ckpt, loss = generator.optimize(centered_im[None].cuda(),
                                        centered_mask[:1][None].cuda())

        w_optimized = ckpt['current_z']
        loss = np.array(loss).squeeze()
        im_optimized = renormalize.as_image(ckpt['current_x'][0])
        torch.save({'w': w_optimized.detach().cpu()}, output_filename + '.pth')
        np.savez(output_filename + '_loss.npz', loss=loss)
        im_optimized.save(output_filename + '_optimized_im.png')
def train(train_loader, model_dict, criterion_dict, optimizer_dict, fake_pool,
          recon_pool, WR_pool, visualizer, epoch, args, val_loader, fixed):
    """Run one training epoch of the password-conditioned generator and its
    adversaries.

    For every full-size batch of ``train_loader`` this:

    1. draws codes with ``generate_code`` and runs ``model_dict['G']`` to get
       ``fake`` = G(img, z)                  -- scored by D as 'M',
       ``recon`` = G(fake, inv_z)            -- scored by D as 'R',
       ``rand_fake`` = G(img, rand_z),
       ``rand_recon`` = G(fake, rand_inv_z)  -- scored by D as 'WR',
       ``rand_recon_2nd`` = G(fake, rand_inv_2nd_z);
    2. with ``G_nets`` requires_grad off and ``D_nets`` on, backpropagates
       face-recognition (FR) and GAN losses and steps ``optimizer_dict['D']``;
    3. with ``D_nets`` off and ``G_nets`` on, backpropagates GAN, infoGAN
       (``model_dict['Q']``), FR, feature and L1/reconstruction losses and
       steps ``optimizer_dict['G']``;
    4. logs losses, pushes visuals, saves checkpoints and runs ``validate`` at
       the frequencies configured in ``args``.

    Args:
        train_loader: yields ``(img, label, landmarks, img_path)`` batches.
        model_dict: nets keyed 'G', 'D', 'FR', 'Q', plus 'G_nets' / 'D_nets'
            (passed to ``set_requires_grad``).
        criterion_dict: losses keyed 'FR', 'GAN', 'DIS', 'L1' (the dict is
            also handed to ``get_feat_loss``).
        optimizer_dict: optimizers keyed 'D' and 'G'.
        fake_pool, recon_pool, WR_pool: objects whose ``query`` supplies the
            generated images D scores for 'M', 'R' and 'WR' (presumably
            image-history pools -- confirm).
        visualizer: logging / display helper.
        epoch: current epoch index (used in logs and checkpoints).
        args: parsed options. NOTE: this function sets ``args.display_id`` to
            -1 if displaying results times out.
        val_loader, fixed: forwarded to ``validate``.
    """
    iter_data_time = time.time()

    for i, (img, label, landmarks, img_path) in enumerate(train_loader):
        # Only full batches are used: the affine grid and the cosine-similarity
        # target below are both sized with args.batch_size.
        if img.size(0) != args.batch_size:
            continue

        img_cuda = img.cuda(non_blocking=True)

        # Timing is only measured on iterations that will print losses.
        if i % args.print_loss_freq == 0:
            iter_start_time = time.time()
            t_data = iter_start_time - iter_data_time

        visualizer.reset()

        # -------------------- forward & get aligned --------------------
        # Sampling grid for 112x96 crops, built from the affine parameters
        # that alignment() derives from the landmarks.
        theta = alignment(landmarks)
        grid = torch.nn.functional.affine_grid(
            theta, torch.Size((args.batch_size, 3, 112, 96)))

        # -------------------- generate password --------------------
        # Five codes, each followed by its *_dis_target; only dis_target,
        # inv_dis_target and rand_inv_dis_target are used (infoGAN loss).
        z, dis_target, rand_z, rand_dis_target, \
        inv_z, inv_dis_target, rand_inv_z, rand_inv_dis_target, \
        rand_inv_2nd_z, rand_inv_2nd_dis_target = generate_code(args.passwd_length,
                                                                args.batch_size,
                                                                args.device,
                                                                inv=True,
                                                                use_minus_one=args.use_minus_one,
                                                                gen_random_WR=True)
        real_aligned = grid_sample(img_cuda, grid)  # (B, 3, h, w)
        # [2, 1, 0] reverses the channel order (presumably RGB <-> BGR for the
        # FR net -- confirm); every aligned image below is reordered the same way.
        real_aligned = real_aligned[:, [2, 1, 0], ...]

        # NOTE(review): G is fed CPU tensors here (img, z.cpu()) but
        # fake / inv_z for the reconstructions; presumably G handles device
        # placement itself (e.g. a DataParallel wrapper) -- confirm.
        fake = model_dict['G'](img, z.cpu())
        fake_aligned = grid_sample(fake, grid)
        fake_aligned = fake_aligned[:, [2, 1, 0], ...]

        recon = model_dict['G'](fake, inv_z)
        recon_aligned = grid_sample(recon, grid)
        recon_aligned = recon_aligned[:, [2, 1, 0], ...]

        rand_fake = model_dict['G'](img, rand_z.cpu())
        rand_fake_aligned = grid_sample(rand_fake, grid)
        rand_fake_aligned = rand_fake_aligned[:, [
            2,
            1,
            0,
        ], ...]

        rand_recon = model_dict['G'](fake, rand_inv_z)
        rand_recon_aligned = grid_sample(rand_recon, grid)
        rand_recon_aligned = rand_recon_aligned[:, [2, 1, 0], ...]

        rand_recon_2nd = model_dict['G'](fake, rand_inv_2nd_z)
        rand_recon_2nd_aligned = grid_sample(rand_recon_2nd, grid)
        rand_recon_2nd_aligned = rand_recon_2nd_aligned[:, [2, 1, 0], ...]

        # init loss dict for plot & print
        current_losses = {}

        # -------------------- D PART --------------------
        # init
        # requires_grad off for G-side nets, on for D-side nets.
        set_requires_grad(model_dict['G_nets'], False)
        set_requires_grad(model_dict['D_nets'], True)
        optimizer_dict['D'].zero_grad()
        loss_D = 0.

        # ========== Face Recognition (FR) losses (L_{adv}, L_{rec\_cls}) ==========
        # FAKE FRs
        # M
        # Generated images are detached so this update does not reach G.
        id_fake = model_dict['FR'](fake_aligned.detach())[0]
        loss_D_FR_fake = criterion_dict['FR'](id_fake, label.to(args.device))

        # R & WR
        # Sign convention on the D side: the FR loss is minimised on fake and
        # rand_recon but maximised (negated) on recon. The G part below uses
        # the opposite sign for each of the three.
        id_recon = model_dict['FR'](recon_aligned.detach())[0]
        loss_D_FR_recon = -criterion_dict['FR'](id_recon, label.to(
            args.device))

        id_rand_recon = model_dict['FR'](rand_recon_aligned.detach())[0]
        loss_D_FR_rand_recon = criterion_dict['FR'](id_rand_recon,
                                                    label.to(args.device))

        # Weighted average of the three terms (recon carries weight 1).
        loss_D_FR_fake_total = args.lambda_FR_M * loss_D_FR_fake + loss_D_FR_recon \
                               + args.lambda_FR_WR * loss_D_FR_rand_recon
        loss_D_FR_fake_avg = loss_D_FR_fake_total / float(1. +
                                                          args.lambda_FR_M +
                                                          args.lambda_FR_WR)
        current_losses.update({
            'D_FR_M': loss_D_FR_fake.item(),
            'D_FR_R': loss_D_FR_recon.item(),
            'D_FR_WR': loss_D_FR_rand_recon.item(),
        })

        # REAL FR
        id_real = model_dict['FR'](real_aligned)[0]
        loss_D_FR_real = criterion_dict['FR'](id_real, label.to(args.device))

        # FR part of loss_D: mean of the real and (averaged) generated terms.
        loss_D += args.lambda_FR * (loss_D_FR_real + loss_D_FR_fake_avg) * 0.5
        current_losses.update({
            'D_FR_real': loss_D_FR_real.item(),
            'D_FR_fake': loss_D_FR_fake_avg.item()
        })

        # ========== GAN loss (L_{GAN}) ==========
        # fake
        # Generated images are detached, moved to CPU and passed through the
        # pools' query(); the second argument to D ('M' / 'R' / 'WR')
        # presumably selects the D branch that scores them. The 'M' branch
        # sees both fake and rand_fake.
        all_M = torch.cat((
            fake.detach().cpu(),
            rand_fake.detach().cpu(),
        ), 0)
        pred_D_M = model_dict['D'](fake_pool.query(all_M,
                                                   batch_size=args.batch_size),
                                   'M')
        loss_D_M = criterion_dict['GAN'](pred_D_M, False)

        # R
        pred_D_recon = model_dict['D'](recon_pool.query(
            recon.detach().cpu(), batch_size=args.batch_size), 'R')
        loss_D_recon = criterion_dict['GAN'](pred_D_recon, False)

        # WR
        # The 'WR' branch sees both wrong-code reconstructions.
        all_WR = torch.cat(
            (rand_recon.detach().cpu(), rand_recon_2nd.detach().cpu()), 0)
        pred_D_WR = model_dict['D'](WR_pool.query(all_WR,
                                                  batch_size=args.batch_size),
                                    'WR')
        loss_D_WR = criterion_dict['GAN'](pred_D_WR, False)

        # Lambda-weighted average over the three branches.
        loss_D_fake_total = args.lambda_GAN_M * loss_D_M + \
                            args.lambda_GAN_recon * loss_D_recon + \
                            args.lambda_GAN_WR * loss_D_WR
        loss_D_fake_total_weights = args.lambda_GAN_M + \
                                    args.lambda_GAN_recon + \
                                    args.lambda_GAN_WR
        loss_D_GAN_fake = loss_D_fake_total / loss_D_fake_total_weights
        current_losses.update({
            'D_GAN_M': loss_D_M.item(),
            'D_GAN_R': loss_D_recon.item(),
            'D_GAN_WR': loss_D_WR.item()
        })

        # real
        # The same real batch is scored as real by each of the three branches.
        pred_D_real_M = model_dict['D'](img, 'M')
        pred_D_real_R = model_dict['D'](img, 'R')
        pred_D_real_WR = model_dict['D'](img, 'WR')

        loss_D_real_M = criterion_dict['GAN'](pred_D_real_M, True)
        loss_D_real_R = criterion_dict['GAN'](pred_D_real_R, True)
        loss_D_real_WR = criterion_dict['GAN'](pred_D_real_WR, True)

        loss_D_GAN_real = (args.lambda_GAN_M * loss_D_real_M +
                           args.lambda_GAN_recon * loss_D_real_R +
                           args.lambda_GAN_WR * loss_D_real_WR) / \
                          (args.lambda_GAN_M +
                           args.lambda_GAN_recon +
                           args.lambda_GAN_WR)

        loss_D += args.lambda_GAN * (loss_D_GAN_fake + loss_D_GAN_real) * 0.5
        current_losses.update({
            'D_GAN_real': loss_D_GAN_real.item(),
            'D_GAN_fake': loss_D_GAN_fake.item()
        })
        current_losses['D'] = loss_D.item()

        # D backward and optimizer steps
        loss_D.backward()
        optimizer_dict['D'].step()

        # -------------------- G PART --------------------
        # init
        # requires_grad off for D-side nets, on for G-side nets.
        set_requires_grad(model_dict['D_nets'], False)
        set_requires_grad(model_dict['G_nets'], True)
        optimizer_dict['G'].zero_grad()
        loss_G = 0

        # ========== GAN loss (L_{GAN}) ==========
        # Not detached: these losses backpropagate through D's outputs into G.
        pred_G_fake = model_dict['D'](fake, 'M')
        loss_G_GAN_fake = criterion_dict['GAN'](pred_G_fake, True)

        pred_G_recon = model_dict['D'](recon, 'R')
        loss_G_GAN_recon = criterion_dict['GAN'](pred_G_recon, True)

        # NOTE(review): only rand_recon (not rand_recon_2nd) enters G's WR GAN
        # term, whereas D's WR term above uses both.
        pred_G_WR = model_dict['D'](rand_recon, 'WR')
        loss_G_GAN_WR = criterion_dict['GAN'](pred_G_WR, True)

        loss_G_GAN_total = args.lambda_GAN_M * loss_G_GAN_fake + \
                           args.lambda_GAN_recon * loss_G_GAN_recon + \
                           args.lambda_GAN_WR * loss_G_GAN_WR
        loss_G_GAN_total_weights = args.lambda_GAN_M + args.lambda_GAN_recon + args.lambda_GAN_WR
        loss_G_GAN = loss_G_GAN_total / loss_G_GAN_total_weights
        loss_G += args.lambda_GAN * loss_G_GAN

        current_losses.update({
            'G_GAN_M': loss_G_GAN_fake.item(),
            'G_GAN_R': loss_G_GAN_recon.item(),
            'G_GAN_WR': loss_G_GAN_WR.item(),
            'G_GAN': loss_G_GAN.item()
        })

        # ========== infoGAN loss (L_{aux}) ==========
        # Q's output is indexed per group (passwd_length // 4 groups) and each
        # group is scored with criterion 'DIS' against the matching column of
        # the dis target. Accuracies are computed with .item() and only logged.
        if args.lambda_dis > 0:
            fake_dis_logits = model_dict['Q'](infoGAN_input(img_cuda, fake))
            infogan_fake_acc = 0
            loss_G_fake_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = fake_dis_logits[dis_idx].max(dim=1)[1]
                b = dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_fake_acc += acc.item()
                loss_G_fake_dis += criterion_dict['DIS'](
                    fake_dis_logits[dis_idx], dis_target[:, dis_idx])
            infogan_fake_acc = infogan_fake_acc / float(
                args.passwd_length // 4)

            recon_dis_logits = model_dict['Q'](infoGAN_input(fake, recon))
            infogan_recon_acc = 0
            loss_G_recon_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = recon_dis_logits[dis_idx].max(dim=1)[1]
                b = inv_dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_recon_acc += acc.item()
                loss_G_recon_dis += criterion_dict['DIS'](
                    recon_dis_logits[dis_idx], inv_dis_target[:, dis_idx])
            infogan_recon_acc = infogan_recon_acc / float(
                args.passwd_length // 4)

            rand_recon_dis_logits = model_dict['Q'](infoGAN_input(
                fake, rand_recon))
            infogan_rand_recon_acc = 0
            loss_G_recon_rand_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = rand_recon_dis_logits[dis_idx].max(dim=1)[1]
                b = rand_inv_dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_rand_recon_acc += acc.item()
                loss_G_recon_rand_dis += criterion_dict['DIS'](
                    rand_recon_dis_logits[dis_idx],
                    rand_inv_dis_target[:, dis_idx])
            infogan_rand_recon_acc = infogan_rand_recon_acc / float(
                args.passwd_length // 4)

            # Losses are summed; the logged accuracy is averaged over the 3 pairs.
            dis_loss_total = loss_G_fake_dis + loss_G_recon_dis + loss_G_recon_rand_dis
            dis_acc_total = infogan_fake_acc + infogan_recon_acc + infogan_rand_recon_acc
            dis_cnt = 3

            loss_G += args.lambda_dis * dis_loss_total
            current_losses.update({
                'dis': dis_loss_total.item(),
                'dis_acc': dis_acc_total / float(dis_cnt)
            })

        # ========== Face Recognition (FR) loss (L_{adv}, L{rec_cls}})==========
        # (netFR must not be fixed)
        # Signs are the opposite of the D part: G maximises the FR loss on
        # fake and rand_recon and minimises it on recon. The second FR output
        # (features) is reused by the feature losses below.
        id_fake_G, fake_feat = model_dict['FR'](fake_aligned)
        loss_G_FR_fake = -criterion_dict['FR'](id_fake_G, label.to(
            args.device))

        id_recon_G, recon_feat = model_dict['FR'](recon_aligned)
        loss_G_FR_recon = criterion_dict['FR'](id_recon_G,
                                               label.to(args.device))

        id_rand_recon_G, rand_recon_feat = model_dict['FR'](rand_recon_aligned)
        loss_G_FR_rand_recon = -criterion_dict['FR'](id_rand_recon_G,
                                                     label.to(args.device))

        loss_G_FR_avg = (args.lambda_FR_M * loss_G_FR_fake +
                         loss_G_FR_recon +
                         args.lambda_FR_WR * loss_G_FR_rand_recon) /\
                        (args.lambda_FR_M + 1. + args.lambda_FR_WR)
        loss_G += args.lambda_FR * loss_G_FR_avg

        current_losses.update({
            'G_FR_M': loss_G_FR_fake.item(),
            'G_FR_R': loss_G_FR_recon.item(),
            'G_FR_WR': loss_G_FR_rand_recon.item(),
            'G_FR': loss_G_FR_avg.item()
        })

        # ========== Feature losses (L_{feat} is the sum of the three L_{dis}'s) ==========
        # With 'cos' feature loss, every pair gets target -1 (presumably asking
        # the two feature vectors to be dissimilar -- see get_feat_loss).
        if args.feature_loss == 'cos':  # make cos sim target
            FR_cos_sim_target = torch.empty(size=(args.batch_size, 1),
                                            dtype=torch.float32,
                                            device=args.device)
            FR_cos_sim_target.fill_(-1.)
        else:
            FR_cos_sim_target = None

        # NOTE(review): id_rand_fake_G and id_rand_recon_2nd_G are unused, and
        # these two FR forwards run even when every feature lambda is 0.
        id_rand_fake_G, rand_fake_feat = model_dict['FR'](rand_fake_aligned)
        id_rand_recon_2nd_G, rand_recon_2nd_feat = model_dict['FR'](
            rand_recon_2nd_aligned)

        # Each feature term is only computed (and logged) when its lambda is
        # truthy; otherwise it contributes 0.
        if args.lambda_Feat:
            loss_G_feat = get_feat_loss(fake_feat, rand_fake_feat,
                                        FR_cos_sim_target, args.feature_loss,
                                        criterion_dict)
            current_losses['G_feat'] = loss_G_feat.item()
        else:
            loss_G_feat = 0.

        if args.lambda_WR_Feat:
            loss_G_WR_feat = get_feat_loss(rand_recon_feat,
                                           rand_recon_2nd_feat,
                                           FR_cos_sim_target,
                                           args.feature_loss, criterion_dict)
            current_losses['G_WR_feat'] = loss_G_WR_feat.item()
        else:
            loss_G_WR_feat = 0.

        if args.lambda_false_recon_diff:
            loss_G_M_WR_feat = get_feat_loss(fake_feat, rand_recon_feat,
                                             FR_cos_sim_target,
                                             args.feature_loss, criterion_dict)
            current_losses['G_feat_M_WR'] = loss_G_M_WR_feat.item()
        else:
            loss_G_M_WR_feat = 0.

        loss_G += args.lambda_Feat * loss_G_feat + \
                  args.lambda_WR_Feat * loss_G_WR_feat + \
                  args.lambda_false_recon_diff * loss_G_M_WR_feat

        # ========== L1/Recon losses (L_1, L_{rec}) ==========
        # fake, rand_recon and recon are each compared with the real image.
        loss_G_L1 = criterion_dict['L1'](fake, img_cuda)
        loss_G_rand_recon_L1 = criterion_dict['L1'](rand_recon, img_cuda)
        loss_G_recon = criterion_dict['L1'](recon, img_cuda)

        loss_G += args.lambda_L1 * loss_G_L1 + \
                  args.lambda_rand_recon_L1 * loss_G_rand_recon_L1 + \
                  args.lambda_G_recon * loss_G_recon

        current_losses.update({
            'L1_M': loss_G_L1.item(),
            'recon': loss_G_recon.item(),
            'L1_WR': loss_G_rand_recon_L1.item()
        })

        current_losses['G'] = loss_G.item()

        # G backward and optimizer steps
        loss_G.backward()
        optimizer_dict['G'].step()

        # -------------------- LOGGING PART --------------------
        # t: compute time per image for this iteration (measured from after
        # the batch was loaded and moved to the GPU).
        if i % args.print_loss_freq == 0:
            t = (time.time() - iter_start_time) / args.batch_size
            visualizer.print_current_losses(epoch, i, current_losses, t,
                                            t_data)
            if args.display_id > 0 and i % args.plot_loss_freq == 0:
                visualizer.plot_current_losses(epoch,
                                               float(i) / len(train_loader),
                                               args, current_losses)

        if i % args.visdom_visual_freq == 0:
            save_result = i % args.update_html_freq == 0

            current_visuals = OrderedDict()
            current_visuals['real'] = img.detach()
            current_visuals['fake'] = fake.detach()
            current_visuals['rand_fake'] = rand_fake.detach()
            current_visuals['recon'] = recon.detach()
            current_visuals['rand_recon'] = rand_recon.detach()
            current_visuals['rand_recon_2nd'] = rand_recon_2nd.detach()

            # Display gets at most 60 s; on timeout it is disabled for the rest
            # of the run by mutating args.display_id.
            try:
                with time_limit(60):
                    visualizer.display_current_results(current_visuals, epoch,
                                                       save_result, args)
            except TimeoutException:
                visualizer.logger.log(
                    'TIME OUT visualizer.display_current_results epoch:{} iter:{}. Change display_id to -1'
                    .format(epoch, i))
                # disable visdom display ever since
                args.display_id = -1

        # +1 so that we do not save/test for 0th iteration
        if (i + 1) % args.save_iter_freq == 0:
            save_model(epoch,
                       model_dict,
                       optimizer_dict,
                       args,
                       iter=i,
                       save_sep=False)
            if args.display_id > 0:
                visualizer.vis.save([args.name])

        if (i + 1) % args.html_iter_freq == 0:
            validate(val_loader, model_dict, visualizer, epoch, args, fixed, i)

        # Reset the data timer just before the next logging iteration, so its
        # t_data covers only that batch's loading time.
        if (i + 1) % args.print_loss_freq == 0:
            iter_data_time = time.time()
Example #10
0
def train(train_loader, model_dict, criterion_dict, optimizer_dict, fake_pool, recon_pool, fake_pair_pool, WR_pool, visualizer, epoch, args, test_loader, fixed_z, fixed_rand_z):
    """Run one training epoch over ``train_loader``.

    For every batch this function:
      1. (when ``args.lambda_dis > 0``) draws codes with ``generate_code`` and
         runs the generator ``model_dict['G']`` to get ``fake`` / ``rand_fake``
         and, when ``args.lambda_G_recon > 0``, ``recon`` / ``rand_recon``;
      2. warps real and generated images with an affine grid built from
         ``landmarks`` and reverses their channel order before feeding the
         identity network ``model_dict['FR']``;
      3. if ``optimizer_dict['D']`` is not None, updates the "D side" (FR
         identity loss, optional GAN / pair-GAN losses, optional WGAN-GP
         penalty) with the generator nets frozen;
      4. updates the "G side" (adversarial, pair-GAN, infoGAN code-recovery,
         FR, feature-distance and L1 losses) with the D nets frozen;
      5. logs losses, displays images, saves checkpoints and calls ``test`` at
         the frequencies configured in ``args``.

    Args:
        train_loader: iterable of ``(img, label, landmarks, img_path)`` batches.
        model_dict: networks by name. Keys read here: ``'G'``, ``'FR'``,
            ``'D'``, ``'pair_D'``, ``'Q'``, plus ``'G_nets'`` / ``'D_nets'``
            which are passed to ``set_requires_grad``. List-valued entries are
            skipped when printing gradients.
        criterion_dict: losses by name: ``'FR'``, ``'GAN'``, ``'DIS'``,
            ``'Feat'``, ``'L1'``.
        optimizer_dict: ``'G'`` and ``'D'`` optimizers; the whole D update is
            skipped when ``optimizer_dict['D']`` is None.
        fake_pool, recon_pool, fake_pair_pool, WR_pool: objects exposing
            ``query(tensor)`` through which generated samples reach the
            discriminators (presumably image-history buffers -- confirm).
        visualizer: logging/display helper (``reset``,
            ``print_current_losses``, ``plot_current_losses``,
            ``display_current_results``, ``vis``, ``overview_vis``,
            ``logger``).
        epoch: current epoch index, used for logging and ``save_model``.
        args: parsed options; all ``args.lambda_*`` weights and feature flags
            are read from here. ``args.display_id`` is set to -1 if image
            display times out.
        test_loader, fixed_z, fixed_rand_z: forwarded to ``test`` every
            ``args.html_iter_freq`` iterations.

    NOTE(review): several tensors are created only under some flags but used
    under others (e.g. ``rand_fake`` only exists when ``args.lambda_dis > 0``
    yet is aligned unconditionally; ``recon`` / ``rand_recon`` only exist when
    ``args.lambda_G_recon > 0``), so some flag combinations raise NameError.
    """
    iter_data_time = time.time()

    for i, (img, label, landmarks, img_path) in enumerate(train_loader):
        iter_start_time = time.time()
        if i % args.print_loss_freq == 0:
            # Loading time of this batch: iter_data_time is reset at the end of
            # the iteration just before each logged one (see loop bottom).
            t_data = iter_start_time - iter_data_time

        visualizer.reset()
        batch_size = img.size(0)

        if args.lambda_dis > 0:
            # -------------------- generate password --------------------
            # Each code (z, rand_z, inv_z, another_rand_z) comes with the
            # target that Q must recover in the infoGAN loss below.
            z, dis_target, rand_z, rand_dis_target, inv_z, inv_dis_target, another_rand_z, another_rand_dis_target = generate_code(args.passwd_length, batch_size, args.device, inv=True)

            # -------------------- forward --------------------
            # TODO: whether to detach
            # NOTE(review): z / rand_z are moved to CPU but inv_z /
            # another_rand_z are not; presumably G handles device placement
            # itself (e.g. a DataParallel wrapper) -- confirm.
            fake = model_dict['G'](img, z.cpu())
            rand_fake = model_dict['G'](img, rand_z.cpu())
            if args.lambda_G_recon > 0:
                # recon: fake decoded with inv_z (the recon L1 loss below asks
                # it to match img); rand_recon: fake decoded with another
                # random code.
                recon = model_dict['G'](fake, inv_z)
                rand_recon = model_dict['G'](fake, another_rand_z)
        else:
            # NOTE(review): rand_fake is not assigned on this path but is used
            # unconditionally in the alignment step below.
            fake = model_dict['G'](img)
            if args.lambda_G_recon > 0:
                recon = model_dict['G'](fake)

        # FR forward and FR losses
        # One affine sampling grid (from the landmarks) warps every image to a
        # 112x96 crop; indexing channels with [2, 1, 0] reverses their order
        # (presumably RGB <-> BGR for the FR network -- confirm).
        # NOTE(review): affine_grid / grid_sample are called without
        # align_corners, which recent PyTorch versions warn about.
        theta = alignment(landmarks)
        grid = torch.nn.functional.affine_grid(theta, torch.Size((batch_size, 3, 112, 96)))
        real_aligned = torch.nn.functional.grid_sample(img.cuda(), grid)
        real_aligned = real_aligned[:, [2, 1, 0], ...]

        fake_aligned = torch.nn.functional.grid_sample(fake, grid)
        fake_aligned = fake_aligned[:, [2, 1, 0], ...]

        rand_fake_aligned = torch.nn.functional.grid_sample(rand_fake, grid)
        rand_fake_aligned = rand_fake_aligned[:, [2, 1, 0, ], ...]
        # (B, 3, h, w)

        if args.lambda_G_recon > 0:
            recon_aligned = torch.nn.functional.grid_sample(recon, grid)
            recon_aligned = recon_aligned[:, [2, 1, 0], ...]
            rand_recon_aligned = torch.nn.functional.grid_sample(rand_recon, grid)
            rand_recon_aligned = rand_recon_aligned[:, [2, 1, 0], ...]

        current_losses = {}
        # -------------------- D PART --------------------
        # Skipped entirely when no D optimizer is configured.
        if optimizer_dict['D'] is not None:
            # Freeze the generator nets and train the D nets for this step.
            set_requires_grad(model_dict['G_nets'], False)
            set_requires_grad(model_dict['D_nets'], True)
            optimizer_dict['D'].zero_grad()

            # Identity loss of FR on the aligned real images.
            id_real = model_dict['FR'](real_aligned)[0]
            loss_D_FR_real = criterion_dict['FR'](id_real, label.to(args.device))

            # Generated-image FR terms are summed here and divided by the
            # weighted count cnt_FR_fake below.
            cnt_FR_fake = 0.
            loss_D_FR_fake_total = 0
            if args.train_M:
                # In the D step FR is trained to still identify fake and
                # rand_fake; the G step uses the negated loss.
                id_fake = model_dict['FR'](fake_aligned.detach())[0]
                id_rand_fake = model_dict['FR'](rand_fake_aligned.detach())[0]

                loss_D_FR_fake = criterion_dict['FR'](id_fake, label.to(args.device))
                loss_D_FR_rand_fake = criterion_dict['FR'](id_rand_fake, label.to(args.device))

                loss_D_FR_fake_total += loss_D_FR_fake + loss_D_FR_rand_fake
                cnt_FR_fake += 2.
                current_losses.update({'D_FR_fake': loss_D_FR_fake.item(),
                                       'D_FR_rand': loss_D_FR_rand_fake.item(),
                                       # 'D_FR_rand_recon': loss_D_FR_rand_recon.item()
                                       })

            if args.recon_FR:
                # Requires recon_aligned, i.e. args.lambda_G_recon > 0.
                # Negated: FR is pushed away from identifying recon, while
                # rand_recon (weight lambda_FR_WR) should be identified; the G
                # step uses the opposite signs.
                # TODO: rand_fake_recon FR loss?
                id_recon = model_dict['FR'](recon_aligned.detach())[0]
                loss_D_FR_recon = -criterion_dict['FR'](id_recon, label.to(args.device))
                if args.lambda_FR_WR:
                    id_rand_recon = model_dict['FR'](rand_recon_aligned.detach())[0]
                    loss_D_FR_rand_recon = criterion_dict['FR'](id_rand_recon, label.to(args.device))
                    current_losses.update({'D_FR_rand_recon': loss_D_FR_rand_recon.item()
                                           })
                else:
                    loss_D_FR_rand_recon = 0.

                loss_D_FR_fake_total += loss_D_FR_recon + args.lambda_FR_WR * loss_D_FR_rand_recon
                cnt_FR_fake += 1. + args.lambda_FR_WR
                current_losses.update({'D_FR_recon': loss_D_FR_recon.item(),
                                       # 'D_FR_rand_recon': loss_D_FR_rand_recon.item()
                                       })


            # NOTE(review): raises ZeroDivisionError when neither args.train_M
            # nor args.recon_FR is set (cnt_FR_fake stays 0).
            loss_D_FR_fake_avg = loss_D_FR_fake_total / float(cnt_FR_fake)

            loss_D = args.lambda_FR * (loss_D_FR_real + loss_D_FR_fake_avg) * 0.5
            # NOTE(review): this overwrites the per-term 'D_FR_fake' value
            # logged in the train_M branch with the averaged fake loss.
            current_losses.update({'D_FR_real': loss_D_FR_real.item(),
                                   'D_FR_fake': loss_D_FR_fake_avg.item()
                              # 'D_FR_fake': loss_D_FR_fake.item(),
                              # 'D_FR_rand': loss_D_FR_rand_fake.item(),
                              # 'D_FR_rand_recon': loss_D_FR_rand_recon.item()
                              })

            # GAN loss
            if args.lambda_GAN > 0:
                # real
                # With recon_pair_GAN, D scores channel-concatenated pairs:
                # (img, recon) is the "real" pair and (recon, img) the "fake".
                if args.recon_pair_GAN:
                    assert args.single_GAN_recon_only
                    real_input = torch.cat((img.cuda(), recon.detach()), dim=1)
                else:
                    real_input = img

                pred_D_real = model_dict['D'](real_input)
                loss_D_real = criterion_dict['GAN'](pred_D_real, True)

                # fake
                # Weighted average of the enabled fake terms; each sample set
                # goes through its pool before reaching D.
                loss_D_fake_total = 0.
                loss_D_fake_total_weights = 0.

                # recon
                if args.lambda_GAN_recon:
                    if args.recon_pair_GAN:
                        recon_input_to_pool = torch.cat((recon.detach().cpu(), img), dim=1)
                    else:
                        recon_input_to_pool = recon.detach().cpu()

                    pred_D_recon = model_dict['D'](recon_pool.query(recon_input_to_pool))
                    loss_D_recon = criterion_dict['GAN'](pred_D_recon, False)

                    loss_D_fake_total += args.lambda_GAN_recon * loss_D_recon
                    loss_D_fake_total_weights += args.lambda_GAN_recon
                    current_losses['D_recon'] = loss_D_recon.item()

                if not args.single_GAN_recon_only:
                    # WR_pool receives 3-channel rand_recon here but 6-channel
                    # pairs in the pair-GAN branch; presumably the assert keeps
                    # the two uses apart -- confirm.
                    assert args.lambda_pair_GAN == 0
                    if args.train_M:
                        # fake and rand_fake stacked along the batch dimension.
                        all_M = torch.cat((fake.detach().cpu(),
                                           rand_fake.detach().cpu(),
                                           ), 0)
                        pred_D_M = model_dict['D'](fake_pool.query(all_M))
                        loss_D_M = criterion_dict['GAN'](pred_D_M, False)

                        loss_D_fake_total += args.lambda_GAN_M * loss_D_M
                        loss_D_fake_total_weights += args.lambda_GAN_M
                        current_losses['D_M'] = loss_D_M.item()

                    if args.lambda_GAN_WR:
                        pred_D_WR = model_dict['D'](WR_pool.query(rand_recon.detach().cpu()))
                        loss_D_WR = criterion_dict['GAN'](pred_D_WR, False)

                        loss_D_fake_total += args.lambda_GAN_WR * loss_D_WR
                        loss_D_fake_total_weights += args.lambda_GAN_WR
                        current_losses['D_WR'] = loss_D_WR.item()


                # NOTE(review): raises ZeroDivisionError if no fake term was
                # enabled (loss_D_fake_total_weights stays 0.).
                loss_D_fake = loss_D_fake_total / loss_D_fake_total_weights
                loss_D += args.lambda_GAN * (loss_D_fake + loss_D_real) * 0.5

                current_losses.update({
                    'D_real': loss_D_real.item(),
                    'D_fake': loss_D_fake.item()
                })


            # Pair GAN: pair_D scores channel-concatenated pairs, treating
            # (img, X) as real and (X, img) as fake.
            if args.lambda_pair_GAN > 0:
                loss_pair_fake_total = 0
                loss_pair_real_total = 0
                loss_pair_cnt = 0.
                if args.train_M:
                    pred_pair_real1 = model_dict['pair_D'](torch.cat((img.cuda(), fake.detach()), 1))
                    pred_pair_real2 = model_dict['pair_D'](torch.cat((img.cuda(), rand_fake.detach()), 1))

                    all_fake_pair = torch.cat((torch.cat((fake.detach().cpu(), img), 1),
                                               torch.cat((rand_fake.detach().cpu(), img), 1),
                                               ), 0)
                    pred_pair_fake = model_dict['pair_D'](fake_pair_pool.query(all_fake_pair))

                    loss_pair_M_real = (criterion_dict['GAN'](pred_pair_real1, True) + criterion_dict['GAN'](pred_pair_real2, True)) / 2.
                    loss_pair_M_fake = criterion_dict['GAN'](pred_pair_fake, False)

                    loss_pair_real_total += loss_pair_M_real
                    loss_pair_fake_total += loss_pair_M_fake
                    loss_pair_cnt += 1

                # WR pair term, weighted by args.multiple_pair_WR_GAN; requires
                # rand_recon (args.lambda_G_recon > 0).
                pred_pair_WR_real = model_dict['pair_D'](torch.cat((img.cuda(), rand_recon.detach()), 1))
                pred_pair_WR_fake = model_dict['pair_D'](WR_pool.query(torch.cat((rand_recon.detach().cpu(), img), 1)))

                loss_pair_WR_real = criterion_dict['GAN'](pred_pair_WR_real, True)
                loss_pair_WR_fake = criterion_dict['GAN'](pred_pair_WR_fake, False)

                loss_pair_real_total += args.multiple_pair_WR_GAN * loss_pair_WR_real
                loss_pair_fake_total += args.multiple_pair_WR_GAN * loss_pair_WR_fake
                loss_pair_cnt += args.multiple_pair_WR_GAN

                loss_pair_D_real = loss_pair_real_total / loss_pair_cnt  # (loss_pair_M_real + args.multiple_pair_WR_GAN * loss_pair_WR_real) / (1. + args.multiple_pair_WR_GAN)
                loss_pair_D_fake = loss_pair_fake_total / loss_pair_cnt #(loss_pair_M_fake + args.multiple_pair_WR_GAN * loss_pair_WR_fake) / (1. + args.multiple_pair_WR_GAN)

                current_losses.update({
                    'pair_D_fake': loss_pair_D_fake.item(),
                    'pair_D_real': loss_pair_D_real.item()
                })
                loss_D += args.lambda_pair_GAN * (loss_pair_D_fake + loss_pair_D_real) * 0.5

            current_losses['D'] = loss_D.item()
            # D backward and optimizer steps
            loss_D.backward()

            if args.gan_mode == 'wgangp':
                # Gradient penalty between the real batch (duplicated to match
                # size) and a randomly chosen fake / rand_fake stacked with
                # rand_recon. Its gradients accumulate with loss_D's before the
                # single D optimizer step below.
                real_to_wgangp = torch.cat((img, img), 0).to(args.device)
                if np.random.rand() > 0.5:
                    fake_selected = fake.detach()
                else:
                    fake_selected = rand_fake.detach()
                fake_to_wgangp = torch.cat((fake_selected, rand_recon.detach()), 0)
                loss_gp, gradients = models.cal_gradient_penalty(model_dict['D'], real_to_wgangp, fake_to_wgangp, args.device)
                # print('gradeints abs/l2 mean:', gradients[0], gradients[1])
                loss_gp *= args.lambda_GAN
                # print('loss_gp', loss_gp.item())
                loss_gp.backward()

            optimizer_dict['D'].step()

        # -------------------- G PART --------------------
        # init
        # Freeze the D nets and train the generator nets for this step.
        set_requires_grad(model_dict['D_nets'], False)
        set_requires_grad(model_dict['G_nets'], True)
        optimizer_dict['G'].zero_grad()

        loss_G = 0
        # GAN loss
        if args.lambda_GAN > 0:
            # Weighted average of the enabled adversarial terms, mirroring the
            # D step; G wants D to label its outputs as real (True).
            loss_G_GAN_total = 0.
            loss_G_GAN_total_weights = 0.

            # recon
            if args.lambda_GAN_recon:
                if args.recon_pair_GAN:
                    recon_input_G = torch.cat((recon, img.cuda()), dim=1)
                else:
                    recon_input_G = recon
                pred_G_recon = model_dict['D'](recon_input_G)
                loss_G_recon = criterion_dict['GAN'](pred_G_recon, True)

                loss_G_GAN_total += args.lambda_GAN_recon * loss_G_recon
                loss_G_GAN_total_weights += args.lambda_GAN_recon
                current_losses['G_recon'] = loss_G_recon.item()

            if not args.single_GAN_recon_only:
                if args.train_M:
                    pred_G_fake = model_dict['D'](fake)
                    pred_G_rand_fake = model_dict['D'](rand_fake)

                    loss_G_fake = criterion_dict['GAN'](pred_G_fake, True)
                    loss_G_rand_fake = criterion_dict['GAN'](pred_G_rand_fake, True)

                    loss_G_GAN_total += args.lambda_GAN_M * 0.5 * (loss_G_fake + loss_G_rand_fake)
                    loss_G_GAN_total_weights += args.lambda_GAN_M

                    current_losses['G_M'] = 0.5 * (loss_G_fake.item() + loss_G_rand_fake.item())

                # Unlike the D step, this WR term is computed even when
                # args.lambda_GAN_WR is 0 (it then carries zero weight); it
                # requires rand_recon.
                pred_G_WR = model_dict['D'](rand_recon)
                loss_G_WR = criterion_dict['GAN'](pred_G_WR, True)
                current_losses['G_WR'] = loss_G_WR.item()

                loss_G_GAN_total += args.lambda_GAN_WR * loss_G_WR
                loss_G_GAN_total_weights += args.lambda_GAN_WR

            loss_G_GAN = loss_G_GAN_total / loss_G_GAN_total_weights
            loss_G += args.lambda_GAN * loss_G_GAN

            current_losses.update({'G_GAN': loss_G_GAN.item(),
                                   })


        # Pair GAN: G wants pair_D to score its (generated, img) pairs as real.
        if args.lambda_pair_GAN > 0:
            loss_pair_G_total = 0
            cnt_pair_G = 0.

            if args.train_M:
                pred_pair_fake1_G = model_dict['pair_D'](torch.cat((fake, img.cuda()), 1))
                pred_pair_fake2_G = model_dict['pair_D'](torch.cat((rand_fake, img.cuda()), 1))

                loss_pair_M_G = (criterion_dict['GAN'](pred_pair_fake1_G, True)
                               + criterion_dict['GAN'](pred_pair_fake2_G, True)) / 2.

                loss_pair_G_total += loss_pair_M_G
                cnt_pair_G += 1.

            pred_pair_fake3_G = model_dict['pair_D'](torch.cat((rand_recon, img.cuda()), 1))
            loss_pair_WR_G = criterion_dict['GAN'](pred_pair_fake3_G, True)

            loss_pair_G_total += args.multiple_pair_WR_GAN * loss_pair_WR_G
            cnt_pair_G += args.multiple_pair_WR_GAN

            loss_pair_G_avg = loss_pair_G_total / cnt_pair_G

            loss_G += args.lambda_pair_GAN * loss_pair_G_avg
            current_losses['pair_G'] = loss_pair_G_avg.item()

        # infoGAN loss
        # Q's input for a (source, output) image pair: their difference when
        # args.use_minus_Q is set, otherwise a channel-wise concatenation.
        def infoGAN_input(img1, img2):
            if args.use_minus_Q:
                return img2 - img1
            else:
                return torch.cat((img1, img2), 1)

        if args.lambda_dis > 0:
            # Q must recover, from each (source, output) pair, the target that
            # generate_code paired with the code used for that transformation.
            # Q returns args.passwd_length // 4 logit heads, each scored with
            # criterion_dict['DIS'] against one target column; accuracy is the
            # argmax hit rate averaged over heads.
            infogan_acc = 0
            infogan_inv_acc = 0
            infogan_rand_acc = 0
            infogan_recon_rand_acc = 0

            # img -> fake, produced with z.
            dis_logits = model_dict['Q'](infoGAN_input(img.cuda(), fake))
            loss_G_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = dis_logits[dis_idx].max(dim=1)[1]
                b = dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_acc += acc.item()
                loss_G_dis += criterion_dict['DIS'](dis_logits[dis_idx], dis_target[:, dis_idx])
            infogan_acc = infogan_acc / float(args.passwd_length // 4)

            # fake -> recon, produced with inv_z (requires args.lambda_G_recon > 0).
            inv_dis_logits = model_dict['Q'](infoGAN_input(fake, recon))
            loss_G_inv_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = inv_dis_logits[dis_idx].max(dim=1)[1]
                b = inv_dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_inv_acc += acc.item()
                loss_G_inv_dis += criterion_dict['DIS'](inv_dis_logits[dis_idx], inv_dis_target[:, dis_idx])
            infogan_inv_acc = infogan_inv_acc / float(args.passwd_length // 4)

            # img -> rand_fake, produced with rand_z.
            rand_dis_logits = model_dict['Q'](infoGAN_input(img.cuda(), rand_fake))
            loss_G_rand_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = rand_dis_logits[dis_idx].max(dim=1)[1]
                b = rand_dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_rand_acc += acc.item()
                loss_G_rand_dis += criterion_dict['DIS'](rand_dis_logits[dis_idx], rand_dis_target[:, dis_idx])
            infogan_rand_acc = infogan_rand_acc / float(args.passwd_length // 4)

            # fake -> rand_recon, produced with another_rand_z.
            recon_rand_dis_logits = model_dict['Q'](infoGAN_input(fake, rand_recon))
            loss_G_recon_rand_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = recon_rand_dis_logits[dis_idx].max(dim=1)[1]
                b = another_rand_dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_recon_rand_acc += acc.item()
                loss_G_recon_rand_dis += criterion_dict['DIS'](recon_rand_dis_logits[dis_idx], another_rand_dis_target[:, dis_idx])
            infogan_recon_rand_acc = infogan_recon_rand_acc / float(args.passwd_length // 4)

            # current_losses.update({'G_dis': loss_G_dis.item(),
            #                        'G_inv_dis': loss_G_inv_dis.item(),
            #                        'G_dis_acc': infogan_acc,
            #                        'G_inv_dis_acc': infogan_inv_acc,
            #                        'G_rand_dis': loss_G_rand_dis.item(),
            #                        'G_recon_rand_dis': loss_G_recon_rand_dis.item(),
            #                        'G_rand_dis_acc': infogan_rand_acc,
            #                        'G_recon_rand_dis_acc': infogan_recon_rand_acc
            #                        })
            # The four code-recovery losses are summed (not averaged); the
            # logged accuracy is their mean.
            loss_dis = (loss_G_dis + loss_G_inv_dis + loss_G_rand_dis + loss_G_recon_rand_dis)
            dis_acc = (infogan_acc + infogan_inv_acc + infogan_rand_acc + infogan_recon_rand_acc) / 4.
            loss_G += args.lambda_dis * loss_dis
            current_losses.update({
                'dis': loss_dis.item(),
                'dis_acc': dis_acc
            })

        # FR loss, netFR must not be fixed
        # FR terms for G are summed and divided by the weighted count cnt_G_FR.
        loss_G_FR_total = 0
        cnt_G_FR = 0.

        if args.train_M:
            # Negated CE: G is rewarded when FR fails to identify fake and
            # rand_fake.
            id_fake_G, fake_feat = model_dict['FR'](fake_aligned)
            loss_G_FR = -criterion_dict['FR'](id_fake_G, label.to(args.device))
            # current_losses['G_FR'] = loss_G_FR.item()

            id_rand_fake_G, rand_fake_feat = model_dict['FR'](rand_fake_aligned)
            loss_G_FR_rand = -criterion_dict['FR'](id_rand_fake_G, label.to(args.device))
            # current_losses['G_FR_rand'] = loss_G_FR_rand.item()

            loss_G_FR_total += loss_G_FR + loss_G_FR_rand
            cnt_G_FR += 2

        if args.feature_loss == 'cos':
            # Target -1 for every pair (for a cosine-embedding style loss this
            # asks the two features to be dissimilar -- confirm criterion).
            # NOTE(review): shape is (batch_size, 1); nn.CosineEmbeddingLoss
            # documents a 1-D target of shape (N,).
            FR_cos_sim_target = torch.empty(size=(batch_size, 1), dtype=torch.float32, device=args.device)
            FR_cos_sim_target.fill_(-1.)

        if args.lambda_Feat:
            # Push the FR features of fake and rand_fake apart. Requires
            # fake_feat / rand_fake_feat, which exist only when args.train_M.
            # For non-'cos' losses the criterion is negated (presumably a
            # distance that G should maximise -- confirm).
            if args.feature_loss == 'cos':
                loss_G_feat = criterion_dict['Feat'](fake_feat, rand_fake_feat, target=FR_cos_sim_target)
            else:
                loss_G_feat = -criterion_dict['Feat'](fake_feat, rand_fake_feat)
            current_losses['G_feat'] = loss_G_feat.item()
            loss_G += args.lambda_Feat * loss_G_feat


        if args.lambda_G_recon:
            id_recon_G, recon_feat = model_dict['FR'](recon_aligned)
            if args.lambda_FR_WR:
                id_rand_recon_G, rand_recon_feat = model_dict['FR'](rand_recon_aligned)

            if args.lambda_recon_Feat:
                # Push recon and rand_recon features apart; requires
                # rand_recon_feat (args.lambda_FR_WR set).
                if args.feature_loss == 'cos':
                    loss_G_recon_feat = criterion_dict['Feat'](recon_feat, rand_recon_feat, target=FR_cos_sim_target)
                else:
                    loss_G_recon_feat = -criterion_dict['Feat'](recon_feat, rand_recon_feat)
                current_losses['G_recon_feat'] = loss_G_recon_feat.item()
                loss_G += args.lambda_recon_Feat * loss_G_recon_feat

            if args.lambda_false_recon_diff:
                # Push fake and rand_recon features apart; requires fake_feat
                # (args.train_M) and rand_recon_feat (args.lambda_FR_WR).
                if args.feature_loss == 'cos':
                    loss_G_false_recon_feat =criterion_dict['Feat'](fake_feat, rand_recon_feat, target=FR_cos_sim_target)
                else:
                    loss_G_false_recon_feat =-criterion_dict['Feat'](fake_feat, rand_recon_feat)
                current_losses['G_false_recon_feat'] = loss_G_false_recon_feat.item()
                loss_G += args.lambda_false_recon_diff * loss_G_false_recon_feat

            if args.recon_FR:
                # G wants recon identified (positive CE) and rand_recon not
                # identified (negated CE, weight lambda_FR_WR) -- the opposite
                # of the D step's signs.
                loss_G_FR_recon = criterion_dict['FR'](id_recon_G, label.to(args.device))
                # current_losses['G_FR_recon'] = loss_G_FR_recon.item()
                if args.lambda_FR_WR:
                    loss_G_FR_rand_recon = -criterion_dict['FR'](id_rand_recon_G, label.to(args.device))
                else:
                    loss_G_FR_rand_recon = 0.
                # current_losses['G_FR_rand_recon'] = loss_G_FR_rand_recon.item()
                loss_G_FR_total += loss_G_FR_recon + args.lambda_FR_WR * loss_G_FR_rand_recon
                cnt_G_FR += 1. + args.lambda_FR_WR

        # NOTE(review): raises ZeroDivisionError if no FR term was added
        # (cnt_G_FR stays 0), e.g. args.train_M off and args.recon_FR off.
        loss_G_FR_avg = loss_G_FR_total / cnt_G_FR

        loss_G += args.lambda_FR * loss_G_FR_avg
        current_losses['G_FR'] = loss_G_FR_avg.item()


        # Pixel-space L1 terms against the real image.
        # loss_L1 = 0
        # cnt_loss_L1 = 0
        if args.lambda_L1 > 0:
            loss_G_L1 = criterion_dict['L1'](fake, img.cuda())
            current_losses['L1'] = loss_G_L1.item()
            # loss_L1 += loss_G_L1.item()
            # cnt_loss_L1 += 1
            loss_G += args.lambda_L1 * loss_G_L1

        if args.lambda_rand_L1 > 0:
            loss_G_rand_L1 = criterion_dict['L1'](rand_fake, img.cuda())
            current_losses['rand_L1'] = loss_G_rand_L1.item()
            # loss_L1 += loss_G_rand_L1.item()
            # cnt_loss_L1 += 1
            loss_G += args.lambda_rand_L1 * loss_G_rand_L1

        if args.lambda_rand_recon_L1 > 0:
            loss_G_rand_recon_L1 = criterion_dict['L1'](rand_recon, img.cuda())
            current_losses['wrong_recon_L1'] = loss_G_rand_recon_L1.item()
            # loss_L1 += loss_G_rand_recon_L1.item()
            # cnt_loss_L1 += 1
            loss_G += args.lambda_rand_recon_L1 * loss_G_rand_recon_L1

        # current_losses['L1'] = loss_L1 / float(cnt_loss_L1)

        # Reconstruction: decoding fake with inv_z should give back img.
        if args.lambda_G_recon > 0:
            loss_G_recon = criterion_dict['L1'](recon, img.cuda())
            loss_G += args.lambda_G_recon * loss_G_recon
            current_losses['recon'] = loss_G_recon.item()

        if args.lambda_G_rand_recon > 0:
            # Invert rand_z (negation or 1 - z, per args.use_minus_one), decode
            # rand_fake with it, and ask the result to match img.
            if args.use_minus_one:
                inv_rand_z = rand_z * -1
            else:
                inv_rand_z = 1.0 - rand_z
            rand_fake_recon = model_dict['G'](rand_fake, inv_rand_z)
            loss_G_rand_recon = criterion_dict['L1'](rand_fake_recon, img.cuda())
            loss_G += args.lambda_G_rand_recon * loss_G_rand_recon
            current_losses['another_recon'] = loss_G_rand_recon.item()

        current_losses['G'] = loss_G.item()

        # G backward and optimizer steps
        loss_G.backward()
        optimizer_dict['G'].step()

        # -------------------- LOGGING PART --------------------
        if i % args.print_loss_freq == 0:
            # Wall time of this iteration's compute, per sample.
            t = (time.time() - iter_start_time) / batch_size
            visualizer.print_current_losses(epoch, i, current_losses, t, t_data)
            if args.display_id > 0 and i % args.plot_loss_freq == 0:
                visualizer.plot_current_losses(epoch, float(i) / len(train_loader), args, current_losses)
            if args.print_gradient:
                for net_name, net in model_dict.items():
                    # if net_name != 'Q':
                    #     continue
                    # List-valued entries (presumably the 'G_nets' / 'D_nets'
                    # groups) have no named_parameters(); skip them.
                    if isinstance(net, list):
                        continue
                    print(('================ NET %s ================' % net_name))
                    for name, param in net.named_parameters():
                        print_param_info(name, param, print_std=True)

        if i % args.visdom_visual_freq == 0:
            # True on multiples of args.update_html_freq; passed on to
            # display_current_results.
            save_result = i % args.update_html_freq == 0

            current_visuals = OrderedDict()
            current_visuals['real'] = img.detach()
            current_visuals['fake'] = fake.detach()
            current_visuals['rand_fake'] = rand_fake.detach()
            if args.lambda_G_recon:
                current_visuals['recon'] = recon.detach()
                current_visuals['rand_recon'] = rand_recon.detach()
            if args.lambda_G_rand_recon > 0:
                current_visuals['rand_fake_recon'] = rand_fake_recon.detach()
            current_visuals['real_aligned'] = real_aligned.detach()
            current_visuals['fake_aligned'] = fake_aligned.detach()
            current_visuals['rand_fake_aligned'] = rand_fake_aligned.detach()
            if args.lambda_G_recon:
                current_visuals['recon_aligned'] = recon_aligned.detach()
                current_visuals['rand_recon_aligned'] = rand_recon_aligned.detach()

            # Bound the display call to 60 s; on timeout, disable display
            # (args.display_id = -1) for the rest of training.
            try:
                with time_limit(60):
                    visualizer.display_current_results(current_visuals, epoch, save_result, args)
            except TimeoutException:
                visualizer.logger.log('TIME OUT visualizer.display_current_results epoch:{} iter:{}. Change display_id to -1'.format(epoch, i))
                args.display_id = -1

        # (i + 1) so saving/testing fires after every save_iter_freq-th /
        # html_iter_freq-th iteration rather than at i == 0.
        if (i + 1) % args.save_iter_freq == 0:
            save_model(epoch, model_dict, optimizer_dict, args, iter=i)
            if args.display_id > 0:
                visualizer.vis.save([args.name])
                visualizer.overview_vis.save(['overview'])

        if (i + 1) % args.html_iter_freq == 0:
            test(test_loader, model_dict, criterion_dict, visualizer, epoch, args, fixed_z, fixed_rand_z, i)

        # Restart the data timer just before the next logged iteration, so
        # t_data there measures only that batch's loading time.
        if (i + 1) % args.print_loss_freq == 0:
            iter_data_time = time.time()