Example #1
    def add_norm_layer(layer):
        nonlocal norm_type
        if norm_type.startswith('spectral'):
            layer = sn(layer)  # sn: spectral-norm wrapper (e.g. nn.utils.spectral_norm)
            subnorm_type = norm_type[len('spectral'):]
        else:
            subnorm_type = norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'syncbatch':
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError('normalization layer %s is not recognized' % subnorm_type)

        return nn.Sequential(layer, norm_layer)
Example #2
    def __init__(self, config_text, norm_nc, label_nc, use_weight_norm=False):
        super().__init__()

        assert config_text.startswith("spade")
        parsed = re.search(r"spade(\D+)(\d)x\d", config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == "instance":
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == "syncbatch":
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc,
                                                           affine=False)
        elif param_free_norm_type == "batch":
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(
                "%s is not a recognized param-free norm type in SPADE" %
                param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU())
        self.mlp_gamma = nn.Conv2d(nhidden,
                                   norm_nc,
                                   kernel_size=ks,
                                   padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.use_weight_norm = use_weight_norm
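
A minimal sketch of how a SPADE layer of this shape is typically applied in its forward pass. It follows the reference SPADE formulation (normalize, then modulate with per-pixel gamma and beta computed from the resized segmentation map); it is an illustration, not part of the example above, and it assumes torch.nn.functional is imported as F.

    def forward(self, x, segmap):
        # Part 1: parameter-free normalization of the incoming activations
        normalized = self.param_free_norm(x)

        # Part 2: per-pixel scale and bias predicted from the segmentation
        # map, resized to the spatial size of x
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # Modulate the normalized activations
        return normalized * (1 + gamma) + beta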
Example #3
    def __init__(self,
                 config_text,
                 norm_nc,
                 label_nc,
                 no_instance=True,
                 add_dist=False):
        super().__init__()
        self.no_instance = no_instance
        self.add_dist = add_dist
        assert config_text.startswith('spade')
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc,
                                                           affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(
                '%s is not a recognized param-free norm type in SPADE' %
                param_free_norm_type)
        self.class_specified_affine = ClassAffine(label_nc, norm_nc, add_dist)

        if not no_instance:
            self.inst_conv = nn.Conv2d(1, 1, kernel_size=1, padding=0)
Example #4
    def add_norm_layer(layer):
        nonlocal norm_type
        if norm_type.startswith("spectral"):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len("spectral"):]
        else:
            subnorm_type = norm_type

        if subnorm_type == "none" or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, "bias", None) is not None:
            delattr(layer, "bias")
            layer.register_parameter("bias", None)

        if subnorm_type == "batch":
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "sync_batch":
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer),
                                                 affine=True)
        elif subnorm_type == "instance":
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer),
                                           affine=False)
        else:
            raise ValueError("normalization layer %s is not recognized" %
                             subnorm_type)

        return nn.Sequential(layer, norm_layer)
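
A hedged usage sketch for the closure above: in SPADE-style codebases this function is usually returned by a factory such as get_nonspade_norm_layer(opt, norm_type); the factory name, the option string, and the availability of get_out_channel are assumptions here, and only the wrapping pattern is illustrated.

    # hypothetical usage of the returned closure
    norm_layer = get_nonspade_norm_layer(opt, "spectralinstance")
    conv = norm_layer(nn.Conv2d(3, 64, kernel_size=3, padding=1))
    # conv is now nn.Sequential(spectral_norm(Conv2d), InstanceNorm2d(64))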
Example #5
    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('fade')
        parsed = re.search(r'fade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc,
                                                           affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(
                '%s is not a recognized param-free norm type in FADE' %
                param_free_norm_type)

        pw = ks // 2
        self.mlp_gamma = nn.Conv2d(label_nc,
                                   norm_nc,
                                   kernel_size=ks,
                                   padding=pw)
        self.mlp_beta = nn.Conv2d(label_nc,
                                  norm_nc,
                                  kernel_size=ks,
                                  padding=pw)
Example #6
    def __init__(self, config_text, norm_nc, label_nc, opt, triple=False):
        super().__init__()

        assert config_text.startswith('spade')
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc,
                                                           affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(
                '%s is not a recognized param-free norm type in SPADE' %
                param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        pw = ks // 2
        if opt.from_disp or triple:
            label_nc = opt.disp_nc + label_nc
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU())
        self.mlp_gamma = nn.Conv2d(nhidden,
                                   norm_nc,
                                   kernel_size=ks,
                                   padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
Example #7
    def __init__(self, config_text, norm_nc, label_nc, mask_emb_dim, attr_nc,
                 attr_emb_dim):
        super().__init__()

        assert config_text.startswith('spade')
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc,
                                                           affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(
                '%s is not a recognized param-free norm type in SPADE' %
                param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        if attr_nc == 0:
            attr_emb_dim = 0
            self.has_attr = False
        else:
            self.has_attr = True

        pw = ks // 2
        self.emb = nn.Conv2d(label_nc, mask_emb_dim, 1, bias=False)
        if self.has_attr:
            self.emb_attr = nn.Conv2d(attr_nc, attr_emb_dim, 1, bias=False)
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(mask_emb_dim + attr_emb_dim,
                      nhidden,
                      kernel_size=ks,
                      padding=pw), nn.ReLU(inplace=True))
        self.mlp_gamma = nn.Conv2d(nhidden,
                                   norm_nc,
                                   kernel_size=ks,
                                   padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

        self.mlp_shared_avg = nn.Sequential(
            nn.Conv2d(mask_emb_dim + attr_emb_dim,
                      nhidden,
                      kernel_size=ks,
                      padding=pw), nn.ReLU(inplace=True))
        self.mlp_gamma_avg = nn.Conv2d(nhidden,
                                       norm_nc,
                                       kernel_size=1,
                                       padding=0,
                                       bias=False)
        self.mlp_beta_avg = nn.Conv2d(nhidden,
                                      norm_nc,
                                      kernel_size=1,
                                      padding=0,
                                      bias=False)
Example #8
    def convlayer(in_channels,
                  out_channels,
                  kernel_size,
                  stride=1,
                  padding=0,
                  dilation=1,
                  groups=1,
                  bias=True):
        # normalize the input first, then convolve and apply LeakyReLU
        return nn.Sequential(
            SynchronizedBatchNorm2d(in_channels),
            Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                   dilation, groups, bias), nn.LeakyReLU(0.2))
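
A hypothetical one-line usage of this helper, showing the norm-before-conv ordering it builds (the channel counts are made up for illustration):

    # SynchronizedBatchNorm2d -> Conv2d -> LeakyReLU, downsampling by 2
    block = convlayer(64, 128, kernel_size=3, stride=2, padding=1)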
Example #9
    def __init__(self, fin, fout, opt):
        super().__init__()
        # attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = fin

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if 'spectral' in opt.norm_S:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)

        # define normalization layers
        subnorm_type = opt.norm_S.replace('spectral', '')
        if subnorm_type == 'batch':
            self.norm_layer_in = nn.BatchNorm2d(fin, affine=True)
            self.norm_layer_out = nn.BatchNorm2d(fout, affine=True)
            if self.learned_shortcut:
                self.norm_layer_s = nn.BatchNorm2d(fout, affine=True)
        elif subnorm_type == 'syncbatch':
            self.norm_layer_in = SynchronizedBatchNorm2d(fin, affine=True)
            self.norm_layer_out = SynchronizedBatchNorm2d(fout, affine=True)
            if self.learned_shortcut:
                self.norm_layer_s = SynchronizedBatchNorm2d(fout, affine=True)
        elif subnorm_type == 'instance':
            self.norm_layer_in = nn.InstanceNorm2d(fin, affine=False)
            self.norm_layer_out = nn.InstanceNorm2d(fout, affine=False)
            if self.learned_shortcut:
                self.norm_layer_s = nn.InstanceNorm2d(fout, affine=False)
        else:
            raise ValueError('normalization layer %s is not recognized' %
                             subnorm_type)
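
One plausible forward for such a block, assuming the usual residual pattern in which each convolution is followed by its matching norm layer and a LeakyReLU activation. This is a sketch under those assumptions, not the repository's own forward, and it assumes torch.nn.functional is imported as F.

    def forward(self, x):
        # shortcut branch: 1x1 conv plus its own norm when channel counts differ
        x_s = self.norm_layer_s(self.conv_s(x)) if self.learned_shortcut else x

        # main branch: conv -> norm -> activation, then conv -> norm
        dx = F.leaky_relu(self.norm_layer_in(self.conv_0(x)), 0.2)
        dx = self.norm_layer_out(self.conv_1(dx))

        return x_s + dx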
Example #10
    def __init__(self,
                 config_text,
                 norm_nc,
                 label_nc,
                 ACE_Name=None,
                 status='train',
                 spade_params=None,
                 use_rgb=True):
        super().__init__()

        self.ACE_Name = ACE_Name
        self.status = status
        self.save_npy = False
        self.Spade = SPADE(*spade_params)
        self.use_rgb = use_rgb
        self.style_length = 512
        self.blending_gamma = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.blending_beta = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.noise_var = nn.Parameter(torch.zeros(norm_nc), requires_grad=True)

        assert config_text.startswith('spade')
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))
        pw = ks // 2

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc,
                                                           affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(
                '%s is not a recognized param-free norm type in SPADE' %
                param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.

        if self.use_rgb:
            self.create_gamma_beta_fc_layers()

            self.conv_gamma = nn.Conv2d(self.style_length,
                                        norm_nc,
                                        kernel_size=ks,
                                        padding=pw)
            self.conv_beta = nn.Conv2d(self.style_length,
                                       norm_nc,
                                       kernel_size=ks,
                                       padding=pw)
Example #11
 def __init__(self, opt):
     super().__init__()
     nf = opt.ngf
     kw = 4 if opt.domain_rela else 3
     pw = int((kw - 1.0) / 2)
     self.feature = nn.Sequential(nn.Conv2d(4 * nf, 2 * nf, kw, stride=2, padding=pw),
                             SynchronizedBatchNorm2d(2 * nf, affine=True),
                             nn.LeakyReLU(0.2, False),
                             nn.Conv2d(2 * nf, nf, kw, stride=2, padding=pw),
                             SynchronizedBatchNorm2d(nf, affine=True),
                             nn.LeakyReLU(0.2, False),
                             nn.Conv2d(nf, int(nf // 2), kw, stride=2, padding=pw),
                             SynchronizedBatchNorm2d(int(nf // 2), affine=True),
                             nn.LeakyReLU(0.2, False))  #32*8*8
     model = [nn.Linear(int(nf // 2) * 8 * 8, 100),
             SynchronizedBatchNorm1d(100, affine=True),
             nn.ReLU()]
     if opt.domain_rela:
         model += [nn.Linear(100, 1)]
     else:
         model += [nn.Linear(100, 2),
                   nn.LogSoftmax(dim=1)]
     self.classifier = nn.Sequential(*model)
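
A minimal forward sketch for this domain classifier (an assumption, not part of the example): the convolutional feature stack reduces the input to an 8x8 map, which is flattened before the linear head.

 def forward(self, x):
     # e.g. x: (N, 4*nf, 64, 64); three stride-2 convs give (N, nf//2, 8, 8)
     feat = self.feature(x)
     # flatten spatial dimensions before the linear classifier
     feat = feat.view(feat.size(0), -1)
     return self.classifier(feat)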
Example #12
    def __init__(self, config_text, norm_nc, image_nc, downsample_n):
        super().__init__()

        assert config_text.startswith('spade')
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc,
                                                           affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(
                '%s is not a recognized param-free norm type in SPADE' %
                param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(image_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU())
        self.middle = []
        for i in range(downsample_n):
            self.middle += [
                nn.Conv2d(nhidden,
                          nhidden,
                          kernel_size=3,
                          padding=pw,
                          stride=2)
            ]
            self.middle += [nn.ReLU()]
        self.middle = nn.Sequential(*self.middle)

        self.mlp_gamma = nn.Conv2d(nhidden,
                                   norm_nc,
                                   kernel_size=ks,
                                   padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
Example #13
    def __init__(self, norm_nc, hidden_nc=0, norm='batch', ks=3,
                 params_free=False):
        super().__init__()
        pw = ks // 2
        if not isinstance(hidden_nc, list):
            hidden_nc = [hidden_nc]
        for i, nhidden in enumerate(hidden_nc):                                    
            mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
            mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
            
            if not params_free or (i != 0):
                s = str(i+1) if i > 0 else ''                                
                setattr(self, 'mlp_gamma%s' % s, mlp_gamma)
                setattr(self, 'mlp_beta%s' % s, mlp_beta)

        if 'batch' in norm:
            self.batch_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        else:
            self.batch_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        
        self.params_free = params_free

        self.hidden_nc = hidden_nc
Example #14
    def __init__(self,
                 config_text,
                 norm_nc,
                 label_nc,
                 PONO=False,
                 use_apex=False):
        super().__init__()

        assert config_text.startswith('spade')
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))
        self.pad_type = 'zero'

        if PONO:
            self.param_free_norm = PositionalNorm2d
        elif param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            if use_apex:
                self.param_free_norm = apex.parallel.SyncBatchNorm(
                    norm_nc, affine=False)
            else:
                self.param_free_norm = SynchronizedBatchNorm2d(norm_nc,
                                                               affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(
                '%s is not a recognized param-free norm type in SPADE' %
                param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128
        pw = ks // 2
        #print(ks, pw)
        if self.pad_type != 'zero':
            self.mlp_shared = nn.Sequential(
                nn.ReflectionPad2d(pw),
                nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=0),
                nn.ReLU())
            self.pad = nn.ReflectionPad2d(pw)
            self.mlp_gamma = nn.Conv2d(nhidden,
                                       norm_nc,
                                       kernel_size=ks,
                                       padding=0)
            self.mlp_beta = nn.Conv2d(nhidden,
                                      norm_nc,
                                      kernel_size=ks,
                                      padding=0)
        else:
            self.mlp_shared = nn.Sequential(
                nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
                nn.ReLU())
            self.mlp_gamma = nn.Conv2d(nhidden,
                                       norm_nc,
                                       kernel_size=ks,
                                       padding=pw)
            self.mlp_beta = nn.Conv2d(nhidden,
                                      norm_nc,
                                      kernel_size=ks,
                                      padding=pw)
Example #15
    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('spade')
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc,
                                                           affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(
                '%s is not a recognized param-free norm type in SPADE' %
                param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        pw = ks // 2
        '''
        self.mlp_shared_2 = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=2, dilation = 2),
            nn.LeakyReLU(0.2)
        )
        self.mlp_gamma_2 = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=2, dilation = 2)
        self.mlp_beta_2 = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=2, dilation = 2)
        


        self.mlp_shared_4 = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=4, dilation = 4),
            nn.LeakyReLU(0.2)
        )
        self.mlp_gamma_4 = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=4, dilation = 4)
        self.mlp_beta_4 = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=4, dilation = 4)
        


        self.mlp_shared_8 = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=8, dilation = 8),
            nn.LeakyReLU(0.2)
        )
        self.mlp_gamma_8 = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=8, dilation = 8)
        self.mlp_beta_8 = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=8, dilation = 8)
        


        self.mlp_shared_16 = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=16, dilation = 16),
            nn.LeakyReLU(0.2)
        )
        self.mlp_gamma_16 = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=16, dilation = 16)
        self.mlp_beta_16 = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=16, dilation = 16)
        '''

        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc,
                      nhidden,
                      kernel_size=ks,
                      padding=pw,
                      dilation=1), nn.LeakyReLU(0.2))
        self.mlp_gamma = nn.Conv2d(nhidden,
                                   norm_nc,
                                   kernel_size=ks,
                                   padding=pw,
                                   dilation=1)
        self.mlp_beta = nn.Conv2d(nhidden,
                                  norm_nc,
                                  kernel_size=ks,
                                  padding=pw,
                                  dilation=1)
Example #16
    def __init__(self,
                 config_text,
                 norm_nc,
                 label_nc,
                 group_num=0,
                 use_instance=False,
                 data_mode='deepfashion'):
        super().__init__()
        if group_num == 0:
            group_num = label_nc
        assert config_text.startswith('spade')
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc,
                                                           affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'group':
            self.param_free_norm = nn.GroupNorm(label_nc, norm_nc)
        else:
            raise ValueError(
                '%s is not a recognized param-free norm type in SPADE' %
                param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        if use_instance:
            seg_in_dim = label_nc * 2
        else:
            seg_in_dim = label_nc
        pw = ks // 2
        # print(data_mode)
        if data_mode == 'deepfashion':
            nhidden = 128
            self.mlp_shared = nn.Sequential(
                nn.Conv2d(seg_in_dim,
                          nhidden,
                          kernel_size=ks,
                          padding=pw,
                          groups=group_num), nn.ReLU())
            self.mlp_gamma = nn.Conv2d(nhidden,
                                       norm_nc,
                                       kernel_size=ks,
                                       padding=pw,
                                       groups=group_num)
            self.mlp_beta = nn.Conv2d(nhidden,
                                      norm_nc,
                                      kernel_size=ks,
                                      padding=pw,
                                      groups=group_num)
        elif data_mode == 'sketch':
            nhidden = 64
            self.mlp_shared = nn.Sequential(
                nn.Conv2d(seg_in_dim,
                          nhidden,
                          kernel_size=ks,
                          padding=pw,
                          groups=group_num), nn.ReLU())
            self.mlp_gamma = nn.Conv2d(nhidden,
                                       norm_nc,
                                       kernel_size=ks,
                                       padding=pw,
                                       groups=group_num)
            self.mlp_beta = nn.Conv2d(nhidden,
                                      norm_nc,
                                      kernel_size=ks,
                                      padding=pw,
                                      groups=group_num)
        elif data_mode == 'cityscapes':
            nhidden = label_nc * group_num
            # print(nhidden)
            self.mlp_shared = nn.Sequential(
                nn.Conv2d(seg_in_dim,
                          nhidden,
                          kernel_size=ks,
                          padding=pw,
                          groups=label_nc),
                # nn.Conv2d(seg_in_dim, nhidden, kernel_size=ks, padding=pw),
                nn.ReLU())
            self.mlp_gamma = nn.Conv2d(nhidden,
                                       norm_nc,
                                       kernel_size=ks,
                                       padding=pw,
                                       groups=label_nc)
            self.mlp_beta = nn.Conv2d(nhidden,
                                      norm_nc,
                                      kernel_size=ks,
                                      padding=pw,
                                      groups=label_nc)
        elif data_mode == 'ade20k':
            if label_nc % group_num == 0:
                nhidden = label_nc * 2
            else:
                nhidden = label_nc * group_num
            self.mlp_shared = nn.Sequential(
                nn.Conv2d(seg_in_dim,
                          nhidden,
                          kernel_size=ks,
                          padding=pw,
                          groups=label_nc),
                # nn.Conv2d(seg_in_dim, nhidden, kernel_size=ks, padding=pw),
                nn.ReLU())
            self.mlp_gamma = nn.Conv2d(nhidden,
                                       norm_nc,
                                       kernel_size=ks,
                                       padding=pw,
                                       groups=group_num)
            self.mlp_beta = nn.Conv2d(nhidden,
                                      norm_nc,
                                      kernel_size=ks,
                                      padding=pw,
                                      groups=group_num)