Ejemplo n.º 1
0
    def __init__(self, dis_cfg, data_cfg):
        r"""Build the image, additional, and temporal discriminators.

        Args:
            dis_cfg (obj): Discriminator definition part of the yaml config.
            data_cfg (obj): Data definition part of the yaml config.
        """
        super().__init__()
        self.data_cfg = data_cfg

        # Channel counts for the discriminator inputs.
        label_channels = get_paired_input_label_channel_number(data_cfg)
        if label_channels == 0:
            label_channels = getattr(data_cfg, 'label_channels', 1)
        img_channels = get_paired_input_image_channel_number(data_cfg)

        self.num_frames_D = data_cfg.num_frames_D
        self.num_scales = get_nested_attr(dis_cfg, 'temporal.num_scales', 0)

        # Few-shot models concatenate a reference pair, doubling the input.
        self.use_few_shot = 'few_shot' in data_cfg.type
        image_D_channels = label_channels + img_channels
        if self.use_few_shot:
            image_D_channels *= 2
        self.net_D = MultiPatchDiscriminator(dis_cfg.image, image_D_channels)

        # Optional extra discriminators for specific image regions.
        self.add_dis_cfg = getattr(dis_cfg, 'additional_discriminators', None)
        if self.add_dis_cfg is not None:
            for name in self.add_dis_cfg:
                cfg_i = self.add_dis_cfg[name]
                ch = img_channels * (2 if self.use_few_shot else 1)
                setattr(self, 'net_D_' + name,
                        MultiPatchDiscriminator(cfg_i, ch))

        # Temporal discriminators, one per temporal scale; each sees
        # num_frames_D frames stacked along the channel axis.
        self.num_netDT_input_channels = img_channels * self.num_frames_D
        for n in range(self.num_scales):
            setattr(self, 'net_DT%d' % n,
                    MultiPatchDiscriminator(dis_cfg.temporal,
                                            self.num_netDT_input_channels))
        self.has_fg = getattr(data_cfg, 'has_foreground', False)
Ejemplo n.º 2
0
 def __init__(self, gen_cfg, data_cfg, num_input_channels, padding_mode,
              base_conv_block, base_res_block):
     r"""Build the coarse (global) generator: stem -> downsample ->
     residual trunk -> upsample -> tanh image output.

     Args:
         gen_cfg (obj): Generator definition part of the yaml config file.
         data_cfg (obj): Data definition part of the yaml config file.
         num_input_channels (int): Number of input label channels.
         padding_mode (str): Padding mode for the final conv layer.
         base_conv_block (obj): Partially-applied conv block constructor.
         base_res_block (obj): Partially-applied residual block constructor.
     """
     super(GlobalGenerator, self).__init__()
     num_img_channels = get_paired_input_image_channel_number(data_cfg)
     num_filters = getattr(gen_cfg, 'num_filters', 64)
     num_downsamples = getattr(gen_cfg, 'num_downsamples', 4)
     num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 9)

     # 7x7 stem.
     layers = [base_conv_block(num_input_channels, num_filters,
                               kernel_size=7, padding=3)]
     # Each strided conv halves resolution and doubles channels.
     for level in range(num_downsamples):
         width = num_filters * (2 ** level)
         layers.append(base_conv_block(width, width * 2, 3,
                                       padding=1, stride=2))
     # Residual trunk at the lowest resolution.
     width = num_filters * (2 ** num_downsamples)
     for _ in range(num_res_blocks):
         layers.append(base_res_block(width, width, 3, padding=1))
     # Mirror the downsampling path with nearest-neighbor upsampling.
     for level in reversed(range(num_downsamples)):
         width = num_filters * (2 ** level)
         layers.append(NearestUpsample(scale_factor=2))
         layers.append(base_conv_block(width * 2, width, 3, padding=1))
     # Project to image channels with a tanh output.
     layers.append(Conv2dBlock(num_filters, num_img_channels, 7, padding=3,
                               padding_mode=padding_mode,
                               nonlinearity='tanh'))
     self.model = nn.Sequential(*layers)
Ejemplo n.º 3
0
    def __init__(self, gen_cfg, data_cfg, num_input_channels, num_filters,
                 padding_mode, base_conv_block, base_res_block,
                 output_img=False):
        r"""Build one local-enhancer stage: a downsampling head plus a
        residual/upsampling tail, optionally ending in an image output.

        Args:
            gen_cfg (obj): Generator definition part of the yaml config file.
            data_cfg (obj): Data definition part of the yaml config file.
            num_input_channels (int): Number of input label channels.
            num_filters (int): Base filter count for this stage.
            padding_mode (str): Padding mode for the final conv layer.
            base_conv_block (obj): Partially-applied conv block constructor.
            base_res_block (obj): Partially-applied residual block
                constructor.
            output_img (bool): If True, append a tanh image-output layer.
        """
        super(LocalEnhancer, self).__init__()
        num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 3)
        num_img_channels = get_paired_input_image_channel_number(data_cfg)

        # Downsampling head: 7x7 stem followed by a stride-2 conv.
        head = [base_conv_block(num_input_channels, num_filters, 7,
                                padding=3),
                base_conv_block(num_filters, num_filters * 2, 3, stride=2,
                                padding=1)]

        # Residual blocks at half resolution.
        tail = [base_res_block(num_filters * 2, num_filters * 2, 3,
                               padding=1)
                for _ in range(num_res_blocks)]
        # Back to full resolution.
        tail.append(NearestUpsample(scale_factor=2))
        tail.append(base_conv_block(num_filters * 2, num_filters, 3,
                                    padding=1))
        # Only the outermost enhancer emits the final tanh image.
        if output_img:
            tail.append(Conv2dBlock(num_filters, num_img_channels, 7,
                                    padding=3, padding_mode=padding_mode,
                                    nonlinearity='tanh'))

        self.model_downsample = nn.Sequential(*head)
        self.model_upsample = nn.Sequential(*tail)
Ejemplo n.º 4
0
    def __init__(self, dis_cfg, data_cfg):
        r"""Multi-resolution patch discriminator constructor.

        Args:
            dis_cfg (obj): Discriminator definition part of the yaml config.
            data_cfg (obj): Data definition part of the yaml config.
        """
        super(Discriminator, self).__init__()
        print('Multi-resolution patch discriminator initialization.')
        # The input is the ground-truth image concatenated with the label
        # map; we assume the first datum is the ground truth image.
        image_channels = get_paired_input_image_channel_number(data_cfg)
        num_labels = get_paired_input_label_channel_number(data_cfg)
        num_input_channels = image_channels + num_labels

        # Hyper-parameters, with defaults matching the reference model.
        kernel_size = getattr(dis_cfg, 'kernel_size', 3)
        num_filters = getattr(dis_cfg, 'num_filters', 128)
        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
        num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
        num_layers = getattr(dis_cfg, 'num_layers', 5)
        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
        print('\tBase filter number: %d' % num_filters)
        print('\tNumber of discriminators: %d' % num_discriminators)
        print('\tNumber of layers in a discriminator: %d' % num_layers)
        print('\tWeight norm type: %s' % weight_norm_type)

        self.model = MultiResPatchDiscriminator(
            num_discriminators, kernel_size, num_input_channels, num_filters,
            num_layers, max_num_filters, activation_norm_type,
            weight_norm_type)
        print('Done with the Multi-resolution patch '
              'discriminator initialization.')
Ejemplo n.º 5
0
    def __init__(self, dis_cfg, data_cfg):
        r"""Construct a multi-resolution patch discriminator together with an
        FPSE discriminator.

        Args:
            dis_cfg (obj): Discriminator definition part of the yaml config.
            data_cfg (obj): Data definition part of the yaml config.
        """
        super(Discriminator, self).__init__()
        print('Multi-resolution patch discriminator initialization.')
        # We assume the first datum is the ground truth image.
        image_channels = get_paired_input_image_channel_number(data_cfg)
        # Video datasets report label channels differently.
        if data_cfg.type == 'imaginaire.datasets.paired_videos':
            num_labels = get_paired_input_label_channel_number(data_cfg,
                                                               video=True)
        else:
            num_labels = get_paired_input_label_channel_number(data_cfg)

        # Hyper-parameters, with defaults matching the reference model.
        kernel_size = getattr(dis_cfg, 'kernel_size', 3)
        num_filters = getattr(dis_cfg, 'num_filters', 128)
        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
        num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
        num_layers = getattr(dis_cfg, 'num_layers', 5)
        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
        print('\tBase filter number: %d' % num_filters)
        print('\tNumber of discriminators: %d' % num_discriminators)
        print('\tNumber of layers in a discriminator: %d' % num_layers)
        print('\tWeight norm type: %s' % weight_norm_type)

        # One patch discriminator per image resolution.
        num_input_channels = image_channels + num_labels
        self.discriminators = nn.ModuleList(
            NLayerPatchDiscriminator(
                kernel_size, num_input_channels, num_filters, num_layers,
                max_num_filters, activation_norm_type, weight_norm_type)
            for _ in range(num_discriminators))
        print('Done with the Multi-resolution patch '
              'discriminator initialization.')

        # FPSE discriminator on image features + label embedding.
        fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3)
        fpse_activation_norm_type = getattr(dis_cfg,
                                            'fpse_activation_norm_type',
                                            'none')
        self.fpse_discriminator = FPSEDiscriminator(
            image_channels, num_labels, num_filters, fpse_kernel_size,
            weight_norm_type, fpse_activation_norm_type)
Ejemplo n.º 6
0
    def __init__(self, enc_cfg, data_cfg):
        r"""Feature encoder constructor: an image-to-feature-map hourglass
        plus per-label cluster-center buffers.

        Args:
            enc_cfg (obj): Encoder definition part of the yaml config file.
            data_cfg (obj): Data definition part of the yaml config file.
        """
        super(Encoder, self).__init__()
        label_nc = get_paired_input_label_channel_number(data_cfg)
        feat_nc = enc_cfg.num_feat_channels
        n_clusters = getattr(enc_cfg, 'num_clusters', 10)
        # One zero-initialized cluster-center buffer per semantic label.
        for idx in range(label_nc):
            self.register_buffer(
                'cluster_%d' % idx,
                torch.zeros(n_clusters, feat_nc, dtype=torch.float32))

        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        self.num_feat_channels = getattr(enc_cfg, 'num_feat_channels', 3)
        num_filters = getattr(enc_cfg, 'num_filters', 64)
        num_downsamples = getattr(enc_cfg, 'num_downsamples', 4)
        weight_norm_type = getattr(enc_cfg, 'weight_norm_type', 'none')
        activation_norm_type = getattr(enc_cfg, 'activation_norm_type',
                                       'instance')
        padding_mode = getattr(enc_cfg, 'padding_mode', 'reflect')
        conv = partial(Conv2dBlock,
                       padding_mode=padding_mode,
                       weight_norm_type=weight_norm_type,
                       activation_norm_type=activation_norm_type,
                       nonlinearity='relu')

        # 7x7 stem.
        layers = [conv(num_img_channels, num_filters, 7, padding=3)]
        # Downsampling path: halve resolution, double channels.
        for level in range(num_downsamples):
            width = num_filters * (2 ** level)
            layers.append(conv(width, width * 2, 3, stride=2, padding=1))
        # Upsampling path mirrors the downsampling one.
        for level in reversed(range(num_downsamples)):
            width = num_filters * (2 ** level)
            layers.append(NearestUpsample(scale_factor=2))
            layers.append(conv(width * 2, width, 3, padding=1))
        # Project to feature channels with a tanh output.
        layers.append(Conv2dBlock(num_filters, self.num_feat_channels, 7,
                                  padding=3, padding_mode=padding_mode,
                                  nonlinearity='tanh'))
        self.model = nn.Sequential(*layers)
Ejemplo n.º 7
0
    def __init__(self, gen_cfg, data_cfg):
        r"""Few-shot vid2vid generator constructor.

        Sets up the label embedding network, the (lazily initialized) flow
        network hooks, and the image synthesis backbone.

        Args:
            gen_cfg (obj): Generator definition part of the yaml config file.
            data_cfg (obj): Data definition part of the yaml config file.
        """
        super().__init__()
        self.gen_cfg = gen_cfg
        self.data_cfg = data_cfg
        self.num_frames_G = data_cfg.num_frames_G
        # Number of residual blocks in generator.
        self.num_layers = num_layers = getattr(gen_cfg, 'num_layers', 7)
        # Number of downsamplings for previous frame.
        self.num_downsamples_img = getattr(gen_cfg, 'num_downsamples_img', 4)
        # Number of filters in the first layer.
        self.num_filters = num_filters = getattr(gen_cfg, 'num_filters', 32)
        self.max_num_filters = getattr(gen_cfg, 'max_num_filters', 1024)
        self.kernel_size = kernel_size = getattr(gen_cfg, 'kernel_size', 3)
        padding = kernel_size // 2

        # For pose dataset.
        self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset')
        if self.is_pose_data:
            pose_cfg = data_cfg.for_pose_dataset
            self.pose_type = getattr(pose_cfg, 'pose_type', 'both')
            self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels',
                                              False)

        # Input data params.
        self.num_input_channels = num_input_channels = \
            get_paired_input_label_channel_number(data_cfg)
        num_img_channels = get_paired_input_image_channel_number(data_cfg)

        # Output spatial size comes from the validation augmentations,
        # given as a 'H,W' string.
        aug_cfg = data_cfg.val.augmentations
        if hasattr(aug_cfg, 'center_crop_h_w'):
            crop_h_w = aug_cfg.center_crop_h_w
        elif hasattr(aug_cfg, 'resize_h_w'):
            crop_h_w = aug_cfg.resize_h_w
        else:
            raise ValueError('Need to specify output size.')
        crop_h, crop_w = crop_h_w.split(',')
        crop_h, crop_w = int(crop_h), int(crop_w)
        # Spatial size at the bottle neck of generator.
        self.sh = crop_h // (2**num_layers)
        self.sw = crop_w // (2**num_layers)

        # Noise vector dimension.
        self.z_dim = getattr(gen_cfg, 'style_dims', 256)
        self.use_segmap_as_input = \
            getattr(gen_cfg, 'use_segmap_as_input', False)

        # Label / image embedding network.
        self.emb_cfg = emb_cfg = getattr(gen_cfg, 'embed', None)
        # BUGFIX: the default used to be the string 'True' (always truthy);
        # use a real boolean so the flag has a consistent type.
        self.use_embed = getattr(emb_cfg, 'use_embed', True)
        self.num_downsamples_embed = getattr(emb_cfg, 'num_downsamples', 5)
        if self.use_embed:
            self.label_embedding = LabelEmbedder(emb_cfg, num_input_channels)

        # Flow network.
        self.flow_cfg = flow_cfg = gen_cfg.flow
        # Use SPADE to combine warped and hallucinated frames instead of
        # linear combination.
        self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True)
        # Number of layers to perform multi-spade combine.
        self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine,
                                              'num_layers', 3)
        # At beginning of training, only train an image generator.
        self.temporal_initialized = False
        # Whether to output hallucinated frame (when training temporal network)
        # for additional loss.
        self.generate_raw_output = False

        # Image generation network.
        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')
        activation_norm_type = gen_cfg.activation_norm_type
        activation_norm_params = gen_cfg.activation_norm_params
        if self.use_embed and \
                not hasattr(activation_norm_params, 'num_filters'):
            activation_norm_params.num_filters = 0
        nonlinearity = 'leakyrelu'

        # Shared residual-block factory used by both the upsampling path
        # here and the temporal downsampling path built later.
        self.base_res_block = base_res_block = partial(
            Res2dBlock,
            kernel_size=kernel_size,
            padding=padding,
            weight_norm_type=weight_norm_type,
            activation_norm_type=activation_norm_type,
            activation_norm_params=activation_norm_params,
            nonlinearity=nonlinearity,
            order='NACNAC')

        # Upsampling residual blocks; cond_dims/partial are read by the
        # factory at construction time, so they are set per layer.
        for i in range(num_layers, -1, -1):
            activation_norm_params.cond_dims = self.get_cond_dims(i)
            activation_norm_params.partial = self.get_partial(i) if hasattr(
                self, 'get_partial') else False
            layer = base_res_block(self.get_num_filters(i + 1),
                                   self.get_num_filters(i))
            setattr(self, 'up_%d' % i, layer)

        # Final conv layer.
        self.conv_img = Conv2dBlock(num_filters,
                                    num_img_channels,
                                    kernel_size,
                                    padding=padding,
                                    nonlinearity=nonlinearity,
                                    order='AC')

        # Bottleneck input: either a conv on the downsampled segmap, or a
        # linear layer on the noise vector.
        num_filters = min(self.max_num_filters,
                          num_filters * (2**(self.num_layers + 1)))
        if self.use_segmap_as_input:
            self.fc = Conv2dBlock(num_input_channels,
                                  num_filters,
                                  kernel_size=3,
                                  padding=1)
        else:
            self.fc = LinearBlock(self.z_dim, num_filters * self.sh * self.sw)

        # Misc.
        self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.upsample = partial(F.interpolate, scale_factor=2)
        self.init_temporal_network()
Ejemplo n.º 8
0
    def __init__(self, flow_cfg, data_cfg):
        r"""Flow estimation network constructor: two downsampling towers
        (labels / previous frames), a residual trunk, a decoder, and two
        output heads (flow and occlusion mask).

        Args:
            flow_cfg (obj): Flow definition part of the yaml config file.
            data_cfg (obj): Data definition part of the yaml config file.
        """
        super().__init__()
        lbl_channels = get_paired_input_label_channel_number(data_cfg)
        img_channels = get_paired_input_image_channel_number(data_cfg)
        num_frames = data_cfg.num_frames_G  # Num. of input frames.

        self.num_filters = num_filters = getattr(flow_cfg, 'num_filters', 32)
        self.max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024)
        num_downsamples = getattr(flow_cfg, 'num_downsamples', 5)
        kernel_size = getattr(flow_cfg, 'kernel_size', 3)
        padding = kernel_size // 2
        self.num_res_blocks = getattr(flow_cfg, 'num_res_blocks', 6)
        # Multiplier on the flow output.
        self.flow_output_multiplier = getattr(flow_cfg,
                                              'flow_output_multiplier', 20)

        activation_norm_type = getattr(flow_cfg, 'activation_norm_type',
                                       'sync_batch')
        weight_norm_type = getattr(flow_cfg, 'weight_norm_type', 'spectral')

        conv = partial(Conv2dBlock,
                       kernel_size=kernel_size,
                       padding=padding,
                       weight_norm_type=weight_norm_type,
                       activation_norm_type=activation_norm_type,
                       nonlinearity='leakyrelu')

        # Labels and previous frames are downsampled separately, then
        # combined; all input frames are stacked along channels.
        down_lbl = [conv(lbl_channels * num_frames, num_filters)]
        down_img = [conv(img_channels * (num_frames - 1), num_filters)]
        for level in range(num_downsamples):
            in_ch = self.get_num_filters(level)
            out_ch = self.get_num_filters(level + 1)
            down_lbl.append(conv(in_ch, out_ch, stride=2))
            down_img.append(conv(in_ch, out_ch, stride=2))

        # Residual trunk at the bottleneck resolution.
        bottleneck_ch = self.get_num_filters(num_downsamples)
        res_flow = [Res2dBlock(bottleneck_ch,
                               bottleneck_ch,
                               kernel_size,
                               padding=padding,
                               weight_norm_type=weight_norm_type,
                               activation_norm_type=activation_norm_type,
                               order='CNACN')
                    for _ in range(self.num_res_blocks)]

        # Decoder back to full resolution.
        up_flow = []
        for level in reversed(range(num_downsamples)):
            up_flow.append(nn.Upsample(scale_factor=2))
            up_flow.append(conv(self.get_num_filters(level + 1),
                                self.get_num_filters(level)))

        # Output heads: 2-channel flow and a sigmoid occlusion mask.
        conv_flow = [Conv2dBlock(num_filters, 2, kernel_size,
                                 padding=padding)]
        conv_mask = [Conv2dBlock(num_filters, 1, kernel_size,
                                 padding=padding, nonlinearity='sigmoid')]

        self.down_lbl = nn.Sequential(*down_lbl)
        self.down_img = nn.Sequential(*down_img)
        self.res_flow = nn.Sequential(*res_flow)
        self.up_flow = nn.Sequential(*up_flow)
        self.conv_flow = nn.Sequential(*conv_flow)
        self.conv_mask = nn.Sequential(*conv_mask)
Ejemplo n.º 9
0
    def init_temporal_network(self, cfg_init=None):
        r"""When starting training multiple frames, initialize the
        downsampling network and flow network.

        Args:
            cfg_init (dict): Weight initialization config; when given, each
                newly created module is re-initialized with it.
        """
        n_down_img = self.num_downsamples_img

        def maybe_init(module):
            # Re-initialize weights when an init config was supplied.
            if cfg_init is not None:
                module.apply(weights_init(cfg_init.type, cfg_init.gain))
            return module

        # Round the previous-frame residual-block count up to an even number.
        self.num_res_blocks = int(
            np.ceil((self.num_layers - n_down_img) / 2.0) * 2)

        # First conv layer for the previous frame.
        num_img_channels = get_paired_input_image_channel_number(self.data_cfg)
        self.down_first = maybe_init(
            Conv2dBlock(num_img_channels, self.num_filters, self.kernel_size,
                        padding=self.kernel_size // 2))

        # Downsampling residual blocks; cond_dims is read by the factory at
        # construction time, so it is set per layer.
        activation_norm_params = self.gen_cfg.activation_norm_params
        for i in range(n_down_img + 1):
            activation_norm_params.cond_dims = self.get_cond_dims(i)
            block = maybe_init(
                self.base_res_block(self.get_num_filters(i),
                                    self.get_num_filters(i + 1)))
            setattr(self, 'down_%d' % i, block)

        # Additional residual blocks at the lowest resolution.
        res_ch = self.get_num_filters(n_down_img + 1)
        activation_norm_params.cond_dims = self.get_cond_dims(n_down_img + 1)
        for i in range(self.num_res_blocks):
            setattr(self, 'res_%d' % i,
                    maybe_init(self.base_res_block(res_ch, res_ch)))

        # Flow network. Note: generate_raw_output reads the OLD value of
        # self.spade_combine before it is refreshed below.
        flow_cfg = self.flow_cfg
        self.temporal_initialized = True
        self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output',
                                           False) and self.spade_combine
        self.flow_network_temp = maybe_init(
            FlowGenerator(flow_cfg, self.data_cfg))

        self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True)
        if self.spade_combine:
            emb_cfg = flow_cfg.multi_spade_combine.embed
            num_img_channels = get_paired_input_image_channel_number(
                self.data_cfg)
            # Warped previous frame plus its occlusion mask (+1 channel).
            self.img_prev_embedding = maybe_init(
                LabelEmbedder(emb_cfg, num_img_channels + 1))
Ejemplo n.º 10
0
    def __init__(self, gen_cfg, data_cfg):
        r"""SPADE generator constructor.

        Builds the SPADE generator backbone and, when a style code is used,
        a style encoder. NOTE: this fills in missing defaults by mutating
        ``gen_cfg`` sub-namespaces in place.

        Args:
            gen_cfg (obj): Generator definition part of the yaml config file.
            data_cfg (obj): Data definition part of the yaml config file.
        """
        super(Generator, self).__init__()
        print('SPADE generator initialization.')
        # We assume the first datum is the ground truth image.
        image_channels = get_paired_input_image_channel_number(data_cfg)
        # Calculate number of channels in the input label.
        num_labels = get_paired_input_label_channel_number(data_cfg)
        crop_h, crop_w = get_crop_h_w(data_cfg.train.augmentations)
        # Build the generator
        out_image_small_side_size = crop_w if crop_w < crop_h else crop_h
        num_filters = getattr(gen_cfg, 'num_filters', 128)
        kernel_size = getattr(gen_cfg, 'kernel_size', 3)
        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')

        # Total conditioning-code dimension = style code + attribute code.
        cond_dims = 0
        # Check whether we use the style code.
        style_dims = getattr(gen_cfg, 'style_dims', None)
        self.style_dims = style_dims
        if style_dims is not None:
            print('\tStyle code dimensions: %d' % style_dims)
            cond_dims += style_dims
            self.use_style = True
        else:
            self.use_style = False
        # Check whether we use the attribute code.
        if hasattr(gen_cfg, 'attribute_dims'):
            self.use_attribute = True
            self.attribute_dims = gen_cfg.attribute_dims
            cond_dims += gen_cfg.attribute_dims
        else:
            self.use_attribute = False

        # A style encoder is needed only when either code is in use.
        if not self.use_style and not self.use_attribute:
            self.use_style_encoder = False
        else:
            self.use_style_encoder = True
        print('\tBase filter number: %d' % num_filters)
        print('\tConvolution kernel size: %d' % kernel_size)
        print('\tWeight norm type: %s' % weight_norm_type)
        skip_activation_norm = \
            getattr(gen_cfg, 'skip_activation_norm', True)
        # Fill in defaults for the SPADE activation-norm parameters; each
        # missing field is set on the config namespace in place.
        activation_norm_params = \
            getattr(gen_cfg, 'activation_norm_params', None)
        if activation_norm_params is None:
            activation_norm_params = types.SimpleNamespace()
        if not hasattr(activation_norm_params, 'num_filters'):
            setattr(activation_norm_params, 'num_filters', 128)
        if not hasattr(activation_norm_params, 'kernel_size'):
            setattr(activation_norm_params, 'kernel_size', 3)
        if not hasattr(activation_norm_params, 'activation_norm_type'):
            setattr(activation_norm_params, 'activation_norm_type',
                    'sync_batch')
        if not hasattr(activation_norm_params, 'separate_projection'):
            setattr(activation_norm_params, 'separate_projection', False)
        if not hasattr(activation_norm_params, 'activation_norm_params'):
            activation_norm_params.activation_norm_params = \
                types.SimpleNamespace()
            activation_norm_params.activation_norm_params.affine = True
        # The SPADE layers are conditioned on the label map.
        setattr(activation_norm_params, 'cond_dims', num_labels)
        if not hasattr(activation_norm_params, 'weight_norm_type'):
            setattr(activation_norm_params, 'weight_norm_type',
                    weight_norm_type)
        global_adaptive_norm_type = getattr(gen_cfg,
                                            'global_adaptive_norm_type',
                                            'sync_batch')
        use_posenc_in_input_layer = getattr(gen_cfg,
                                            'use_posenc_in_input_layer', True)
        # Log the fully-resolved activation-norm parameters.
        print(activation_norm_params)
        self.spade_generator = SPADEGenerator(
            num_labels, out_image_small_side_size, image_channels, num_filters,
            kernel_size, cond_dims, activation_norm_params, weight_norm_type,
            global_adaptive_norm_type, skip_activation_norm,
            use_posenc_in_input_layer, self.use_style_encoder)
        if self.use_style:
            # Build the encoder, filling in missing defaults in place.
            style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
            if style_enc_cfg is None:
                style_enc_cfg = types.SimpleNamespace()
            if not hasattr(style_enc_cfg, 'num_filters'):
                setattr(style_enc_cfg, 'num_filters', 128)
            if not hasattr(style_enc_cfg, 'kernel_size'):
                setattr(style_enc_cfg, 'kernel_size', 3)
            if not hasattr(style_enc_cfg, 'weight_norm_type'):
                setattr(style_enc_cfg, 'weight_norm_type', weight_norm_type)
            setattr(style_enc_cfg, 'input_image_channels', image_channels)
            setattr(style_enc_cfg, 'style_dims', style_dims)
            self.style_encoder = StyleEncoder(style_enc_cfg)

        # Cached noise vector (set later during inference).
        self.z = None
        print('Done with the SPADE generator initialization.')