def forward(self):
    """Generate fake_x from (real_x, real_y) and run the two reconstruction passes.

    Sets:
        fake_x:    netG(real_x, real_y)
        fake_x_IP: fake_x resized to opt.fineSize_IP
        fake_x_E:  fake_x resized to opt.fineSize_E
        fake_y:    netE applied to real_x_E (must already exist on self)
        rec_x:     netG(real_x, fake_y)
        rec_y:     netE applied to fake_x_E
    """
    generated = self.netG(self.real_x, self.real_y)
    self.fake_x = generated
    self.fake_x_IP = upsample2d(generated, self.opt.fineSize_IP)
    self.fake_x_E = upsample2d(generated, self.opt.fineSize_E)

    # Encode the real image, then regenerate from the real image with that code.
    predicted_y = self.netE(self.real_x_E)
    self.fake_y = predicted_y
    self.rec_x = self.netG(self.real_x, predicted_y)

    # Encode the generated image.
    self.rec_y = self.netE(self.fake_x_E)
def set_input(self, input):
    """Copy one data-loader batch onto the model.

    Training:
        * If ``opt.no_mixed_label_D`` is false, the batch supplies 'A', 'B',
          'A_age', 'B_age', 'label' and 'B_paths' directly.
        * Otherwise one label index is drawn from ``range(len(self.relabel_D))``
          with probabilities ``self.weight_label_D``. The tensors are then read
          from the keys prefixed with that index (e.g. '0_A', '0_B_age').
        * In both cases real_A / real_B are also resized to
          ``opt.fineSize_IP`` / ``opt.fineSize_E``.
    Testing: only 'A' and 'A_paths' are read.

    Always increments ``current_iter`` and records the batch size (dim 0 of
    real_A).
    """
    if not self.isTrain:
        self.real_A = input['A'].to(self.device)
        self.image_paths = input['A_paths']
    else:
        if self.opt.no_mixed_label_D:
            # sample a label for D (A < B or A > B)
            drawn = np.random.choice(range(len(self.relabel_D)),
                                     p=self.weight_label_D)
            self.label_AB = [drawn]
            key = str(drawn)
            self.real_A = input[key + '_A'].to(self.device)
            self.real_B = input[key + '_B'].to(self.device)
            self.age_A = input[key + '_A_age'].to(self.device)
            self.age_B = input[key + '_B_age'].to(self.device)
            self.image_paths = input[key + '_B_paths']
        else:
            device = self.device
            self.real_A = input['A'].to(device)
            self.real_B = input['B'].to(device)
            self.age_A = input['A_age'].to(device)
            self.age_B = input['B_age'].to(device)
            self.label_AB = input['label']
            self.image_paths = input['B_paths']

        size_IP, size_E = self.opt.fineSize_IP, self.opt.fineSize_E
        self.real_A_IP = upsample2d(self.real_A, size_IP)
        self.real_A_E = upsample2d(self.real_A, size_E)
        self.real_B_E = upsample2d(self.real_B, size_E)

    self.current_iter += 1
    self.current_batch_size = int(self.real_A.size(0))
def forward(self):
    """Translate real_A to label_B, then map the result back to label_A.

    Sets fake_B = netG(real_A, one_hot_labels[label_B]), its resized copies
    fake_B_IP (opt.fineSize_IP) and fake_B_AC (opt.fineSize_AC), and the
    cycle reconstruction rec_A = netG(fake_B, one_hot_labels[label_A]).
    When opt.detach_fake_B is set, fake_B is detached before the cycle pass,
    so that pass does not backpropagate into the first netG call.
    """
    target_code = self.one_hot_labels[self.label_B]
    translated = self.netG(self.real_A, target_code)
    self.fake_B = translated
    self.fake_B_IP = upsample2d(translated, self.opt.fineSize_IP)
    self.fake_B_AC = upsample2d(translated, self.opt.fineSize_AC)

    cycle_input = translated.detach() if self.opt.detach_fake_B else translated
    self.rec_A = self.netG(cycle_input, self.one_hot_labels[self.label_A])
def forward(self):
    """Translate real_A to the embedding of age_B and reconstruct it back.

    embedding_A / embedding_B are embedding_normalize(age_A / age_B).
    fake_B = netG(real_A, embedding_B), with resized copies fake_B_IP
    (opt.fineSize_IP) and fake_B_E (opt.fineSize_E).
    rec_A = netG(fake_B, embedding_A); fake_B is detached first when
    opt.detach_fake_B is set.
    """
    normalize = self.embedding_normalize
    self.embedding_A = normalize(self.age_A)
    self.embedding_B = normalize(self.age_B)

    translated = self.netG(self.real_A, self.embedding_B)
    self.fake_B = translated
    self.fake_B_IP = upsample2d(translated, self.opt.fineSize_IP)
    self.fake_B_E = upsample2d(translated, self.opt.fineSize_E)

    cycle_input = translated.detach() if self.opt.detach_fake_B else translated
    self.rec_A = self.netG(cycle_input, self.embedding_A)
def set_input(self, input):
    """Copy one data-loader batch onto the model.

    Training: draw (label_A, label_B, label_B_not) from self.sample_labels().
    real_A / real_B are the batch entries keyed by label_A / label_B, moved to
    self.device. real_A is resized to opt.fineSize_IP and real_B to
    opt.fineSize_AC. The paths come from the key 'path_<label_B>'.
    Testing: only 'A' and 'A_paths' are read.

    Always increments current_iter and records the batch size (dim 0 of real_A).
    """
    if self.isTrain:
        self.label_A, self.label_B, self.label_B_not = self.sample_labels()
        source = input[self.label_A].to(self.device)
        self.real_A = source
        target = input[self.label_B].to(self.device)
        self.real_B = target
        self.real_A_IP = upsample2d(source, self.opt.fineSize_IP)
        self.real_B_AC = upsample2d(target, self.opt.fineSize_AC)
        self.image_paths = input['path_' + str(self.label_B)]
    else:
        self.real_A = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    self.current_iter += 1
    self.current_batch_size = int(self.real_A.size(0))
def forward(self):
    """Encode real_A_E / real_B_E with netE and translate real_A with netG.

    The embeddings are embedding_normalize(netE(transform_E(x))). Unless
    opt.lr_E > 0 they are detached, so no gradient reaches netE.
    fake_B = netG(real_A, embedding_B), with resized copies fake_B_IP and
    fake_B_E. rec_A = netG(fake_B, embedding_A), where fake_B is detached
    first when opt.detach_fake_B is set.
    """
    encoder_trainable = self.opt.lr_E > 0.0

    code_A = self.embedding_normalize(self.netE(self.transform_E(self.real_A_E)))
    if not encoder_trainable:
        code_A = code_A.detach()
    code_B = self.embedding_normalize(self.netE(self.transform_E(self.real_B_E)))
    if not encoder_trainable:
        code_B = code_B.detach()
    self.embedding_A = code_A
    self.embedding_B = code_B

    translated = self.netG(self.real_A, code_B)
    self.fake_B = translated
    self.fake_B_IP = upsample2d(translated, self.opt.fineSize_IP)
    self.fake_B_E = upsample2d(translated, self.opt.fineSize_E)

    cycle_input = translated.detach() if self.opt.detach_fake_B else translated
    self.rec_A = self.netG(cycle_input, code_A)
def set_input(self, input):
    """Copy one data-loader batch onto the model.

    real_x always comes from 'A' (moved to self.device).
    Training: real_y = attr_normalize('B_attr'), the paths come from
    'B_paths', and real_x is also resized to opt.fineSize_IP / opt.fineSize_E.
    Testing: the paths come from 'A_paths'. If the batch also carries
    'B_attr', real_y is set from it and the paths are taken from 'B_paths'.

    Always increments current_iter and records the batch size (dim 0 of real_x).
    """
    self.real_x = input['A'].to(self.device)
    if self.isTrain:
        self.real_y = self.attr_normalize(input['B_attr'].to(self.device))
        self.image_paths = input['B_paths']
        self.real_x_IP = upsample2d(self.real_x, self.opt.fineSize_IP)
        self.real_x_E = upsample2d(self.real_x, self.opt.fineSize_E)
    else:
        self.image_paths = input['A_paths']
        if 'B_attr' in input:
            self.real_y = self.attr_normalize(input['B_attr'].to(self.device))
            self.image_paths = input['B_paths']

    self.current_iter += 1
    self.current_batch_size = int(self.real_x.size(0))
def forward(self):
    """Encode real_A / real_B with netE and translate real_A with netG.

    Steps:
      1. Resize real_A to opt.fineSize_IP and opt.fineSize_E, and real_B to
         opt.fineSize_E.
      2. Encode both E-sized images with netE. When opt.bayesian is set, use
         compute_mu_and_var(netE, x, opt.bnn_T, noisy) instead of a single
         netE call. When opt.noisy is set, the encoder also returns a
         variance-like term.
      3. If opt.noisy_var_type contains the letter the active branch checks
         ('a' in the noisy branches, 'e' in the Bayesian-only branch), also
         store resampled, normalized embeddings in resample_A / resample_B.
      4. Set y_A / y_B and embedding_A / embedding_B (normalized). If
         opt.lr_E <= 0, detach them so no gradient reaches netE.
      5. fake_B = netG(real_A, embedding_B), with resized copies fake_B_IP and
         fake_B_E. rec_A = netG(fake_B, embedding_A), where fake_B is detached
         first when opt.detach_fake_B is set.
    """
    self.real_A_IP = upsample2d(self.real_A, self.opt.fineSize_IP)
    self.real_A_E = upsample2d(self.real_A, self.opt.fineSize_E)
    self.real_B_E = upsample2d(self.real_B, self.opt.fineSize_E)

    # Not every flag combination assigns resample_A/B (e.g. bayesian=False and
    # noisy=False never does). Remember whether this pass did, so the detach
    # step below never touches an attribute that was not set.
    resampled = False
    if not self.opt.bayesian and not self.opt.noisy:
        y_A = self.netE(self.transform_E(self.real_A_E))
        y_B = self.netE(self.transform_E(self.real_B_E))
    elif not self.opt.bayesian and self.opt.noisy:
        y_A, logvar_A = self.netE(self.transform_E(self.real_A_E))
        y_B, logvar_B = self.netE(self.transform_E(self.real_B_E))
        if 'a' in self.opt.noisy_var_type:
            # The second netE output is used as a log-variance; exp() gives the variance.
            y_A_s2 = torch.exp(logvar_A)
            y_B_s2 = torch.exp(logvar_B)
            self.resample_A = self.embedding_normalize(resample(y_A, y_A_s2))
            self.resample_B = self.embedding_normalize(resample(y_B, y_B_s2))
            resampled = True
    elif self.opt.bayesian and not self.opt.noisy:
        y_A, y_A_var = compute_mu_and_var(
            self.netE, self.transform_E(self.real_A_E), self.opt.bnn_T, False)
        y_B, y_B_var = compute_mu_and_var(
            self.netE, self.transform_E(self.real_B_E), self.opt.bnn_T, False)
        if 'e' in self.opt.noisy_var_type:
            self.resample_A = self.embedding_normalize(resample(y_A, y_A_var))
            self.resample_B = self.embedding_normalize(resample(y_B, y_B_var))
            resampled = True
    else:  # bayesian and noisy
        y_A, y_A_var, y_A_s2 = compute_mu_and_var(
            self.netE, self.transform_E(self.real_A_E), self.opt.bnn_T, True)
        y_B, y_B_var, y_B_s2 = compute_mu_and_var(
            self.netE, self.transform_E(self.real_B_E), self.opt.bnn_T, True)
        # NOTE(review): only 'a' is checked here and both variance terms are
        # always summed; confirm whether an 'e'-only setting should resample
        # with y_*_var alone.
        if 'a' in self.opt.noisy_var_type:
            self.resample_A = self.embedding_normalize(resample(y_A, y_A_s2 + y_A_var))
            self.resample_B = self.embedding_normalize(resample(y_B, y_B_s2 + y_B_var))
            resampled = True

    self.y_A = y_A
    self.y_B = y_B
    self.embedding_A = self.embedding_normalize(self.y_A)
    self.embedding_B = self.embedding_normalize(self.y_B)
    if self.opt.lr_E <= 0.0:
        # lr_E <= 0 presumably means netE is not trained: cut the graph into it.
        self.y_A = self.y_A.detach()
        self.y_B = self.y_B.detach()
        self.embedding_A = self.embedding_A.detach()
        self.embedding_B = self.embedding_B.detach()
        if resampled:
            self.resample_A = self.resample_A.detach()
            self.resample_B = self.resample_B.detach()

    self.fake_B = self.netG(self.real_A, self.embedding_B)
    self.fake_B_IP = upsample2d(self.fake_B, self.opt.fineSize_IP)
    self.fake_B_E = upsample2d(self.fake_B, self.opt.fineSize_E)
    if not self.opt.detach_fake_B:
        self.rec_A = self.netG(self.fake_B, self.embedding_A)
    else:
        self.rec_A = self.netG(self.fake_B.detach(), self.embedding_A)
def sample_from_prior(self):
    """Translate real_A using an embedding encoded from real_B (A -> B).

    real_B is resized to opt.fineSize_E, passed through transform_E, and
    encoded with netE. When opt.bayesian is set, compute_mu_and_var(...,
    opt.bnn_T, opt.noisy) is used instead. Only the mean output is kept and it
    is detached. The result is normalized and stored in self.embedding_B.

    Returns:
        netG(real_A, embedding_B).
    """
    # A -> B, sample embeddings according to B.
    # Every branch must use this freshly resized tensor: self.real_B_E may be
    # stale, or missing entirely at test time.
    real_B_E = upsample2d(self.real_B, self.opt.fineSize_E)
    encoder_input = self.transform_E(real_B_E)
    if not self.opt.bayesian and not self.opt.noisy:
        y_B = self.netE(encoder_input)
    elif not self.opt.bayesian and self.opt.noisy:
        y_B, _logvar_B = self.netE(encoder_input)
    elif self.opt.bayesian and not self.opt.noisy:
        y_B, _var_B = compute_mu_and_var(
            self.netE, encoder_input, self.opt.bnn_T, False)
    else:  # bayesian and noisy
        y_B, _var_B, _s2_B = compute_mu_and_var(
            self.netE, encoder_input, self.opt.bnn_T, True)
    self.embedding_B = self.embedding_normalize(y_B.detach())
    return self.netG(self.real_A, self.embedding_B)
def test(self):
    """Inference: if real_B is set, translate real_A using real_B's embedding.

    On first use, 'real_B' and 'fake_B' are added to visual_names. real_B is
    resized to opt.fineSize_E and encoded like in sample_from_prior (detached,
    normalized into self.embedding_B). fake_B is then netG(real_A,
    embedding_B). When real_B is not set, nothing is computed.
    """
    if hasattr(self, 'real_B'):
        # Every branch must use this freshly resized tensor: self.real_B_E is
        # not set by the test-mode input path and may be stale or missing.
        real_B_E = upsample2d(self.real_B, self.opt.fineSize_E)
        if 'real_B' not in self.visual_names:
            self.visual_names += ['real_B', 'fake_B']
        encoder_input = self.transform_E(real_B_E)
        if not self.opt.bayesian and not self.opt.noisy:
            y_B = self.netE(encoder_input)
        elif not self.opt.bayesian and self.opt.noisy:
            y_B, _logvar_B = self.netE(encoder_input)
        elif self.opt.bayesian and not self.opt.noisy:
            y_B, _var_B = compute_mu_and_var(
                self.netE, encoder_input, self.opt.bnn_T, False)
        else:  # bayesian and noisy
            y_B, _var_B, _s2_B = compute_mu_and_var(
                self.netE, encoder_input, self.opt.bnn_T, True)
        self.embedding_B = self.embedding_normalize(y_B.detach())
        self.fake_B = self.netG(self.real_A, self.embedding_B)
    return