Example #1
    def init_temporal_network(self):
        flow_cfg = self.flow_cfg
        data_cfg = self.data_cfg
        emb_cfg = self.emb_cfg
        gen_cfg = self.gen_cfg

        print("initialize temporal network")
        self.sep_prev_flownet = flow_cfg.sep_prev_flow or (
            self.num_frames_G != 2) or not flow_cfg.warp_ref
        if self.sep_prev_flownet:  # False
            self.flow_network_temp = FlowGenerator(flow_cfg, data_cfg,
                                                   self.num_frames_G)
        else:
            self.flow_network_temp = self.flow_network_ref

        self.sep_prev_embedding = emb_cfg.sep_warp_embed or not flow_cfg.warp_ref
        if self.sep_prev_embedding:  # True
            num_img_channels = data_utils.get_paired_input_image_channel_number(
                self.data_cfg)
            self.prev_image_embedding = LabelEmbedding(gen_cfg, emb_cfg,
                                                       data_cfg,
                                                       num_img_channels + 1)
        else:
            self.prev_image_embedding = self.ref_image_embedding

        self.flow_temp_is_initalized = True

        self.temporal_initialized = True
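
Example #5 below constructs the generator with temporal_initialized set to False and only calls init_temporal_network() on demand. A minimal sketch of that hand-off, assuming a hypothetical generator instance and a start_multi_frame_training flag supplied by the training loop:

    # Hypothetical driver code: single-frame pretraining runs first; once
    # multi-frame training begins, the flow network and the previous-frame
    # embedding are built exactly once.
    if start_multi_frame_training and not generator.temporal_initialized:
        generator.init_temporal_network()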
Example #2
    def __init__(self, dis_cfg, data_cfg):
        super().__init__()
        self.data_cfg = data_cfg
        num_input_channels = get_paired_input_label_channel_number(data_cfg) # 6
        if num_input_channels == 0:
            num_input_channels = getattr(data_cfg, 'label_channels', 1)
        num_img_channels = get_paired_input_image_channel_number(data_cfg) # 3
        self.num_frames_D = data_cfg.num_frames_D # 2
        self.num_scales = get_nested_attr(dis_cfg, "temporal.num_scales", 0)
        num_netD_input_channels = (num_input_channels + num_img_channels)
        self.use_few_shot = 'few_shot' in data_cfg.type # True
        if self.use_few_shot:
            num_netD_input_channels *= 2
        self.net_D = MultiPatchDiscriminator(dis_cfg.image, num_netD_input_channels)

        self.add_dis_cfg = getattr(dis_cfg, 'additional_discriminators', None)
        if self.add_dis_cfg is not None:
            for name in self.add_dis_cfg:
                add_dis_cfg = self.add_dis_cfg[name]
                num_ch = num_img_channels * (2 if self.use_few_shot else 1)
                self.add_sublayer('net_D_' + name, MultiPatchDiscriminator(add_dis_cfg, num_ch))

        # Temporal discriminator.
        self.num_netDT_input_channels = num_img_channels * self.num_frames_D
        for n in range(self.num_scales):
            self.add_sublayer('net_DT%d' % n, MultiPatchDiscriminator(dis_cfg.temporal, self.num_netDT_input_channels))
        
        self.has_fg = getattr(data_cfg, 'has_foreground', False)
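
The constructor reads only a handful of config attributes directly; helper functions such as get_paired_input_label_channel_number consume additional fields that are not shown here. A minimal sketch of the expected shape, using types.SimpleNamespace as a hypothetical stand-in for the real config objects:

    from types import SimpleNamespace

    dis_cfg = SimpleNamespace(
        image=SimpleNamespace(),                 # forwarded to MultiPatchDiscriminator
        temporal=SimpleNamespace(num_scales=0),  # > 0 enables the net_DT%d discriminators
        additional_discriminators=None,          # or a dict of named sub-configs
    )
    data_cfg = SimpleNamespace(
        num_frames_D=2,           # frames stacked per temporal-discriminator input
        type='few_shot_video',    # 'few_shot' in type doubles the input channels
        label_channels=1,         # fallback when no paired label channels exist
        has_foreground=False,
    )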
Example #3
    def __init__(self, dis_cfg, data_cfg):
        super(Discriminator, self).__init__()
        print('Multi-resolution patch discriminator initialization.')
        # We assume the first datum is the ground truth image.
        image_channels = get_paired_input_image_channel_number(data_cfg) # 3
        # Calculate number of channels in the input label. 
        num_labels = get_paired_input_label_channel_number(data_cfg) # 6

        # Build the discriminator.
        kernel_size = getattr(dis_cfg, 'kernel_size', 3)
        num_filters = getattr(dis_cfg, 'num_filters', 128)
        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
        num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
        num_layers = getattr(dis_cfg, 'num_layers', 5)
        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
        print('\tBase filter number: %d' % num_filters)
        print('\tNumber of discriminators: %d' % num_discriminators)
        print('\tNumber of layers in a discriminator: %d' % num_layers)
        print('\tWeight norm type: %s' % weight_norm_type)
        num_input_channels = image_channels + num_labels
        self.model = MultiResPatchDiscriminator(
            num_discriminators, kernel_size, num_input_channels, num_filters,
            num_layers, max_num_filters, activation_norm_type,
            weight_norm_type)
        print("Done with the Multi-resolution patch discriminator initialization.")
Example #4
    def __init__(self, gen_cfg, data_cfg):
        super().__init__()

        hyper_cfg = gen_cfg.hyper

        num_filters = gen_cfg.num_filters  # 32
        self.max_num_filters = gen_cfg.max_num_filters  # 1024
        self.num_downsamples = num_downsamples = gen_cfg.num_downsamples  # 5
        self.num_filters_each_layer = num_filters_each_layer = \
            [min(self.max_num_filters, num_filters * (2 ** i)) for i in range(num_downsamples + 2)]

        kernel_size = getattr(gen_cfg.activation_norm_params, 'kernel_size', 1)
        activation_norm_type = getattr(hyper_cfg, 'activation_norm_type',
                                       'instance')  # instance
        weight_norm_type = getattr(hyper_cfg, 'weight_norm_type', 'spectral')

        self.concat_ref_label = 'concat' in hyper_cfg.method_to_use_ref_labels
        self.mul_ref_label = 'mul' in hyper_cfg.method_to_use_ref_labels

        num_input_channels = data_utils.get_paired_input_label_channel_number(
            data_cfg)  # 6: densepose + openpose
        if num_input_channels == 0:
            num_input_channels = getattr(data_cfg, 'label_channels', 1)
        elif misc_utils.get_nested_attr(data_cfg, 'for_pose_dataset.pose_type',
                                        'both') == 'open':
            num_input_channels -= 3
        data_cfg.num_input_channels = num_input_channels
        num_img_channels = data_utils.get_paired_input_image_channel_number(
            data_cfg)  # 3
        num_ref_channels = num_img_channels + (
            num_input_channels if self.concat_ref_label else 0
        )  # 3 for mul_ref

        conv_2d_block = partial(
            Conv2dBlock,
            kernel_size=kernel_size,
            padding=(kernel_size // 2),
            weight_norm_type=weight_norm_type,
            activation_norm_type=activation_norm_type,
            nonlinearity='leakyrelu',
        )
        self.ref_img_first = conv_2d_block(num_ref_channels, num_filters)
        if self.mul_ref_label:
            self.ref_label_first = conv_2d_block(num_input_channels,
                                                 num_filters)

        for i in range(num_downsamples):
            in_ch, out_ch = num_filters_each_layer[i], num_filters_each_layer[
                i + 1]
            self.add_sublayer('ref_img_down_%d' % i,
                              conv_2d_block(in_ch, out_ch, stride=2))
            self.add_sublayer('ref_img_up_%d' % i,
                              conv_2d_block(out_ch, in_ch))
            if self.mul_ref_label:
                self.add_sublayer('ref_label_down_%d' % i,
                                  conv_2d_block(in_ch, out_ch, stride=2))
                self.add_sublayer('ref_label_up_%d' % i,
                                  conv_2d_block(out_ch, in_ch))

        if hasattr(hyper_cfg, 'attention'):
            self.num_downsample_atn = misc_utils.get_and_setattr(
                hyper_cfg.attention, 'num_downsamples', 2)  # 2
            if data_cfg.initial_few_shot_K > 1:
                self.attention_module = AttentionModule(
                    hyper_cfg, data_cfg, conv_2d_block, num_filters_each_layer)
        else:
            self.num_downsample_atn = 0
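
With the commented values (num_filters = 32, max_num_filters = 1024, num_downsamples = 5), the per-scale channel widths used by the reference-image and reference-label branches work out as follows:

    num_filters = 32          # gen_cfg.num_filters
    max_num_filters = 1024    # gen_cfg.max_num_filters
    num_downsamples = 5       # gen_cfg.num_downsamples

    # Channel width per scale, capped at max_num_filters.
    num_filters_each_layer = [min(max_num_filters, num_filters * (2 ** i))
                              for i in range(num_downsamples + 2)]
    print(num_filters_each_layer)  # [32, 64, 128, 256, 512, 1024, 1024]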
Example #5
    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        self.gen_cfg = gen_cfg
        self.data_cfg = data_cfg
        self.flow_cfg = flow_cfg = gen_cfg.flow
        self.emb_cfg = emb_cfg = flow_cfg.multi_spade_combine.embed
        hyper_cfg = gen_cfg.hyper

        self.use_hyper_embed = hyper_cfg.is_hyper_embed  # True
        self.num_frames_G = data_cfg.num_frames_G

        num_img_channels = data_utils.get_paired_input_image_channel_number(
            data_cfg)

        num_input_channels = data_utils.get_paired_input_label_channel_number(
            data_cfg)
        if num_input_channels == 0:
            num_input_channels = getattr(data_cfg, 'label_channels', 1)
        elif misc_utils.get_nested_attr(data_cfg, 'for_pose_dataset.pose_type',
                                        'both') == 'open':
            num_input_channels -= 3

        # Number of hyper layers
        self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine,
                                              'num_layers', 3)

        # Whether to generate raw output for additional losses.
        self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output',
                                           False)

        # For pose dataset.
        self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset')
        if self.is_pose_data:
            pose_cfg = data_cfg.for_pose_dataset
            self.pose_type = getattr(pose_cfg, 'pose_type', 'both')
            self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels',
                                              False)

        self.main_generator = Generator(gen_cfg, data_cfg)
        self.reference_encoder = ReferenceEncoder(gen_cfg, data_cfg)
        self.weight_generator = WeightGenerator(gen_cfg, data_cfg)
        self.label_embedding = LabelEmbedding(gen_cfg,
                                              gen_cfg.embed,
                                              data_cfg,
                                              num_input_channels,
                                              num_hyper_layers=-1)

        # Flow estimation module.
        # Whether to warp reference image and combine with the synthesized.
        self.warp_ref = getattr(flow_cfg, 'warp_ref', True)  # True
        if self.warp_ref:
            self.flow_network_ref = FlowGenerator(flow_cfg, data_cfg, 2)
            self.ref_image_embedding = LabelEmbedding(gen_cfg, emb_cfg,
                                                      data_cfg,
                                                      num_img_channels + 1)

        # At beginning of training, only train an image generator.
        # When starting training multiple frames, initialize the flow network.
        self.temporal_initialized = False
        if getattr(gen_cfg, 'init_temporal', False):
            self.init_temporal_network()
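
The reference-image embedding takes num_img_channels + 1 inputs; the extra channel is presumably the single-channel sigmoid mask produced by the flow network (see conv_mask in Example #7), concatenated with the warped reference image. A channel-count sketch under that assumption:

    num_img_channels = 3    # RGB reference frame
    num_mask_channels = 1   # assumption: sigmoid mask from the flow network
                            # (Example #7 ends in a 1-channel conv_mask head)
    embedding_in_channels = num_img_channels + num_mask_channels  # 4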
Example #6
    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        self.gen_cfg = gen_cfg
        self.data_cfg = data_cfg
        self.num_frames_G = data_cfg.num_frames_G # 2
        self.flow_cfg = flow_cfg = gen_cfg.flow

        num_img_channels = data_utils.get_paired_input_image_channel_number(data_cfg)
        self.num_downsamples = num_downsamples = misc_utils.get_and_setattr(gen_cfg, 'num_downsamples', 5)
        conv_kernel_size = misc_utils.get_and_setattr(gen_cfg, 'kernel_size', 3)
        num_filters = misc_utils.get_and_setattr(gen_cfg, 'num_filters', 32)
        max_num_filters = misc_utils.get_and_setattr(gen_cfg, 'max_num_filters', 1024)
        self.max_num_filters = gen_cfg.max_num_filters = min(max_num_filters, num_filters * (2 ** num_downsamples))

        # Get number of filters at each layer in the main branch
        num_filters_each_layer = [min(self.max_num_filters, num_filters * (2 ** i)) for i in range(num_downsamples + 2)]

        # Hyper normalization / convolution.
        hyper_cfg = gen_cfg.hyper
        # Use adaptive weight generation for SPADE
        self.use_hyper_spade = hyper_cfg.is_hyper_spade # True
        # Use adaptive for convolutional layers in the main branch.
        self.use_hyper_conv = hyper_cfg.is_hyper_conv # True
        # Number of hyper layers.
        self.num_hyper_layers = getattr(hyper_cfg, 'num_hyper_layers', 4)
        if self.num_hyper_layers == -1:
            self.num_hyper_layers = num_downsamples
        gen_cfg.hyper.num_hyper_layers = self.num_hyper_layers

        # Number of layers to perform multi-spade combine. 
        self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine, 'num_layers', 3)

        # Whether to generate raw output for additional losses.
        self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output', False)

        # Main branch image generation.
        padding = conv_kernel_size // 2
        activation_norm_type = misc_utils.get_and_setattr(gen_cfg, 'activation_norm_type', 'sync_batch')
        weight_norm_type = misc_utils.get_and_setattr(gen_cfg, 'weight_norm_type', 'spectral')
        activation_norm_params = misc_utils.get_and_setattr(gen_cfg, 'activation_norm_params', None) # spatially_adaptive

        spade_in_channels = [] # Input channel size in SPADE module.
        for i in range(num_downsamples + 1):
            spade_in_channels += [[num_filters_each_layer[i]]] \
                if i >= self.num_multi_spade_layers else [[num_filters_each_layer[i]] * 3]
        
        order = getattr(gen_cfg.hyper, 'hyper_block_order', 'NAC')
        for i in reversed(range(num_downsamples + 1)): # 5 -> 0
            activation_norm_params.cond_dims = spade_in_channels[i]
            is_hyper_conv = self.use_hyper_conv and i < self.num_hyper_layers
            is_hyper_norm = self.use_hyper_spade and i < self.num_hyper_layers

            self.add_sublayer('up_%d' % i, HyperRes2dBlock(
                num_filters_each_layer[i + 1], num_filters_each_layer[i], conv_kernel_size, padding=padding,
                weight_norm_type=weight_norm_type, activation_norm_type=activation_norm_type,
                activation_norm_params=activation_norm_params, order=order*2, 
                is_hyper_conv=is_hyper_conv, is_hyper_norm=is_hyper_norm))
        
        self.conv_img = Conv2dBlock(num_filters, num_img_channels, conv_kernel_size, padding=padding,
            nonlinearity='leakyrelu', order='AC')
        self.upsample = partial(L.image_resize, scale=2)
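
With num_downsamples = 5 and num_multi_spade_layers = 3, the three finest scales (up_0 to up_2) receive three SPADE conditioning inputs each (presumably the label embedding plus the two warped-image embeddings combined by multi_spade_combine), while the coarser scales receive one. A worked sketch of the loop above:

    num_filters_each_layer = [32, 64, 128, 256, 512, 1024, 1024]
    num_downsamples = 5
    num_multi_spade_layers = 3

    spade_in_channels = []
    for i in range(num_downsamples + 1):
        spade_in_channels += [[num_filters_each_layer[i]]] \
            if i >= num_multi_spade_layers else [[num_filters_each_layer[i]] * 3]
    print(spade_in_channels)
    # [[32, 32, 32], [64, 64, 64], [128, 128, 128], [256], [512], [1024]]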
Example #7
    def __init__(self, flow_cfg, data_cfg, num_frames):
        super().__init__()
        num_input_channels = data_cfg.num_input_channels  # 6
        if num_input_channels == 0:
            num_input_channels = 1
        num_prev_img_channels = data_utils.get_paired_input_image_channel_number(
            data_cfg)  # 3
        num_downsamples = getattr(flow_cfg, 'num_downsamples', 3)
        kernel_size = getattr(flow_cfg, 'kernel_size', 3)
        padding = kernel_size // 2
        num_blocks = getattr(flow_cfg, 'num_blocks', 6)
        num_filters = getattr(flow_cfg, 'num_filters', 32)
        max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024)
        num_filters_each_layer = [
            min(max_num_filters, num_filters * (2**i))
            for i in range(num_downsamples + 1)
        ]

        self.flow_output_multiplier = getattr(flow_cfg,
                                              'flow_output_multiplier', 20)
        self.sep_up_mask = getattr(flow_cfg, 'sep_up_mask', False)
        activation_norm_type = getattr(flow_cfg, 'activation_norm_type',
                                       'sync_batch')
        weight_norm_type = getattr(flow_cfg, 'weight_norm_type', 'spectral')
        base_conv_block = partial(Conv2dBlock,
                                  kernel_size=kernel_size,
                                  padding=padding,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  nonlinearity='leakyrelu')

        num_input_channels = num_input_channels * num_frames + num_prev_img_channels * (
            num_frames - 1)

        # First layer.
        down_flow = [('0', base_conv_block(num_input_channels, num_filters))]

        # Downsamples.
        for i in range(num_downsamples):
            down_flow += [(str(i + 1),
                           base_conv_block(num_filters_each_layer[i],
                                           num_filters_each_layer[i + 1],
                                           stride=2))]

        # Resnet blocks.
        res_flow = []
        ch = num_filters_each_layer[num_downsamples]
        for i in range(num_blocks):
            res_flow += [(str(i),
                          Res2dBlock(ch,
                                     ch,
                                     kernel_size,
                                     padding=padding,
                                     weight_norm_type=weight_norm_type,
                                     activation_norm_type=activation_norm_type,
                                     order='NACNAC'))]

        # Upsamples.
        up_flow = []
        for i in reversed(range(num_downsamples)):
            up_flow += [(str(
                (num_downsamples - 1 - i) * 2), Upsample(scale=2)),
                        (str((num_downsamples - 1 - i) * 2 + 1),
                         base_conv_block(num_filters_each_layer[i + 1],
                                         num_filters_each_layer[i]))]

        conv_flow = [('0',
                      Conv2dBlock(num_filters, 2, kernel_size,
                                  padding=padding))]
        conv_mask = [('0',
                      Conv2dBlock(num_filters,
                                  1,
                                  kernel_size,
                                  padding=padding,
                                  nonlinearity='sigmoid'))]

        self.down_flow = dg.Sequential(*down_flow)
        self.res_flow = dg.Sequential(*res_flow)
        self.up_flow = dg.Sequential(*up_flow)

        if self.sep_up_mask:
            self.up_mask = dg.Sequential(*copy.deepcopy(up_flow))
        self.conv_flow = dg.Sequential(*conv_flow)
        self.conv_mask = dg.Sequential(*conv_mask)
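
Using the commented values (6 label channels, 3 image channels) and num_frames = 2, the flow network stacks the labels of every frame with the images of all previous frames, then predicts a 2-channel flow field (presumably scaled by flow_output_multiplier in the forward pass, which is not shown) and a 1-channel sigmoid mask. A channel-count sketch:

    num_label_channels = 6   # data_cfg.num_input_channels (densepose + openpose)
    num_img_channels = 3     # RGB frames
    num_frames = 2           # num_frames_G passed to FlowGenerator

    # Labels for every frame plus images for all but the current frame.
    flow_net_in_channels = (num_label_channels * num_frames
                            + num_img_channels * (num_frames - 1))
    print(flow_net_in_channels)  # 15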