Ejemplo n.º 1
0
 def evaluate(self):
     """Load the EMA networks at ``args.resume_iter`` and compute the
     evaluation metrics in both latent-guided and reference-guided modes."""
     step = self.args.resume_iter
     self._load_checkpoint(step)
     for eval_mode in ('latent', 'reference'):
         calculate_metrics(self.nets_ema, self.args, step=step, mode=eval_mode)
Ejemplo n.º 2
0
 def evaluate(self):
     """Load the checkpoint at ``args.resume_iter`` and return the
     test-mode FID values together with their mean."""
     resume_iter = self.args.resume_iter
     self._load_checkpoint(resume_iter)
     fid_values, fid_mean = calculate_metrics(self.nets_ema,
                                              self.args,
                                              step=resume_iter,
                                              mode='test')
     return fid_values, fid_mean
Ejemplo n.º 3
0
 def evaluate(self):
     """Compute latent-mode evaluation metrics using the EMA networks.

     Switches every EMA network to eval mode, loads the checkpoint at
     ``args.resume_iter``, computes the metrics, restores train mode,
     and returns the metric results.
     """
     args = self.args
     nets_ema = self.nets_ema
     for name in self.nets_ema:
         self.nets_ema[name].eval()
     resume_iter = args.resume_iter
     print("check point loading.......")
     self._load_checkpoint(args.resume_iter)
     print("check point loaded.......")
     # BUG FIX: the original returned directly from calculate_metrics,
     # which made the loop below that restores train mode unreachable.
     # Capture the result first so the nets are switched back before
     # returning.
     metrics = calculate_metrics(nets_ema,
                                 args,
                                 step=resume_iter,
                                 mode='latent')
     # calculate_metrics(nets_ema, args, step=resume_iter, mode='reference')
     for name in self.nets_ema:
         self.nets_ema[name].train()
     return metrics
Ejemplo n.º 4
0
    def train(self, loaders):
        """Run the StarGAN-v2 training loop.

        Alternates discriminator and generator updates, each done twice per
        iteration: once with latent-code-guided styles and once with
        reference-image-guided styles. Keeps an exponential-moving-average
        (EMA) copy of the generator-side networks for sampling/evaluation,
        linearly decays the diversity-sensitive loss weight, and
        periodically logs, samples, checkpoints, and evaluates.

        Args:
            loaders: namespace with ``src``, ``ref`` and ``val`` data
                loaders used to build the input fetchers.
        """
        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema
        optims = self.optims

        # fetch random validation images for debugging
        fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim, 'train')
        fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
        inputs_val = next(fetcher_val)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)
            x_real, y_org = inputs.x_src, inputs.y_src
            x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
            z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

            # face-alignment heatmaps are only needed when the
            # high-pass-filter weight is enabled
            masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

            # train the discriminator
            # first with a style sampled from a latent code ...
            d_loss, d_losses_latent = compute_d_loss(
                nets, args, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # ... then with a style extracted from a reference image
            d_loss, d_losses_ref = compute_d_loss(
                nets, args, x_real, y_org, y_trg, x_ref=x_ref, masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # train the generator
            # latent-guided pass updates generator, mapping network and
            # style encoder together
            g_loss, g_losses_latent = compute_g_loss(
                nets, args, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            optims.mapping_network.step()
            optims.style_encoder.step()

            # reference-guided pass updates only the generator
            g_loss, g_losses_ref = compute_g_loss(
                nets, args, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()

            # compute moving average of network parameters
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999)
            moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999)

            # decay weight for diversity sensitive loss
            # (linear decay from initial_lambda_ds over ds_iter iterations)
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i+1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (elapsed, i+1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref],
                                        ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join(['%s: [%.4f]' % (key, value) for key, value in all_losses.items()])
                print(log)
                wandb.log({"Losses":all_losses})

            # generate images for debugging
            if (i+1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema, args, inputs=inputs_val, step=i+1)

            # save model checkpoints
            if (i+1) % args.save_every == 0:
                self._save_checkpoint(step=i+1)

            # compute FID and LPIPS if necessary
            if (i+1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i+1, mode='latent')
                calculate_metrics(nets_ema, args, i+1, mode='reference')
Ejemplo n.º 5
0
    def train(self, loaders):
        """Train the landmark-transformer networks (discriminator,
        classifier, and encoder/decoder triplet) on 'mpie' or '300vw' data.

        Builds an identity criterion from the configured loss ('arcface' or
        'perceptual'), adds an Adam optimizer for the linear classifier,
        then iterates: fetch landmark/one-hot inputs, update discriminator,
        classifier, and transformer (with EMA copies), decay the diversity
        loss weight, and periodically log, sample, checkpoint, and evaluate.

        Args:
            loaders: namespace with ``src`` and ``val`` data loaders.
        """
        if self.args.loss == 'arcface':
            # hard-coded local path to a pretrained IR-50 face backbone
            BACKBONE_RESUME_ROOT = 'D:/face-recognition/stargan-v2-master/ms1m_ir50/backbone_ir50_ms1m_epoch120.pth'

            INPUT_SIZE = [112, 112]
            arcface = IR_50(INPUT_SIZE)

            if os.path.isfile(BACKBONE_RESUME_ROOT):
                arcface.load_state_dict(torch.load(BACKBONE_RESUME_ROOT))
                print("Loading Backbone Checkpoint '{}'".format(
                    BACKBONE_RESUME_ROOT))

            DEVICE = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
            criterion_id = arcface.to(DEVICE)
        elif self.args.loss == 'perceptual':
            criterion_id = network.LossEG(False, 0)
        # NOTE(review): for any other args.loss value, criterion_id is
        # never defined; it is only consumed by the commented-out code
        # below, so this currently has no effect — confirm before reuse.

        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema

        optims = self.optims
        # give the linear classifier its own optimizer with a separate
        # learning rate (args.lr2); 'linear_classfier' spelling matches
        # the network registry key used elsewhere
        for net in nets.keys():
            if net == 'linear_classfier':

                optims[net] = torch.optim.Adam(params=nets[net].parameters(),
                                               lr=args.lr2,
                                               betas=[args.beta1, args.beta2],
                                               weight_decay=args.weight_decay)

        print(optims)
        print(self.nets.keys())
        # assert False

        # fetch random validation images for debugging
        if args.dataset == 'mpie':
            fetcher = InputFetcher_mpie(loaders.src, args.latent_dim, 'train')
            fetcher_val = InputFetcher_mpie(loaders.val, args.latent_dim,
                                            'val')

        elif args.dataset == '300vw':
            fetcher = InputFetcher_300vw(loaders.src, args.latent_dim, 'train')
            fetcher_val = InputFetcher_300vw(loaders.val, args.latent_dim,
                                             'val')
        inputs_val = next(fetcher_val)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # step-wise learning-rate decay: lr * 0.1**k every decay_every
            # iterations.
            # NOTE(review): this rebuilds every optimizer from scratch
            # (optims = Munch()), discarding Adam moment state and the
            # separate lr2 for the classifier — confirm this is intended.
            if (i + 1) % args.decay_every == 0:

                # print('54555')
                times = (i + 1) / args.decay_every
                # print(args.lr*0.1**int(times))
                optims = Munch()
                for net in nets.keys():
                    if net == 'fan':
                        continue
                    optims[net] = torch.optim.Adam(
                        params=nets[net].parameters(),
                        lr=args.lr * 0.1**int(times),
                        betas=[args.beta1, args.beta2],
                        weight_decay=args.weight_decay)

                # optims = torch.optim.Adam(
                #             params=self.nets[net].parameters(),
                #             lr=args.lr*0.1**int(times),
                #             betas=[args.beta1, args.beta2],
                #             weight_decay=args.weight_decay)

            # fetch images and labels
            inputs = next(fetcher)
            # x_label, x2_label, x3_label, x4_label, x1_one_hot, x3_one_hot, x1_id, x3_id

            x1_label = inputs.x_label
            x2_label = inputs.x2_label
            x3_label = inputs.x3_label
            if args.dataset == 'mpie':
                x4_label = inputs.x4_label

                param_x4 = x4_label[:, 0, :].unsqueeze(0)
                param_x4 = param_x4.view(-1, 136).float()

            x1_one_hot, x3_one_hot = inputs.x1_one_hot, inputs.x3_one_hot
            x1_id, x3_id = inputs.x1_id, inputs.x3_id

            # flatten each landmark tensor to (batch, 136);
            # 136 = 68 landmarks x 2 coordinates — presumably; confirm
            # against the dataset definition
            param_x1 = x1_label[:, 0, :].unsqueeze(0)
            param_x1 = param_x1.view(-1, 136).float()

            param_x2 = x2_label[:, 0, :].unsqueeze(0)
            param_x2 = param_x2.view(-1, 136).float()

            param_x3 = x3_label[:, 0, :].unsqueeze(0)
            param_x3 = param_x3.view(-1, 136).float()

            # identity one-hot vectors flattened to (batch, 150)
            one_hot_x1 = x1_one_hot[:, 0, :].unsqueeze(0)
            one_hot_x1 = one_hot_x1.view(-1, 150).float()

            one_hot_x3 = x3_one_hot[:, 0, :].unsqueeze(0)
            one_hot_x3 = one_hot_x3.view(-1, 150).float()

            # print(param_x1.shape)
            # print(one_hot_x1.shape)
            # assert False

            # masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

            # linear_decoder = Linear_decoder()
            # id_linear_encoder = Id_linear_encoder()
            # lm_linear_encoder = Lm_linear_encoder()
            # linear_discriminator = Linear_discriminator()

            # NOTE(review): in the '300vw' branch no losses are computed,
            # so d_tran_losses / t_losses / c_losses referenced in the
            # logging block below would be undefined (NameError) — the
            # '300vw' path looks unfinished.
            if args.dataset == '300vw':
                print('300vw')

            elif args.dataset == 'mpie':
                # train the discriminator
                d_tran_loss, d_tran_losses = compute_d_tran_loss(
                    nets,
                    args,
                    param_x1,
                    param_x2,
                    param_x3,
                    param_x4,
                    one_hot_x1,
                    one_hot_x3,
                    x1_id,
                    x3_id,
                    masks=None,
                    loss_select=args.loss)
                self._reset_grad()
                d_tran_loss.backward()
                optims.linear_discriminator.step()
                moving_average(nets.linear_discriminator,
                               nets_ema.linear_discriminator,
                               beta=0.999)

                # train the classfier
                c_loss, c_losses = compute_c_loss(nets,
                                                  args,
                                                  param_x1,
                                                  param_x2,
                                                  param_x3,
                                                  param_x4,
                                                  one_hot_x1,
                                                  one_hot_x3,
                                                  x1_id,
                                                  x3_id,
                                                  masks=None,
                                                  loss_select=args.loss)
                self._reset_grad()
                c_loss.backward()
                optims.linear_classfier.step()
                moving_average(nets.linear_classfier,
                               nets_ema.linear_classfier,
                               beta=0.999)

                # train the transformer (decoder + both encoders step
                # together, each tracked by its EMA copy)
                t_loss, t_losses = compute_t_loss(nets,
                                                  args,
                                                  param_x1,
                                                  param_x2,
                                                  param_x3,
                                                  param_x4,
                                                  one_hot_x1,
                                                  one_hot_x3,
                                                  x1_id,
                                                  x3_id,
                                                  masks=None,
                                                  loss_select=args.loss)
                self._reset_grad()
                t_loss.backward()
                optims.linear_decoder.step()
                optims.lm_linear_encoder.step()
                optims.id_linear_encoder.step()

                moving_average(nets.linear_decoder,
                               nets_ema.linear_decoder,
                               beta=0.999)
                moving_average(nets.lm_linear_encoder,
                               nets_ema.lm_linear_encoder,
                               beta=0.999)
                moving_average(nets.id_linear_encoder,
                               nets_ema.id_linear_encoder,
                               beta=0.999)

            # # train the discriminator
            # d_loss, d_losses = compute_d_loss(
            #     nets, args, x1_source,x1_source_lm, x2_target, x2_target_lm, masks=None, loss_select = args.loss)
            # self._reset_grad()
            # d_loss.backward()
            # optims.discriminator.step()
            #
            #
            #
            # # train the generator
            # g_loss, g_losses = compute_g_loss(
            #     nets, args, x1_source, x1_source_lm, x2_target, x2_target_lm, criterion_id, masks=None, loss_select = args.loss)
            # self._reset_grad()
            # g_loss.backward()
            #
            # if args.transformer:
            #     optims.lm_encoder.step()
            #     optims.lm_transformer.step()
            #     optims.lm_decoder.step()
            #     optims.style_encoder.step()
            # else:
            #     optims.generator.step()
            #     optims.style_encoder.step()
            #
            # if args.transformer:
            #     moving_average(nets.lm_encoder, nets_ema.lm_encoder, beta=0.999)
            #     moving_average(nets.lm_transformer, nets_ema.lm_transformer, beta=0.999)
            #     moving_average(nets.lm_decoder, nets_ema.lm_decoder, beta=0.999)
            #     moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999)
            #
            # else:
            #     # compute moving average of network parameters
            #     moving_average(nets.generator, nets_ema.generator, beta=0.999)
            #     # moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999)
            #     moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999)

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i + 1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (
                    elapsed, i + 1, args.total_iters)
                all_losses = dict()
                # for loss, prefix in zip([d_losses, g_losses],['D/', 'G/']):
                for loss, prefix in zip([d_tran_losses, t_losses, c_losses],
                                        ['D/', 'G/', 'C/']):

                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join([
                    '%s: [%.4f]' % (key, value)
                    for key, value in all_losses.items()
                ])
                print(log)
                # mirror the losses to TensorBoard
                for key, value in all_losses.items():
                    self.writer.add_scalar(key, value, i + 1)

            # generate images for debugging
            if (i + 1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema,
                                  args,
                                  inputs=inputs_val,
                                  step=i + 1)

            # save model checkpoints
            if (i + 1) % args.save_every == 0:
                self._save_checkpoint(step=i + 1)

            # compute FID and LPIPS if necessary
            if (i + 1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i + 1, mode='latent')
                calculate_metrics(nets_ema, args, i + 1, mode='reference')
        self.writer.close()
Ejemplo n.º 6
0
    def train(self, loaders):
        """Run the StarGAN-v2 training loop with AdaIN/whitening diagnostics.

        Performs the standard alternating discriminator/generator updates
        (latent-guided and reference-guided), maintains EMA copies of the
        generator-side networks, and periodically logs, samples,
        checkpoints, and evaluates. Additionally dumps learned AdaIN alpha
        parameters, per-block activation statistics, whitening sqrt errors,
        and injected color statistics to text files under args.notes_path.

        Args:
            loaders: namespace with ``src``, ``ref`` and ``val`` data
                loaders used to build the input fetchers.
        """
        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema
        optims = self.optims

        # fetch random validation images for debugging
        fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim,
                               'train')
        fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
        inputs_val = next(fetcher_val)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)
            x_real, y_org = inputs.x_src, inputs.y_src
            x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
            z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

            # heatmaps only needed when the high-pass-filter weight is on
            masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

            # train the discriminator
            # latent-guided pass ...
            d_loss, d_losses_latent = compute_d_loss(nets,
                                                     args,
                                                     x_real,
                                                     y_org,
                                                     y_trg,
                                                     z_trg=z_trg,
                                                     masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # ... then reference-guided pass
            d_loss, d_losses_ref = compute_d_loss(nets,
                                                  args,
                                                  x_real,
                                                  y_org,
                                                  y_trg,
                                                  x_ref=x_ref,
                                                  masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # train the generator
            # latent-guided pass updates generator, mapping network and
            # style encoder
            g_loss, g_losses_latent = compute_g_loss(nets,
                                                     args,
                                                     x_real,
                                                     y_org,
                                                     y_trg,
                                                     z_trgs=[z_trg, z_trg2],
                                                     masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            optims.mapping_network.step()
            optims.style_encoder.step()

            # reference-guided pass updates only the generator
            g_loss, g_losses_ref = compute_g_loss(nets,
                                                  args,
                                                  x_real,
                                                  y_org,
                                                  y_trg,
                                                  x_refs=[x_ref, x_ref2],
                                                  masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()

            # compute moving average of network parameters
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            moving_average(nets.mapping_network,
                           nets_ema.mapping_network,
                           beta=0.999)
            moving_average(nets.style_encoder,
                           nets_ema.style_encoder,
                           beta=0.999)

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i + 1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (
                    elapsed, i + 1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([
                        d_losses_latent, d_losses_ref, g_losses_latent,
                        g_losses_ref
                ], ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join([
                    '%s: [%.4f]' % (key, value)
                    for key, value in all_losses.items()
                ])
                print(log)

            # generate images for debugging
            if (i + 1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema,
                                  args,
                                  inputs=inputs_val,
                                  step=i + 1)

            # save model checkpoints
            if (i + 1) % args.save_every == 0:
                self._save_checkpoint(step=i + 1)

            # compute FID and LPIPS if necessary
            if (i + 1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i + 1, mode='latent')
                calculate_metrics(nets_ema, args, i + 1, mode='reference')

            # NOTE(review): import executed every iteration; Python caches
            # modules so only the first one does real work, but this would
            # conventionally be hoisted to the top of the file.
            import adaiw
            # dump learned AdaIN alpha (white/color) parameters per decoder
            # block to notes_path/alphas.txt
            if (i + 1) % args.print_learned == 0:
                alphas_list = []
                for module in self.nets.generator.decode:
                    if isinstance(module.norm1, adaiw.AdaIN):
                        alphas_list.append([
                            ('norm_1_white', module.norm1.alpha_white_param,
                             'norm_1_color', module.norm1.alpha_color_param),
                            ('norm_2_white', module.norm2.alpha_white_param,
                             'norm_2_color', module.norm2.alpha_color_param)
                        ])

                txt_f = os.path.join(args.notes_path, "alphas.txt")
                with open(txt_f, "a+") as f:
                    f.write("{} iterations: \n".format(i + 1))
                    for item in alphas_list:
                        f.write("{}\n".format(item))

            # dump per-block activation std statistics to notes_path/std.txt
            if (i + 1) % args.print_std == 0:
                with open(os.path.join(args.notes_path, "std.txt"), "a+") as f:
                    f.write("{} iterations: \n".format(i + 1))
                    for j, module in enumerate(self.nets.generator.decode):
                        print([
                            module.std_b4_norm_1.item(),
                            module.std_b4_norm_2.item(),
                            module.std_b4_join.item(),
                            module.std_b4_output.item()
                        ],
                              file=f)

            # dump whitening-normalizer sqrt errors to
            # notes_path/sqrt_errors.txt
            if (i + 1) % args.print_sqrt_error == 0:
                with open(os.path.join(args.notes_path, "sqrt_errors.txt"),
                          "a+") as f:
                    f.write("{} iterations: \n".format(i + 1))
                    all_errors = []
                    for j, module in enumerate(self.nets.generator.decode):
                        errors = []
                        for norm in [module.norm1, module.norm2]:
                            if isinstance(norm, adaiw.AdaIN) and isinstance(
                                    norm.normalizer,
                                    adaiw.normalizer.Whitening):
                                errors.append(norm.normalizer.last_error)
                        all_errors.append(tuple(errors))
                    print(all_errors, file=f)

            # dump the last injected color statistics (diagonal vs.
            # lower-triangular parts) to notes_path/last_color.txt
            if (i + 1) % args.print_color == 0:
                with torch.no_grad():
                    with open(os.path.join(args.notes_path, "last_color.txt"),
                              "a+") as f:
                        f.write("{} iterations: \n".format(i + 1))
                        for _, module in enumerate(self.nets.generator.decode):
                            if hasattr(module.norm1, 'last_injected_stat'):
                                diag1, tri1 = module.norm1.last_injected_stat.diagonal(
                                ), module.norm1.last_injected_stat.tril(-1)
                                print("Mean Diag:",
                                      diag1[0].mean().item(),
                                      "Std Diag:",
                                      diag1[0].std().item(),
                                      file=f)
                                print("Mean Tril:",
                                      tri1[0].mean().item(),
                                      "Std Tril:",
                                      tri1[0].std().item(),
                                      file=f)
                                diag2, tri2 = module.norm2.last_injected_stat.diagonal(
                                ), module.norm2.last_injected_stat.tril(-1)
                                print("Mean Diag:",
                                      diag2[0].mean().item(),
                                      "Std Diag:",
                                      diag2[0].std().item(),
                                      file=f)
                                print("Mean Tril:",
                                      tri2[0].mean().item(),
                                      "Std Tril:",
                                      tri2[0].std().item(),
                                      file=f)
                                print(file=f)
Ejemplo n.º 7
0
    def train(self, loaders):
        if self.args.loss == 'arcface':
            BACKBONE_RESUME_ROOT = 'D:/face-recognition/stargan-v2-master/ms1m_ir50/backbone_ir50_ms1m_epoch120.pth'

            INPUT_SIZE = [112, 112]
            arcface = IR_50(INPUT_SIZE)

            if os.path.isfile(BACKBONE_RESUME_ROOT):
                arcface.load_state_dict(torch.load(BACKBONE_RESUME_ROOT))
                print("Loading Backbone Checkpoint '{}'".format(
                    BACKBONE_RESUME_ROOT))

            DEVICE = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
            criterion_id = arcface.to(DEVICE)
        elif self.args.loss == 'perceptual':
            criterion_id = network.LossEG(False, 0)
        elif self.args.loss == 'lightcnn':
            BACKBONE_RESUME_ROOT = 'D:/face-reenactment/stargan-v2-master/FR_Pretrained_Test/Pretrained/LightCNN/LightCNN_29Layers_V2_checkpoint.pth.tar'

            # INPUT_SIZE = [128, 128]
            Model = LightCNN_29Layers_v2()

            if os.path.isfile(BACKBONE_RESUME_ROOT):

                Model = WrappedModel(Model)
                checkpoint = torch.load(BACKBONE_RESUME_ROOT)
                Model.load_state_dict(checkpoint['state_dict'])
                print("Loading Backbone Checkpoint '{}'".format(
                    BACKBONE_RESUME_ROOT))

            DEVICE = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
            criterion_id = Model.to(DEVICE)
            # criterion_id = FR_Pretrained_Test.LossEG(False, 0)
        if self.args.id_embed:
            id_embedder = network.vgg_feature(False, 0)
        else:
            id_embedder = None

        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema
        optims = self.optims

        # fetch random validation images for debugging
        fetcher = InputFetcher(loaders.src, args.latent_dim, 'train',
                               args.multi)
        fetcher_val = InputFetcher(loaders.val, args.latent_dim, 'val',
                                   args.multi)
        inputs_val = next(fetcher_val)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)

            if args.multi:
                x1_1_source = inputs.x1
                x1_2_source = inputs.x2
                x1_3_source = inputs.x3
                x1_4_source = inputs.x4

                x2_target, x2_target_lm = inputs.x5, inputs.x5_lm

                d_loss, d_losses = compute_d_loss_multi(nets,
                                                        args,
                                                        x1_1_source,
                                                        x1_2_source,
                                                        x1_3_source,
                                                        x1_4_source,
                                                        x2_target,
                                                        x2_target_lm,
                                                        masks=None,
                                                        loss_select=args.loss)
                self._reset_grad()
                d_loss.backward()
                optims.discriminator.step()

                # train the generator
                g_loss, g_losses = compute_g_loss_multi(nets,
                                                        args,
                                                        x1_1_source,
                                                        x1_2_source,
                                                        x1_3_source,
                                                        x1_4_source,
                                                        x2_target,
                                                        x2_target_lm,
                                                        criterion_id,
                                                        masks=None,
                                                        loss_select=args.loss)
                self._reset_grad()
                g_loss.backward()
                if args.id_embed:
                    optims.generator.step()
                    optims.mlp.step()
                else:
                    optims.generator.step()

            else:
                x1_source_lm = None
                x1_source = inputs.x1
                x2_target, x2_target_lm = inputs.x2, inputs.x2_lm
                # label = inputs.label

                # masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

                # train the discriminator
                d_loss, d_losses = compute_d_loss(nets,
                                                  args,
                                                  x1_source,
                                                  x1_source_lm,
                                                  x2_target,
                                                  x2_target_lm,
                                                  masks=None,
                                                  loss_select=args.loss,
                                                  embedder=id_embedder)
                self._reset_grad()
                d_loss.backward()
                optims.discriminator.step()

                # train the generator
                g_loss, g_losses = compute_g_loss(nets,
                                                  args,
                                                  x1_source,
                                                  x1_source_lm,
                                                  x2_target,
                                                  x2_target_lm,
                                                  criterion_id,
                                                  masks=None,
                                                  loss_select=args.loss,
                                                  embedder=id_embedder)
                self._reset_grad()
                g_loss.backward()
                if args.id_embed:
                    optims.generator.step()
                    optims.mlp.step()
                else:
                    optims.generator.step()

            # compute moving average of network parameters
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            # moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999)
            if args.id_embed:
                moving_average(nets.mlp, nets_ema.mlp, beta=0.999)
            else:
                moving_average(nets.style_encoder,
                               nets_ema.style_encoder,
                               beta=0.999)

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i + 1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (
                    elapsed, i + 1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([d_losses, g_losses], ['D/', 'G/']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join([
                    '%s: [%.4f]' % (key, value)
                    for key, value in all_losses.items()
                ])
                print(log)
                for key, value in all_losses.items():
                    self.writer.add_scalar(key, value, i + 1)

            # generate images for debugging
            if (i + 1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)

                utils.debug_image(nets_ema,
                                  args,
                                  inputs=inputs_val,
                                  step=i + 1,
                                  embedder=id_embedder)

            # save model checkpoints
            if (i + 1) % args.save_every == 0:
                self._save_checkpoint(step=i + 1)

            # compute FID and LPIPS if necessary
            if (i + 1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i + 1, mode='latent')
                # calculate_metrics(nets_ema, args, i+1, mode='reference')
        self.writer.close()
    def train(self, loaders):
        """Run the full training loop for the flow/mask-conditioned StarGAN-v2 variant.

        Alternates discriminator and generator updates (one latent-guided and
        one reference-guided pass each per iteration), maintains EMA copies of
        the networks, decays the diversity-sensitive loss weight, and
        periodically logs losses, saves samples/checkpoints, computes metrics,
        and dumps the style codes produced by the generator passes.

        Args:
            loaders: a 2-tuple ``(fc2_loader, test_loader)`` — unpacked below.
        """
        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema
        optims = self.optims

        # Debug switch: when set to 1, the `0 + dbg` comparisons below make the
        # periodic actions fire on the first matching iteration, and the
        # `if dbg == 1: blah` at the bottom raises NameError to halt the run.
        dbg = 0

        # fetch random validation images for debugging
        #fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim, 'train')
        #fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
        #inputs_val = next(fetcher_val)

        # `loaders` is a plain 2-tuple here (unlike other variants that use a
        # Munch with .src/.ref/.val attributes).
        fc2_loader, test_loader = loaders
        fetcher = FC2Fetcher(fc2_loader, None, args.latent_dim)
        # NOTE(review): inputs_val is drawn from the *training* fetcher, so the
        # "validation" debug images come from training data — confirm intended.
        inputs_val = next(fetcher)
        # Directory for dumping style-code tensors every 100 iterations.
        s_trg_path = os.getcwd() + "/s_trg_out"
        self.grid = None

        if not os.path.exists(s_trg_path):
            os.makedirs(s_trg_path)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)
            x_real, x_real2, y_org = inputs.x_src, inputs.x2_src, inputs.y_src
            x_ref, y_trg = inputs.x_ref, inputs.y_ref
            flow, mask, = inputs.flow, inputs.mask
            z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

            # Facial-landmark heatmaps are only used when high-pass filtering
            # (w_hpf) is enabled.
            masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

            # train the discriminator: first with a latent-sampled style...
            d_loss, d_losses_latent = self.compute_d_loss(nets,
                                                          args,
                                                          x_real,
                                                          y_org,
                                                          y_trg,
                                                          z_trg=z_trg,
                                                          masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # ...then with a reference-image style.
            d_loss, d_losses_ref = self.compute_d_loss(nets,
                                                       args,
                                                       x_real,
                                                       y_org,
                                                       y_trg,
                                                       x_ref=x_ref,
                                                       masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # train the generator (latent-guided); also returns the style code
            # used, for the periodic dump below.
            g_loss, g_losses_latent, s_trg_lat = self.compute_g_loss(
                nets,
                args,
                x_real,
                x_real2,
                flow,
                mask,
                y_org,
                y_trg,
                z_trgs=[z_trg, z_trg2],
                masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            optims.mapping_network.step()
            optims.style_encoder.step()

            # Generator update (reference-guided): only the generator steps
            # here — mapping network / style encoder are updated on the
            # latent pass above only.
            g_loss, g_losses_ref, s_trg_ref = self.compute_g_loss(
                nets,
                args,
                x_real,
                x_real2,
                flow,
                mask,
                y_org,
                y_trg,
                x_refs=[x_ref, None],
                masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()

            # compute moving average of network parameters
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            moving_average(nets.mapping_network,
                           nets_ema.mapping_network,
                           beta=0.999)
            moving_average(nets.style_encoder,
                           nets_ema.style_encoder,
                           beta=0.999)

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i + 1) % args.print_every == 0 + dbg:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (
                    elapsed, i + 1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([
                        d_losses_latent, d_losses_ref, g_losses_latent,
                        g_losses_ref
                ], ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join([
                    '%s: [%.4f]' % (key, value)
                    for key, value in all_losses.items()
                ])
                print(log)
                # Append the same line to a plain-text log file.
                with open('losses.txt', 'a') as log_file:
                    log_file.write(log + '\n')

            # generate images for debugging
            if (i + 1) % args.sample_every == 0 + dbg:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema,
                                  args,
                                  inputs=inputs_val,
                                  step=i + 1)

            # save model checkpoints
            if (i + 1) % args.save_every == 0 + dbg:
                self._save_checkpoint(step=i + 1)

            # compute FID and LPIPS if necessary
            if (i + 1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i + 1, mode='latent')
                calculate_metrics(nets_ema, args, i + 1, mode='reference')

            # save sintel debug
            if (i + 1) % args.sample_every == 0 + dbg:
                self.debugSintel(nets, args, i + 1, test_loader)
                self.debugSintel(nets, args, i + 1, test_loader,
                                 inputs_val.x_ref)

            # Periodically dump the latent- and reference-derived style codes.
            # NOTE(review): .numpy() requires a CPU tensor that does not
            # require grad — if s_trg_* live on GPU or carry grad, this raises;
            # presumably compute_g_loss detaches them. Verify.
            if (i + 1) % 100 == 0 + dbg:
                s_trg_out = torch.cat((s_trg_lat, s_trg_ref), 0).numpy()
                np.save(s_trg_path + "/s_trg_" + str(i + 1) + ".npy",
                        s_trg_out)

            # Deliberate debug trap: `blah` is undefined, so setting dbg=1
            # stops training with a NameError after one iteration.
            if dbg == 1:
                blah
Ejemplo n.º 9
0
    def train(self, loaders):
        """Train the landmark-conditioned generator with an ArcFace identity loss.

        Loads a pretrained IR-50 face-recognition backbone (used inside
        ``compute_g_loss`` as an identity-preservation critic), then runs the
        standard alternating D/G loop with EMA tracking, loss-weight decay,
        TensorBoard logging, sampling, checkpointing, and metric evaluation.

        Args:
            loaders: Munch-like object with ``src`` and ``val`` data loaders.
        """
        # NOTE(review): hard-coded absolute Windows path — this silently skips
        # loading (and thus trains against a randomly initialized ArcFace) on
        # any other machine, since the isfile() guard just falls through.
        BACKBONE_RESUME_ROOT = 'D:/face-recognition/stargan-v2-master/ms1m_ir50/backbone_ir50_ms1m_epoch120.pth'

        # IR-50 expects 112x112 face crops.
        INPUT_SIZE = [112, 112]
        arcface = IR_50(INPUT_SIZE)

        if os.path.isfile(BACKBONE_RESUME_ROOT):
            arcface.load_state_dict(torch.load(BACKBONE_RESUME_ROOT))
            print("Loading Backbone Checkpoint '{}'".format(
                BACKBONE_RESUME_ROOT))

        DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        arcface = arcface.to(DEVICE)

        args = self.args
        nets = self.nets
        # conf = self.conf
        # arcface = self.arcface
        nets_ema = self.nets_ema
        optims = self.optims

        # fetch random validation images for debugging
        fetcher = InputFetcher(loaders.src, args.latent_dim, 'train')
        fetcher_val = InputFetcher(loaders.val, args.latent_dim, 'val')
        inputs_val = next(fetcher_val)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)

            # Source image + its landmarks; target image + its landmarks.
            x1_source, x1_source_lm = inputs.x1, inputs.x_lm
            x2_target, x2_target_lm = inputs.x2, inputs.x2_lm

            # x_source_4_channel, x2_target, x2_target_lm = inputs.x1_c, inputs.x2, inputs.x2_lm
            # x_source_4_channel = nn.functional.interpolate(x_source_4_channel[:, :, :, :], size=(128, 128), mode='bilinear')

            # x_real, y_org = inputs.x_src, inputs.y_src
            # x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
            # z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

            # masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

            # train the discriminator
            d_loss, d_losses = compute_d_loss(nets,
                                              args,
                                              x1_source,
                                              x1_source_lm,
                                              x2_target,
                                              x2_target_lm,
                                              z_trg=None,
                                              masks=None)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # d_loss, d_losses_ref = compute_d_loss(
            #     nets, args, x_real, y_org, y_trg, x_ref=x_ref, masks=masks)
            # self._reset_grad()
            # d_loss.backward()
            # optims.discriminator.step()

            # train the generator
            # g_loss, g_losses = compute_g_loss(
            #     nets, args, x1_source, x2_target, x2_target_lm, arcface, conf, z_trgs=None, masks=None)
            g_loss, g_losses = compute_g_loss(nets,
                                              args,
                                              x1_source,
                                              x1_source_lm,
                                              x2_target,
                                              x2_target_lm,
                                              arcface,
                                              z_trgs=None,
                                              masks=None)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            optims.style_encoder.step()

            # g_loss, g_losses_ref = compute_g_loss(
            #     nets, args, x_real, y_org, y_trg, arcface, conf, x_refs=[x_ref, x_ref2], masks=masks)
            # self._reset_grad()
            # g_loss.backward()
            # optims.generator.step()

            # compute moving average of network parameters
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            # moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999)
            moving_average(nets.style_encoder,
                           nets_ema.style_encoder,
                           beta=0.999)

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i + 1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (
                    elapsed, i + 1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([d_losses, g_losses], ['D/', 'G/']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join([
                    '%s: [%.4f]' % (key, value)
                    for key, value in all_losses.items()
                ])
                print(log)
                # Mirror the losses to TensorBoard.
                for key, value in all_losses.items():
                    self.writer.add_scalar(key, value, i + 1)

            # generate images for debugging
            if (i + 1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema,
                                  args,
                                  inputs=inputs_val,
                                  step=i + 1)

            # save model checkpoints
            if (i + 1) % args.save_every == 0:
                self._save_checkpoint(step=i + 1)

            # compute FID and LPIPS if necessary
            if (i + 1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i + 1, mode='latent')
                calculate_metrics(nets_ema, args, i + 1, mode='reference')
        self.writer.close()
Ejemplo n.º 10
0
    def train(self, loaders):
        """Latent-only StarGAN-v2 training loop (reference-based passes removed).

        Trains discriminator and generator with latent-sampled styles only;
        the reference-guided passes and the EMA updates are deliberately
        disabled (left as string-literal blocks below). Periodically logs
        losses, renders a fixed batch translated into every domain, saves
        checkpoints, and computes metrics.

        Args:
            loaders: Munch-like object with ``src``, ``ref`` and ``val``
                data loaders.
        """
        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema
        optims = self.optims

        # fetch random validation images for debugging
        fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim, 'train')
        fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
        # Fixed batch of validation source images, reused for every sample grid.
        x_fixed = next(fetcher_val)
        x_fixed = x_fixed.x_src

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()

        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)
            x_real, y_org = inputs.x_src, inputs.y_src
            x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
            z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

            masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

            # train the discriminator (latent-guided pass only)
            d_loss, d_losses_latent = compute_d_loss(
                nets, args, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            """
            Removing Reference based training
            d_loss, d_losses_ref = compute_d_loss(
                nets, args, x_real, y_org, y_trg, x_ref=x_ref, masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()
            """

            # train the generator (latent-guided pass only)
            g_loss, g_losses_latent = compute_g_loss(
                nets, args, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            optims.mapping_network.step()
            optims.style_encoder.step()

            """
            Removing reference based training
            g_loss, g_losses_ref = compute_g_loss(
                nets, args, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            """

            # compute moving average of network parameters
            # NOTE(review): EMA updates are disabled below, yet nets_ema is
            # still passed to calculate_metrics — the EMA weights stay frozen
            # at their loaded/initial values. Confirm this is intended.
            """
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999)
            moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999)
            """

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i+1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (elapsed, i+1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([d_losses_latent, g_losses_latent],
                                        ['D/latent_', 'G/latent_']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join(['%s: [%.4f]' % (key, value) for key, value in all_losses.items()])
                print(log)

            """
            # generate images for debugging
            if (i+1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema, args, inputs=inputs_val, step=i+1)
            """
            # Sample grid: translate the fixed batch into every domain using a
            # fresh random latent per domain, then concatenate side by side.
            # Uses the raw (non-EMA) nets.
            if (i+1) % args.sample_every == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for j in range(args.num_domains):
                        label = torch.ones((x_fixed.size(0),),dtype=torch.long).to(self.device)
                        label = label*j
                        z = torch.randn((x_fixed.size(0),args.latent_dim)).to(self.device)
                        style = self.nets.mapping_network(z,label)
                        x_fake_list.append(self.nets.generator(x_fixed, style))
                    # dim=3: concatenate along image width.
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join('samples', '{}-images.jpg'.format(i+1))
                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            # save model checkpoints
            if (i+1) % args.save_every == 0:
                self._save_checkpoint(step=i+1)

            # compute FID and LPIPS if necessary
            if (i+1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i+1, mode='latent')
                calculate_metrics(nets_ema, args, i+1, mode='reference')
Ejemplo n.º 11
0
    def train(self, loaders):
        """Train the landmark-transformer pipeline (discriminator, classifier,
        encoder/decoder) on 68-point facial landmarks.

        Each iteration performs three updates: the linear landmark
        discriminator, the linear identity classifier, and the landmark
        transformer (encoder + decoder), each followed by an EMA update of the
        corresponding network. Supports the 'vox1' (1251 identities) and
        'rafd' (67 identities) datasets.

        Args:
            loaders: Munch-like object with ``src`` and ``val`` data loaders.

        Raises:
            ValueError: if ``args.dataset`` is neither 'vox1' nor 'rafd'.
        """
        # Bind vgg_encode unconditionally: it is passed to the loss functions
        # and to utils.debug_image below, so leaving it unbound when the flag
        # is off would raise NameError on the first iteration.
        vgg_encode = None
        if self.args.vgg_encode:
            vgg_encode = network.vgg_feature(False, 0)

        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema

        optims = self.optims
        # The linear classifier gets its own optimizer with a separate
        # learning rate (args.lr2) instead of the default one.
        for net in nets.keys():
            if net == 'linear_classfier':
                optims[net] = torch.optim.Adam(
                    params=nets[net].parameters(),
                    lr=args.lr2,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay)

        print(optims)
        print(self.nets.keys())

        # fetch random validation images for debugging
        if args.dataset == 'rafd':
            fetcher = InputFetcher_mpie(loaders.src, args.latent_dim, 'train')
            fetcher_val = InputFetcher_mpie(loaders.val, args.latent_dim, 'val')
        elif args.dataset == 'vox1':
            fetcher = InputFetcher_300vw(loaders.src, args.latent_dim, 'train')
            fetcher_val = InputFetcher_300vw(loaders.val, args.latent_dim, 'val')
        else:
            # Fail fast with a clear message instead of an opaque NameError
            # on the unbound fetchers below.
            raise ValueError("unsupported dataset: %r (expected 'rafd' or 'vox1')"
                             % (args.dataset,))

        inputs_val = next(fetcher_val)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)

            x1_label = inputs.x_label
            x2_label = inputs.x2_label
            x3_label = inputs.x3_label
            if args.dataset == 'rafd':
                # Keep the first frame's landmarks, flattened to (N, 136)
                # (68 points x 2 coordinates).
                x4_label = inputs.x4_label
                param_x4 = x4_label[:, 0, :].unsqueeze(0)
                param_x4 = param_x4.view(-1, 136).float()
                param_x5 = None
            else:
                param_x4 = None
                param_x5 = None

            x1_one_hot, x3_one_hot = inputs.x1_one_hot, inputs.x3_one_hot
            x1_id, x3_id = inputs.x1_id, inputs.x3_id

            # x1/x3 labels are passed through unflattened; only x2 (and x4
            # above) are reduced to flat (N, 136) landmark vectors.
            param_x1 = x1_label
            param_x2 = x2_label[:, 0, :].unsqueeze(0)
            param_x2 = param_x2.view(-1, 136).float()
            param_x3 = x3_label

            # One-hot identity vectors: vox1 has 1251 identities, rafd 67.
            num_ids = 1251 if args.dataset == 'vox1' else 67
            one_hot_x1 = x1_one_hot[:, 0, :].unsqueeze(0)
            one_hot_x1 = one_hot_x1.view(-1, num_ids).float()
            one_hot_x3 = x3_one_hot[:, 0, :].unsqueeze(0)
            one_hot_x3 = one_hot_x3.view(-1, num_ids).float()

            # The original code duplicated the identical training step for
            # 'vox1' and 'rafd'; the dataset is guaranteed to be one of the
            # two by the guard above, so a single copy suffices.

            # train the landmark discriminator
            d_tran_loss, d_tran_losses = compute_d_tran_loss(
                nets, args, param_x1, param_x2, param_x3, param_x4,
                one_hot_x1, one_hot_x3, x1_id, x3_id, masks=None,
                loss_select=args.loss, vgg_encode=vgg_encode)
            self._reset_grad()
            d_tran_loss.backward()
            optims.linear_discriminator.step()
            moving_average(nets.linear_discriminator,
                           nets_ema.linear_discriminator, beta=0.999)

            # train the classfier
            c_loss, c_losses = compute_c_loss(
                nets, args, param_x1, param_x2, param_x3, param_x4,
                one_hot_x1, one_hot_x3, x1_id, x3_id, masks=None,
                loss_select=args.loss)
            self._reset_grad()
            c_loss.backward()
            optims.linear_classfier.step()
            moving_average(nets.linear_classfier, nets_ema.linear_classfier,
                           beta=0.999)

            # train the transformer (landmark encoder + decoder)
            t_loss, t_losses = compute_t_loss(
                nets, args, param_x1, param_x2, param_x3, param_x4, param_x5,
                one_hot_x1, one_hot_x3, x1_id, x3_id, masks=None,
                loss_select=args.loss, vgg_encode=vgg_encode)
            self._reset_grad()
            t_loss.backward()
            optims.linear_decoder.step()
            optims.lm_linear_encoder.step()

            moving_average(nets.linear_decoder, nets_ema.linear_decoder,
                           beta=0.999)
            moving_average(nets.lm_linear_encoder, nets_ema.lm_linear_encoder,
                           beta=0.999)

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i+1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (elapsed, i+1, args.total_iters)
                all_losses = dict()
                # Classifier losses (c_losses) are intentionally not logged.
                for loss, prefix in zip([d_tran_losses, t_losses], ['D/', 'G/']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join(['%s: [%.4f]' % (key, value) for key, value in all_losses.items()])
                print(log)
                for key, value in all_losses.items():
                    self.writer.add_scalar(key, value, i+1)

            # generate images for debugging
            if (i+1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema, args, inputs=inputs_val, step=i+1, vgg_encode=vgg_encode)

            # save model checkpoints
            if (i+1) % args.save_every == 0:
                self._save_checkpoint(step=i+1)

            # compute FID and LPIPS if necessary
            if (i+1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i+1, mode='latent')
                calculate_metrics(nets_ema, args, i+1, mode='reference')
        self.writer.close()