コード例 #1
0
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
        if self.opt.flip_equivariance:
            self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
            if self.flipped_for_equivariance:
                self.real = torch.flip(self.real, [3])

        self.fake = self.netG(self.real)
        self.fake_B = self.fake[:self.real_A.size(0)]
        if self.opt.nce_idt:
            self.idt_B = self.fake[self.real_A.size(0):]

        d = 1
        self.pred_real_A = self.netf_s(self.real_A)    
        self.gt_pred_A = F.log_softmax(self.pred_real_A,dim= d).argmax(dim=d)
            
        self.pred_fake_B = self.netf_s(self.fake_B)
        self.pfB = F.log_softmax(self.pred_fake_B,dim=d)#.argmax(dim=d)
        self.pfB_max = self.pfB.argmax(dim=d)

        if hasattr(self,'criterionMask'):
                label_A = self.input_A_label
                label_A_in = label_A.unsqueeze(1)
                label_A_inv = torch.tensor(np.ones(label_A.size())).to(self.device) - label_A>0.5
                label_A_inv = label_A_inv.unsqueeze(1)
                self.real_A_out_mask = self.real_A *label_A_inv
                self.fake_B_out_mask = self.fake_B *label_A_inv            

                if self.opt.D_noise > 0.0:
                    self.fake_B_noisy = gaussian(self.fake_B, self.opt.D_noise)
                    self.real_B_noisy = gaussian(self.real_B, self.opt.D_noise)
コード例 #2
0
    def forward(self):
        """Run both translation cycles and, when training, the classifier.

        Stores ``fake_B``/``rec_A`` for the A -> B -> A cycle and
        ``fake_A``/``rec_B`` for the B -> A -> B cycle.  When ``rec_noise`` is
        positive, the intermediate image is passed through ``gaussian`` before
        reconstruction.  In training mode the ``netCLS`` outputs and their
        index of the maximum along dim 1 are stored as well.
        """
        # A -> B -> A
        self.fake_B = self.netG_A(self.real_A)
        if self.rec_noise > 0.0:
            self.fake_B_noisy1 = gaussian(self.fake_B, self.rec_noise)
            cycle_input = self.fake_B_noisy1
        else:
            cycle_input = self.fake_B
        self.rec_A = self.netG_B(cycle_input)

        # B -> A -> B
        self.fake_A = self.netG_B(self.real_B)
        if self.rec_noise > 0.0:
            self.fake_A_noisy1 = gaussian(self.fake_A, self.rec_noise)
            cycle_input = self.fake_A_noisy1
        else:
            cycle_input = self.fake_A
        self.rec_B = self.netG_A(cycle_input)

        if not self.isTrain:
            return

        # All four images go through the classifier; the real_B logits are
        # only needed locally to derive gt_pred_B.
        self.pred_real_A = self.netCLS(self.real_A)
        self.gt_pred_A = self.pred_real_A.max(1)[1]
        logits_real_B = self.netCLS(self.real_B)
        self.gt_pred_B = logits_real_B.max(1)[1]
        self.pred_fake_A = self.netCLS(self.fake_A)
        self.pred_fake_B = self.netCLS(self.fake_B)

        # Flagged as possibly unused by the original author.
        self.pfB = self.pred_fake_B.max(1)[1]
コード例 #3
0
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        super().forward()

        d = 1
        self.pred_real_A = self.netf_s(self.real_A)
        self.gt_pred_A = F.log_softmax(self.pred_real_A, dim=d).argmax(dim=d)

        self.pred_real_B = self.netf_s(self.real_B)
        self.gt_pred_B = F.log_softmax(self.pred_real_B, dim=d).argmax(dim=d)

        self.pred_fake_B = self.netf_s(self.fake_B)
        self.pfB = F.log_softmax(self.pred_fake_B, dim=d)  #.argmax(dim=d)
        self.pfB_max = self.pfB.argmax(dim=d)

        if hasattr(self, 'criterionMask'):
            label_A = self.input_A_label
            label_A_in = label_A.unsqueeze(1)
            label_A_inv = torch.tensor(np.ones(label_A.size())).to(
                self.device) - label_A > 0.5
            label_A_inv = label_A_inv.unsqueeze(1)
            self.real_A_out_mask = self.real_A * label_A_inv
            self.fake_B_out_mask = self.fake_B * label_A_inv

            if self.opt.D_noise > 0.0:
                self.fake_B_noisy = gaussian(self.fake_B, self.opt.D_noise)
                self.real_B_noisy = gaussian(self.real_B, self.opt.D_noise)
コード例 #4
0
    def forward(self):
        """Translate ``real_A`` to ``fake_B``; when training, also run both cycles.

        ``netG_A`` / ``netG_B`` each return a pair that is unpacked into a
        latent (``z_*``) and a noise value (``n_*``), which are then fed to
        ``netDecoderG_A`` / ``netDecoderG_B``.  In training mode this also
        stores ``rec_A``, ``fake_A`` (with ``latent_fake_A``) and ``rec_B``.
        When ``rec_noise`` is positive, the intermediate image is passed
        through ``gaussian`` before being re-encoded.
        """

        def decode(decoder, latent, mean_latent, noise, **extra):
            # Every decoder call in this method shares these settings; only
            # the truncation latent, the noise and `return_latents` vary.
            return decoder(
                latent,
                input_is_latent=True,
                truncation=self.truncation,
                truncation_latent=mean_latent,
                randomize_noise=False,
                noise=noise,
                **extra)

        self.z_fake_B, self.n_fake_B = self.netG_A(self.real_A)
        self.fake_B, self.latent_fake_B = decode(
            self.netDecoderG_A, self.z_fake_B, self.mean_latent_A,
            self.n_fake_B, return_latents=True)

        if not self.isTrain:
            return

        # A -> B -> A reconstruction
        if self.rec_noise > 0.0:
            self.fake_B_noisy1 = gaussian(self.fake_B, self.rec_noise)
            self.z_rec_A, self.n_rec_A = self.netG_B(self.fake_B_noisy1)
        else:
            self.z_rec_A, self.n_rec_A = self.netG_B(self.fake_B)
        # Only the first element of the decoder output is kept.
        self.rec_A = decode(
            self.netDecoderG_B, self.z_rec_A, self.mean_latent_B,
            self.n_rec_A)[0]

        # B -> A translation
        self.z_fake_A, self.n_fake_A = self.netG_B(self.real_B)
        self.fake_A, self.latent_fake_A = decode(
            self.netDecoderG_B, self.z_fake_A, self.mean_latent_B,
            self.n_fake_A, return_latents=True)

        # B -> A -> B reconstruction
        if self.rec_noise > 0.0:
            self.fake_A_noisy1 = gaussian(self.fake_A, self.rec_noise)
            self.z_rec_B, self.n_rec_B = self.netG_A(self.fake_A_noisy1)
        else:
            self.z_rec_B, self.n_rec_B = self.netG_A(self.fake_A)
        self.rec_B = decode(
            self.netDecoderG_A, self.z_rec_B, self.mean_latent_A,
            self.n_rec_B)[0]
コード例 #5
0
    def calculate_NCE_loss(self, src, tgt):
        """Return the NCE loss between ``src`` and ``tgt``, averaged over ``nce_layers``.

        Features of both inputs are taken from ``netG`` in encode-only mode and
        pooled by ``netF``; the sample ids returned for the ``src`` features
        are reused for the ``tgt`` features.  Each layer's criterion output is
        scaled by ``opt.lambda_NCE`` and reduced with ``mean()``; the sum is
        divided by ``len(self.nce_layers)``.
        """
        opt = self.opt
        layers = self.nce_layers
        layer_count = len(layers)

        query_feats = self.netG(tgt, layers, encode_only=True)
        if opt.flip_equivariance and self.flipped_for_equivariance:
            # Flip the query features along dim 3 when the forward pass
            # flipped its input.
            query_feats = [torch.flip(feat, [3]) for feat in query_feats]
        key_feats = self.netG(src, layers, encode_only=True)

        key_pool, sample_ids = self.netF(key_feats, opt.num_patches, None)
        query_pool, _ = self.netF(query_feats, opt.num_patches, sample_ids)

        running_total = 0.0
        # nce_layers stays in the zip so the iteration length is unchanged.
        for query, key, criterion, _layer in zip(query_pool, key_pool,
                                                 self.criterionNCE, layers):
            if opt.contrastive_noise > 0.0:
                query = gaussian(query, opt.contrastive_noise)
                key = gaussian(key, opt.contrastive_noise)
            layer_loss = criterion(query, key) * opt.lambda_NCE
            running_total += layer_loss.mean()

        return running_total / layer_count
コード例 #6
0
ファイル: cut_model.py プロジェクト: pnsuau/joliGAN
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        super().forward()

        self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
        if self.opt.flip_equivariance:
            self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
            if self.flipped_for_equivariance:
                self.real = torch.flip(self.real, [3])

        self.fake = self.netG(self.real)
        self.fake_B = self.fake[:self.real_A.size(0)]
        if self.opt.nce_idt:
            self.idt_B = self.fake[self.real_A.size(0):]

        if self.opt.D_noise > 0.0:
            self.fake_B_noisy = gaussian(self.fake_B, self.opt.D_noise)
            self.real_B_noisy = gaussian(self.real_B, self.opt.D_noise)

        self.diff_real_A_fake_B = self.real_A - self.fake_B
コード例 #7
0
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        super().forward()
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        if self.rec_noise > 0.0:
            self.fake_B_noisy1 = gaussian(self.fake_B, self.rec_noise)
            self.rec_A= self.netG_B(self.fake_B_noisy1)
        else:
            self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        if self.rec_noise > 0.0:
            self.fake_A_noisy1 = gaussian(self.fake_A, self.rec_noise)
            self.rec_B = self.netG_A(self.fake_A_noisy1)
        else:
            self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))
            
        if self.opt.D_noise > 0.0:
            self.fake_B_noisy = gaussian(self.fake_B, self.opt.D_noise)
            self.real_A_noisy = gaussian(self.real_A, self.opt.D_noise)
            self.fake_A_noisy = gaussian(self.fake_A, self.opt.D_noise)
            self.real_B_noisy = gaussian(self.real_B, self.opt.D_noise)

        if self.opt.lambda_identity > 0:
            self.idt_A = self.netG_A(self.real_B)
            self.idt_B = self.netG_B(self.real_A)
    def forward(self):
        """Translate ``real_A`` to ``fake_B``; when training, also run both cycles.

        ``netG_A`` / ``netG_B`` each return a pair that is unpacked into a
        latent (``z_*``) and a noise value (``n_*``), which are fed to
        ``netDecoderG_A`` / ``netDecoderG_B``.  Every decoded image is resized
        with ``F.interpolate`` to ``opt.crop_size`` when ``opt.decoder_size``
        is larger.  In training mode this also stores the reconstructions, the
        ``netf_s`` predictions for ``real_A``, ``real_B`` and ``fake_A``, and
        -- when the model has a ``criterionMask`` attribute -- the images
        multiplied by ``1 - label``.  The ``netf_s`` prediction for
        ``fake_B`` is computed in every mode.
        """
        class_dim = 1

        def decode(decoder, latent, mean_latent, noise, **extra):
            # Every decoder call in this method shares these settings; only
            # the truncation latent, the noise and `return_latents` vary.
            return decoder(
                latent,
                input_is_latent=True,
                truncation=self.truncation,
                truncation_latent=mean_latent,
                randomize_noise=self.randomize_noise,
                noise=noise,
                **extra)

        def fit_to_crop(image):
            # Resize only when the decoder output is larger than the crop.
            if self.opt.decoder_size > self.opt.crop_size:
                return F.interpolate(image, self.opt.crop_size)
            return image

        def inverse_mask(label):
            # 1 - label with a leading channel dim.  The ones tensor is float64
            # (what the former NumPy round-trip produced), so products with
            # this mask keep the dtype they had before.
            ones = torch.ones(label.size(), dtype=torch.float64, device=self.device)
            return (ones - label).unsqueeze(1)

        self.z_fake_B, self.n_fake_B = self.netG_A(self.real_A)
        self.fake_B, self.latent_fake_B = decode(
            self.netDecoderG_A, self.z_fake_B, self.mean_latent_A,
            self.n_fake_B, return_latents=True)
        self.fake_B = fit_to_crop(self.fake_B)

        if self.isTrain:
            # A -> B -> A reconstruction
            if self.rec_noise > 0.0:
                self.fake_B_noisy1 = gaussian(self.fake_B, self.rec_noise)
                self.z_rec_A, self.n_rec_A = self.netG_B(self.fake_B_noisy1)
            else:
                self.z_rec_A, self.n_rec_A = self.netG_B(self.fake_B)
            # Only the first element of the decoder output is kept.
            self.rec_A = fit_to_crop(decode(
                self.netDecoderG_B, self.z_rec_A, self.mean_latent_B,
                self.n_rec_A)[0])

            # B -> A translation
            self.z_fake_A, self.n_fake_A = self.netG_B(self.real_B)
            self.fake_A, self.latent_fake_A = decode(
                self.netDecoderG_B, self.z_fake_A, self.mean_latent_B,
                self.n_fake_A, return_latents=True)
            self.fake_A = fit_to_crop(self.fake_A)

            # B -> A -> B reconstruction
            if self.rec_noise > 0.0:
                self.fake_A_noisy1 = gaussian(self.fake_A, self.rec_noise)
                self.z_rec_B, self.n_rec_B = self.netG_A(self.fake_A_noisy1)
            else:
                self.z_rec_B, self.n_rec_B = self.netG_A(self.fake_A)
            self.rec_B = fit_to_crop(decode(
                self.netDecoderG_A, self.z_rec_B, self.mean_latent_A,
                self.n_rec_B)[0])

            # Predictions of netf_s on the real images and on fake_A.
            self.pred_real_A = self.netf_s(self.real_A)
            self.gt_pred_A = F.log_softmax(self.pred_real_A, dim=class_dim).argmax(dim=class_dim)

            self.pred_real_B = self.netf_s(self.real_B)
            self.gt_pred_B = F.log_softmax(self.pred_real_B, dim=class_dim).argmax(dim=class_dim)

            self.pred_fake_A = self.netf_s(self.fake_A)
            self.pfA = F.log_softmax(self.pred_fake_A, dim=class_dim)
            self.pfA_max = self.pfA.argmax(dim=class_dim)

            if hasattr(self, 'criterionMask'):
                label_A_inv = inverse_mask(self.input_A_label)
                self.real_A_out_mask = self.real_A * label_A_inv
                self.fake_B_out_mask = self.fake_B * label_A_inv

                # NOTE(review): here D_noise is tested for truthiness and
                # gaussian() is called without a noise level -- confirm this
                # matches the signature of gaussian() in this file.
                if self.D_noise:
                    self.fake_B_noisy = gaussian(self.fake_B)
                    self.real_A_noisy = gaussian(self.real_A)

                if hasattr(self, 'input_B_label'):
                    label_B_inv = inverse_mask(self.input_B_label)
                    self.real_B_out_mask = self.real_B * label_B_inv
                    self.fake_A_out_mask = self.fake_A * label_B_inv

                    if self.D_noise:
                        self.fake_A_noisy = gaussian(self.fake_A)
                        self.real_B_noisy = gaussian(self.real_B)

        self.pred_fake_B = self.netf_s(self.fake_B)
        self.pfB = F.log_softmax(self.pred_fake_B, dim=class_dim)
        self.pfB_max = self.pfB.argmax(dim=class_dim)
コード例 #9
0
    def forward(self):
        """Translate ``real_A`` to ``fake_B``; when training, also run both cycles.

        In training mode this stores the reconstructions ``rec_A``/``rec_B``,
        ``fake_A``, and the ``netf_s`` predictions for ``real_A``, ``real_B``
        and ``fake_A``.  When the model has a ``criterionMask`` attribute, the
        images multiplied by ``1 - label`` (``*_out_mask``) are stored; with
        ``disc_in_mask`` the images multiplied by the label (``*_mask_in``)
        and the recombined ``*_mask`` images are stored as well, and with a
        positive ``D_noise`` noisy copies are added.  The ``netf_s``
        prediction for ``fake_B`` is computed in every mode.
        """
        class_dim = 1

        def inverse_mask(label):
            # 1 - label with a leading channel dim.  The ones tensor is float64
            # (what the former NumPy round-trip produced), so products with
            # this mask keep the dtype they had before.
            ones = torch.ones(label.size(), dtype=torch.float64, device=self.device)
            return (ones - label).unsqueeze(1)

        self.fake_B = self.netG_A(self.real_A)

        if self.isTrain:
            # A -> B -> A reconstruction
            if self.rec_noise > 0.0:
                self.fake_B_noisy1 = gaussian(self.fake_B, self.rec_noise)
                self.rec_A = self.netG_B(self.fake_B_noisy1)
            else:
                self.rec_A = self.netG_B(self.fake_B)

            # B -> A translation and B -> A -> B reconstruction
            self.fake_A = self.netG_B(self.real_B)
            if self.rec_noise > 0.0:
                self.fake_A_noisy1 = gaussian(self.fake_A, self.rec_noise)
                self.rec_B = self.netG_A(self.fake_A_noisy1)
            else:
                self.rec_B = self.netG_A(self.fake_A)

            # Predictions of netf_s on the real images and on fake_A.
            self.pred_real_A = self.netf_s(self.real_A)
            self.gt_pred_A = F.log_softmax(self.pred_real_A, dim=class_dim).argmax(dim=class_dim)

            self.pred_real_B = self.netf_s(self.real_B)
            self.gt_pred_B = F.log_softmax(self.pred_real_B, dim=class_dim).argmax(dim=class_dim)

            self.pred_fake_A = self.netf_s(self.fake_A)
            self.pfA = F.log_softmax(self.pred_fake_A, dim=class_dim)
            self.pfA_max = self.pfA.argmax(dim=class_dim)

            if hasattr(self, 'criterionMask'):
                label_A = self.input_A_label
                label_A_in = label_A.unsqueeze(1)
                label_A_inv = inverse_mask(label_A)

                self.real_A_out_mask = self.real_A * label_A_inv
                self.fake_B_out_mask = self.fake_B * label_A_inv

                if self.disc_in_mask:
                    self.real_A_mask_in = self.real_A * label_A_in
                    self.fake_B_mask_in = self.fake_B * label_A_in
                    self.real_A_mask = self.real_A
                    # fake_B where the label is set, real_A elsewhere; the
                    # out-mask is cast back with .float() before the sum.
                    self.fake_B_mask = self.fake_B_mask_in + self.real_A_out_mask.float()

                if self.D_noise > 0.0:
                    self.fake_B_noisy = gaussian(self.fake_B, self.D_noise)
                    self.real_A_noisy = gaussian(self.real_A, self.D_noise)

                if hasattr(self, 'input_B_label'):
                    label_B = self.input_B_label
                    label_B_in = label_B.unsqueeze(1)
                    label_B_inv = inverse_mask(label_B)

                    self.real_B_out_mask = self.real_B * label_B_inv
                    self.fake_A_out_mask = self.fake_A * label_B_inv

                    if self.disc_in_mask:
                        self.real_B_mask_in = self.real_B * label_B_in
                        self.fake_A_mask_in = self.fake_A * label_B_in
                        self.real_B_mask = self.real_B
                        self.fake_A_mask = self.fake_A_mask_in + self.real_B_out_mask.float()

                    if self.D_noise > 0.0:
                        self.fake_A_noisy = gaussian(self.fake_A, self.D_noise)
                        self.real_B_noisy = gaussian(self.real_B, self.D_noise)

        self.pred_fake_B = self.netf_s(self.fake_B)
        self.pfB = F.log_softmax(self.pred_fake_B, dim=class_dim)
        self.pfB_max = self.pfB.argmax(dim=class_dim)