Example #1
0
class Pix2Pix:
    """Trainer for a Pix2Pix conditional GAN on the facades dataset.

    The generator is optimised with a GAN loss plus a ``LAMBDA``-weighted L1
    loss; the discriminator with the mean of its real/fake BCE-with-logits
    losses. Sample images are written every ``2 * eval_interval`` epochs and,
    when ``args.save`` is set, final weights and loss histories are stored
    under ``args.path``.
    """

    def __init__(self, args) -> None:
        """Build networks, optimizers, losses, output directories and logger.

        Args:
            args: namespace providing ``learning_rate``, ``LAMBDA``, ``save``,
                ``batch_size``, ``path`` and ``epoch_num``.
        """
        self.lr = args.learning_rate
        self.LAMBDA = args.LAMBDA
        self.save = args.save
        self.batch_size = args.batch_size
        self.path = args.path
        self.n_epochs = args.epoch_num
        # Log every `eval_interval` batches; write sample images every
        # `2 * eval_interval` epochs.
        self.eval_interval = 10
        # Per-batch loss histories, dumped to .npy files at the end of training.
        self.G_image_loss = []
        self.G_GAN_loss = []
        self.G_total_loss = []
        self.D_loss = []
        self.netG = Generator().to("cuda")
        self.netD = Discriminator().to("cuda")
        self.optimizerG = flow.optim.Adam(self.netG.parameters(),
                                          lr=self.lr,
                                          betas=(0.5, 0.999))
        self.optimizerD = flow.optim.Adam(self.netD.parameters(),
                                          lr=self.lr,
                                          betas=(0.5, 0.999))
        self.criterionGAN = flow.nn.BCEWithLogitsLoss()
        self.criterionL1 = flow.nn.L1Loss()

        self.checkpoint_path = os.path.join(self.path, "checkpoint")
        self.test_images_path = os.path.join(self.path, "test_images")

        mkdirs(self.checkpoint_path, self.test_images_path)
        self.logger = init_logger(os.path.join(self.path, "log.txt"))

    def train(self):
        """Train G and D for ``self.n_epochs`` epochs, then optionally save."""
        # init dataset
        x, y = load_facades()
        # flow.Tensor() bug in here: inputs must be C-contiguous
        x, y = np.ascontiguousarray(x), np.ascontiguousarray(y)
        self.fixed_inp = to_tensor(x[:self.batch_size].astype(np.float32))
        self.fixed_target = to_tensor(y[:self.batch_size].astype(np.float32))

        # Floor division drops an incomplete trailing batch, so every batch
        # matches the fixed-size label tensors below.
        batch_num = len(x) // self.batch_size
        # (1, 30, 30) per sample is hard-coded; presumably the output map size
        # of this Discriminator -- TODO confirm.
        label1 = to_tensor(np.ones((self.batch_size, 1, 30, 30)),
                           dtype=flow.float32)
        label0 = to_tensor(np.zeros((self.batch_size, 1, 30, 30)),
                           dtype=flow.float32)

        for epoch_idx in range(self.n_epochs):
            self.netG.train()
            self.netD.train()
            start = time.time()

            # Batches are taken in dataset order; no shuffling happens here.
            for batch_idx in range(batch_num):
                lo = batch_idx * self.batch_size
                hi = lo + self.batch_size
                inp = to_tensor(x[lo:hi].astype(np.float32))
                target = to_tensor(y[lo:hi].astype(np.float32))

                # update D
                d_fake_loss, d_real_loss, d_loss = self.train_discriminator(
                    inp, target, label0, label1)

                # update G
                g_gan_loss, g_image_loss, g_total_loss, _ = self.train_generator(
                    inp, target, label1)

                self.G_GAN_loss.append(g_gan_loss)
                self.G_image_loss.append(g_image_loss)
                self.G_total_loss.append(g_total_loss)
                self.D_loss.append(d_loss)
                if (batch_idx + 1) % self.eval_interval == 0:
                    self.logger.info(
                        "{}th epoch, {}th batch, d_fakeloss:{:>8.4f}, d_realloss:{:>8.4f},  ggan_loss:{:>8.4f}, gl1_loss:{:>8.4f}"
                        .format(
                            epoch_idx + 1,
                            batch_idx + 1,
                            d_fake_loss,
                            d_real_loss,
                            g_gan_loss,
                            g_image_loss,
                        ))

            self.logger.info("Time for epoch {} is {} sec.".format(
                epoch_idx + 1,
                time.time() - start))

            # Parenthesised on purpose: `a % 2 * b == 0` would parse as
            # `(a % 2) * b == 0` and fire on every even epoch.
            if (epoch_idx + 1) % (2 * self.eval_interval) == 0:
                self._eval_generator_and_save_images(epoch_idx)

        if self.save:
            # self.n_epochs equals the last epoch_idx + 1, but unlike the loop
            # variable it is defined even when no epoch ran.
            flow.save(
                self.netG.state_dict(),
                os.path.join(self.checkpoint_path,
                             "pix2pix_g_{}".format(self.n_epochs)),
            )

            flow.save(
                self.netD.state_dict(),
                os.path.join(self.checkpoint_path,
                             "pix2pix_d_{}".format(self.n_epochs)),
            )

            # save train loss and val error to plot
            np.save(
                os.path.join(self.path,
                             "G_image_loss_{}.npy".format(self.n_epochs)),
                self.G_image_loss,
            )
            np.save(
                os.path.join(self.path,
                             "G_GAN_loss_{}.npy".format(self.n_epochs)),
                self.G_GAN_loss,
            )
            np.save(
                os.path.join(self.path,
                             "G_total_loss_{}.npy".format(self.n_epochs)),
                self.G_total_loss,
            )
            np.save(
                os.path.join(self.path, "D_loss_{}.npy".format(self.n_epochs)),
                self.D_loss,
            )
        self.logger.info("*************** Train done ***************** ")

    def train_generator(self, input, target, label1):
        """One generator step.

        Returns (gan_loss, LAMBDA * l1_loss, total_loss, generated images),
        all converted with ``to_numpy``.
        """
        g_out = self.netG(input)
        # First, G(A) should fake the discriminator
        fake_AB = flow.cat([input, g_out], 1)
        pred_fake = self.netD(fake_AB)
        gan_loss = self.criterionGAN(pred_fake, label1)
        # Second, G(A) = B
        l1_loss = self.criterionL1(g_out, target)
        # combine loss and calculate gradients
        g_loss = gan_loss + self.LAMBDA * l1_loss
        # Clear any stale G gradients before accumulating new ones. This
        # backward also writes gradients into netD; train_discriminator clears
        # those before its own backward.
        self.optimizerG.zero_grad()
        g_loss.backward()
        self.optimizerG.step()
        return (
            to_numpy(gan_loss),
            to_numpy(self.LAMBDA * l1_loss),
            to_numpy(g_loss),
            to_numpy(g_out, False),
        )

    def train_discriminator(self, input, target, label0, label1):
        """One discriminator step; returns (fake_loss, real_loss, mean_loss)."""
        # Fake: G's output is only an input to D here, so no G graph is needed.
        with flow.no_grad():
            g_out = self.netG(input)
        fake_AB = flow.cat([input, g_out], 1)
        pred_fake = self.netD(fake_AB)
        d_fake_loss = self.criterionGAN(pred_fake, label0)

        # Real
        real_AB = flow.cat([input, target], 1)
        pred_real = self.netD(real_AB)
        d_real_loss = self.criterionGAN(pred_real, label1)

        # combine loss and calculate gradients
        d_loss = (d_fake_loss + d_real_loss) * 0.5
        # Must zero *before* backward: the previous generator step's backward
        # left gradients in netD that would otherwise leak into this update.
        self.optimizerD.zero_grad()
        d_loss.backward()
        self.optimizerD.step()
        return to_numpy(d_fake_loss), to_numpy(d_real_loss), to_numpy(d_loss)

    def _eval_generator_and_save_images(self, epoch_idx):
        """Render the fixed batch with G and save it as a PNG for this epoch."""
        results = self._eval_generator()
        save_images(
            results,
            to_numpy(self.fixed_inp, False),
            to_numpy(self.fixed_target, False),
            path=os.path.join(self.test_images_path,
                              "testimage_{:02d}.png".format(epoch_idx + 1)),
        )

    def _eval_generator(self):
        """Run G in eval mode, without gradients, on the fixed input batch."""
        self.netG.eval()
        with flow.no_grad():
            g_out = self.netG(self.fixed_inp)
        return to_numpy(g_out, False)
Example #2
0
class DeformablePose_GAN(nn.Module):
    """Pose-transfer GAN with a deformable generator and one discriminator.

    The 'baseline' generator returns three images (final output plus two
    auxiliary outputs); each is scored by the discriminator and supervised
    with L1. The 'stacked' generator returns a sequence whose last element is
    the final output.
    """

    def __init__(self, opt):
        super(DeformablePose_GAN, self).__init__()

        # load generator and discriminator models
        # adding extra layers for larger image size
        small = max(opt.image_size) < 256
        nfilters_decoder = ((512, 512, 512, 256, 128, 3) if small else
                            (512, 512, 512, 512, 256, 128, 3))
        nfilters_encoder = ((64, 128, 256, 512, 512, 512) if small else
                            (64, 128, 256, 512, 512, 512, 512))

        # Image channels plus target pose, plus source pose if use_input_pose.
        if opt.use_input_pose:
            input_nc = 3 + 2 * opt.pose_dim
        else:
            input_nc = 3 + opt.pose_dim

        self.batch_size = opt.batch_size
        self.num_stacks = opt.num_stacks
        self.pose_dim = opt.pose_dim
        if opt.gen_type == 'stacked':
            self.gen = Stacked_Generator(input_nc,
                                         opt.num_stacks,
                                         opt.image_size,
                                         opt.pose_dim,
                                         nfilters_encoder,
                                         nfilters_decoder,
                                         opt.warp_skip,
                                         use_input_pose=opt.use_input_pose)
            # HACK: initialise the inner generator from a pretrained model
            # ("hack to get better results"). Overridable via
            # opt.pretrained_gen_path; the default is the original path.
            pretrained_gen_path = (
                getattr(opt, 'pretrained_gen_path', None) or
                '../exp/' + 'full_' + opt.dataset + '/models/gen_090.pkl')
            self.gen.generator.load_state_dict(torch.load(pretrained_gen_path))
            print("Loaded generator from pretrained model ")
        elif opt.gen_type == 'baseline':
            self.gen = Deformable_Generator(input_nc,
                                            self.pose_dim,
                                            opt.image_size,
                                            nfilters_encoder,
                                            nfilters_decoder,
                                            opt.warp_skip,
                                            use_input_pose=opt.use_input_pose)
        else:
            raise Exception('Invalid gen_type')
        # discriminator also sees the output image for the target pose
        self.disc = Discriminator(input_nc + 3,
                                  use_input_pose=opt.use_input_pose)
        # Overridable via opt.pretrained_disc_path; the default keeps the
        # original (machine-specific) absolute path.
        pretrained_disc_path = (
            getattr(opt, 'pretrained_disc_path', None) or
            "/home/linlilang/pose-transfer/exp/baseline_market/models/disc_020.pkl")
        self.disc.load_state_dict(torch.load(pretrained_disc_path))
        # Only report the load after it has actually succeeded.
        print("Loaded discriminator from pretrained model ")

        print('---------- Networks initialized -------------')
        print('-----------------------------------------------')
        # Setup the optimizers
        lr = opt.learning_rate
        self.disc_opt = torch.optim.Adam(self.disc.parameters(),
                                         lr=lr,
                                         betas=(0.5, 0.999))
        self.gen_opt = torch.optim.Adam(self.gen.parameters(),
                                        lr=lr,
                                        betas=(0.5, 0.999))

        self.content_loss_layer = opt.content_loss_layer
        self.nn_loss_area_size = opt.nn_loss_area_size
        if self.content_loss_layer != 'none':
            self.content_model = resnet101(pretrained=True)
            # Fixed feature extractor that no optimizer updates: stop autograd
            # from accumulating (never-cleared) gradients in its weights.
            # Gradients still flow through it to the generator output.
            for param in self.content_model.parameters():
                param.requires_grad_(False)
        self.gen.cuda()
        self.disc.cuda()
        self._nn_loss_area_size = opt.nn_loss_area_size
        # Xavier initialisation is skipped because both nets load pretrained weights.
        self.ll_loss_criterion = torch.nn.L1Loss()

    def _forward_generator(self, input, other_inputs, opt):
        """Run the generator; return (final output, stacked outputs, aux outputs)."""
        if opt['gen_type'] == 'stacked':
            outputs_gen = self.gen(input, other_inputs['interpol_pose'],
                                   other_inputs['interpol_warps'],
                                   other_inputs['interpol_masks'])
            return outputs_gen[-1], outputs_gen, []
        out_gen, out_gen_2, out_gen_3 = self.gen(input, other_inputs['warps'],
                                                 other_inputs['masks'])
        return out_gen, [], [out_gen_2, out_gen_3]

    @staticmethod
    def _gen_adv_loss(out_dis):
        """Sum over the batch of the per-sample mean of -log(D(x) + 1e-7)."""
        per_sample = out_dis.reshape(out_dis.shape[0], -1)
        return -torch.log(per_sample + 1e-7).mean(dim=1).sum()

    @staticmethod
    def _disc_adv_losses(res_dis, n_real):
        """Split D outputs into the first ``n_real`` (real) and the rest (fake).

        Returns (true_loss, fake_loss): batch sums of the per-sample mean of
        -log(D) for real samples and -log(1 - D) for fake samples.
        """
        per_sample = res_dis.reshape(res_dis.shape[0], -1)
        true_loss = -torch.log(per_sample[:n_real] + 1e-7).mean(dim=1).sum()
        fake_loss = -torch.log(1 - per_sample[n_real:] + 1e-7).mean(dim=1).sum()
        return true_loss, fake_loss

    # add code for intermediate supervision for the interpolated poses using pretrained pose-estimator
    def gen_update(self, input, target, other_inputs, opt):
        """One generator step.

        Returns (final output, stacked outputs or [], [total, l1, adversarial]).
        """
        self.gen.zero_grad()
        out_gen, outputs_gen, aux_outs = self._forward_generator(
            input, other_inputs, opt)
        # Only the baseline generator produces auxiliary outputs.
        candidates = [out_gen] + aux_outs

        inp_img, inp_pose, out_pose = pose_utils.get_imgpose(
            input, opt['use_input_pose'], opt['pose_dim'])

        # computing adversarial loss for every generated image
        ad_loss = 0
        for candidate in candidates:
            out_dis = self.disc(
                torch.cat([inp_img, inp_pose, candidate, out_pose], dim=1))
            ad_loss = ad_loss + self._gen_adv_loss(out_dis)

        if self.content_loss_layer != 'none':
            content_out_gen = pose_utils.Feature_Extractor(
                self.content_model,
                input=out_gen,
                layer_name=self.content_loss_layer)
            content_target = pose_utils.Feature_Extractor(
                self.content_model,
                input=target,
                layer_name=self.content_loss_layer)
            ll_loss = self.nn_loss(content_out_gen, content_target,
                                   self.nn_loss_area_size,
                                   self.nn_loss_area_size)
        else:
            ll_loss = 0
        # Pixel L1 for each generated image. The main output is counted once,
        # including when no content loss is used.
        for candidate in candidates:
            ll_loss = ll_loss + self.ll_loss_criterion(candidate, target)

        ad_loss = ad_loss * opt['gan_penalty_weight'] / self.batch_size
        ll_loss = ll_loss * opt['l1_penalty_weight']
        total_loss = ad_loss + ll_loss
        total_loss.backward()
        self.gen_opt.step()
        self.gen_ll_loss = ll_loss.item()
        self.gen_ad_loss = ad_loss.item()
        self.gen_total_loss = total_loss.item()
        return out_gen, outputs_gen, [
            self.gen_total_loss, self.gen_ll_loss, self.gen_ad_loss
        ]

    def dis_update(self, input, target, other_inputs, real_inp, real_target,
                   opt):
        """One discriminator step; returns [total, true_loss, fake_loss]."""
        self.disc.zero_grad()

        # Generator outputs are only D inputs here; no generator graph needed.
        with torch.no_grad():
            out_gen, _, aux_outs = self._forward_generator(
                input, other_inputs, opt)

        inp_img, inp_pose, out_pose = pose_utils.get_imgpose(
            input, opt['use_input_pose'], opt['pose_dim'])
        r_inp_img, r_inp_pose, r_out_pose = pose_utils.get_imgpose(
            real_inp, opt['use_input_pose'], opt['pose_dim'])
        real_disc_inp = torch.cat(
            [r_inp_img, r_inp_pose, real_target, r_out_pose], dim=1)
        # Real samples come first in each D batch; split on the actual count
        # so a short final batch is not misclassified.
        n_real = real_disc_inp.shape[0]

        ad_true_loss = 0
        ad_fake_loss = 0
        for candidate in [out_gen] + aux_outs:
            fake_disc_inp = torch.cat([inp_img, inp_pose, candidate, out_pose],
                                      dim=1)
            res_dis = self.disc(torch.cat((real_disc_inp, fake_disc_inp), 0))
            true_loss, fake_loss = self._disc_adv_losses(res_dis, n_real)
            ad_true_loss = ad_true_loss + true_loss
            ad_fake_loss = ad_fake_loss + fake_loss

        ad_true_loss = ad_true_loss * opt[
            'gan_penalty_weight'] / self.batch_size
        ad_fake_loss = ad_fake_loss * opt[
            'gan_penalty_weight'] / self.batch_size
        loss = ad_true_loss + ad_fake_loss
        loss.backward()
        self.disc_opt.step()
        self.dis_total_loss = loss.item()
        self.dis_true_loss = ad_true_loss.item()
        self.dis_fake_loss = ad_fake_loss.item()
        return [self.dis_total_loss, self.dis_true_loss, self.dis_fake_loss]

    def nn_loss(self, predicted, ground_truth, nh=3, nw=3):
        """Nearest-neighbour L1 loss.

        For every spatial position of ``predicted``, compare it (summing
        absolute differences over dim 1) with each ``ground_truth`` position
        in an ``nh`` x ``nw`` window centred there. Keep the minimum and
        average over everything. Assumes 4-D inputs of equal shape with
        height on dim 2 and width on dim 3 (NCHW) -- TODO confirm at callers.

        Raises:
            ValueError: if ``nh`` or ``nw`` is even (the window needs a centre).
        """
        if nh % 2 == 0 or nw % 2 == 0:
            raise ValueError(
                'nn_loss window sizes must be odd, got {}x{}'.format(nh, nw))
        v_pad = nh // 2
        h_pad = nw // 2
        # ConstantPad2d takes (left, right, top, bottom): width padding comes
        # first. The large negative fill ensures padded cells never win the min.
        val_pad = nn.ConstantPad2d((h_pad, h_pad, v_pad, v_pad),
                                   -10000)(ground_truth)

        reference_tensors = []
        for i_begin in range(0, nh):
            i_end = i_begin - nh + 1
            i_end = None if i_end == 0 else i_end
            for j_begin in range(0, nw):
                j_end = j_begin - nw + 1
                j_end = None if j_end == 0 else j_end
                sub_tensor = val_pad[:, :, i_begin:i_end, j_begin:j_end]
                reference_tensors.append(sub_tensor.unsqueeze(-1))
        # Last dim enumerates the nh * nw shifted copies of ground_truth.
        reference = torch.cat(reference_tensors, dim=-1)

        diff = torch.abs(reference - predicted.unsqueeze(-1))
        # sum along channels
        norms = torch.sum(diff, dim=1)
        # min over neighbourhood
        loss, _ = torch.min(norms, dim=-1)
        return torch.mean(loss)

    def resume(self, save_dir):
        """Load the latest gen/dis checkpoints from ``save_dir``.

        Returns the epoch parsed from the discriminator filename, or 1 if
        either checkpoint is missing.
        """
        last_model_name = pose_utils.get_model_list(save_dir, "gen")
        if last_model_name is None:
            return 1
        self.gen.load_state_dict(torch.load(last_model_name))
        epoch = int(last_model_name[-7:-4])
        print('Resume gen from epoch %d' % epoch)
        last_model_name = pose_utils.get_model_list(save_dir, "dis")
        if last_model_name is None:
            return 1
        epoch = int(last_model_name[-7:-4])
        self.disc.load_state_dict(torch.load(last_model_name))
        print('Resume disc from epoch %d' % epoch)
        return epoch

    def save(self, save_dir, epoch):
        """Write gen_XXX.pkl and disc_XXX.pkl state dicts for ``epoch``."""
        gen_filename = os.path.join(save_dir, 'gen_{0:03d}.pkl'.format(epoch))
        disc_filename = os.path.join(save_dir,
                                     'disc_{0:03d}.pkl'.format(epoch))
        torch.save(self.gen.state_dict(), gen_filename)
        torch.save(self.disc.state_dict(), disc_filename)

    def normalize_image(self, x):
        """Return the first three channels of ``x``."""
        return x[:, 0:3, :, :]
Example #3
0
class Pose_GAN(nn.Module):
  """Pose-transfer GAN: a generator conditioned on image + poses and a discriminator.

  Both networks are initialised with ``xavier_weights_init`` and trained with
  Adam. Adversarial terms use -log(D) / -log(1 - D), and the generator also
  gets an L1 reconstruction loss.
  """

  def __init__(self, opt):
    super(Pose_GAN, self).__init__()

    # load generator and discriminator models
    # adding extra layers for larger image size
    small = max(opt.image_size) < 256
    if opt.checkMode == 0:
      nfilters_decoder = (512, 512, 512, 256, 128, 3) if small else (512, 512, 512, 512, 256, 128, 3)
      nfilters_encoder = (64, 128, 256, 512, 512, 512) if small else (64, 128, 256, 512, 512, 512, 512)
    else:
      nfilters_decoder = (128, 3) if small else (256, 128, 3)
      nfilters_encoder = (64, 128) if small else (64, 128, 256)

    # Image channels plus target pose, plus source pose if use_input_pose.
    if opt.use_input_pose:
      input_nc = 3 + 2*opt.pose_dim
    else:
      input_nc = 3 + opt.pose_dim

    self.num_stacks = opt.num_stacks
    self.batch_size = opt.batch_size
    self.pose_dim = opt.pose_dim
    if opt.gen_type == 'stacked':
      self.gen = Stacked_Generator(input_nc, opt.num_stacks, opt.pose_dim, nfilters_encoder, nfilters_decoder, use_input_pose=opt.use_input_pose)
    elif opt.gen_type == 'baseline':
      self.gen = Generator(input_nc, nfilters_encoder, nfilters_decoder, use_input_pose=opt.use_input_pose)
    else:
      raise Exception('Invalid gen_type')
    # discriminator also sees the output image for the target pose
    self.disc = Discriminator(input_nc + 3, use_input_pose=opt.use_input_pose, checkMode=opt.checkMode)
    print('---------- Networks initialized -------------')
    print_network(self.gen)
    print_network(self.disc)
    print('-----------------------------------------------')
    # Setup the optimizers
    lr = opt.learning_rate
    self.disc_opt = torch.optim.Adam(self.disc.parameters(), lr=lr, betas=(0.5, 0.999))
    self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999))

    # Network weight initialization (in place, so the optimizers stay valid)
    self.gen.cuda()
    self.disc.cuda()
    self.disc.apply(xavier_weights_init)
    self.gen.apply(xavier_weights_init)

    # Setup the loss function for training
    self.ll_loss_criterion = torch.nn.L1Loss()

  @staticmethod
  def _gen_adv_loss(out_dis):
    """Sum over the batch of the per-sample mean of -log(D(x) + 1e-7)."""
    per_sample = out_dis.reshape(out_dis.shape[0], -1)
    return -torch.log(per_sample + 1e-7).mean(dim=1).sum()

  @staticmethod
  def _disc_adv_losses(res_dis, n_real):
    """Return (true_loss, fake_loss) for D outputs whose first ``n_real`` rows are real.

    Each is a batch sum of the per-sample mean of -log(D) (real) or
    -log(1 - D) (fake).
    """
    per_sample = res_dis.reshape(res_dis.shape[0], -1)
    true_loss = -torch.log(per_sample[:n_real] + 1e-7).mean(dim=1).sum()
    fake_loss = -torch.log(1 - per_sample[n_real:] + 1e-7).mean(dim=1).sum()
    return true_loss, fake_loss

  # add code for intermediate supervision for the interpolated poses using pretrained pose-estimator
  def gen_update(self, input, target, interpol_pose, opt):
    """One generator step; returns (final output, stacked outputs or [], [total, l1, adv])."""
    self.gen.zero_grad()

    if opt['gen_type'] == 'stacked':
      outputs_gen = self.gen(input, interpol_pose)
      out_gen = outputs_gen[-1]
    else:
      out_gen = self.gen(input)
      outputs_gen = []

    inp_img, inp_pose, out_pose = pose_utils.get_imgpose(input, opt['use_input_pose'], opt['pose_dim'])

    inp_dis = torch.cat([inp_img, inp_pose, out_gen, out_pose], dim=1)
    out_dis = self.disc(inp_dis)

    # computing adversarial loss
    ad_loss = self._gen_adv_loss(out_dis) * opt['gan_penalty_weight'] / self.batch_size
    ll_loss = self.ll_loss_criterion(out_gen, target) * opt['l1_penalty_weight']
    total_loss = ad_loss + ll_loss
    total_loss.backward()
    self.gen_opt.step()
    self.gen_ll_loss = ll_loss.item()
    self.gen_ad_loss = ad_loss.item()
    self.gen_total_loss = total_loss.item()
    return out_gen, outputs_gen, [self.gen_total_loss, self.gen_ll_loss, self.gen_ad_loss]

  def dis_update(self, input, target, interpol_pose, real_inp, real_target, opt):
    """One discriminator step; returns [total, true_loss, fake_loss]."""
    self.disc.zero_grad()

    # Generator output is only a D input here; no generator graph needed.
    with torch.no_grad():
      if opt['gen_type'] == 'stacked':
        out_gen = self.gen(input, interpol_pose)[-1]
      else:
        out_gen = self.gen(input)

    inp_img, inp_pose, out_pose = pose_utils.get_imgpose(input, opt['use_input_pose'], opt['pose_dim'])

    fake_disc_inp = torch.cat([inp_img, inp_pose, out_gen, out_pose], dim=1)
    r_inp_img, r_inp_pose, r_out_pose = pose_utils.get_imgpose(real_inp, opt['use_input_pose'], opt['pose_dim'])
    real_disc_inp = torch.cat([r_inp_img, r_inp_pose, real_target, r_out_pose], dim=1)
    data_dis = torch.cat((real_disc_inp, fake_disc_inp), 0)
    res_dis = self.disc(data_dis)

    # Real samples come first; split on the actual real count so a short
    # final batch is not misclassified.
    ad_true_loss, ad_fake_loss = self._disc_adv_losses(res_dis, real_disc_inp.shape[0])

    ad_true_loss = ad_true_loss*opt['gan_penalty_weight']/self.batch_size
    ad_fake_loss = ad_fake_loss*opt['gan_penalty_weight']/self.batch_size
    loss = ad_true_loss + ad_fake_loss
    loss.backward()
    self.disc_opt.step()
    self.dis_total_loss = loss.item()
    self.dis_true_loss = ad_true_loss.item()
    self.dis_fake_loss = ad_fake_loss.item()
    return [self.dis_total_loss, self.dis_true_loss, self.dis_fake_loss]

  def resume(self, save_dir):
    """Load the latest gen/dis checkpoints; return the dis epoch, or 1 if missing."""
    last_model_name = pose_utils.get_model_list(save_dir, "gen")
    if last_model_name is None:
      return 1
    self.gen.load_state_dict(torch.load(last_model_name))
    epoch = int(last_model_name[-7:-4])
    print('Resume gen from epoch %d' % epoch)
    last_model_name = pose_utils.get_model_list(save_dir, "dis")
    if last_model_name is None:
      return 1
    epoch = int(last_model_name[-7:-4])
    self.disc.load_state_dict(torch.load(last_model_name))
    print('Resume disc from epoch %d' % epoch)
    return epoch

  def save(self, save_dir, epoch):
    """Write gen_XXX.pkl and disc_XXX.pkl state dicts for ``epoch``."""
    gen_filename = os.path.join(save_dir, 'gen_{0:03d}.pkl'.format(epoch))
    disc_filename = os.path.join(save_dir, 'disc_{0:03d}.pkl'.format(epoch))
    torch.save(self.gen.state_dict(), gen_filename)
    torch.save(self.disc.state_dict(), disc_filename)

  def normalize_image(self, x):
    """Return the first three channels of ``x``."""
    return x[:, 0:3, :, :]
Example #4
0
        # Log the overall generator loss for this iteration.
        # NOTE(review): `summary` looks like a TensorBoard-style writer -- confirm.
        summary.add_scalar(f'loss G/loss Overall',
                           loss_overall.data.cpu().numpy(), iter_count)

        # Remaining-time estimate: elapsed time scaled by
        # (remaining iterations / completed iterations); `eps` guards against
        # division by zero. Assumes `stime`, `iter_count` and
        # `total_epoch_iter` share the same origin -- TODO confirm.
        etime = time.time() - stime
        rtime = etime * (total_epoch_iter - iter_count) / (iter_count + eps)
        # Progress line split across three prints joined by end=''.
        print(
            f'Epoch: {epoch+1:03d}/{num_epochs:03d}, Iter: {i+1:04d}/{total_iter:04d}, ',
            end='')
        print(f'Loss G: {loss_overall.data:.4f}, Loss D: {loss_D.data:.4f}, ',
              end='')
        print(f'Elapsed: {sec2time(etime)}, Remaining: {sec2time(rtime)}')

        # Every 10 iterations, log the first sample of the sr / lr / hr
        # batches (presumably super-resolved, low-res input and high-res
        # target -- confirm).
        if (i + 1) % 10 == 0:
            summary.add_image(f'image/sr_image', sr[0], iter_count)
            summary.add_image(f'image/lr_image', lr[0], iter_count)
            summary.add_image(f'image/hr_image', hr[0], iter_count)

    # End of epoch: checkpoint both networks. The filenames embed the loss of
    # the last processed batch, not an epoch average.
    torch.save(
        G.state_dict(),
        f'./models/weights/G_epoch_{epoch+1}_loss_{loss_overall.data:.4f}.pth')
    torch.save(
        D.state_dict(),
        f'./models/weights/D_epoch_{epoch+1}_loss_{loss_D.data:.4f}.pth')

    # Halve both learning rates every 10 epochs.
    if (epoch + 1) % 10 == 0:
        learning_rateG *= 0.5
        learning_rateD *= 0.5

    # Push the current rates into the optimizers every epoch.
    # NOTE(review): assumes update_lr sets the lr of each param group -- confirm.
    update_lr(optimizerG, learning_rateG)
    update_lr(optimizerD, learning_rateD)