def forward(self, z, label, shared_label=None, eval=False):
    """Generate an image batch from latent codes and optional conditioning.

    Args:
        z: Latent input. When ``self.MODEL.info_type`` is not ``"N/A"`` it
            is either passed through ``self.info_mix_linear`` ("concat") or
            split at column ``self.z_dim`` into the latent part and an info
            code ("cBN").
        label: Class labels; only forwarded to ``self.shared`` when
            conditioning is enabled and ``shared_label`` is not given.
        shared_label: Optional precomputed result of ``self.shared(label)``.
        eval: When true, autocast is not entered even if
            ``self.mixed_precision`` is set.

    Returns:
        ``self.tanh`` applied to the output of ``self.conv4``.
    """
    use_amp = self.mixed_precision and not eval
    cond_inputs = []
    with torch.cuda.amp.autocast() if use_amp else misc.dummy_context_mgr():
        if self.MODEL.info_type != "N/A":
            injection = self.g_info_injection
            if injection == "concat":
                z = self.info_mix_linear(z)
            elif injection == "cBN":
                # Columns past z_dim carry the info code; it is projected
                # and used as conditioning instead of feeding linear0.
                z_info = z[:, self.z_dim:]
                z = z[:, :self.z_dim]
                cond_inputs.append(self.info_proj_linear(z_info))

        if self.g_cond_mtd != "W/O":
            if shared_label is None:
                shared_label = self.shared(label)
            cond_inputs.append(shared_label)

        # One conditioning tensor is shared by every conditional block.
        affines = torch.cat(cond_inputs, 1) if cond_inputs else None

        feat = self.linear0(z).view(-1, self.in_dims[0], 4, 4)
        for stage in self.blocks:
            for layer in stage:
                # Self-attention layers take no conditioning input.
                if isinstance(layer, ops.SelfAttention):
                    feat = layer(feat)
                else:
                    feat = layer(feat, affines)

        out = self.tanh(self.conv4(feat))
    return out
def forward(self, z, label, shared_label=None, eval=False):
    """Generate an image batch, feeding one latent chunk to each block.

    The latent is split along dim 1 into pieces of ``self.chunk_size``. The
    first piece drives ``self.linear0``; each remaining piece is consumed,
    in order, by the next block that is not an ``ops.SelfAttention``.

    Args:
        z: Latent input. When ``self.MODEL.info_type`` is not ``"N/A"`` it
            is either passed through ``self.info_mix_linear`` ("concat") or
            split at column ``self.z_dim`` into the latent part and an info
            code ("cBN").
        label: Class labels; only forwarded to ``self.shared`` when
            conditioning is enabled and ``shared_label`` is not given.
        shared_label: Optional precomputed result of ``self.shared(label)``.
        eval: When true, autocast is not entered even if
            ``self.mixed_precision`` is set.

    Returns:
        ``self.tanh`` applied to the output of ``self.conv2d5``.
    """
    use_amp = self.mixed_precision and not eval
    cond_inputs = []
    with torch.cuda.amp.autocast() if use_amp else misc.dummy_context_mgr():
        if self.MODEL.info_type != "N/A":
            injection = self.MODEL.g_info_injection
            if injection == "concat":
                z = self.info_mix_linear(z)
            elif injection == "cBN":
                # Columns past z_dim carry the info code; it is projected
                # and used as conditioning instead of feeding linear0.
                z_info = z[:, self.z_dim:]
                z = z[:, :self.z_dim]
                cond_inputs.append(self.info_proj_linear(z_info))

        chunks = torch.split(z, self.chunk_size, 1)
        z = chunks[0]
        per_block = chunks[1:]

        if self.g_cond_mtd != "W/O":
            if shared_label is None:
                shared_label = self.shared(label)
            cond_inputs.append(shared_label)

        if cond_inputs:
            # Prefix every latent chunk with the shared conditioning.
            affines = [torch.cat(cond_inputs + [chunk], 1) for chunk in per_block]
        else:
            affines = list(per_block)

        feat = self.linear0(z)
        feat = feat.view(-1, self.in_dims[0], self.bottom, self.bottom)

        next_affine = 0
        for stage in self.blocks:
            for layer in stage:
                if isinstance(layer, ops.SelfAttention):
                    # Attention layers do not consume a latent chunk.
                    feat = layer(feat)
                    continue
                feat = layer(feat, affines[next_affine])
                next_affine += 1

        feat = self.activation(self.bn4(feat))
        out = self.tanh(self.conv2d5(feat))
    return out
def forward(self, x, label, eval=False, adc_fake=False):
    """Run the discriminator and gather the outputs of every enabled head.

    Args:
        x: Input batch; after ``self.blocks`` and ``self.conv1`` it must
            have at least four dimensions, since dims 2 and 3 are summed.
        label: Class labels. Remapped to ``2 * label`` (or ``2 * label + 1``
            when ``adc_fake``) if ``self.aux_cls_type == "ADC"``.
        eval: When true, autocast is not entered even if
            ``self.mixed_precision`` is set.
        adc_fake: Selects the odd label mapping for the ADC case.

    Returns:
        A dict with keys ``h``, ``adv_output``, ``embed``, ``proxy``,
        ``cls_output``, ``label``, ``mi_embed``, ``mi_proxy``,
        ``mi_cls_output``, ``info_discrete_c_logits``, ``info_conti_mu`` and
        ``info_conti_var``. Heads that are not enabled stay ``None``.

    Raises:
        NotImplementedError: If ``self.d_cond_mtd`` is not a known method.
    """
    embed = proxy = cls_output = None
    mi_embed = mi_proxy = mi_cls_output = None
    info_discrete_c_logits = info_conti_mu = info_conti_var = None

    use_amp = self.mixed_precision and not eval
    amp_ctx = torch.cuda.amp.autocast() if use_amp else misc.dummy_context_mgr()
    with amp_ctx:
        h = x
        for blocklist in self.blocks:
            for block in blocklist:
                h = block(h)
        h = self.conv1(h)
        if not self.apply_d_sn:
            h = self.bn1(h)
        bottom_h, bottom_w = h.shape[2], h.shape[3]
        h = self.activation(h)
        h = torch.sum(h, dim=[2, 3])

        # adversarial training
        # Squeeze only the logit axis. A bare torch.squeeze() would also
        # drop the batch axis for a batch of one, which breaks the "MD"
        # gather below and yields inconsistent shapes across heads.
        adv_output = torch.squeeze(self.linear1(h), dim=1)

        # make class labels odd (for fake) or even (for real) for ADC
        if self.aux_cls_type == "ADC":
            label = label * 2 + 1 if adc_fake else label * 2

        # forward pass through InfoGAN Q head
        info_type = self.MODEL.info_type
        if info_type in ["discrete", "continuous", "both"]:
            # Summed features divided by the spatial size, i.e. the mean.
            h_mean = h / (bottom_h * bottom_w)
            if info_type in ["discrete", "both"]:
                info_discrete_c_logits = self.info_discrete_linear(h_mean)
            if info_type in ["continuous", "both"]:
                info_conti_mu = self.info_conti_mu_linear(h_mean)
                info_conti_var = torch.exp(self.info_conti_var_linear(h_mean))

        # class conditioning
        if self.d_cond_mtd == "AC":
            if self.normalize_d_embed:
                # NOTE(review): this only rebinds the loop variable; the
                # parameters of linear2 are NOT modified. Kept as-is so
                # outputs stay unchanged - confirm the intended behavior.
                for W in self.linear2.parameters():
                    W = F.normalize(W, dim=1)
                h = F.normalize(h, dim=1)
            cls_output = self.linear2(h)
        elif self.d_cond_mtd == "PD":
            adv_output = adv_output + torch.sum(
                torch.mul(self.embedding(label), h), 1)
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            embed = self.linear2(h)
            proxy = self.embedding(label)
            if self.normalize_d_embed:
                embed = F.normalize(embed, dim=1)
                proxy = F.normalize(proxy, dim=1)
        elif self.d_cond_mtd == "MD":
            # Pick, for each sample, the logit of its own class.
            idx = torch.arange(label.size(0), device=label.device)
            adv_output = adv_output[idx, label]
        elif self.d_cond_mtd in ["W/O", "MH"]:
            pass
        else:
            raise NotImplementedError

        # extra conditioning for TACGAN and ADCGAN
        if self.aux_cls_type == "TAC":
            if self.d_cond_mtd == "AC":
                if self.normalize_d_embed:
                    # NOTE(review): same no-op rebinding as for linear2.
                    for W in self.linear_mi.parameters():
                        W = F.normalize(W, dim=1)
                mi_cls_output = self.linear_mi(h)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                mi_embed = self.linear_mi(h)
                mi_proxy = self.embedding_mi(label)
                if self.normalize_d_embed:
                    mi_embed = F.normalize(mi_embed, dim=1)
                    mi_proxy = F.normalize(mi_proxy, dim=1)

    return {
        "h": h,
        "adv_output": adv_output,
        "embed": embed,
        "proxy": proxy,
        "cls_output": cls_output,
        "label": label,
        "mi_embed": mi_embed,
        "mi_proxy": mi_proxy,
        "mi_cls_output": mi_cls_output,
        "info_discrete_c_logits": info_discrete_c_logits,
        "info_conti_mu": info_conti_mu,
        "info_conti_var": info_conti_var
    }