def forward(self, input_A, input_B, inst_A, fake_B_prev, dummy_bs=0):
    """Run the generator on one training chunk of a sequence.

    Args:
        input_A: conditioning input; encoded together with ``input_B`` and
            ``inst_A`` by ``self.encode_input``.
        input_B: real target input.
        inst_A: extra input forwarded to ``self.encode_input``.
        fake_B_prev: frames generated for the previous chunk, or ``None`` at
            the beginning of a sequence, in which case the history is created
            by ``self.generate_first_frame``.
        dummy_bs: forwarded to ``util.remove_dummy_from_tensor`` when
            ``input_A`` lives on ``self.gpu_ids[0]``.

    Returns:
        ``self.return_dummy(input_A)`` if removing the dummy entries leaves an
        empty batch. Otherwise a 7-tuple
        ``(fake_B, fake_B_raw, flow, weight, real_A, real_B_prev, fake_B_prev)``:

        * ``fake_B``: first entry of the generator output with its first
          ``n_frames_G - 1`` positions along dim 1 dropped.
        * ``fake_B_raw``, ``flow``, ``weight``: passed through unchanged from
          ``self.generate_frame_train``.
        * ``real_A``, ``real_B_prev``: ``real_A_all[:, n_frames_G - 1:]`` and
          ``real_B_all[:, n_frames_G - 2:]``.
        * ``fake_B_prev``: for every generator output, the last
          ``n_frames_G - 1`` positions along dim 1 (for ``n_frames_G >= 2``),
          detached, to seed the next call.

        NOTE(review): dim 1 looks like the frame axis; confirm against the
        encoder.
    """
    n_frames_G = self.opt.n_frames_G
    gpu_split_id = self.opt.n_gpus_gen + 1

    # Inputs on the primary GPU go through dummy-entry removal; an empty
    # batch afterwards short-circuits to the model's dummy output.
    if input_A.get_device() == self.gpu_ids[0]:
        input_A, input_B, inst_A, fake_B_prev = util.remove_dummy_from_tensor(
            [input_A, input_B, inst_A, fake_B_prev], dummy_bs)
        if input_A.size(0) == 0:
            return self.return_dummy(input_A)

    real_A_all, real_B_all, _ = self.encode_input(input_A, input_B, inst_A)

    is_first_frame = fake_B_prev is None
    if is_first_frame:
        # Start of a sequence: there is no generated history yet.
        fake_B_prev = self.generate_first_frame(real_A_all, real_B_all)

    def _copies_of(generator):
        # Replicate onto the generator GPUs when splitting; otherwise wrap the
        # single module so both cases yield a list.
        if not self.split_gpus:
            return [generator]
        return torch.nn.parallel.replicate(generator, self.opt.gpu_ids[:gpu_split_id])

    # One entry per scale, built from the attributes netG0, netG1, ...
    netG = [_copies_of(getattr(self, 'netG' + str(scale)))
            for scale in range(self.n_scales)]

    if self.split_gpus:
        start_gpu = self.gpu_ids[1]
    else:
        start_gpu = real_A_all.get_device()

    fake_B, fake_B_raw, flow, weight = self.generate_frame_train(
        netG, real_A_all, fake_B_prev, start_gpu, is_first_frame)

    # The first n_frames_G - 1 positions are history; the trailing ones are
    # detached and carried over as history for the next call.
    overlap = n_frames_G - 1
    history = [frames[:, -overlap:].detach() for frames in fake_B]
    current = [frames[:, overlap:] for frames in fake_B]

    return (current[0], fake_B_raw, flow, weight,
            real_A_all[:, overlap:], real_B_all[:, overlap - 1:], history)
def forward(self, scale_T, tensors_list, dummy_bs=0):
    """Compute the loss terms for one batch.

    ``scale_T`` selects the mode:

    * ``scale_T > 0`` (temporal branch): ``tensors_list`` is
      ``[real_B, fake_B, flow_ref, conf_ref]`` and ``real_B.size()`` must
      unpack into five values. The losses come from ``self.compute_loss_D_T``,
      called with ``flow_ref / 20`` and ``scale_T - 1``. Returns
      ``[G_T_GAN, G_T_GAN_Feat, D_T_real, D_T_fake, G_T_Warp]``, where
      ``G_T_Warp`` is always a zero tensor.
    * otherwise (per-frame branch): ``tensors_list`` is
      ``[real_B, fake_B, fake_B_raw, real_A, real_B_prev, fake_B_prev,
      flow, weight, flow_ref, conf_ref]`` and ``real_B.size()`` must unpack
      into four values.
      * ``flow`` may be ``None``; the three flow losses are then zeros shaped
        like ``conf_ref``.
      * ``fake_B_raw`` may be ``None``; when present it is scored with the
        same VGG/GAN terms and added in.
      * Returns ``[G_VGG, G_GAN, G_GAN_Feat, D_real, D_fake, G_Warp, F_Flow,
        F_Warp, W]``, plus ``[G_f_GAN, G_f_GAN_Feat, D_f_real, D_f_fake]``
        when ``opt.add_face_disc`` is set.

    Both branches store the last two entries of ``real_B.size()`` in
    ``self.height`` / ``self.width``. Every returned loss is reshaped with
    ``view(-1, 1)``.

    Args:
        scale_T: temporal scale index; ``> 0`` selects the temporal branch.
        tensors_list: input tensors, laid out as described above.
        dummy_bs: forwarded to ``util.remove_dummy_from_tensor`` when the
            first tensor lives on ``self.gpu_ids[0]``.

    Returns:
        The list of loss tensors described above. If dummy removal leaves an
        empty batch, returns one independent zero ``(1, 1)`` tensor per entry
        of ``self.loss_names_T`` (temporal branch) or ``self.loss_names``.
    """
    lambda_feat = self.opt.lambda_feat
    lambda_F = self.opt.lambda_F
    lambda_T = self.opt.lambda_T
    scale_S = self.opt.n_scales_spatial

    if tensors_list[0].get_device() == self.gpu_ids[0]:
        tensors_list = util.remove_dummy_from_tensor(tensors_list, dummy_bs)
        if tensors_list[0].size(0) == 0:
            n_losses = len(self.loss_names_T) if scale_T > 0 else len(self.loss_names)
            # A fresh tensor per slot: `[t] * n` would return the same tensor
            # object n times, so an in-place op on one slot would hit them all.
            return [self.Tensor(1, 1).fill_(0) for _ in range(n_losses)]

    if scale_T > 0:
        real_B, fake_B, flow_ref, conf_ref = tensors_list
        _, _, _, self.height, self.width = real_B.size()
        # NOTE(review): the /20 looks like a flow-magnitude normalisation for
        # the temporal discriminator; confirm the constant's origin.
        loss_D_T_real, loss_D_T_fake, loss_G_T_GAN, loss_G_T_GAN_Feat = self.compute_loss_D_T(
            real_B, fake_B, flow_ref / 20, conf_ref, scale_T - 1)
        loss_G_T_Warp = torch.zeros_like(loss_G_T_GAN)

        loss_list = [loss_G_T_GAN, loss_G_T_GAN_Feat, loss_D_T_real, loss_D_T_fake, loss_G_T_Warp]
        return [loss.view(-1, 1) for loss in loss_list]

    (real_B, fake_B, fake_B_raw, real_A, real_B_prev, fake_B_prev,
     flow, weight, flow_ref, conf_ref) = tensors_list
    _, _, self.height, self.width = real_B.size()

    ################### Flow loss #################
    if flow is not None:
        # similar to flownet flow; down-weighted by 2 ** (n_scales_spatial - 1)
        loss_F_Flow = self.criterionFlow(flow, flow_ref, conf_ref) * lambda_F / (2**(scale_S - 1))
        # warped prev image should be close to current image
        real_B_warp = self.resample(real_B_prev, flow)
        loss_F_Warp = self.criterionFlow(real_B_warp, real_B, conf_ref) * lambda_T

        ################## weight loss ##################
        # Only active with opt.no_first_img: pushes `weight` towards zero.
        if self.opt.no_first_img:
            loss_W = self.criterionFlow(weight, torch.zeros_like(weight), conf_ref)
        else:
            loss_W = torch.zeros_like(weight)
    else:
        loss_F_Flow = loss_F_Warp = loss_W = torch.zeros_like(conf_ref)

    #################### fake_B loss ####################
    ### VGG + GAN loss
    if self.opt.no_vgg:
        loss_G_VGG = torch.zeros_like(loss_W)
    else:
        loss_G_VGG = self.criterionVGG(fake_B, real_B) * lambda_feat
    loss_D_real, loss_D_fake, loss_G_GAN, loss_G_GAN_Feat = self.compute_loss_D(
        self.netD, real_A, real_B, fake_B)

    ### Warp loss: compare fake_B with fake_B_prev resampled by flow_ref; the
    ### target is detached so no gradient flows through it.
    fake_B_warp_ref = self.resample(fake_B_prev, flow_ref)
    loss_G_Warp = self.criterionWarp(fake_B, fake_B_warp_ref.detach(), conf_ref) * lambda_T

    if fake_B_raw is not None:
        # Score the raw output with the same terms and accumulate in place.
        if not self.opt.no_vgg:
            loss_G_VGG += self.criterionVGG(fake_B_raw, real_B) * lambda_feat
        l_D_real, l_D_fake, l_G_GAN, l_G_GAN_Feat = self.compute_loss_D(
            self.netD, real_A, real_B, fake_B_raw)
        loss_G_GAN += l_G_GAN
        loss_G_GAN_Feat += l_G_GAN_Feat
        loss_D_real += l_D_real
        loss_D_fake += l_D_fake

    if self.opt.add_face_disc:
        face_weight = 2
        ys, ye, xs, xe = self.get_face_region(real_A)
        if ys is not None:
            # Crop all three images to the face box and score them with netD_f.
            loss_D_f_real, loss_D_f_fake, loss_G_f_GAN, loss_G_f_GAN_Feat = self.compute_loss_D(
                self.netD_f, real_A[:, :, ys:ye, xs:xe], real_B[:, :, ys:ye, xs:xe],
                fake_B[:, :, ys:ye, xs:xe])
            # Only the generator-side face losses are up-weighted.
            loss_G_f_GAN *= face_weight
            loss_G_f_GAN_Feat *= face_weight
        else:
            loss_D_f_real = loss_D_f_fake = loss_G_f_GAN = loss_G_f_GAN_Feat = torch.zeros_like(
                loss_D_real)

    loss_list = [loss_G_VGG, loss_G_GAN, loss_G_GAN_Feat,
                 loss_D_real, loss_D_fake,
                 loss_G_Warp, loss_F_Flow, loss_F_Warp, loss_W]
    if self.opt.add_face_disc:
        loss_list += [loss_G_f_GAN, loss_G_f_GAN_Feat, loss_D_f_real, loss_D_f_fake]
    return [loss.view(-1, 1) for loss in loss_list]