Example #1
    def __init__(self, gen_cfg, data_cfg):
        super(Generator, self).__init__()
        # input downsample
        nonlinearity = gen_cfg.nonlinearity
        self.downsample1 = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True)
        self.downsample2 = nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True)
        self.downsample3 = nn.Upsample(scale_factor=0.125, mode='bilinear', align_corners=True)
        self.downsample4 = nn.Upsample(scale_factor=0.0625, mode='bilinear', align_corners=True)

        conv_params = dict(kernel_size=3,
                           padding=1,
                           activation_norm_type="instance",
                           nonlinearity=nonlinearity,
                           inplace_nonlinearity=True)
        # encoder
        self.layer1 = Conv2dBlock(in_channels=9, out_channels=64, kernel_size=3, padding=1, stride=2,
                                  nonlinearity=nonlinearity, inplace_nonlinearity=True)
        self.layer2 = Conv2dBlock(in_channels=64+6, out_channels=128, stride=2, **conv_params)
        self.layer3 = Conv2dBlock(in_channels=128+6, out_channels=256, stride=2, **conv_params)
        self.layer4 = Conv2dBlock(in_channels=256+6, out_channels=512, stride=2, **conv_params)

        # decoder
        self.layer5 = UpscaleBlock(in_channels=512 + 256 + 6 + 6, out_channels=256)
        self.layer6 = UpscaleBlock(in_channels=256 + 128 + 6, out_channels=128)
        self.layer7 = UpscaleBlock(in_channels=128 + 64 + 6, out_channels=64)
        self.layer8 = UpscaleBlock(in_channels=64 + 6, out_channels=64, use_norm=True, use_act=True)
        self.outlayer = Out_Branch(nc_in=64)
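
The channel arithmetic above only works because the 6-channel conditioning map is re-injected at every encoder scale after bilinear downsampling (hence the "+6" on every in_channels). The forward pass is not shown in the source; the following is a minimal, self-contained sketch of the pattern in plain PyTorch with illustrative channel sizes.

import torch
import torch.nn as nn

# Sketch of the multi-scale conditioning pattern: the conditioning map is
# bilinearly downsampled to each feature resolution and concatenated before
# the next strided convolution. Channel sizes are illustrative only.
class PyramidEncoder(nn.Module):
    def __init__(self, in_ch=9, cond_ch=6):
        super().__init__()
        self.down = nn.Upsample(scale_factor=0.5, mode='bilinear',
                                align_corners=True)
        self.layer1 = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        self.layer2 = nn.Conv2d(64 + cond_ch, 128, 3, stride=2, padding=1)

    def forward(self, x, cond):
        f1 = self.layer1(x)                       # features at 1/2 resolution
        c1 = self.down(cond)                      # condition at 1/2 resolution
        f2 = self.layer2(torch.cat([f1, c1], 1))  # re-inject, halve again
        return f2

enc = PyramidEncoder()
out = enc(torch.randn(1, 9, 64, 64), torch.randn(1, 6, 64, 64))
print(out.shape)  # torch.Size([1, 128, 16, 16])
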
Example #2
    def __init__(self, num_downsamples, image_channels, num_filters,
                 style_channels, padding_mode, activation_norm_type,
                 weight_norm_type, nonlinearity):
        super().__init__()
        conv_params = dict(padding_mode=padding_mode,
                           activation_norm_type=activation_norm_type,
                           weight_norm_type=weight_norm_type,
                           nonlinearity=nonlinearity,
                           inplace_nonlinearity=True)
        model = []
        model += [
            Conv2dBlock(image_channels, num_filters, 7, 1, 3, **conv_params)
        ]
        for i in range(2):
            model += [
                Conv2dBlock(num_filters, 2 * num_filters, 4, 2, 1,
                            **conv_params)
            ]
            num_filters *= 2
        for i in range(num_downsamples - 2):
            model += [
                Conv2dBlock(num_filters, num_filters, 4, 2, 1, **conv_params)
            ]
        model += [nn.AdaptiveAvgPool2d(1)]
        model += [nn.Conv2d(num_filters, style_channels, 1, 1, 0)]
        self.model = nn.Sequential(*model)
        self.output_dim = num_filters
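
The tail of this style encoder is the part worth noting: nn.AdaptiveAvgPool2d(1) followed by a 1x1 convolution maps a feature map of any spatial size to a fixed style code of shape (N, style_channels, 1, 1). A standalone sketch with illustrative channel counts:

import torch
import torch.nn as nn

# Global average pooling plus a 1x1 convolution: any H x W collapses to 1x1.
tail = nn.Sequential(nn.AdaptiveAvgPool2d(1),
                     nn.Conv2d(256, 8, 1, 1, 0))
feats = torch.randn(2, 256, 37, 53)  # arbitrary spatial size
print(tail(feats).shape)             # torch.Size([2, 8, 1, 1])
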
Example #3
    def __init__(self,
                 image_channels=3,
                 num_filters=64,
                 max_num_filters=512,
                 first_kernel_size=1,
                 num_layers=4,
                 padding_mode='zeros',
                 activation_norm_type='',
                 weight_norm_type='',
                 aggregation='conv',
                 order='pre_act',
                 anti_aliased=False,
                 **kwargs):
        super().__init__()
        for key in kwargs:
            if key != 'type' and key != 'patch_wise':
                warnings.warn(
                    "Discriminator argument {} is not used".format(key))

        conv_params = dict(padding_mode=padding_mode,
                           activation_norm_type=activation_norm_type,
                           weight_norm_type=weight_norm_type,
                           nonlinearity='leakyrelu')

        first_padding = (first_kernel_size - 1) // 2
        model = [
            Conv2dBlock(image_channels, num_filters, first_kernel_size, 1,
                        first_padding, **conv_params)
        ]
        for _ in range(num_layers):
            num_filters_prev = num_filters
            num_filters = min(num_filters * 2, max_num_filters)
            model.append(
                Res2dBlock(num_filters_prev,
                           num_filters,
                           order=order,
                           **conv_params))
            model.append(nn.AvgPool2d(2, stride=2))
        if aggregation == 'pool':
            model += [nn.AdaptiveAvgPool2d(1)]
        elif aggregation == 'conv':
            model += [
                Conv2dBlock(num_filters,
                            num_filters,
                            4,
                            1,
                            0,
                            nonlinearity='leakyrelu')
            ]
        else:
            raise ValueError('The aggregation mode %s is not recognized.' %
                             aggregation)
        self.model = nn.Sequential(*model)
        self.classifier = nn.Linear(num_filters, 1)
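
The two aggregation modes are not interchangeable: 'pool' accepts any input size, while the 'conv' branch applies a 4x4 convolution with no padding and therefore assumes the feature map is exactly 4x4 when it is reached (e.g. a 64x64 input after the four AvgPool2d(2) stages). A standalone sketch of that constraint:

import torch
import torch.nn as nn

# A 4x4 valid convolution only reduces to 1x1 if the input is exactly 4x4.
agg = nn.Conv2d(512, 512, 4, 1, 0)
x = torch.randn(1, 512, 4, 4)
print(agg(x).shape)  # torch.Size([1, 512, 1, 1])
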
Example #4
    def __init__(self,
                 image_channels=3,
                 num_classes=119,
                 num_filters=64,
                 max_num_filters=1024,
                 num_layers=6,
                 padding_mode='reflect',
                 weight_norm_type='',
                 **kwargs):
        super().__init__()
        for key in kwargs:
            if key != 'type':
                warnings.warn(
                    "Discriminator argument {} is not used".format(key))

        conv_params = dict(padding_mode=padding_mode,
                           activation_norm_type='none',
                           weight_norm_type=weight_norm_type,
                           bias=[True, True, True],
                           nonlinearity='leakyrelu',
                           order='NACNAC')

        first_kernel_size = 7
        first_padding = (first_kernel_size - 1) // 2
        model = [
            Conv2dBlock(image_channels,
                        num_filters,
                        first_kernel_size,
                        1,
                        first_padding,
                        padding_mode=padding_mode,
                        weight_norm_type=weight_norm_type)
        ]
        for i in range(num_layers):
            num_filters_prev = num_filters
            num_filters = min(num_filters * 2, max_num_filters)
            model += [
                Res2dBlock(num_filters_prev, num_filters_prev, **conv_params),
                Res2dBlock(num_filters_prev, num_filters, **conv_params)
            ]
            if i != num_layers - 1:
                model += [nn.ReflectionPad2d(1), nn.AvgPool2d(3, stride=2)]
        self.model = nn.Sequential(*model)
        self.classifier = Conv2dBlock(num_filters,
                                      1,
                                      1,
                                      1,
                                      0,
                                      nonlinearity='leakyrelu',
                                      weight_norm_type=weight_norm_type,
                                      order='NACNAC')

        self.embedder = nn.Embedding(num_classes, num_filters)
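
The final nn.Embedding suggests a projection-discriminator head, in which the class embedding is projected onto the pooled features to form a class-conditional logit. The forward pass is not shown in the source, so the standalone sketch below of that projection is an assumption, not the class's actual method:

import torch
import torch.nn as nn

# Projection-discriminator pattern (assumed): dot the class embedding with
# the pooled features to get a per-sample, class-conditional score.
emb = nn.Embedding(10, 16)                # num_classes=10, num_filters=16
feat = torch.randn(4, 16)                 # pooled discriminator features
labels = torch.tensor([0, 3, 7, 9])
proj = (emb(labels) * feat).sum(1, keepdim=True)
print(proj.shape)                         # torch.Size([4, 1])
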
Example #5
File: unit.py  Project: yejees/ObjectSwap
    def __init__(self,
                 num_upsamples,
                 num_res_blocks,
                 num_filters,
                 num_image_channels,
                 padding_mode,
                 activation_norm_type,
                 weight_norm_type,
                 nonlinearity,
                 output_nonlinearity,
                 pre_act=False,
                 apply_noise=False):
        super().__init__()

        conv_params = dict(padding_mode=padding_mode,
                           nonlinearity=nonlinearity,
                           inplace_nonlinearity=True,
                           apply_noise=apply_noise,
                           weight_norm_type=weight_norm_type,
                           activation_norm_type=activation_norm_type)

        # The order of operations in residual blocks.
        order = 'pre_act' if pre_act else 'CNACNA'

        # Residual blocks.
        self.decoder = nn.ModuleList()
        for _ in range(num_res_blocks):
            self.decoder += [
                Res2dBlock(num_filters,
                           num_filters,
                           **conv_params,
                           order=order)
            ]

        # Convolutional blocks with upsampling.
        for i in range(num_upsamples):
            self.decoder += [NearestUpsample(scale_factor=2)]
            self.decoder += [
                Conv2dBlock(num_filters, num_filters // 2, 5, 1, 2,
                            **conv_params)
            ]
            num_filters //= 2
        self.decoder += [
            Conv2dBlock(num_filters,
                        num_image_channels,
                        7,
                        1,
                        3,
                        nonlinearity=output_nonlinearity,
                        padding_mode=padding_mode)
        ]
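
Each upsampling stage pairs NearestUpsample(scale_factor=2) with a 5x5 convolution that halves the channel count, a standard alternative to transposed convolutions that sidesteps checkerboard artifacts. A standalone sketch with plain PyTorch modules:

import torch
import torch.nn as nn

# Nearest-neighbor 2x upsampling followed by a 5x5 conv (padding 2 keeps the
# spatial size), halving the channel count as in the decoder loop above.
up = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
                   nn.Conv2d(64, 32, 5, 1, 2))
print(up(torch.randn(1, 64, 16, 16)).shape)  # torch.Size([1, 32, 32, 32])
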
Example #6
File: unit.py  Project: yejees/ObjectSwap
    def __init__(self,
                 num_downsamples,
                 num_res_blocks,
                 num_image_channels,
                 num_filters,
                 max_num_filters,
                 padding_mode,
                 activation_norm_type,
                 weight_norm_type,
                 nonlinearity,
                 pre_act=False):
        super().__init__()
        conv_params = dict(padding_mode=padding_mode,
                           activation_norm_type=activation_norm_type,
                           weight_norm_type=weight_norm_type,
                           nonlinearity=nonlinearity)
        # Whether or not it is safe to use inplace nonlinear activation.
        if not pre_act or (activation_norm_type != ''
                           and activation_norm_type != 'none'):
            conv_params['inplace_nonlinearity'] = True

        # The order of operations in residual blocks.
        order = 'pre_act' if pre_act else 'CNACNA'

        model = []
        model += [
            Conv2dBlock(num_image_channels, num_filters, 7, 1, 3,
                        **conv_params)
        ]

        # Downsampling blocks.
        for i in range(num_downsamples):
            num_filters_prev = num_filters
            num_filters = min(num_filters * 2, max_num_filters)
            model += [
                Conv2dBlock(num_filters_prev, num_filters, 4, 2, 1,
                            **conv_params)
            ]

        # Residual blocks.
        for _ in range(num_res_blocks):
            model += [
                Res2dBlock(num_filters,
                           num_filters,
                           **conv_params,
                           order=order)
            ]
        self.model = nn.Sequential(*model)
        self.output_dim = num_filters
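
The inplace guard at the top is subtle: in a pre-activation ('pre_act') block with no normalization, the block input feeds both the skip connection and the first nonlinearity, so an inplace activation would overwrite the skip term. A deterministic standalone sketch of the failure mode the guard avoids:

import torch
import torch.nn as nn

# The skip connection aliases the input tensor; an in-place pre-activation
# nonlinearity therefore clobbers the value the skip path still needs.
x = torch.tensor([[-1.0, 2.0]])
skip = x                      # skip path keeps a reference, not a copy
nn.ReLU(inplace=True)(x)      # in-place pre-activation
print(skip)                   # tensor([[0., 2.]]) -- the -1.0 is gone
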
Example #7
    def __init__(self, gen_cfg, data_cfg, num_input_channels, padding_mode,
                 base_conv_block, base_res_block):
        super(GlobalGenerator, self).__init__()
        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        num_filters = getattr(gen_cfg, 'num_filters', 64)
        num_downsamples = getattr(gen_cfg, 'num_downsamples', 4)
        num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 9)
        # First layer.
        model = [base_conv_block(num_input_channels, num_filters,
                                 kernel_size=7, padding=3)]
        # Downsample.
        for i in range(num_downsamples):
            ch = num_filters * (2 ** i)
            model += [base_conv_block(ch, ch * 2, 3, padding=1, stride=2)]
        # ResNet blocks.
        ch = num_filters * (2 ** num_downsamples)
        for i in range(num_res_blocks):
            model += [base_res_block(ch, ch, 3, padding=1)]
        # Upsample.
        num_upsamples = num_downsamples
        for i in reversed(range(num_upsamples)):
            ch = num_filters * (2 ** i)
            model += \
                [NearestUpsample(scale_factor=2),
                 base_conv_block(ch * 2, ch, 3, padding=1)]
        model += [Conv2dBlock(num_filters, num_img_channels, 7, padding=3,
                              padding_mode=padding_mode, nonlinearity='tanh')]
        self.model = nn.Sequential(*model)
Example #8
    def __init__(self, gen_cfg, data_cfg, num_input_channels, num_filters,
                 padding_mode, base_conv_block, base_res_block,
                 output_img=False):
        super(LocalEnhancer, self).__init__()
        num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 3)
        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        # Downsample.
        model_downsample = \
            [base_conv_block(num_input_channels, num_filters, 7, padding=3),
             base_conv_block(num_filters, num_filters * 2, 3, stride=2,
                             padding=1)]
        # Residual blocks.
        model_upsample = []
        for i in range(num_res_blocks):
            model_upsample += [base_res_block(num_filters * 2, num_filters * 2,
                                              3, padding=1)]
        # Upsample.
        model_upsample += \
            [NearestUpsample(scale_factor=2),
             base_conv_block(num_filters * 2, num_filters, 3, padding=1)]

        # Final convolution.
        if output_img:
            model_upsample += [Conv2dBlock(num_filters, num_img_channels, 7,
                                           padding=3, padding_mode=padding_mode,
                                           nonlinearity='tanh')]

        self.model_downsample = nn.Sequential(*model_downsample)
        self.model_upsample = nn.Sequential(*model_upsample)
Example #9
    def __init__(self,
                 num_enc_output_channels,
                 style_channels,
                 num_image_channels=3,
                 num_upsamples=4,
                 padding_type='reflect',
                 weight_norm_type='none',
                 nonlinearity='relu'):
        super(Decoder, self).__init__()
        adain_params = SimpleNamespace(
            activation_norm_type='instance',
            activation_norm_params=SimpleNamespace(affine=False),
            cond_dims=style_channels)

        base_res_block = partial(Res2dBlock,
                                 kernel_size=3,
                                 padding=1,
                                 padding_mode=padding_type,
                                 nonlinearity=nonlinearity,
                                 activation_norm_type='adaptive',
                                 activation_norm_params=adain_params,
                                 weight_norm_type=weight_norm_type)

        base_up_res_block = partial(UpRes2dBlock,
                                    kernel_size=5,
                                    padding=2,
                                    padding_mode=padding_type,
                                    weight_norm_type=weight_norm_type,
                                    activation_norm_type='adaptive',
                                    activation_norm_params=adain_params,
                                    skip_activation_norm='instance',
                                    skip_nonlinearity=nonlinearity,
                                    nonlinearity=nonlinearity,
                                    hidden_channels_equal_out_channels=True)

        dims = num_enc_output_channels

        # Residual blocks with AdaIN.
        self.decoder = nn.ModuleList()
        self.decoder += [base_res_block(dims, dims)]
        self.decoder += [base_res_block(dims, dims)]
        for _ in range(num_upsamples):
            self.decoder += [base_up_res_block(dims, dims // 2)]
            dims = dims // 2
        self.decoder += [
            Conv2dBlock(dims,
                        num_image_channels,
                        kernel_size=7,
                        stride=1,
                        padding=3,
                        padding_mode=padding_type,
                        nonlinearity='tanh')
        ]
Example #10
    def __init__(self, gen_cfg, data_cfg):
        super(Discriminator, self).__init__()
        nonlinearity = gen_cfg.nonlinearity
        # input downsample
        self.downsample1 = nn.Upsample(scale_factor=0.5, mode='bilinear')
        self.downsample2 = nn.Upsample(scale_factor=0.25, mode='bilinear')
        self.downsample3 = nn.Upsample(scale_factor=0.125, mode='bilinear')
        self.downsample4 = nn.Upsample(scale_factor=0.0625, mode='bilinear')

        conv_params = dict(kernel_size=3,
                           padding=1,
                           activation_norm_type="instance",
                           nonlinearity=nonlinearity,
                           inplace_nonlinearity=True)
        # encoder
        self.apply_noise = ApplyNoise()
        self.layer1 = Conv2dBlock(in_channels=6,
                                  out_channels=64,
                                  kernel_size=3,
                                  padding=1,
                                  stride=2,
                                  nonlinearity=nonlinearity,
                                  inplace_nonlinearity=True)
        self.layer2 = Conv2dBlock(in_channels=64 + 6,
                                  out_channels=128,
                                  stride=2,
                                  **conv_params)
        self.layer3 = Conv2dBlock(in_channels=128 + 6,
                                  out_channels=256,
                                  stride=2,
                                  **conv_params)
        self.layer4 = Conv2dBlock(in_channels=256 + 6,
                                  out_channels=512,
                                  stride=2,
                                  **conv_params)
        self.outlayer = Conv2dBlock(in_channels=512 + 6,
                                    out_channels=1,
                                    kernel_size=3,
                                    nonlinearity="sigmoid")
Example #11
    def __init__(self, num_downsamples, num_res_blocks, image_channels,
                 num_filters, padding_mode, activation_norm_type,
                 weight_norm_type, nonlinearity):
        super().__init__()
        conv_params = dict(padding_mode=padding_mode,
                           activation_norm_type=activation_norm_type,
                           weight_norm_type=weight_norm_type,
                           nonlinearity=nonlinearity,
                           inplace_nonlinearity=True,
                           order='CNACNA')
        model = []
        model += [
            Conv2dBlock(image_channels, num_filters, 7, 1, 3, **conv_params)
        ]
        dims = num_filters
        for i in range(num_downsamples):
            model += [Conv2dBlock(dims, dims * 2, 4, 2, 1, **conv_params)]
            dims *= 2

        for _ in range(num_res_blocks):
            model += [Res2dBlock(dims, dims, **conv_params)]
        self.model = nn.Sequential(*model)
        self.output_dim = dims
Example #12
    def __init__(self, enc_cfg, data_cfg):
        super(Encoder, self).__init__()
        label_nc = get_paired_input_label_channel_number(data_cfg)
        feat_nc = enc_cfg.num_feat_channels
        n_clusters = getattr(enc_cfg, 'num_clusters', 10)
        for i in range(label_nc):
            dummy_arr = np.zeros((n_clusters, feat_nc), dtype=np.float32)
            self.register_buffer('cluster_%d' % i,
                                 torch.tensor(dummy_arr, dtype=torch.float32))
        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        self.num_feat_channels = getattr(enc_cfg, 'num_feat_channels', 3)
        num_filters = getattr(enc_cfg, 'num_filters', 64)
        num_downsamples = getattr(enc_cfg, 'num_downsamples', 4)
        weight_norm_type = getattr(enc_cfg, 'weight_norm_type', 'none')
        activation_norm_type = getattr(enc_cfg, 'activation_norm_type',
                                       'instance')
        padding_mode = getattr(enc_cfg, 'padding_mode', 'reflect')
        base_conv_block = partial(Conv2dBlock,
                                  padding_mode=padding_mode,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  nonlinearity='relu')
        model = [base_conv_block(num_img_channels, num_filters, 7, padding=3)]
        # Downsample.
        for i in range(num_downsamples):
            ch = num_filters * (2 ** i)
            model += [base_conv_block(ch, ch * 2, 3, stride=2, padding=1)]
        # Upsample.
        for i in reversed(range(num_downsamples)):
            ch = num_filters * (2 ** i)
            model += [NearestUpsample(scale_factor=2),
                      base_conv_block(ch * 2, ch, 3, padding=1)]

        model += [Conv2dBlock(num_filters, self.num_feat_channels, 7,
                              padding=3, padding_mode=padding_mode,
                              nonlinearity='tanh')]
        self.model = nn.Sequential(*model)
Example #13
    def __init__(self, kernel_size, num_input_channels, num_filters,
                 num_layers, max_num_filters, activation_norm_type,
                 weight_norm_type):
        super(NLayerPatchDiscriminator, self).__init__()
        self.num_layers = num_layers
        padding = int(np.floor((kernel_size - 1.0) / 2))
        nonlinearity = 'leakyrelu'
        base_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              # inplace_nonlinearity=True,
                              order='CNA')
        layers = [[
            base_conv2d_block(num_input_channels, num_filters, stride=2)
        ]]
        for n in range(num_layers):
            num_filters_prev = num_filters
            num_filters = min(num_filters * 2, max_num_filters)
            stride = 2 if n < (num_layers - 1) else 1
            layers += [[
                base_conv2d_block(num_filters_prev, num_filters, stride=stride)
            ]]
        layers += [[
            Conv2dBlock(num_filters,
                        1,
                        3,
                        1,
                        padding,
                        weight_norm_type=weight_norm_type)
        ]]
        for n in range(len(layers)):
            setattr(self, 'layer' + str(n), nn.Sequential(*layers[n]))
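
Registering every stage as a separate attribute (layer0, layer1, ...) rather than one nn.Sequential lets the discriminator return intermediate activations, e.g. for a feature-matching loss. The source's forward is not shown; below is a minimal standalone sketch of that layout with stand-in convolutions:

import torch
import torch.nn as nn

# Stand-in patch discriminator: each stage is registered via setattr so the
# forward pass can collect every intermediate feature map.
class TinyPatchD(nn.Module):
    def __init__(self, num_stages=3):
        super().__init__()
        self.num_stages = num_stages
        ch = 3
        for n in range(num_stages):
            setattr(self, 'layer' + str(n),
                    nn.Sequential(nn.Conv2d(ch, 8, 3, 2, 1),
                                  nn.LeakyReLU(0.2)))
            ch = 8

    def forward(self, x):
        feats = []
        for n in range(self.num_stages):
            x = getattr(self, 'layer' + str(n))(x)
            feats.append(x)
        return feats

feats = TinyPatchD()(torch.randn(1, 3, 32, 32))
print([tuple(f.shape) for f in feats])
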
Example #14
File: fpse.py  Project: yejees/ObjectSwap
    def __init__(self, num_input_channels, num_labels, num_filters,
                 kernel_size, weight_norm_type, activation_norm_type):
        super().__init__()
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        nonlinearity = 'leakyrelu'
        stride1_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=1,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              # inplace_nonlinearity=True,
                              order='CNA')
        down_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=2,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              # inplace_nonlinearity=True,
                              order='CNA')
        latent_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=1,
                              stride=1,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              # inplace_nonlinearity=True,
                              order='CNA')
        # bottom-up pathway

        self.enc1 = down_conv2d_block(num_input_channels, num_filters)
        self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters)
        self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters)
        self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters)
        self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters)

        # top-down pathway
        self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters)
        self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters)
        self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
        self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters)

        # upsampling
        self.upsample2x = nn.Upsample(scale_factor=2,
                                      mode='bilinear',
                                      align_corners=False)

        # final layers
        self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
        self.final3 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
        self.final4 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)

        # true/false prediction and semantic alignment prediction
        self.output = Conv2dBlock(num_filters * 2, 1, kernel_size=1)
        self.seg = Conv2dBlock(num_filters * 2, num_filters * 2, kernel_size=1)
        self.embedding = Conv2dBlock(num_labels,
                                     num_filters * 2,
                                     kernel_size=1)
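
This discriminator is built like a feature pyramid network: the enc* blocks form the bottom-up pathway, the lat* blocks project each scale to a common width of 4 * num_filters, and the top-down pathway presumably upsamples each level 2x and adds it to the lateral projection one level below (the forward pass is not shown, so this is inferred from the layer widths). A standalone sketch of one merge step:

import torch
import torch.nn as nn

# One FPN merge step: upsample the coarser level 2x and add it to the
# lateral projection of the finer level; both already have the same width.
up2x = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
lat_hi = torch.randn(1, 128, 8, 8)    # coarser level after its 1x1 lat conv
lat_lo = torch.randn(1, 128, 16, 16)  # finer level after its 1x1 lat conv
merged = lat_lo + up2x(lat_hi)
print(merged.shape)                   # torch.Size([1, 128, 16, 16])
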
Example #15
    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        self.gen_cfg = gen_cfg
        self.data_cfg = data_cfg
        self.num_frames_G = data_cfg.num_frames_G
        # Number of residual blocks in generator.
        self.num_layers = num_layers = getattr(gen_cfg, 'num_layers', 7)
        # Number of downsamplings for previous frame.
        self.num_downsamples_img = getattr(gen_cfg, 'num_downsamples_img', 4)
        # Number of filters in the first layer.
        self.num_filters = num_filters = getattr(gen_cfg, 'num_filters', 32)
        self.max_num_filters = getattr(gen_cfg, 'max_num_filters', 1024)
        self.kernel_size = kernel_size = getattr(gen_cfg, 'kernel_size', 3)
        padding = kernel_size // 2

        # For pose dataset.
        self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset')
        if self.is_pose_data:
            pose_cfg = data_cfg.for_pose_dataset
            self.pose_type = getattr(pose_cfg, 'pose_type', 'both')
            self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels',
                                              False)

        # Input data params.
        self.num_input_channels = num_input_channels = get_paired_input_label_channel_number(
            data_cfg)

        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        aug_cfg = data_cfg.val.augmentations
        if hasattr(aug_cfg, 'center_crop_h_w'):
            crop_h_w = aug_cfg.center_crop_h_w
        elif hasattr(aug_cfg, 'resize_h_w'):
            crop_h_w = aug_cfg.resize_h_w
        else:
            raise ValueError('Need to specify output size.')
        crop_h, crop_w = crop_h_w.split(',')
        crop_h, crop_w = int(crop_h), int(crop_w)
        # Spatial size at the bottle neck of generator.
        self.sh = crop_h // (2**num_layers)
        self.sw = crop_w // (2**num_layers)

        # Noise vector dimension.
        self.z_dim = getattr(gen_cfg, 'style_dims', 256)
        self.use_segmap_as_input = \
            getattr(gen_cfg, 'use_segmap_as_input', False)

        # Label / image embedding network.
        self.emb_cfg = emb_cfg = getattr(gen_cfg, 'embed', None)
        self.use_embed = getattr(emb_cfg, 'use_embed', True)
        self.num_downsamples_embed = getattr(emb_cfg, 'num_downsamples', 5)
        if self.use_embed:
            self.label_embedding = LabelEmbedder(emb_cfg, num_input_channels)

        # Flow network.
        self.flow_cfg = flow_cfg = gen_cfg.flow
        # Use SPADE to combine warped and hallucinated frames instead of
        # linear combination.
        self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True)
        # Number of layers to perform multi-spade combine.
        self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine,
                                              'num_layers', 3)
        # At beginning of training, only train an image generator.
        self.temporal_initialized = False
        # Whether to output hallucinated frame (when training temporal network)
        # for additional loss.
        self.generate_raw_output = False

        # Image generation network.
        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')
        activation_norm_type = gen_cfg.activation_norm_type
        activation_norm_params = gen_cfg.activation_norm_params
        if self.use_embed and \
                not hasattr(activation_norm_params, 'num_filters'):
            activation_norm_params.num_filters = 0
        nonlinearity = 'leakyrelu'

        self.base_res_block = base_res_block = partial(
            Res2dBlock,
            kernel_size=kernel_size,
            padding=padding,
            weight_norm_type=weight_norm_type,
            activation_norm_type=activation_norm_type,
            activation_norm_params=activation_norm_params,
            nonlinearity=nonlinearity,
            order='NACNAC')

        # Upsampling residual blocks.
        for i in range(num_layers, -1, -1):
            activation_norm_params.cond_dims = self.get_cond_dims(i)
            activation_norm_params.partial = self.get_partial(i) if hasattr(
                self, 'get_partial') else False
            layer = base_res_block(self.get_num_filters(i + 1),
                                   self.get_num_filters(i))
            setattr(self, 'up_%d' % i, layer)

        # Final conv layer.
        self.conv_img = Conv2dBlock(num_filters,
                                    num_img_channels,
                                    kernel_size,
                                    padding=padding,
                                    nonlinearity=nonlinearity,
                                    order='AC')

        num_filters = min(self.max_num_filters,
                          num_filters * (2**(self.num_layers + 1)))
        if self.use_segmap_as_input:
            self.fc = Conv2dBlock(num_input_channels,
                                  num_filters,
                                  kernel_size=3,
                                  padding=1)
        else:
            self.fc = LinearBlock(self.z_dim, num_filters * self.sh * self.sw)

        # Misc.
        self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.upsample = partial(F.interpolate, scale_factor=2)
        self.init_temporal_network()
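
The bottleneck arithmetic is easy to sanity-check: self.sh and self.sw are the crop size divided by 2 ** num_layers, which is the spatial size the fully-connected (or segmap) input must produce before the num_layers + 1 up_* blocks grow it back to full resolution. A quick standalone check with assumed crop sizes:

# Assumed example values; the defaults above give num_layers=7.
crop_h, crop_w, num_layers = 512, 256, 7
sh, sw = crop_h // 2 ** num_layers, crop_w // 2 ** num_layers
print(sh, sw)  # 4 2
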
Example #16
    def init_temporal_network(self, cfg_init=None):
        r"""When starting training multiple frames, initialize the
        downsampling network and flow network.

        Args:
            cfg_init (dict) : Weight initialization config.
        """
        # Number of image downsamplings for the previous frame.
        num_downsamples_img = self.num_downsamples_img
        # Number of residual blocks for the previous frame.
        self.num_res_blocks = int(
            np.ceil((self.num_layers - num_downsamples_img) / 2.0) * 2)

        # First conv layer.
        num_img_channels = get_paired_input_image_channel_number(self.data_cfg)
        self.down_first = \
            Conv2dBlock(num_img_channels,
                        self.num_filters, self.kernel_size,
                        padding=self.kernel_size // 2)
        if cfg_init is not None:
            self.down_first.apply(weights_init(cfg_init.type, cfg_init.gain))

        # Downsampling residual blocks.
        activation_norm_params = self.gen_cfg.activation_norm_params
        for i in range(num_downsamples_img + 1):
            activation_norm_params.cond_dims = self.get_cond_dims(i)
            layer = self.base_res_block(self.get_num_filters(i),
                                        self.get_num_filters(i + 1))
            if cfg_init is not None:
                layer.apply(weights_init(cfg_init.type, cfg_init.gain))
            setattr(self, 'down_%d' % i, layer)

        # Additional residual blocks.
        res_ch = self.get_num_filters(num_downsamples_img + 1)
        activation_norm_params.cond_dims = \
            self.get_cond_dims(num_downsamples_img + 1)
        for i in range(self.num_res_blocks):
            layer = self.base_res_block(res_ch, res_ch)
            if cfg_init is not None:
                layer.apply(weights_init(cfg_init.type, cfg_init.gain))
            setattr(self, 'res_%d' % i, layer)

        # Flow network.
        flow_cfg = self.flow_cfg
        self.temporal_initialized = True
        self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output',
                                           False) and self.spade_combine
        self.flow_network_temp = FlowGenerator(flow_cfg, self.data_cfg)
        if cfg_init is not None:
            self.flow_network_temp.apply(
                weights_init(cfg_init.type, cfg_init.gain))

        self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True)
        if self.spade_combine:
            emb_cfg = flow_cfg.multi_spade_combine.embed
            num_img_channels = get_paired_input_image_channel_number(
                self.data_cfg)
            self.img_prev_embedding = LabelEmbedder(emb_cfg,
                                                    num_img_channels + 1)
            if cfg_init is not None:
                self.img_prev_embedding.apply(
                    weights_init(cfg_init.type, cfg_init.gain))
Example #17
    def __init__(self, flow_cfg, data_cfg):
        super().__init__()
        num_input_channels = get_paired_input_label_channel_number(data_cfg)
        num_prev_img_channels = get_paired_input_image_channel_number(data_cfg)
        num_frames = data_cfg.num_frames_G  # Num. of input frames.

        self.num_filters = num_filters = getattr(flow_cfg, 'num_filters', 32)
        self.max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024)
        num_downsamples = getattr(flow_cfg, 'num_downsamples', 5)
        kernel_size = getattr(flow_cfg, 'kernel_size', 3)
        padding = kernel_size // 2
        self.num_res_blocks = getattr(flow_cfg, 'num_res_blocks', 6)
        # Multiplier on the flow output.
        self.flow_output_multiplier = getattr(flow_cfg,
                                              'flow_output_multiplier', 20)

        activation_norm_type = getattr(flow_cfg, 'activation_norm_type',
                                       'sync_batch')
        weight_norm_type = getattr(flow_cfg, 'weight_norm_type', 'spectral')

        base_conv_block = partial(Conv2dBlock,
                                  kernel_size=kernel_size,
                                  padding=padding,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  nonlinearity='leakyrelu')

        # Will downsample the labels and prev frames separately, then combine.
        down_lbl = [
            base_conv_block(num_input_channels * num_frames, num_filters)
        ]
        down_img = [
            base_conv_block(num_prev_img_channels * (num_frames - 1),
                            num_filters)
        ]
        for i in range(num_downsamples):
            down_lbl += [
                base_conv_block(self.get_num_filters(i),
                                self.get_num_filters(i + 1),
                                stride=2)
            ]
            down_img += [
                base_conv_block(self.get_num_filters(i),
                                self.get_num_filters(i + 1),
                                stride=2)
            ]

        # Resnet blocks.
        res_flow = []
        ch = self.get_num_filters(num_downsamples)
        for i in range(self.num_res_blocks):
            res_flow += [
                Res2dBlock(ch,
                           ch,
                           kernel_size,
                           padding=padding,
                           weight_norm_type=weight_norm_type,
                           activation_norm_type=activation_norm_type,
                           order='CNACN')
            ]

        # Upsample.
        up_flow = []
        for i in reversed(range(num_downsamples)):
            up_flow += [
                nn.Upsample(scale_factor=2),
                base_conv_block(self.get_num_filters(i + 1),
                                self.get_num_filters(i))
            ]

        conv_flow = [Conv2dBlock(num_filters, 2, kernel_size, padding=padding)]
        conv_mask = [
            Conv2dBlock(num_filters,
                        1,
                        kernel_size,
                        padding=padding,
                        nonlinearity='sigmoid')
        ]

        self.down_lbl = nn.Sequential(*down_lbl)
        self.down_img = nn.Sequential(*down_img)
        self.res_flow = nn.Sequential(*res_flow)
        self.up_flow = nn.Sequential(*up_flow)
        self.conv_flow = nn.Sequential(*conv_flow)
        self.conv_mask = nn.Sequential(*conv_mask)
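
The module layout implies a two-stream forward pass: labels and previous frames are downsampled by their own encoders, fused, refined by the residual blocks, upsampled, and decoded into a 2-channel flow (scaled by flow_output_multiplier) plus a 1-channel occlusion mask. The fusion itself is not shown in the source; the standalone sketch below assumes a simple summation of the two streams:

import torch
import torch.nn as nn

# Stand-in two-stream fusion: separate encoders, summed features, then a
# flow head whose output is scaled by the flow multiplier.
down_lbl = nn.Conv2d(12, 32, 3, 2, 1)     # stand-in for self.down_lbl
down_img = nn.Conv2d(6, 32, 3, 2, 1)      # stand-in for self.down_img
lbl = torch.randn(1, 12, 64, 64)
img = torch.randn(1, 6, 64, 64)
h = down_lbl(lbl) + down_img(img)         # assumed fusion by summation
flow = nn.Conv2d(32, 2, 3, 1, 1)(h) * 20  # flow_output_multiplier=20
print(flow.shape)                         # torch.Size([1, 2, 32, 32])
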
Example #18
    def __init__(self, num_labels, out_image_small_side_size, image_channels,
                 num_filters, kernel_size, style_dims, activation_norm_params,
                 weight_norm_type, global_adaptive_norm_type,
                 skip_activation_norm, use_posenc_in_input_layer,
                 use_style_encoder):
        super(SPADEGenerator, self).__init__()
        self.use_style_encoder = use_style_encoder
        self.use_posenc_in_input_layer = use_posenc_in_input_layer
        self.out_image_small_side_size = out_image_small_side_size
        self.num_filters = num_filters
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        nonlinearity = 'leakyrelu'
        activation_norm_type = 'spatially_adaptive'
        base_res2d_block = \
            functools.partial(Res2dBlock,
                              kernel_size=kernel_size,
                              padding=padding,
                              bias=[True, True, False],
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              activation_norm_params=activation_norm_params,
                              skip_activation_norm=skip_activation_norm,
                              nonlinearity=nonlinearity,
                              order='NACNAC')
        if self.use_style_encoder:
            self.fc_0 = LinearBlock(style_dims,
                                    2 * style_dims,
                                    weight_norm_type=weight_norm_type,
                                    nonlinearity='relu',
                                    order='CAN')
            self.fc_1 = LinearBlock(2 * style_dims,
                                    2 * style_dims,
                                    weight_norm_type=weight_norm_type,
                                    nonlinearity='relu',
                                    order='CAN')

            adaptive_norm_params = types.SimpleNamespace(
                cond_dims=2 * style_dims,
                activation_norm_type=global_adaptive_norm_type,
                weight_norm_type=activation_norm_params.weight_norm_type,
                separate_projection=activation_norm_params.separate_projection)
            adaptive_norm_params.activation_norm_params = \
                types.SimpleNamespace(
                    affine=activation_norm_params.activation_norm_params.affine)
            base_cbn2d_block = \
                functools.partial(Conv2dBlock,
                                  kernel_size=kernel_size,
                                  stride=1,
                                  padding=padding,
                                  bias=True,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type='adaptive',
                                  activation_norm_params=adaptive_norm_params,
                                  nonlinearity=nonlinearity,
                                  order='NAC')
        else:
            base_conv2d_block = \
                functools.partial(Conv2dBlock,
                                  kernel_size=kernel_size,
                                  stride=1,
                                  padding=padding,
                                  bias=True,
                                  weight_norm_type=weight_norm_type,
                                  nonlinearity=nonlinearity,
                                  order='NAC')
        in_num_labels = num_labels
        in_num_labels += 2 if self.use_posenc_in_input_layer else 0
        self.head_0 = Conv2dBlock(in_num_labels,
                                  8 * num_filters,
                                  kernel_size=kernel_size,
                                  stride=1,
                                  padding=padding,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type='none',
                                  nonlinearity=nonlinearity)
        if self.use_style_encoder:
            self.cbn_head_0 = base_cbn2d_block(8 * num_filters,
                                               16 * num_filters)
        else:
            self.conv_head_0 = base_conv2d_block(8 * num_filters,
                                                 16 * num_filters)
        self.head_1 = base_res2d_block(16 * num_filters, 16 * num_filters)
        self.head_2 = base_res2d_block(16 * num_filters, 16 * num_filters)

        self.up_0a = base_res2d_block(16 * num_filters, 8 * num_filters)
        if self.use_style_encoder:
            self.cbn_up_0a = base_cbn2d_block(8 * num_filters, 8 * num_filters)
        else:
            self.conv_up_0a = base_conv2d_block(8 * num_filters,
                                                8 * num_filters)
        self.up_0b = base_res2d_block(8 * num_filters, 8 * num_filters)

        self.up_1a = base_res2d_block(8 * num_filters, 4 * num_filters)
        if self.use_style_encoder:
            self.cbn_up_1a = base_cbn2d_block(4 * num_filters, 4 * num_filters)
        else:
            self.conv_up_1a = base_conv2d_block(4 * num_filters,
                                                4 * num_filters)
        self.up_1b = base_res2d_block(4 * num_filters, 4 * num_filters)
        self.up_2a = base_res2d_block(4 * num_filters, 4 * num_filters)
        if self.use_style_encoder:
            self.cbn_up_2a = base_cbn2d_block(4 * num_filters, 4 * num_filters)
        else:
            self.conv_up_2a = base_conv2d_block(4 * num_filters,
                                                4 * num_filters)
        self.up_2b = base_res2d_block(4 * num_filters, 2 * num_filters)
        self.conv_img256 = Conv2dBlock(2 * num_filters,
                                       image_channels,
                                       5,
                                       stride=1,
                                       padding=2,
                                       weight_norm_type=weight_norm_type,
                                       activation_norm_type='none',
                                       nonlinearity=nonlinearity,
                                       order='ANC')
        self.base = 16
        if self.out_image_small_side_size == 512:
            self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
            self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
            self.conv_img512 = Conv2dBlock(1 * num_filters,
                                           image_channels,
                                           5,
                                           stride=1,
                                           padding=2,
                                           weight_norm_type=weight_norm_type,
                                           activation_norm_type='none',
                                           nonlinearity=nonlinearity,
                                           order='ANC')
            self.base = 32
        if self.out_image_small_side_size == 1024:
            self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
            self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
            self.up_4a = base_res2d_block(num_filters, num_filters // 2)
            self.up_4b = base_res2d_block(num_filters // 2, num_filters // 2)
            self.conv_img1024 = Conv2dBlock(num_filters // 2,
                                            image_channels,
                                            5,
                                            stride=1,
                                            padding=2,
                                            weight_norm_type=weight_norm_type,
                                            activation_norm_type='none',
                                            nonlinearity=nonlinearity,
                                            order='ANC')
            self.base = 64
        if self.out_image_small_side_size not in (256, 512, 1024):
            raise ValueError('Generation image size (%d, %d) not supported' %
                             (self.out_image_small_side_size,
                              self.out_image_small_side_size))
        self.nearest_upsample2x = NearestUpsample(scale_factor=2,
                                                  mode='nearest')
        xv, yv = torch.meshgrid(
            [torch.arange(-1, 1.1, 2. / 15),
             torch.arange(-1, 1.1, 2. / 15)])
        self.xy = torch.cat((xv.unsqueeze(0), yv.unsqueeze(0)), 0).unsqueeze(0)
        self.xy = self.xy.cuda()
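
One last note on the coordinate grid: torch.arange(-1, 1.1, 2. / 15) yields 16 samples spanning [-1, 1], so self.xy ends up with shape (1, 2, 16, 16). Newer PyTorch releases also expect an explicit indexing= argument to meshgrid, and registering the grid as a buffer would avoid the hard-coded .cuda() call. A standalone check:

import torch

# 16 samples from -1 to 1 in steps of 2/15; 'ij' matches the legacy default.
xv, yv = torch.meshgrid(torch.arange(-1, 1.1, 2. / 15),
                        torch.arange(-1, 1.1, 2. / 15), indexing='ij')
xy = torch.cat((xv.unsqueeze(0), yv.unsqueeze(0)), 0).unsqueeze(0)
print(xy.shape)  # torch.Size([1, 2, 16, 16])
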