def __init__(self, z_dim, g_shared_dim, img_size, g_conv_dim, apply_attn, attn_g_loc, g_cond_mtd, num_classes,
             g_init, g_depth, mixed_precision, MODULES, MODEL):
    super(Generator, self).__init__()
    # feature-map widths of the generator blocks
    self.in_dims = [512, 256, 128]
    self.out_dims = [256, 128, 64]

    self.z_dim = z_dim
    self.num_classes = num_classes
    self.g_cond_mtd = g_cond_mtd
    self.mixed_precision = mixed_precision
    self.MODEL = MODEL
    self.affine_input_dim = 0

    # dimensionality of the InfoGAN latent codes (if any)
    info_dim = 0
    if self.MODEL.info_type in ["discrete", "both"]:
        info_dim += self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
    if self.MODEL.info_type in ["continuous", "both"]:
        info_dim += self.MODEL.info_num_conti_c

    # extra layers for injecting InfoGAN codes into the generator
    self.g_info_injection = self.MODEL.g_info_injection
    if self.MODEL.info_type != "N/A":
        if self.g_info_injection == "concat":
            self.info_mix_linear = MODULES.g_linear(in_features=self.z_dim + info_dim, out_features=self.z_dim, bias=True)
        elif self.g_info_injection == "cBN":
            self.affine_input_dim += self.z_dim
            self.info_proj_linear = MODULES.g_linear(in_features=info_dim, out_features=self.z_dim, bias=True)

    # shared label embedding for conditional generation
    if self.g_cond_mtd != "W/O":
        self.affine_input_dim += self.z_dim
        self.shared = ops.embedding(num_embeddings=self.num_classes, embedding_dim=self.z_dim)

    # initial projection from z to a 4x4 feature map
    self.linear0 = MODULES.g_linear(in_features=self.z_dim, out_features=self.in_dims[0] * 4 * 4, bias=True)

    # stack of generator blocks with optional self-attention
    self.blocks = []
    for index in range(len(self.in_dims)):
        self.blocks += [[
            GenBlock(in_channels=self.in_dims[index],
                     out_channels=self.out_dims[index],
                     g_cond_mtd=self.g_cond_mtd,
                     g_info_injection=self.g_info_injection,
                     affine_input_dim=self.affine_input_dim,
                     MODULES=MODULES)
        ]]

        if index + 1 in attn_g_loc and apply_attn:
            self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=True, MODULES=MODULES)]]

    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # final convolution producing a 3-channel RGB image in [-1, 1]
    self.conv4 = MODULES.g_conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1)
    self.tanh = nn.Tanh()

    ops.init_weights(self.modules, g_init)
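# NOTE: hypothetical illustration, not part of the original file. It spells out how the
# InfoGAN code dimension computed in __init__ above is assembled, and which extra layer
# each injection mode would create. The concrete numbers (2 discrete codes of 10
# categories, 3 continuous codes, z_dim=128) are assumptions for this example only.
def _example_info_dim(num_discrete_c=2, dim_discrete_c=10, num_conti_c=3, z_dim=128):
    info_dim = num_discrete_c * dim_discrete_c + num_conti_c  # 2 * 10 + 3 == 23
    # g_info_injection == "concat": info_mix_linear maps (z_dim + info_dim) -> z_dim
    concat_in_features = z_dim + info_dim                     # 151
    # g_info_injection == "cBN": info_proj_linear maps info_dim -> z_dim and z_dim is
    # added to affine_input_dim consumed by the conditional batch-norm layers
    cbn_in_features = info_dim                                # 23
    return info_dim, concat_in_features, cbn_in_features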
def __init__(self, z_dim, g_shared_dim, img_size, g_conv_dim, apply_attn, attn_g_loc, g_cond_mtd, num_classes,
             g_init, g_depth, mixed_precision, MODULES, MODEL):
    super(Generator, self).__init__()
    # per-resolution channel widths of the generator blocks
    g_in_dims_collection = {
        "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
        "64": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
        "128": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
        "256": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
        "512": [g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim]
    }

    g_out_dims_collection = {
        "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
        "64": [g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
        "128": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
        "256": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
        "512": [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim, g_conv_dim]
    }

    # spatial size of the initial feature map produced by linear0
    bottom_collection = {"32": 4, "64": 4, "128": 4, "256": 4, "512": 4}

    self.z_dim = z_dim
    self.num_classes = num_classes
    self.g_cond_mtd = g_cond_mtd
    self.mixed_precision = mixed_precision
    self.MODEL = MODEL
    self.in_dims = g_in_dims_collection[str(img_size)]
    self.out_dims = g_out_dims_collection[str(img_size)]
    self.bottom = bottom_collection[str(img_size)]
    self.num_blocks = len(self.in_dims)
    self.affine_input_dim = 0

    # dimensionality of the InfoGAN latent codes (if any)
    info_dim = 0
    if self.MODEL.info_type in ["discrete", "both"]:
        info_dim += self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
    if self.MODEL.info_type in ["continuous", "both"]:
        info_dim += self.MODEL.info_num_conti_c

    # extra layers for injecting InfoGAN codes into the generator
    self.g_info_injection = self.MODEL.g_info_injection
    if self.MODEL.info_type != "N/A":
        if self.g_info_injection == "concat":
            self.info_mix_linear = MODULES.g_linear(in_features=self.z_dim + info_dim, out_features=self.z_dim, bias=True)
        elif self.g_info_injection == "cBN":
            self.affine_input_dim += self.z_dim
            self.info_proj_linear = MODULES.g_linear(in_features=info_dim, out_features=self.z_dim, bias=True)

    # initial projection from z to a bottom x bottom feature map
    self.linear0 = MODULES.g_linear(in_features=self.z_dim, out_features=self.in_dims[0] * self.bottom * self.bottom, bias=True)

    # shared label embedding for conditional generation
    if self.g_cond_mtd != "W/O":
        self.affine_input_dim += self.z_dim
        self.shared = ops.embedding(num_embeddings=self.num_classes, embedding_dim=self.z_dim)

    # stack of generator blocks with optional self-attention
    self.blocks = []
    for index in range(self.num_blocks):
        self.blocks += [[
            GenBlock(in_channels=self.in_dims[index],
                     out_channels=self.out_dims[index],
                     g_cond_mtd=self.g_cond_mtd,
                     g_info_injection=self.g_info_injection,
                     affine_input_dim=self.affine_input_dim,
                     MODULES=MODULES)
        ]]

        if index + 1 in attn_g_loc and apply_attn:
            self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=True, MODULES=MODULES)]]

    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # final batch-norm, activation, and convolution producing a 3-channel RGB image in [-1, 1]
    self.bn4 = ops.batchnorm_2d(in_features=self.out_dims[-1])
    self.activation = MODULES.g_act_fn
    self.conv2d5 = MODULES.g_conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1)
    self.tanh = nn.Tanh()

    ops.init_weights(self.modules, g_init)
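# NOTE: hypothetical illustration, not part of the original file. It expands the width
# lookup above for one concrete setting (img_size=128 with an assumed g_conv_dim=96) and
# checks the output resolution under the assumption that every GenBlock up-samples its
# input by a factor of 2 starting from the bottom x bottom seed produced by linear0.
def _example_generator_widths(g_conv_dim=96):
    in_dims = [g_conv_dim * m for m in (16, 16, 8, 4, 2)]   # "128" entry of g_in_dims_collection
    out_dims = [g_conv_dim * m for m in (16, 8, 4, 2, 1)]   # "128" entry of g_out_dims_collection
    bottom = 4                                              # linear0 reshapes z to a 4x4 seed
    img_size = bottom * 2 ** len(in_dims)                   # 4 * 2**5 == 128
    return in_dims, out_dims, img_size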
def __init__(self, img_size, d_conv_dim, apply_d_sn, apply_attn, attn_d_loc, d_cond_mtd, aux_cls_type, d_embed_dim,
             normalize_d_embed, num_classes, d_init, d_depth, mixed_precision, MODULES, MODEL):
    super(Discriminator, self).__init__()
    self.in_dims = [3] + [64, 128]
    self.out_dims = [64, 128, 256]

    self.apply_d_sn = apply_d_sn
    self.d_cond_mtd = d_cond_mtd
    self.aux_cls_type = aux_cls_type
    self.normalize_d_embed = normalize_d_embed
    self.num_classes = num_classes
    self.mixed_precision = mixed_precision
    self.MODEL = MODEL

    self.blocks = []
    for index in range(len(self.in_dims)):
        self.blocks += [[
            DiscBlock(in_channels=self.in_dims[index],
                      out_channels=self.out_dims[index],
                      apply_d_sn=self.apply_d_sn,
                      MODULES=MODULES)
        ]]

        if index + 1 in attn_d_loc and apply_attn:
            self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=False, MODULES=MODULES)]]

    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    self.activation = MODULES.d_act_fn

    self.conv1 = MODULES.d_conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)

    if not self.apply_d_sn:
        self.bn1 = MODULES.d_bn(in_features=512)

    # linear layer for adversarial training
    if self.d_cond_mtd == "MH":
        self.linear1 = MODULES.d_linear(in_features=512, out_features=1 + num_classes, bias=True)
    elif self.d_cond_mtd == "MD":
        self.linear1 = MODULES.d_linear(in_features=512, out_features=num_classes, bias=True)
    else:
        self.linear1 = MODULES.d_linear(in_features=512, out_features=1, bias=True)

    # double num_classes for Auxiliary Discriminative Classifier
    if self.aux_cls_type == "ADC":
        num_classes = num_classes * 2

    # linear and embedding layers for discriminator conditioning
    if self.d_cond_mtd == "AC":
        self.linear2 = MODULES.d_linear(in_features=512, out_features=num_classes, bias=False)
    elif self.d_cond_mtd == "PD":
        self.embedding = MODULES.d_embedding(num_classes, 512)
    elif self.d_cond_mtd in ["2C", "D2DCE"]:
        self.linear2 = MODULES.d_linear(in_features=512, out_features=d_embed_dim, bias=True)
        self.embedding = MODULES.d_embedding(num_classes, d_embed_dim)
    else:
        pass

    # linear and embedding layers for evolved classifier-based GAN
    if self.aux_cls_type == "TAC":
        if self.d_cond_mtd == "AC":
            self.linear_mi = MODULES.d_linear(in_features=512, out_features=num_classes, bias=False)
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            self.linear_mi = MODULES.d_linear(in_features=512, out_features=d_embed_dim, bias=True)
            self.embedding_mi = MODULES.d_embedding(num_classes, d_embed_dim)
        else:
            raise NotImplementedError

    # Q head network for infoGAN
    if self.MODEL.info_type in ["discrete", "both"]:
        out_features = self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
        self.info_discrete_linear = MODULES.d_linear(in_features=512, out_features=out_features, bias=False)
    if self.MODEL.info_type in ["continuous", "both"]:
        out_features = self.MODEL.info_num_conti_c
        self.info_conti_mu_linear = MODULES.d_linear(in_features=512, out_features=out_features, bias=False)
        self.info_conti_var_linear = MODULES.d_linear(in_features=512, out_features=out_features, bias=False)

    if d_init:
        ops.init_weights(self.modules, d_init)
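# NOTE: hypothetical sketch, not part of the original file, of how the conditioning heads
# built above are typically consumed. It assumes `h` is the pooled 512-dim feature obtained
# after conv1 and the activation, and `label` is a LongTensor of class indices; the "PD"
# branch follows the standard projection-discriminator formulation.
def _example_cond_logit(disc, h, label):
    import torch
    adv_output = disc.linear1(h)  # adversarial logit(s) from the head selected by d_cond_mtd
    if disc.d_cond_mtd == "PD":
        # projection conditioning: shift the logit by the inner product <embedding(label), h>
        adv_output = adv_output + torch.sum(disc.embedding(label) * h, dim=1, keepdim=True)
    return adv_output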
def __init__(self, img_size, d_conv_dim, apply_d_sn, apply_attn, attn_d_loc, d_cond_mtd, aux_cls_type, d_embed_dim,
             normalize_d_embed, num_classes, d_init, d_depth, mixed_precision, MODULES, MODEL):
    super(Discriminator, self).__init__()
    d_in_dims_collection = {
        "32": [3] + [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
        "64": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8],
        "128": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
        "256": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16],
        "512": [3] + [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16]
    }

    d_out_dims_collection = {
        "32": [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
        "64": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
        "128": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
        "256": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
        "512": [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16]
    }

    d_down = {
        "32": [True, True, False, False],
        "64": [True, True, True, True, False],
        "128": [True, True, True, True, True, False],
        "256": [True, True, True, True, True, True, False],
        "512": [True, True, True, True, True, True, True, False]
    }

    self.d_cond_mtd = d_cond_mtd
    self.aux_cls_type = aux_cls_type
    self.normalize_d_embed = normalize_d_embed
    self.num_classes = num_classes
    self.mixed_precision = mixed_precision
    self.in_dims = d_in_dims_collection[str(img_size)]
    self.out_dims = d_out_dims_collection[str(img_size)]
    self.MODEL = MODEL
    down = d_down[str(img_size)]

    self.blocks = []
    for index in range(len(self.in_dims)):
        if index == 0:
            self.blocks += [[
                DiscOptBlock(in_channels=self.in_dims[index],
                             out_channels=self.out_dims[index],
                             apply_d_sn=apply_d_sn,
                             MODULES=MODULES)
            ]]
        else:
            self.blocks += [[
                DiscBlock(in_channels=self.in_dims[index],
                          out_channels=self.out_dims[index],
                          apply_d_sn=apply_d_sn,
                          MODULES=MODULES,
                          downsample=down[index])
            ]]

        if index + 1 in attn_d_loc and apply_attn:
            self.blocks += [[ops.SelfAttention(self.out_dims[index], is_generator=False, MODULES=MODULES)]]

    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    self.activation = MODULES.d_act_fn

    # linear layer for adversarial training
    if self.d_cond_mtd == "MH":
        self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1 + num_classes, bias=True)
    elif self.d_cond_mtd == "MD":
        self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=True)
    else:
        self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=1, bias=True)

    # double num_classes for Auxiliary Discriminative Classifier
    if self.aux_cls_type == "ADC":
        num_classes = num_classes * 2

    # linear and embedding layers for discriminator conditioning
    if self.d_cond_mtd == "AC":
        self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)
    elif self.d_cond_mtd == "PD":
        self.embedding = MODULES.d_embedding(num_classes, self.out_dims[-1])
    elif self.d_cond_mtd in ["2C", "D2DCE"]:
        self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)
        self.embedding = MODULES.d_embedding(num_classes, d_embed_dim)
    else:
        pass

    # linear and embedding layers for evolved classifier-based GAN
    if self.aux_cls_type == "TAC":
        if self.d_cond_mtd == "AC":
            self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=num_classes, bias=False)
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            self.linear_mi = MODULES.d_linear(in_features=self.out_dims[-1], out_features=d_embed_dim, bias=True)
            self.embedding_mi = MODULES.d_embedding(num_classes, d_embed_dim)
        else:
            raise NotImplementedError

    # Q head network for infoGAN
    if self.MODEL.info_type in ["discrete", "both"]:
        out_features = self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
        self.info_discrete_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)
    if self.MODEL.info_type in ["continuous", "both"]:
        out_features = self.MODEL.info_num_conti_c
        self.info_conti_mu_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)
        self.info_conti_var_linear = MODULES.d_linear(in_features=self.out_dims[-1], out_features=out_features, bias=False)

    if d_init:
        ops.init_weights(self.modules, d_init)
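# NOTE: hypothetical illustration, not part of the original file. It traces the spatial
# resolution through the block stack for img_size=64, assuming the initial DiscOptBlock and
# every DiscBlock built with downsample=True halve the resolution while a downsample=False
# block keeps it, which matches the d_down["64"] pattern above.
def _example_discriminator_spatial_sizes(img_size=64):
    down = [True, True, True, True, False]  # d_down["64"]
    size, sizes = img_size, []
    for halve in down:
        size = size // 2 if halve else size
        sizes.append(size)
    return sizes  # [32, 16, 8, 4, 4]: the 4x4 feature is pooled before the linear heads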