Code example #1
    def __init__(self, dis_cfg, data_cfg):
        super().__init__()
        self.data_cfg = data_cfg
        num_input_channels = get_paired_input_label_channel_number(data_cfg)
        if num_input_channels == 0:
            num_input_channels = getattr(data_cfg, 'label_channels', 1)
        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        self.num_frames_D = data_cfg.num_frames_D
        self.num_scales = get_nested_attr(dis_cfg, 'temporal.num_scales', 0)
        num_netD_input_channels = (num_input_channels + num_img_channels)
        self.use_few_shot = 'few_shot' in data_cfg.type
        if self.use_few_shot:
            num_netD_input_channels *= 2
        self.net_D = MultiPatchDiscriminator(dis_cfg.image,
                                             num_netD_input_channels)

        self.add_dis_cfg = getattr(dis_cfg, 'additional_discriminators', None)
        if self.add_dis_cfg is not None:
            for name in self.add_dis_cfg:
                add_dis_cfg = self.add_dis_cfg[name]
                num_ch = num_img_channels * (2 if self.use_few_shot else 1)
                setattr(self, 'net_D_' + name,
                        MultiPatchDiscriminator(add_dis_cfg, num_ch))

        # Temporal discriminator.
        self.num_netDT_input_channels = num_img_channels * self.num_frames_D
        for n in range(self.num_scales):
            setattr(self, 'net_DT%d' % n,
                    MultiPatchDiscriminator(dis_cfg.temporal,
                                            self.num_netDT_input_channels))
        self.has_fg = getattr(data_cfg, 'has_foreground', False)
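Every example on this page sizes its network inputs with get_paired_input_label_channel_number(data_cfg). As a rough mental model, the helper sums the channel counts of the label modalities the data config lists. The sketch below is an illustrative reconstruction under that assumption, not imaginaire's actual implementation (which lives in imaginaire.utils.data and also handles details such as one-hot expansion and video datasets):

import types

def get_paired_input_label_channel_number(data_cfg, video=False):
    # Illustrative sketch only: sum the channel counts of every label
    # modality named in the config. `video` is accepted for parity with
    # example #3; its effect on the real helper's count is omitted here.
    num_channels = 0
    for name in getattr(data_cfg, 'input_labels', []):
        label_cfg = getattr(data_cfg, name)
        num_channels += getattr(label_cfg, 'num_channels', 1)
    return num_channels

# Usage with a stand-in config (imaginaire uses YAML-backed configs,
# so the attribute shape here is itself an assumption):
seg_maps = types.SimpleNamespace(num_channels=35)
data_cfg = types.SimpleNamespace(input_labels=['seg_maps'],
                                 seg_maps=seg_maps)
print(get_paired_input_label_channel_number(data_cfg))  # -> 35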
Code example #2
    def __init__(self, dis_cfg, data_cfg):
        super(Discriminator, self).__init__()
        print('Multi-resolution patch discriminator initialization.')
        # We assume the first datum is the ground truth image.
        image_channels = get_paired_input_image_channel_number(data_cfg)
        # Calculate number of channels in the input label.
        num_labels = get_paired_input_label_channel_number(data_cfg)

        # Build the discriminator.
        kernel_size = getattr(dis_cfg, 'kernel_size', 3)
        num_filters = getattr(dis_cfg, 'num_filters', 128)
        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
        num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
        num_layers = getattr(dis_cfg, 'num_layers', 5)
        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
        print('\tBase filter number: %d' % num_filters)
        print('\tNumber of discriminators: %d' % num_discriminators)
        print('\tNumber of layers in a discriminator: %d' % num_layers)
        print('\tWeight norm type: %s' % weight_norm_type)
        num_input_channels = image_channels + num_labels
        self.model = MultiResPatchDiscriminator(
            num_discriminators, kernel_size, num_input_channels, num_filters,
            num_layers, max_num_filters, activation_norm_type,
            weight_norm_type)
        print('Done with the Multi-resolution patch '
              'discriminator initialization.')
Code example #3
    def __init__(self, dis_cfg, data_cfg):
        super(Discriminator, self).__init__()
        print('Multi-resolution patch discriminator initialization.')
        # We assume the first datum is the ground truth image.
        image_channels = get_paired_input_image_channel_number(data_cfg)
        # Calculate number of channels in the input label.
        if data_cfg.type == 'imaginaire.datasets.paired_videos':
            num_labels = get_paired_input_label_channel_number(data_cfg,
                                                               video=True)
        else:
            num_labels = get_paired_input_label_channel_number(data_cfg)

        # Build the discriminator.
        kernel_size = getattr(dis_cfg, 'kernel_size', 3)
        num_filters = getattr(dis_cfg, 'num_filters', 128)
        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
        num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
        num_layers = getattr(dis_cfg, 'num_layers', 5)
        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
        print('\tBase filter number: %d' % num_filters)
        print('\tNumber of discriminators: %d' % num_discriminators)
        print('\tNumber of layers in a discriminator: %d' % num_layers)
        print('\tWeight norm type: %s' % weight_norm_type)
        num_input_channels = image_channels + num_labels
        self.discriminators = nn.ModuleList()
        for i in range(num_discriminators):
            net_discriminator = NLayerPatchDiscriminator(
                kernel_size, num_input_channels, num_filters, num_layers,
                max_num_filters, activation_norm_type, weight_norm_type)
            self.discriminators.append(net_discriminator)
        print('Done with the Multi-resolution patch '
              'discriminator initialization.')
        fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3)
        fpse_activation_norm_type = getattr(dis_cfg,
                                            'fpse_activation_norm_type',
                                            'none')
        self.fpse_discriminator = FPSEDiscriminator(image_channels, num_labels,
                                                    num_filters,
                                                    fpse_kernel_size,
                                                    weight_norm_type,
                                                    fpse_activation_norm_type)
Code example #4
def cluster_features(cfg, train_data_loader, net_E,
                     preprocess=None, small_ratio=0.0625, is_cityscapes=True):
    r"""Use clustering to compute the features.

    Args:
        cfg (obj): Global configuration file.
        train_data_loader (obj): Dataloader for iterate through the training
            set.
        net_E (nn.Module): Pytorch network.
        preprocess (function): Pre-processing function.
        small_ratio (float): We only consider instance that at least occupy
            $(small_ratio) amount of image space.
        is_cityscapes (bool): Is this is the cityscape dataset? In the
            Cityscapes dataset, the instance labels for car start with 26001,
            26002, ...

    Returns:
        ( num_labels x num_cluster_centers x feature_dims): cluster centers.
    """
    # Encode features.
    label_nc = get_paired_input_label_channel_number(cfg.data)
    feat_nc = cfg.gen.enc.num_feat_channels
    n_clusters = getattr(cfg.gen.enc, 'num_clusters', 10)
    # Compute features.
    features = {}
    for label in range(label_nc):
        features[label] = np.zeros((0, feat_nc + 1))
    for data in train_data_loader:
        if preprocess is not None:
            data = preprocess(data)
        feat = encode_features(net_E, feat_nc, label_nc,
                               data['images'], data['instance_maps'],
                               is_cityscapes)
        # We only collect the feature vectors for the master GPU.
        if is_master():
            for label in range(label_nc):
                features[label] = np.append(
                    features[label], feat[label], axis=0)
    # Clustering.
    # We only perform clustering for the master GPU.
    if is_master():
        for label in range(label_nc):
            feat = features[label]
            # We only consider segments that are greater than a pre-set
            # threshold.
            feat = feat[feat[:, -1] > small_ratio, :-1]
            if feat.shape[0]:
                # Clamp the cluster count per label so a label with few
                # segments does not shrink n_clusters for later labels.
                num_clusters = min(feat.shape[0], n_clusters)
                kmeans = KMeans(n_clusters=num_clusters,
                                random_state=0).fit(feat)
                n, d = kmeans.cluster_centers_.shape
                this_cluster = getattr(net_E, 'cluster_%d' % label)
                this_cluster[0:n, :] = torch.Tensor(
                    kmeans.cluster_centers_).float()
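Note that cluster_features does not return the centers despite computing them; it writes them into the per-label cluster_%d buffers that the Encoder in the next example registers at construction time. A hedged call-site sketch, assuming cfg, train_data_loader and net_E already exist from an imaginaire training setup:

# Hypothetical call site; cfg, train_data_loader and net_E are assumed
# to come from an existing imaginaire training run.
cluster_features(cfg, train_data_loader, net_E)
# Afterwards each per-label buffer on the encoder holds its centers:
centers_for_label_0 = getattr(net_E, 'cluster_0')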
Code example #5
File: pix2pixHD.py Project: yejees/ObjectSwap
    def __init__(self, enc_cfg, data_cfg):
        super(Encoder, self).__init__()
        label_nc = get_paired_input_label_channel_number(data_cfg)
        feat_nc = enc_cfg.num_feat_channels
        n_clusters = getattr(enc_cfg, 'num_clusters', 10)
        for i in range(label_nc):
            dummy_arr = np.zeros((n_clusters, feat_nc), dtype=np.float32)
            self.register_buffer('cluster_%d' % i,
                                 torch.tensor(dummy_arr, dtype=torch.float32))
        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        self.num_feat_channels = getattr(enc_cfg, 'num_feat_channels', 3)
        num_filters = getattr(enc_cfg, 'num_filters', 64)
        num_downsamples = getattr(enc_cfg, 'num_downsamples', 4)
        weight_norm_type = getattr(enc_cfg, 'weight_norm_type', 'none')
        activation_norm_type = getattr(enc_cfg, 'activation_norm_type',
                                       'instance')
        padding_mode = getattr(enc_cfg, 'padding_mode', 'reflect')
        base_conv_block = partial(Conv2dBlock,
                                  padding_mode=padding_mode,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  nonlinearity='relu')
        model = [base_conv_block(num_img_channels, num_filters, 7, padding=3)]
        # Downsample.
        for i in range(num_downsamples):
            ch = num_filters * (2**i)
            model += [base_conv_block(ch, ch * 2, 3, stride=2, padding=1)]
        # Upsample.
        for i in reversed(range(num_downsamples)):
            ch = num_filters * (2 ** i)
            model += [NearestUpsample(scale_factor=2),
                      base_conv_block(ch * 2, ch, 3, padding=1)]

        model += [Conv2dBlock(num_filters, self.num_feat_channels, 7,
                              padding=3, padding_mode=padding_mode,
                              nonlinearity='tanh')]
        self.model = nn.Sequential(*model)
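With the defaults shown above (num_filters=64, num_downsamples=4), this hourglass doubles the width at every stride-2 block on the way down and mirrors the schedule back up before the final tanh projection to num_feat_channels. A quick sanity check of that channel arithmetic:

num_filters, num_downsamples = 64, 4  # defaults from the constructor
down_channels = [num_filters * (2 ** i)
                 for i in range(num_downsamples + 1)]
print(down_channels)  # -> [64, 128, 256, 512, 1024]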
Code example #6
    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        self.gen_cfg = gen_cfg
        self.data_cfg = data_cfg
        self.num_frames_G = data_cfg.num_frames_G
        # Number of residual blocks in generator.
        self.num_layers = num_layers = getattr(gen_cfg, 'num_layers', 7)
        # Number of downsamplings for previous frame.
        self.num_downsamples_img = getattr(gen_cfg, 'num_downsamples_img', 4)
        # Number of filters in the first layer.
        self.num_filters = num_filters = getattr(gen_cfg, 'num_filters', 32)
        self.max_num_filters = getattr(gen_cfg, 'max_num_filters', 1024)
        self.kernel_size = kernel_size = getattr(gen_cfg, 'kernel_size', 3)
        padding = kernel_size // 2

        # For pose dataset.
        self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset')
        if self.is_pose_data:
            pose_cfg = data_cfg.for_pose_dataset
            self.pose_type = getattr(pose_cfg, 'pose_type', 'both')
            self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels',
                                              False)

        # Input data params.
        self.num_input_channels = num_input_channels = \
            get_paired_input_label_channel_number(data_cfg)

        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        aug_cfg = data_cfg.val.augmentations
        if hasattr(aug_cfg, 'center_crop_h_w'):
            crop_h_w = aug_cfg.center_crop_h_w
        elif hasattr(aug_cfg, 'resize_h_w'):
            crop_h_w = aug_cfg.resize_h_w
        else:
            raise ValueError('Need to specify output size.')
        crop_h, crop_w = crop_h_w.split(',')
        crop_h, crop_w = int(crop_h), int(crop_w)
        # Spatial size at the bottle neck of generator.
        self.sh = crop_h // (2**num_layers)
        self.sw = crop_w // (2**num_layers)

        # Noise vector dimension.
        self.z_dim = getattr(gen_cfg, 'style_dims', 256)
        self.use_segmap_as_input = \
            getattr(gen_cfg, 'use_segmap_as_input', False)

        # Label / image embedding network.
        self.emb_cfg = emb_cfg = getattr(gen_cfg, 'embed', None)
        self.use_embed = getattr(emb_cfg, 'use_embed', True)
        self.num_downsamples_embed = getattr(emb_cfg, 'num_downsamples', 5)
        if self.use_embed:
            self.label_embedding = LabelEmbedder(emb_cfg, num_input_channels)

        # Flow network.
        self.flow_cfg = flow_cfg = gen_cfg.flow
        # Use SPADE to combine warped and hallucinated frames instead of
        # linear combination.
        self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True)
        # Number of layers to perform multi-spade combine.
        self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine,
                                              'num_layers', 3)
        # At beginning of training, only train an image generator.
        self.temporal_initialized = False
        # Whether to output hallucinated frame (when training temporal network)
        # for additional loss.
        self.generate_raw_output = False

        # Image generation network.
        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')
        activation_norm_type = gen_cfg.activation_norm_type
        activation_norm_params = gen_cfg.activation_norm_params
        if self.use_embed and \
                not hasattr(activation_norm_params, 'num_filters'):
            activation_norm_params.num_filters = 0
        nonlinearity = 'leakyrelu'

        self.base_res_block = base_res_block = partial(
            Res2dBlock,
            kernel_size=kernel_size,
            padding=padding,
            weight_norm_type=weight_norm_type,
            activation_norm_type=activation_norm_type,
            activation_norm_params=activation_norm_params,
            nonlinearity=nonlinearity,
            order='NACNAC')

        # Upsampling residual blocks.
        for i in range(num_layers, -1, -1):
            activation_norm_params.cond_dims = self.get_cond_dims(i)
            activation_norm_params.partial = self.get_partial(i) if hasattr(
                self, 'get_partial') else False
            layer = base_res_block(self.get_num_filters(i + 1),
                                   self.get_num_filters(i))
            setattr(self, 'up_%d' % i, layer)

        # Final conv layer.
        self.conv_img = Conv2dBlock(num_filters,
                                    num_img_channels,
                                    kernel_size,
                                    padding=padding,
                                    nonlinearity=nonlinearity,
                                    order='AC')

        num_filters = min(self.max_num_filters,
                          num_filters * (2**(self.num_layers + 1)))
        if self.use_segmap_as_input:
            self.fc = Conv2dBlock(num_input_channels,
                                  num_filters,
                                  kernel_size=3,
                                  padding=1)
        else:
            self.fc = LinearBlock(self.z_dim, num_filters * self.sh * self.sw)

        # Misc.
        self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.upsample = partial(F.interpolate, scale_factor=2)
        self.init_temporal_network()
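Both this generator and the flow network in the next example call a get_num_filters helper that is not part of these excerpts. A plausible reconstruction, consistent with how num_filters and max_num_filters are used here, is a per-level doubling schedule clamped at the maximum; treat it as an assumption, not the library's confirmed code:

def get_num_filters(self, num_downsamples):
    # Hypothetical reconstruction of the method used as
    # self.get_num_filters(i) above: double the base width per level,
    # clamped at max_num_filters.
    return min(self.max_num_filters,
               self.num_filters * (2 ** num_downsamples))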
Code example #7
    def __init__(self, flow_cfg, data_cfg):
        super().__init__()
        num_input_channels = get_paired_input_label_channel_number(data_cfg)
        num_prev_img_channels = get_paired_input_image_channel_number(data_cfg)
        num_frames = data_cfg.num_frames_G  # Num. of input frames.

        self.num_filters = num_filters = getattr(flow_cfg, 'num_filters', 32)
        self.max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024)
        num_downsamples = getattr(flow_cfg, 'num_downsamples', 5)
        kernel_size = getattr(flow_cfg, 'kernel_size', 3)
        padding = kernel_size // 2
        self.num_res_blocks = getattr(flow_cfg, 'num_res_blocks', 6)
        # Multiplier on the flow output.
        self.flow_output_multiplier = getattr(flow_cfg,
                                              'flow_output_multiplier', 20)

        activation_norm_type = getattr(flow_cfg, 'activation_norm_type',
                                       'sync_batch')
        weight_norm_type = getattr(flow_cfg, 'weight_norm_type', 'spectral')

        base_conv_block = partial(Conv2dBlock,
                                  kernel_size=kernel_size,
                                  padding=padding,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  nonlinearity='leakyrelu')

        # Will downsample the labels and prev frames separately, then combine.
        down_lbl = [
            base_conv_block(num_input_channels * num_frames, num_filters)
        ]
        down_img = [
            base_conv_block(num_prev_img_channels * (num_frames - 1),
                            num_filters)
        ]
        for i in range(num_downsamples):
            down_lbl += [
                base_conv_block(self.get_num_filters(i),
                                self.get_num_filters(i + 1),
                                stride=2)
            ]
            down_img += [
                base_conv_block(self.get_num_filters(i),
                                self.get_num_filters(i + 1),
                                stride=2)
            ]

        # Resnet blocks.
        res_flow = []
        ch = self.get_num_filters(num_downsamples)
        for i in range(self.num_res_blocks):
            res_flow += [
                Res2dBlock(ch,
                           ch,
                           kernel_size,
                           padding=padding,
                           weight_norm_type=weight_norm_type,
                           activation_norm_type=activation_norm_type,
                           order='CNACN')
            ]

        # Upsample.
        up_flow = []
        for i in reversed(range(num_downsamples)):
            up_flow += [
                nn.Upsample(scale_factor=2),
                base_conv_block(self.get_num_filters(i + 1),
                                self.get_num_filters(i))
            ]

        conv_flow = [Conv2dBlock(num_filters, 2, kernel_size, padding=padding)]
        conv_mask = [
            Conv2dBlock(num_filters,
                        1,
                        kernel_size,
                        padding=padding,
                        nonlinearity='sigmoid')
        ]

        self.down_lbl = nn.Sequential(*down_lbl)
        self.down_img = nn.Sequential(*down_img)
        self.res_flow = nn.Sequential(*res_flow)
        self.up_flow = nn.Sequential(*up_flow)
        self.conv_flow = nn.Sequential(*conv_flow)
        self.conv_mask = nn.Sequential(*conv_mask)
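conv_flow emits a 2-channel flow field (scaled by flow_output_multiplier when used) and conv_mask a sigmoid occlusion mask. Example #6's comments describe the default way these are consumed: a linear combination of the flow-warped previous frame and the hallucinated frame, which multi_spade_combine replaces with SPADE. A sketch of that combination with illustrative tensors (the names are stand-ins, not the library's):

import torch

# Illustrative tensors only; in vid2vid these come from the flow
# network above and the image branch of example #6.
warped = torch.randn(1, 3, 256, 256)        # prev. frame warped by flow
hallucinated = torch.randn(1, 3, 256, 256)  # frame synthesized from labels
mask = torch.sigmoid(torch.randn(1, 1, 256, 256))  # occlusion mask

# Linear combination of the two candidate frames.
output = mask * warped + (1 - mask) * hallucinated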
Code example #8
File: pix2pixHD.py Project: yejees/ObjectSwap
    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        # pix2pixHD has a global generator.
        global_gen_cfg = gen_cfg.global_generator
        num_filters_global = getattr(global_gen_cfg, 'num_filters', 64)
        # Optionally, it can have several local enhancers. They are useful
        # for generating high resolution images.
        local_gen_cfg = gen_cfg.local_enhancer
        self.num_local_enhancers = num_local_enhancers = \
            getattr(local_gen_cfg, 'num_enhancers', 1)
        # By default, pix2pixHD uses instance normalization.
        activation_norm_type = getattr(gen_cfg, 'activation_norm_type',
                                       'instance')
        activation_norm_params = getattr(gen_cfg, 'activation_norm_params',
                                         None)
        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '')
        padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect')
        base_conv_block = partial(Conv2dBlock,
                                  padding_mode=padding_mode,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  activation_norm_params=activation_norm_params,
                                  nonlinearity='relu')
        base_res_block = partial(Res2dBlock,
                                 padding_mode=padding_mode,
                                 weight_norm_type=weight_norm_type,
                                 activation_norm_type=activation_norm_type,
                                 activation_norm_params=activation_norm_params,
                                 nonlinearity='relu', order='CNACN')
        # Get the number of available segmentation labels.
        num_input_channels = get_paired_input_label_channel_number(data_cfg)
        self.concat_features = False
        # Check whether label input contains specific type of data (e.g.
        # instance_maps).
        self.contain_instance_map = False
        if data_cfg.input_labels[-1] == 'instance_maps':
            self.contain_instance_map = True
        # The feature encoder is only useful when the instance map is provided.
        if hasattr(gen_cfg, 'enc') and self.contain_instance_map:
            num_feat_channels = getattr(gen_cfg.enc, 'num_feat_channels', 0)
            if num_feat_channels > 0:
                num_input_channels += num_feat_channels
                self.concat_features = True
                self.encoder = Encoder(gen_cfg.enc, data_cfg)

        # Global generator model.
        global_model = GlobalGenerator(global_gen_cfg, data_cfg,
                                       num_input_channels, padding_mode,
                                       base_conv_block, base_res_block)
        if num_local_enhancers == 0:
            self.global_model = global_model
        else:
            # Get rid of the last layer.
            global_model = global_model.model
            global_model = [global_model[i]
                            for i in range(len(global_model) - 1)]
            # global_model = [global_model[i]
            #                 for i in range(len(global_model) - 2)]
            self.global_model = nn.Sequential(*global_model)

        # Local enhancer model.
        for n in range(num_local_enhancers):
            # num_filters = num_filters_global // (2 ** n)
            num_filters = num_filters_global // (2 ** (n + 1))
            output_img = (n == num_local_enhancers - 1)
            setattr(self, 'enhancer_%d' % n,
                    LocalEnhancer(local_gen_cfg, data_cfg,
                                  num_input_channels, num_filters,
                                  padding_mode, base_conv_block,
                                  base_res_block, output_img))

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1],
                                       count_include_pad=False)
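With the default num_filters_global=64, the enhancer at level n is built with num_filters_global // 2**(n + 1) filters; note this fork halves one step further than the commented-out stock pix2pixHD formula. For two enhancers:

num_filters_global = 64  # default from the constructor above
print([num_filters_global // (2 ** (n + 1)) for n in range(2)])
# -> [32, 16]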
Code example #9
    def __init__(self, gen_cfg, data_cfg):
        super(Generator, self).__init__()
        print('SPADE generator initialization.')
        # We assume the first datum is the ground truth image.
        image_channels = get_paired_input_image_channel_number(data_cfg)
        # Calculate number of channels in the input label.
        num_labels = get_paired_input_label_channel_number(data_cfg)
        crop_h, crop_w = get_crop_h_w(data_cfg.train.augmentations)
        # Build the generator
        out_image_small_side_size = crop_w if crop_w < crop_h else crop_h
        num_filters = getattr(gen_cfg, 'num_filters', 128)
        kernel_size = getattr(gen_cfg, 'kernel_size', 3)
        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')

        cond_dims = 0
        # Check whether we use the style code.
        style_dims = getattr(gen_cfg, 'style_dims', None)
        self.style_dims = style_dims
        if style_dims is not None:
            print('\tStyle code dimensions: %d' % style_dims)
            cond_dims += style_dims
            self.use_style = True
        else:
            self.use_style = False
        # Check whether we use the attribute code.
        if hasattr(gen_cfg, 'attribute_dims'):
            self.use_attribute = True
            self.attribute_dims = gen_cfg.attribute_dims
            cond_dims += gen_cfg.attribute_dims
        else:
            self.use_attribute = False

        self.use_style_encoder = self.use_style or self.use_attribute
        print('\tBase filter number: %d' % num_filters)
        print('\tConvolution kernel size: %d' % kernel_size)
        print('\tWeight norm type: %s' % weight_norm_type)
        skip_activation_norm = \
            getattr(gen_cfg, 'skip_activation_norm', True)
        activation_norm_params = \
            getattr(gen_cfg, 'activation_norm_params', None)
        if activation_norm_params is None:
            activation_norm_params = types.SimpleNamespace()
        if not hasattr(activation_norm_params, 'num_filters'):
            setattr(activation_norm_params, 'num_filters', 128)
        if not hasattr(activation_norm_params, 'kernel_size'):
            setattr(activation_norm_params, 'kernel_size', 3)
        if not hasattr(activation_norm_params, 'activation_norm_type'):
            setattr(activation_norm_params, 'activation_norm_type',
                    'sync_batch')
        if not hasattr(activation_norm_params, 'separate_projection'):
            setattr(activation_norm_params, 'separate_projection', False)
        if not hasattr(activation_norm_params, 'activation_norm_params'):
            activation_norm_params.activation_norm_params = \
                types.SimpleNamespace()
            activation_norm_params.activation_norm_params.affine = True
        setattr(activation_norm_params, 'cond_dims', num_labels)
        if not hasattr(activation_norm_params, 'weight_norm_type'):
            setattr(activation_norm_params, 'weight_norm_type',
                    weight_norm_type)
        global_adaptive_norm_type = getattr(gen_cfg,
                                            'global_adaptive_norm_type',
                                            'sync_batch')
        use_posenc_in_input_layer = getattr(gen_cfg,
                                            'use_posenc_in_input_layer', True)
        print(activation_norm_params)
        self.spade_generator = SPADEGenerator(
            num_labels, out_image_small_side_size, image_channels, num_filters,
            kernel_size, cond_dims, activation_norm_params, weight_norm_type,
            global_adaptive_norm_type, skip_activation_norm,
            use_posenc_in_input_layer, self.use_style_encoder)
        if self.use_style:
            # Build the encoder.
            style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
            if style_enc_cfg is None:
                style_enc_cfg = types.SimpleNamespace()
            if not hasattr(style_enc_cfg, 'num_filters'):
                setattr(style_enc_cfg, 'num_filters', 128)
            if not hasattr(style_enc_cfg, 'kernel_size'):
                setattr(style_enc_cfg, 'kernel_size', 3)
            if not hasattr(style_enc_cfg, 'weight_norm_type'):
                setattr(style_enc_cfg, 'weight_norm_type', weight_norm_type)
            setattr(style_enc_cfg, 'input_image_channels', image_channels)
            setattr(style_enc_cfg, 'style_dims', style_dims)
            self.style_encoder = StyleEncoder(style_enc_cfg)

        self.z = None
        print('Done with the SPADE generator initialization.')
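The conditioning dimension handed to SPADEGenerator is just the sum of whichever optional codes are enabled above, and the style encoder is built only when at least one of them is. A tiny sketch of that bookkeeping with hypothetical config values:

style_dims, attribute_dims = 256, None  # hypothetical gen_cfg values
cond_dims = (style_dims or 0) + (attribute_dims or 0)
use_style_encoder = style_dims is not None or attribute_dims is not None
print(cond_dims, use_style_encoder)  # -> 256 True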