예제 #1
0
 def backward_D(self):
     """Compute the discriminator loss and store it on the instance.

     Unfreezes ``self.net_D`` through ``base_function._unfreeze`` and then
     delegates to ``self.backward_D_basic`` with ``self.input_P2`` and
     ``self.img_gen``. Whatever that helper returns is kept in
     ``self.loss_dis_img_gen``.
     """
     base_function._unfreeze(self.net_D)
     # NOTE(review): the positional order is presumably (net, real, fake);
     # confirm against backward_D_basic.
     target, generated = self.input_P2, self.img_gen
     self.loss_dis_img_gen = self.backward_D_basic(
         self.net_D, target, generated
     )
    def backward_D(self):
        """Compute the losses of both discriminators and store them.

        Two steps run in order, each beginning by unfreezing its network
        through ``base_function._unfreeze``:

        1. ``self.net_D`` receives one randomly chosen entry of
           ``self.img_gen`` together with the matching slice
           ``self.P_frame_step[:, idx, ...]``. The value returned by
           ``self.backward_D_basic`` is kept in ``self.loss_dis_img_gen``.
        2. ``self.net_D_V`` receives differences between consecutive
           entries taken from a randomly placed window of
           ``self.opt.frames_D_V`` entries, concatenated along dim 1. The
           value returned is kept in ``self.loss_dis_img_gen_v``.
        """
        # --- discriminator on a single randomly picked entry ---
        base_function._unfreeze(self.net_D)
        idx = np.random.randint(len(self.img_gen))
        generated = self.img_gen[idx]
        target = self.P_frame_step[:, idx, ...]
        self.loss_dis_img_gen = self.backward_D_basic(
            self.net_D, target, generated
        )

        # --- discriminator on consecutive-entry differences ---
        base_function._unfreeze(self.net_D_V)
        window = self.opt.frames_D_V
        # The upper bound keeps the whole window inside img_gen.
        start = np.random.randint(len(self.img_gen) - window + 1)
        # A window of `window` entries yields `window - 1` differences.
        offsets = range(start, start + window - 1)
        generated_diffs = [
            self.img_gen[k] - self.img_gen[k + 1] for k in offsets
        ]
        target_diffs = [
            self.P_frame_step[:, k, ...] - self.P_frame_step[:, k + 1, ...]
            for k in offsets
        ]
        generated_stack = torch.cat(generated_diffs, dim=1)
        target_stack = torch.cat(target_diffs, dim=1)
        self.loss_dis_img_gen_v = self.backward_D_basic(
            self.net_D_V, target_stack, generated_stack
        )
예제 #3
0
    def backward_D(self):
        """Compute the spatial and temporal discriminator losses.

        Spatial step: unfreezes ``self.net_D`` and passes one randomly
        chosen entry of ``self.img_gen`` plus the matching slice
        ``self.P_step[:, idx, ...]`` to ``self.backward_D_basic``; the
        result is stored in ``self.loss_dis_img_gen``.

        Temporal step: unfreezes ``self.net_D_V``, picks a random window of
        ``self.opt.frames_D_V`` consecutive entries, gives every entry a new
        axis at position 2 and concatenates the entries along that axis.
        The result of ``self.backward_D_basic`` is stored in
        ``self.loss_dis_img_gen_v``.
        """
        # --- Spatial GAN loss ---
        base_function._unfreeze(self.net_D)
        idx = np.random.randint(len(self.img_gen))
        generated = self.img_gen[idx]
        target = self.P_step[:, idx, ...]
        self.loss_dis_img_gen = self.backward_D_basic(
            self.net_D, target, generated
        )

        # --- Temporal GAN loss ---
        base_function._unfreeze(self.net_D_V)
        window = self.opt.frames_D_V
        # The upper bound keeps the whole window inside img_gen.
        start = np.random.randint(len(self.img_gen) - window + 1)
        offsets = range(start, start + window)
        generated_clip = torch.cat(
            [self.img_gen[k].unsqueeze(2) for k in offsets], dim=2
        )
        target_clip = torch.cat(
            [self.P_step[:, k, ...].unsqueeze(2) for k in offsets], dim=2
        )
        self.loss_dis_img_gen_v = self.backward_D_basic(
            self.net_D_V, target_clip, generated_clip
        )
예제 #4
0
 def backward_D(self):
     """Compute the discriminator loss and store it on the instance.

     ``self.net_D`` is unfrozen with ``base_function._unfreeze``; the value
     returned by ``self.backward_D_basic`` for ``self.input_P2`` and
     ``self.img_gen`` is then saved as ``self.loss_dis_img_gen``.
     """
     base_function._unfreeze(self.net_D)
     # NOTE(review): the positional order is presumably (net, real, fake);
     # confirm against backward_D_basic.
     loss_inputs = (self.net_D, self.input_P2, self.img_gen)
     self.loss_dis_img_gen = self.backward_D_basic(*loss_inputs)
예제 #5
0
 def backward_D(self):
     """Compute the discriminator loss and store it on the instance.

     ``self.net_D`` is unfrozen with ``base_function._unfreeze``; the value
     returned by ``self.backward_D_basic`` for ``self.input_fullP2`` and
     ``self.img_gen`` is then saved as ``self.loss_dis_img_gen``.
     """
     base_function._unfreeze(self.net_D)
     # NOTE: mind whether the background is present in these inputs!
     target, generated = self.input_fullP2, self.img_gen
     self.loss_dis_img_gen = self.backward_D_basic(
         self.net_D, target, generated
     )