    def __init__(self, z_dim, g_shared_dim, img_size, g_conv_dim, apply_attn,
                 attn_g_loc, g_cond_mtd, num_classes, g_init, g_depth,
                 mixed_precision, MODULES, MODEL):
        super(Generator, self).__init__()
        # channel widths for the generator blocks: 512 -> 256 -> 128 -> 64
        self.in_dims = [512, 256, 128]
        self.out_dims = [256, 128, 64]

        self.z_dim = z_dim
        self.num_classes = num_classes
        self.g_cond_mtd = g_cond_mtd
        self.mixed_precision = mixed_precision
        self.MODEL = MODEL
        self.affine_input_dim = 0

        # total dimensionality of the InfoGAN latent codes (discrete + continuous)
        info_dim = 0
        if self.MODEL.info_type in ["discrete", "both"]:
            info_dim += self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
        if self.MODEL.info_type in ["continuous", "both"]:
            info_dim += self.MODEL.info_num_conti_c

        self.g_info_injection = self.MODEL.g_info_injection
        # layers injecting the InfoGAN code: mixed into z ("concat") or projected for conditional BN ("cBN")
        if self.MODEL.info_type != "N/A":
            if self.g_info_injection == "concat":
                self.info_mix_linear = MODULES.g_linear(
                    in_features=self.z_dim + info_dim,
                    out_features=self.z_dim,
                    bias=True)
            elif self.g_info_injection == "cBN":
                self.affine_input_dim += self.z_dim
                self.info_proj_linear = MODULES.g_linear(
                    in_features=info_dim, out_features=self.z_dim, bias=True)

        # shared class embedding used for conditional generation
        if self.g_cond_mtd != "W/O":
            self.affine_input_dim += self.z_dim
            self.shared = ops.embedding(num_embeddings=self.num_classes,
                                        embedding_dim=self.z_dim)

        # linear layer projecting z onto the 4x4 bottom feature map
        self.linear0 = MODULES.g_linear(in_features=self.z_dim,
                                        out_features=self.in_dims[0] * 4 * 4,
                                        bias=True)

        # generator blocks, with self-attention inserted at the requested locations
        self.blocks = []
        for index in range(len(self.in_dims)):
            self.blocks += [[
                GenBlock(in_channels=self.in_dims[index],
                         out_channels=self.out_dims[index],
                         g_cond_mtd=self.g_cond_mtd,
                         g_info_injection=self.g_info_injection,
                         affine_input_dim=self.affine_input_dim,
                         MODULES=MODULES)
            ]]

            if index + 1 in attn_g_loc and apply_attn:
                self.blocks += [[
                    ops.SelfAttention(self.out_dims[index],
                                      is_generator=True,
                                      MODULES=MODULES)
                ]]

        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])

        # final 3x3 conv to a 3-channel image, squashed to [-1, 1] by tanh
        self.conv4 = MODULES.g_conv2d(in_channels=self.out_dims[-1],
                                      out_channels=3,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1)
        self.tanh = nn.Tanh()

        ops.init_weights(self.modules, g_init)
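
The only dimension bookkeeping in this constructor is self.affine_input_dim, which is passed to every GenBlock as the width of its conditioning vector. A minimal standalone sketch of that accumulation (the helper name is mine, not part of the source):

def affine_input_dim_for(z_dim, info_type, g_info_injection, g_cond_mtd):
    # mirrors the accumulation of self.affine_input_dim in the constructor above
    dim = 0
    if info_type != "N/A" and g_info_injection == "cBN":
        dim += z_dim  # projected InfoGAN code is part of the conditional-BN input
    if g_cond_mtd != "W/O":
        dim += z_dim  # shared class embedding is the other part
    return dim

# e.g. a class-conditional run with InfoGAN codes injected through cBN:
# affine_input_dim_for(128, "both", "cBN", "cBN") -> 256
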
Example #2
    def __init__(self, z_dim, g_shared_dim, img_size, g_conv_dim, apply_attn,
                 attn_g_loc, g_cond_mtd, num_classes, g_init, g_depth,
                 mixed_precision, MODULES, MODEL):
        super(Generator, self).__init__()
        # per-resolution channel widths for the generator blocks
        g_in_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64":
            [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2],
            "128": [
                g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8,
                g_conv_dim * 4, g_conv_dim * 2
            ],
            "256": [
                g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8,
                g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2
            ],
            "512": [
                g_conv_dim * 16, g_conv_dim * 16, g_conv_dim * 8,
                g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim
            ]
        }

        g_out_dims_collection = {
            "32": [g_conv_dim * 4, g_conv_dim * 4, g_conv_dim * 4],
            "64": [g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim],
            "128": [
                g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4,
                g_conv_dim * 2, g_conv_dim
            ],
            "256": [
                g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8,
                g_conv_dim * 4, g_conv_dim * 2, g_conv_dim
            ],
            "512": [
                g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 8,
                g_conv_dim * 4, g_conv_dim * 2, g_conv_dim, g_conv_dim
            ]
        }

        # spatial size of the bottom feature map produced by linear0
        bottom_collection = {"32": 4, "64": 4, "128": 4, "256": 4, "512": 4}

        self.z_dim = z_dim
        self.num_classes = num_classes
        self.g_cond_mtd = g_cond_mtd
        self.mixed_precision = mixed_precision
        self.MODEL = MODEL
        self.in_dims = g_in_dims_collection[str(img_size)]
        self.out_dims = g_out_dims_collection[str(img_size)]
        self.bottom = bottom_collection[str(img_size)]
        self.num_blocks = len(self.in_dims)
        self.affine_input_dim = 0

        info_dim = 0
        if self.MODEL.info_type in ["discrete", "both"]:
            info_dim += self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
        if self.MODEL.info_type in ["continuous", "both"]:
            info_dim += self.MODEL.info_num_conti_c

        self.g_info_injection = self.MODEL.g_info_injection
        if self.MODEL.info_type != "N/A":
            if self.g_info_injection == "concat":
                self.info_mix_linear = MODULES.g_linear(
                    in_features=self.z_dim + info_dim,
                    out_features=self.z_dim,
                    bias=True)
            elif self.g_info_injection == "cBN":
                self.affine_input_dim += self.z_dim
                self.info_proj_linear = MODULES.g_linear(
                    in_features=info_dim, out_features=self.z_dim, bias=True)

        self.linear0 = MODULES.g_linear(in_features=self.z_dim,
                                        out_features=self.in_dims[0] *
                                        self.bottom * self.bottom,
                                        bias=True)

        if self.g_cond_mtd != "W/O":
            self.affine_input_dim += self.z_dim
            self.shared = ops.embedding(num_embeddings=self.num_classes,
                                        embedding_dim=self.z_dim)

        self.blocks = []
        for index in range(self.num_blocks):
            self.blocks += [[
                GenBlock(in_channels=self.in_dims[index],
                         out_channels=self.out_dims[index],
                         g_cond_mtd=self.g_cond_mtd,
                         g_info_injection=self.g_info_injection,
                         affine_input_dim=self.affine_input_dim,
                         MODULES=MODULES)
            ]]

            if index + 1 in attn_g_loc and apply_attn:
                self.blocks += [[
                    ops.SelfAttention(self.out_dims[index],
                                      is_generator=True,
                                      MODULES=MODULES)
                ]]

        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])

        self.bn4 = ops.batchnorm_2d(in_features=self.out_dims[-1])
        self.activation = MODULES.g_act_fn
        self.conv2d5 = MODULES.g_conv2d(in_channels=self.out_dims[-1],
                                        out_channels=3,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.tanh = nn.Tanh()

        ops.init_weights(self.modules, g_init)
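
Unlike the fixed-width example above, this variant sizes each block from g_conv_dim and img_size via the dictionaries at the top of the constructor. A small illustrative helper (not from the source) that reproduces the "64" entry:

def g_widths_64(g_conv_dim):
    # same values as g_in_dims_collection["64"] / g_out_dims_collection["64"] above
    in_dims = [g_conv_dim * 16, g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2]
    out_dims = [g_conv_dim * 8, g_conv_dim * 4, g_conv_dim * 2, g_conv_dim]
    return in_dims, out_dims

# With g_conv_dim=64: in_dims=[1024, 512, 256, 128], out_dims=[512, 256, 128, 64].
# Assuming each GenBlock upsamples by 2x, the 4x4 map from linear0 grows
# 4 -> 8 -> 16 -> 32 -> 64 across the four blocks.
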
    def __init__(self, img_size, d_conv_dim, apply_d_sn, apply_attn,
                 attn_d_loc, d_cond_mtd, aux_cls_type, d_embed_dim,
                 normalize_d_embed, num_classes, d_init, d_depth,
                 mixed_precision, MODULES, MODEL):
        super(Discriminator, self).__init__()
        # channel widths for the discriminator blocks: RGB input -> 64 -> 128 -> 256
        self.in_dims = [3] + [64, 128]
        self.out_dims = [64, 128, 256]

        self.apply_d_sn = apply_d_sn
        self.d_cond_mtd = d_cond_mtd
        self.aux_cls_type = aux_cls_type
        self.normalize_d_embed = normalize_d_embed
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.MODEL = MODEL

        self.blocks = []
        for index in range(len(self.in_dims)):
            self.blocks += [[
                DiscBlock(in_channels=self.in_dims[index],
                          out_channels=self.out_dims[index],
                          apply_d_sn=self.apply_d_sn,
                          MODULES=MODULES)
            ]]

            if index + 1 in attn_d_loc and apply_attn:
                self.blocks += [[
                    ops.SelfAttention(self.out_dims[index],
                                      is_generator=False,
                                      MODULES=MODULES)
                ]]

        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])

        self.activation = MODULES.d_act_fn
        self.conv1 = MODULES.d_conv2d(in_channels=256,
                                      out_channels=512,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1)

        if not self.apply_d_sn:
            self.bn1 = MODULES.d_bn(in_features=512)

        # linear layer for adversarial training
        if self.d_cond_mtd == "MH":
            self.linear1 = MODULES.d_linear(in_features=512,
                                            out_features=1 + num_classes,
                                            bias=True)
        elif self.d_cond_mtd == "MD":
            self.linear1 = MODULES.d_linear(in_features=512,
                                            out_features=num_classes,
                                            bias=True)
        else:
            self.linear1 = MODULES.d_linear(in_features=512,
                                            out_features=1,
                                            bias=True)

        # double num_classes for Auxiliary Discriminative Classifier
        if self.aux_cls_type == "ADC":
            num_classes = num_classes * 2

        # linear and embedding layers for discriminator conditioning
        if self.d_cond_mtd == "AC":
            self.linear2 = MODULES.d_linear(in_features=512,
                                            out_features=num_classes,
                                            bias=False)
        elif self.d_cond_mtd == "PD":
            self.embedding = MODULES.d_embedding(num_classes, 512)
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            self.linear2 = MODULES.d_linear(in_features=512,
                                            out_features=d_embed_dim,
                                            bias=True)
            self.embedding = MODULES.d_embedding(num_classes, d_embed_dim)
        else:
            pass

        # linear and embedding layers for evolved classifier-based GAN
        if self.aux_cls_type == "TAC":
            if self.d_cond_mtd == "AC":
                self.linear_mi = MODULES.d_linear(in_features=512,
                                                  out_features=num_classes,
                                                  bias=False)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                self.linear_mi = MODULES.d_linear(in_features=512,
                                                  out_features=d_embed_dim,
                                                  bias=True)
                self.embedding_mi = MODULES.d_embedding(
                    num_classes, d_embed_dim)
            else:
                raise NotImplementedError

        # Q head network for infoGAN
        if self.MODEL.info_type in ["discrete", "both"]:
            out_features = self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
            self.info_discrete_linear = MODULES.d_linear(
                in_features=512, out_features=out_features, bias=False)
        if self.MODEL.info_type in ["continuous", "both"]:
            out_features = self.MODEL.info_num_conti_c
            self.info_conti_mu_linear = MODULES.d_linear(
                in_features=512, out_features=out_features, bias=False)
            self.info_conti_var_linear = MODULES.d_linear(
                in_features=512, out_features=out_features, bias=False)

        if d_init:
            ops.init_weights(self.modules, d_init)
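
The width of each head built above depends on d_cond_mtd and aux_cls_type. A hypothetical helper (not part of the source) that tabulates the output sizes of linear1 and linear2:

def head_out_features(d_cond_mtd, aux_cls_type, num_classes, d_embed_dim):
    # adversarial head (linear1), as built above
    if d_cond_mtd == "MH":
        adv = 1 + num_classes
    elif d_cond_mtd == "MD":
        adv = num_classes
    else:
        adv = 1
    # ADC doubles the class count before the conditioning head is created
    cls_classes = num_classes * 2 if aux_cls_type == "ADC" else num_classes
    if d_cond_mtd == "AC":
        cls = cls_classes        # linear2 emits class logits
    elif d_cond_mtd in ["2C", "D2DCE"]:
        cls = d_embed_dim        # linear2 emits an embedding vector
    else:
        cls = None               # e.g. "PD" uses an embedding lookup instead
    return adv, cls

# e.g. head_out_features("AC", "ADC", 10, 512) -> (1, 20)
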
Example #4
    def __init__(self, img_size, d_conv_dim, apply_d_sn, apply_attn,
                 attn_d_loc, d_cond_mtd, aux_cls_type, d_embed_dim,
                 normalize_d_embed, num_classes, d_init, d_depth,
                 mixed_precision, MODULES, MODEL):
        super(Discriminator, self).__init__()
        # per-resolution channel widths for the discriminator blocks
        d_in_dims_collection = {
            "32": [3] + [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
            "64":
            [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8],
            "128": [3] + [
                d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8,
                d_conv_dim * 16
            ],
            "256": [3] + [
                d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8,
                d_conv_dim * 8, d_conv_dim * 16
            ],
            "512": [3] + [
                d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4,
                d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16
            ]
        }

        d_out_dims_collection = {
            "32":
            [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
            "64": [
                d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8,
                d_conv_dim * 16
            ],
            "128": [
                d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8,
                d_conv_dim * 16, d_conv_dim * 16
            ],
            "256": [
                d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8,
                d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16
            ],
            "512": [
                d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4,
                d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16,
                d_conv_dim * 16
            ]
        }

        # whether each discriminator block downsamples its input
        d_down = {
            "32": [True, True, False, False],
            "64": [True, True, True, True, False],
            "128": [True, True, True, True, True, False],
            "256": [True, True, True, True, True, True, False],
            "512": [True, True, True, True, True, True, True, False]
        }

        self.d_cond_mtd = d_cond_mtd
        self.aux_cls_type = aux_cls_type
        self.normalize_d_embed = normalize_d_embed
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.in_dims = d_in_dims_collection[str(img_size)]
        self.out_dims = d_out_dims_collection[str(img_size)]
        self.MODEL = MODEL
        down = d_down[str(img_size)]

        self.blocks = []
        for index in range(len(self.in_dims)):
            if index == 0:
                self.blocks += [[
                    DiscOptBlock(in_channels=self.in_dims[index],
                                 out_channels=self.out_dims[index],
                                 apply_d_sn=apply_d_sn,
                                 MODULES=MODULES)
                ]]
            else:
                self.blocks += [[
                    DiscBlock(in_channels=self.in_dims[index],
                              out_channels=self.out_dims[index],
                              apply_d_sn=apply_d_sn,
                              MODULES=MODULES,
                              downsample=down[index])
                ]]

            if index + 1 in attn_d_loc and apply_attn:
                self.blocks += [[
                    ops.SelfAttention(self.out_dims[index],
                                      is_generator=False,
                                      MODULES=MODULES)
                ]]

        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])

        self.activation = MODULES.d_act_fn

        # linear layer for adversarial training
        if self.d_cond_mtd == "MH":
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1],
                                            out_features=1 + num_classes,
                                            bias=True)
        elif self.d_cond_mtd == "MD":
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1],
                                            out_features=num_classes,
                                            bias=True)
        else:
            self.linear1 = MODULES.d_linear(in_features=self.out_dims[-1],
                                            out_features=1,
                                            bias=True)

        # double num_classes for Auxiliary Discriminative Classifier
        if self.aux_cls_type == "ADC":
            num_classes = num_classes * 2

        # linear and embedding layers for discriminator conditioning
        if self.d_cond_mtd == "AC":
            self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1],
                                            out_features=num_classes,
                                            bias=False)
        elif self.d_cond_mtd == "PD":
            self.embedding = MODULES.d_embedding(num_classes,
                                                 self.out_dims[-1])
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            self.linear2 = MODULES.d_linear(in_features=self.out_dims[-1],
                                            out_features=d_embed_dim,
                                            bias=True)
            self.embedding = MODULES.d_embedding(num_classes, d_embed_dim)
        else:
            pass

        # linear and embedding layers for evolved classifier-based GAN
        if self.aux_cls_type == "TAC":
            if self.d_cond_mtd == "AC":
                self.linear_mi = MODULES.d_linear(
                    in_features=self.out_dims[-1],
                    out_features=num_classes,
                    bias=False)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                self.linear_mi = MODULES.d_linear(
                    in_features=self.out_dims[-1],
                    out_features=d_embed_dim,
                    bias=True)
                self.embedding_mi = MODULES.d_embedding(
                    num_classes, d_embed_dim)
            else:
                raise NotImplementedError

        # Q head network for infoGAN
        if self.MODEL.info_type in ["discrete", "both"]:
            out_features = self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
            self.info_discrete_linear = MODULES.d_linear(
                in_features=self.out_dims[-1],
                out_features=out_features,
                bias=False)
        if self.MODEL.info_type in ["continuous", "both"]:
            out_features = self.MODEL.info_num_conti_c
            self.info_conti_mu_linear = MODULES.d_linear(
                in_features=self.out_dims[-1],
                out_features=out_features,
                bias=False)
            self.info_conti_var_linear = MODULES.d_linear(
                in_features=self.out_dims[-1],
                out_features=out_features,
                bias=False)

        if d_init:
            ops.init_weights(self.modules, d_init)
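
The d_down table controls which DiscBlocks shrink the feature map. A short sketch, assuming a downsampling block halves the spatial size (the usual ResNet-GAN convention; not shown in this snippet), of the resolutions for img_size=128:

def disc_resolutions(img_size, down_flags):
    # each True flag halves the spatial size; False keeps it unchanged
    sizes = [img_size]
    for flag in down_flags:
        sizes.append(sizes[-1] // 2 if flag else sizes[-1])
    return sizes

# d_down["128"] == [True, True, True, True, True, False]
# disc_resolutions(128, [True, True, True, True, True, False])
# -> [128, 64, 32, 16, 8, 4, 4]: five halvings down to 4x4, then one block kept at 4x4
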