Example #1
0
  def dis_update(self, input, target, interpol_pose, real_inp, real_target, opt):
    """Run one discriminator update step.

    Builds a fake batch from the generator's output and a real batch from
    ``real_inp``/``real_target``, scores both with the discriminator, and
    back-propagates a non-saturating GAN loss.

    Returns ``[total_loss, true_loss, fake_loss]`` as Python floats.
    NOTE(review): ``target`` is unused here but kept for interface parity
    with ``gen_update`` — confirm against callers before removing.
    """
    self.disc.zero_grad()

    # A stacked generator returns one output per stage; score only the last.
    if opt['gen_type'] == 'stacked':
      out_gen = self.gen(input, interpol_pose)[-1]
    else:
      out_gen = self.gen(input)

    inp_img, inp_pose, out_pose = pose_utils.get_imgpose(input, opt['use_input_pose'], opt['pose_dim'])

    fake_disc_inp = torch.cat([inp_img, inp_pose, out_gen, out_pose], dim=1)
    r_inp_img, r_inp_pose, r_out_pose = pose_utils.get_imgpose(real_inp, opt['use_input_pose'], opt['pose_dim'])
    real_disc_inp = torch.cat([r_inp_img, r_inp_pose, real_target, r_out_pose], dim=1)
    # First opt['batch_size'] rows are real samples, the remainder fake.
    data_dis = torch.cat((real_disc_inp, fake_disc_inp), 0)
    res_dis = self.disc(data_dis)

    # Accumulate -log D(real) over the real half and -log(1 - D(fake)) over
    # the fake half; the 1e-7 epsilon guards against log(0).
    ad_true_loss = 0
    ad_fake_loss = 0
    for it in range(res_dis.shape[0]):
      out = res_dis[it, :]
      if it < opt['batch_size']:
        ad_true_loss += -torch.mean(torch.log(out + 1e-7))
      else:
        ad_fake_loss += -torch.mean(torch.log(1 - out + 1e-7))

    ad_true_loss = ad_true_loss * opt['gan_penalty_weight'] / self.batch_size
    ad_fake_loss = ad_fake_loss * opt['gan_penalty_weight'] / self.batch_size
    loss = ad_true_loss + ad_fake_loss
    loss.backward()
    self.disc_opt.step()
    self.dis_total_loss = loss.item()
    self.dis_true_loss = ad_true_loss.item()
    self.dis_fake_loss = ad_fake_loss.item()
    return [self.dis_total_loss, self.dis_true_loss, self.dis_fake_loss]
Example #2
0
 def forward(self, input, warps, masks):
     """Encode appearance and target pose in two streams, fuse their skip
     connections using the provided warps/masks, and decode an image."""
     appearance, src_pose, dst_pose = pose_utils.get_imgpose(
         input, self.use_input_pose, self.pose_dim)
     # Appearance encoder sees the source image together with its pose.
     app_skips = self.encoder_app(torch.cat([appearance, src_pose], dim=1))
     pose_skips = self.encoder_pose(dst_pose)
     fused_skips = self.concatenate_skips(app_skips, pose_skips, warps, masks)
     return self.decoder(fused_skips)
Example #3
0
 def forward(self, input, target_pose, target_warps, target_masks):
     """Run the generator over ``num_stacks`` stages.

     Stage 0 consumes the initial image (and its pose when
     ``use_input_pose``); each later stage consumes the previous stage's
     output, the previous target pose, and the current target pose.

     Returns three lists (one entry per stage) with the generator's three
     outputs.

     BUG FIX: the third output was previously appended to ``outputs_2``
     instead of ``outputs_3``, so ``outputs_3`` was always returned empty
     and ``outputs_2`` held interleaved second/third outputs.
     """
     init_input, init_pose, _ = pose_utils.get_imgpose(
         input, self.use_input_pose, self.pose_dim)
     outputs = []
     outputs_2 = []
     outputs_3 = []
     for i in range(self.num_stacks):
         # Pose slice addressed to the current stage.
         cur_pose = target_pose[:, i * self.pose_dim:(i + 1) * self.pose_dim]
         if i == 0:
             if self.use_input_pose:
                 inp = torch.cat([init_input, init_pose, cur_pose], dim=1)
             else:
                 inp = torch.cat([init_input, cur_pose], dim=1)
         else:
             if self.use_input_pose:
                 # Previous stage's target pose acts as this stage's input pose.
                 prev_pose = target_pose[:, (i - 1) * self.pose_dim:i * self.pose_dim]
                 inp = torch.cat([out, prev_pose, cur_pose], dim=1)
             else:
                 inp = torch.cat([out, cur_pose], dim=1)
         out, out_2, out_3 = self.generator(inp, target_warps[:, i],
                                            target_masks[:, i])
         outputs.append(out)
         outputs_2.append(out_2)
         outputs_3.append(out_3)
     return outputs, outputs_2, outputs_3
Example #4
0
  def gen_update(self, input, target, interpol_pose, opt):
    """Run one generator update step.

    Combines a non-saturating adversarial loss (discriminator should score
    the generated image as real) with an L1 reconstruction loss against
    ``target``.

    Returns ``(out_gen, outputs_gen, [total, ll, adv])`` with losses as
    Python floats.
    """
    self.gen.zero_grad()

    if opt['gen_type'] == 'stacked':
      outputs_gen = self.gen(input, interpol_pose)
      out_gen = outputs_gen[-1]
    else:
      out_gen = self.gen(input)
      outputs_gen = []

    inp_img, inp_pose, out_pose = pose_utils.get_imgpose(input, opt['use_input_pose'], opt['pose_dim'])

    inp_dis = torch.cat([inp_img, inp_pose, out_gen, out_pose], dim=1)
    out_dis = self.disc(inp_dis)

    # Non-saturating adversarial loss: generator wants D(fake) -> 1.
    # The 1e-7 epsilon guards against log(0). The per-iteration `all_ones`
    # CUDA tensor of the original was dead code (left over from a commented
    # BCE formulation) and has been removed.
    ad_loss = 0
    for it in range(out_dis.shape[0]):
      ad_loss += -torch.mean(torch.log(out_dis[it, :] + 1e-7))

    ll_loss = self.ll_loss_criterion(out_gen, target)
    ad_loss = ad_loss * opt['gan_penalty_weight'] / self.batch_size
    ll_loss = ll_loss * opt['l1_penalty_weight']
    total_loss = ad_loss + ll_loss
    total_loss.backward()
    self.gen_opt.step()
    self.gen_ll_loss = ll_loss.item()
    self.gen_ad_loss = ad_loss.item()
    self.gen_total_loss = total_loss.item()
    return out_gen, outputs_gen, [self.gen_total_loss, self.gen_ll_loss, self.gen_ad_loss]
Example #5
0
    def gen_update(self, input, target, other_inputs, opt):
        """Run one generator update step for the warp/mask generator.

        Scores all three generator outputs adversarially and adds an L1
        (or perceptual + L1) reconstruction loss against ``target``.

        Returns ``(out_gen, outputs_gen, [total, ll, adv])`` with losses as
        Python floats.

        NOTE(review): the 'stacked' branch does not define
        ``out_gen_2``/``out_gen_3``, so this method currently works only
        for non-stacked generators — confirm against callers.
        """
        self.gen.zero_grad()

        if opt['gen_type'] == 'stacked':
            interpol_pose = other_inputs['interpol_pose']
            interpol_warps = other_inputs['interpol_warps']
            interpol_masks = other_inputs['interpol_masks']
            outputs_gen = self.gen(input, interpol_pose, interpol_warps,
                                   interpol_masks)
            out_gen = outputs_gen[-1]
        else:
            warps = other_inputs['warps']
            masks = other_inputs['masks']
            out_gen, out_gen_2, out_gen_3 = self.gen(input, warps, masks)
            outputs_gen = []

        inp_img, inp_pose, out_pose = pose_utils.get_imgpose(
            input, opt['use_input_pose'], opt['pose_dim'])

        # Non-saturating adversarial loss over all three generator outputs:
        # the generator wants D(fake) -> 1; 1e-7 guards against log(0).
        # The dead per-row `all_ones` CUDA allocations of the original (which
        # also sized themselves from the wrong tensor) have been removed.
        ad_loss = 0
        for fake in (out_gen, out_gen_2, out_gen_3):
            scores = self.disc(
                torch.cat([inp_img, inp_pose, fake, out_pose], dim=1))
            for it in range(scores.shape[0]):
                ad_loss += -torch.mean(torch.log(scores[it, :] + 1e-7))

        # Perceptual (content) loss on the final output when configured.
        if self.content_loss_layer != 'none':
            content_out_gen = pose_utils.Feature_Extractor(
                self.content_model,
                input=out_gen,
                layer_name=self.content_loss_layer)
            content_target = pose_utils.Feature_Extractor(
                self.content_model,
                input=target,
                layer_name=self.content_loss_layer)
            ll_loss = self.nn_loss(content_out_gen, content_target,
                                   self.nn_loss_area_size,
                                   self.nn_loss_area_size)
        else:
            # BUG FIX: the L1 term for out_gen was previously computed here
            # AND again below, double-counting it; start from zero instead.
            ll_loss = 0

        ll_loss += self.ll_loss_criterion(out_gen, target)
        ll_loss += self.ll_loss_criterion(out_gen_2, target)
        ll_loss += self.ll_loss_criterion(out_gen_3, target)

        ad_loss = ad_loss * opt['gan_penalty_weight'] / self.batch_size
        ll_loss = ll_loss * opt['l1_penalty_weight']
        total_loss = ad_loss + ll_loss
        total_loss.backward()
        self.gen_opt.step()
        self.gen_ll_loss = ll_loss.item()
        self.gen_ad_loss = ad_loss.item()
        self.gen_total_loss = total_loss.item()
        return out_gen, outputs_gen, [
            self.gen_total_loss, self.gen_ll_loss, self.gen_ad_loss
        ]