def __init__(self, norm_ref, ch, n_shot, n_downsample_G):
        super().__init__()

        # parameters for model
        for i in range(n_downsample_G):
            ch_in, ch_out = ch[i], ch[i + 1]
            setattr(self, 'ref_img_up_%d' % i,
                    SPADEConv2d(ch_out, ch_in, norm=norm_ref))
            setattr(self, 'ref_label_up_%d' % i,
                    SPADEConv2d(ch_out, ch_in, norm=norm_ref))

        # other parameter
        self.n_downsample_G = n_downsample_G
        self.n_shot = n_shot
    def __init__(self, norm_ref, input_nc, nf, ch, n_shot, n_downsample_A):
        """Create the attention key/query encoder pyramid.

        Args:
            norm_ref: normalization spec forwarded to every SPADEConv2d.
            input_nc: channel count of the raw input maps.
            nf: feature width produced by the stem convolutions.
            ch: channel widths per level; needs at least
                n_downsample_A + 1 entries.
            n_shot: number of reference shots (stored for later use).
            n_downsample_A: number of attention-module downsampling levels.
        """
        super().__init__()

        # Stem convolutions lifting the raw inputs to nf channels.
        self.atn_query_first = SPADEConv2d(input_nc, nf, norm=norm_ref)
        self.atn_key_first = SPADEConv2d(input_nc, nf, norm=norm_ref)

        # One stride-2 convolution per level for both key and query paths.
        for level, (c_in, c_out) in enumerate(
                zip(ch[:n_downsample_A], ch[1:])):
            setattr(self, f'atn_key_{level}',
                    SPADEConv2d(c_in, c_out, stride=2, norm=norm_ref))
            setattr(self, f'atn_query_{level}',
                    SPADEConv2d(c_in, c_out, stride=2, norm=norm_ref))

        # Bookkeeping read elsewhere in the class.
        self.n_shot = n_shot
        self.n_downsample_A = n_downsample_A
    def __init__(self,
                 norm_ref,
                 nf,
                 ch,
                 n_shot,
                 n_downsample_G,
                 n_downsample_A,
                 isTrain,
                 ref_nc=3):
        """Create the reference-image downsampling encoder.

        Args:
            norm_ref: normalization spec forwarded to every SPADEConv2d.
            nf: feature width produced by the stem convolution.
            ch: channel widths per level; needs at least
                n_downsample_G + 1 entries.
            n_shot: number of reference shots (stored for later use).
            n_downsample_G: number of generator downsampling levels.
            n_downsample_A: number of attention downsampling levels (stored).
            isTrain: training-mode flag (stored for later use).
            ref_nc: channel count of the reference image (default RGB).
        """
        super().__init__()

        # Stem: lift the reference image to nf channels.
        self.conv1 = SPADEConv2d(ref_nc, nf, norm=norm_ref)

        # One stride-2 convolution per generator level.
        for level, (c_in, c_out) in enumerate(
                zip(ch[:n_downsample_G], ch[1:])):
            setattr(self, f'ref_down_{level}',
                    SPADEConv2d(c_in, c_out, stride=2, norm=norm_ref))

        # Bookkeeping read elsewhere in the class.
        self.isTrain = isTrain
        self.n_shot = n_shot
        self.n_downsample_G = n_downsample_G
        self.n_downsample_A = n_downsample_A
    def __init__(self,
                 norm_ref,
                 nf,
                 ch,
                 n_shot,
                 n_downsample_G,
                 n_downsample_A,
                 isTrain,
                 ref_nc=3,
                 lmark_nc=1):
        """Create the reference image + landmark encoder, with fusion and
        attention layers added only in the multi-shot case.

        Args:
            norm_ref: normalization spec forwarded to every layer built here.
            nf: feature width produced by the stem convolutions.
            ch: channel widths per level; needs at least
                n_downsample_G + 1 entries.
            n_shot: number of reference shots; > 1 enables fusion layers.
            n_downsample_G: number of generator downsampling levels.
            n_downsample_A: number of attention downsampling levels; the
                fusion layers are created at level n_downsample_A - 1.
            isTrain: training-mode flag (stored for later use).
            ref_nc: reference image channels (default RGB).
            lmark_nc: landmark map channels (default single channel).
        """
        super().__init__()

        # Stems: lift reference image and landmark map to nf channels each.
        self.conv1 = SPADEConv2d(ref_nc, nf, norm=norm_ref)
        self.conv2 = SPADEConv2d(lmark_nc, nf, norm=norm_ref)

        for level, (c_in, c_out) in enumerate(
                zip(ch[:n_downsample_G], ch[1:])):
            # Parallel stride-2 pyramids for the image and landmark paths.
            setattr(self, f'ref_down_img_{level}',
                    SPADEConv2d(c_in, c_out, stride=2, norm=norm_ref))
            setattr(self, f'ref_down_lmark_{level}',
                    SPADEConv2d(c_in, c_out, stride=2, norm=norm_ref))
            # Fusion/attention layers only exist for multi-shot input, and
            # only at the attention module's last resolution level.
            if n_shot > 1 and level == n_downsample_A - 1:
                self.fusion1 = SPADEConv2d(c_out * 2, c_out, norm=norm_ref)
                self.fusion2 = SPADEConv2d(c_out * 2, c_out, norm=norm_ref)
                self.fusion = SPADEResnetBlock(c_out * 2,
                                               c_out,
                                               norm=norm_ref)
                self.atten1 = SPADEConv2d(c_out * n_shot,
                                          c_out,
                                          norm=norm_ref)
                self.atten2 = SPADEConv2d(c_out * n_shot,
                                          c_out,
                                          norm=norm_ref)

        # Bookkeeping read elsewhere in the class.
        self.isTrain = isTrain
        self.n_shot = n_shot
        self.n_downsample_G = n_downsample_G
        self.n_downsample_A = n_downsample_A
# --- Example no. 5 (separator from scraped source; original text: "Esempio n. 5" / "0") ---
    def __init__(self, opt):
        """Build the few-shot generator from the option namespace ``opt``.

        Constructs: reference-image (and optionally reference-label)
        encoders, FC heads generating adaptive SPADE/convolution weights,
        the label embedding network, the main SPADE-resnet upsampling
        branch, optional multi-reference attention encoders, an optional
        KLD bottleneck, and an optional reference flow network.

        Args:
            opt: parsed options object; the attributes read below are the
                exact fields it must provide.
        """
        super().__init__()
        ########################### define params ##########################
        self.opt = opt
        self.n_downsample_G = n_downsample_G = opt.n_downsample_G  # number of downsamplings in generator
        self.n_downsample_A = n_downsample_A = opt.n_downsample_A  # number of downsamplings in attention module
        self.nf = nf = opt.ngf  # base channel size
        # Channel widths double per level but are capped at 1024.
        self.nf_max = nf_max = min(1024, nf * (2**n_downsample_G))
        self.ch = ch = [
            min(nf_max, nf * (2**i)) for i in range(n_downsample_G + 2)
        ]

        ### SPADE
        self.norm = norm = opt.norm_G
        self.conv_ks = conv_ks = opt.conv_ks  # conv kernel size in main branch
        self.embed_ks = embed_ks = opt.embed_ks  # conv kernel size in embedding network
        self.spade_ks = spade_ks = opt.spade_ks  # conv kernel size in SPADE
        self.spade_combine = opt.spade_combine  # combine ref/prev frames with current using SPADE
        self.n_sc_layers = opt.n_sc_layers  # number of layers to perform spade combine
        ch_hidden = []  # hidden channel size in SPADE module
        # Layers doing SPADE-combine get three hidden branches ([ch[i]] * 3);
        # every other layer gets a single one ([ch[i]]).
        for i in range(n_downsample_G + 1):
            ch_hidden += [[
                ch[i]
            ]] if not self.spade_combine or i >= self.n_sc_layers else [
                [ch[i]] * 3
            ]
        self.ch_hidden = ch_hidden

        ### adaptive SPADE / Convolution
        self.adap_spade = opt.adaptive_spade  # use adaptive weight generation for SPADE
        self.adap_embed = opt.adaptive_spade and not opt.no_adaptive_embed  # use adaptive for the label embedding network
        self.adap_conv = opt.adaptive_conv  # use adaptive for convolution layers in the main branch
        self.n_adaptive_layers = opt.n_adaptive_layers if opt.n_adaptive_layers != -1 else n_downsample_G  # number of adaptive layers

        # for reference image encoding
        self.concat_label_ref = 'concat' in opt.use_label_ref  # how to utilize the reference label map: concat | mul
        self.mul_label_ref = 'mul' in opt.use_label_ref
        self.sh_fix = self.sw_fix = 32  # output spatial size for adaptive pooling layer
        self.sw = opt.fineSize // (
            2**opt.n_downsample_G
        )  # output spatial size at the bottle neck of generator
        self.sh = int(self.sw / opt.aspect_ratio)

        # weight generation
        self.n_fc_layers = n_fc_layers = opt.n_fc_layers  # number of fc layers in weight generation

        ########################### define network ##########################
        # Reference encoder uses the plain (non-SPADE) variant of the norm.
        norm_ref = norm.replace('spade', '')
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
        # If labels are concatenated onto the reference image, widen its input.
        ref_nc = opt.output_nc + (0 if not self.concat_label_ref else input_nc)
        self.ref_img_first = SPADEConv2d(ref_nc, nf, norm=norm_ref)
        if self.mul_label_ref:
            self.ref_label_first = SPADEConv2d(input_nc, nf, norm=norm_ref)
        ref_conv = SPADEConv2d if not opt.res_for_ref else SPADEResnetBlock

        ### reference image encoding
        # Down path halves resolution per level; up path maps the coarser
        # channel count back to the finer one.
        for i in range(n_downsample_G):
            ch_in, ch_out = ch[i], ch[i + 1]
            setattr(self, 'ref_img_down_%d' % i,
                    ref_conv(ch_in, ch_out, stride=2, norm=norm_ref))
            setattr(self, 'ref_img_up_%d' % i,
                    ref_conv(ch_out, ch_in, norm=norm_ref))
            if self.mul_label_ref:
                setattr(self, 'ref_label_down_%d' % i,
                        ref_conv(ch_in, ch_out, stride=2, norm=norm_ref))
                setattr(self, 'ref_label_up_%d' % i,
                        ref_conv(ch_out, ch_in, norm=norm_ref))

        ### SPADE / main branch weight generation
        if self.adap_spade or self.adap_conv:
            for i in range(self.n_adaptive_layers):
                ch_in, ch_out = ch[i], ch[i + 1]
                conv_ks2 = conv_ks**2
                embed_ks2 = embed_ks**2
                spade_ks2 = spade_ks**2
                ch_h = ch_hidden[i][0]

                fc_names, fc_outs = [], []
                if self.adap_spade:
                    # NOTE(review): the "+ 1" terms presumably account for a
                    # bias per generated filter — confirm against the consumer
                    # of these FC outputs.
                    fc0_out = fcs_out = (ch_h * spade_ks2 + 1) * 2
                    fc1_out = (ch_h * spade_ks2 +
                               1) * (1 if ch_in != ch_out else 2)
                    fc_names += ['fc_spade_0', 'fc_spade_1', 'fc_spade_s']
                    fc_outs += [fc0_out, fc1_out, fcs_out]
                    if self.adap_embed:
                        fc_names += ['fc_spade_e']
                        fc_outs += [ch_in * embed_ks2 + 1]
                if self.adap_conv:
                    fc0_out = ch_out * conv_ks2 + 1
                    fc1_out = ch_in * conv_ks2 + 1
                    fcs_out = ch_out + 1
                    fc_names += ['fc_conv_0', 'fc_conv_1', 'fc_conv_s']
                    fc_outs += [fc0_out, fc1_out, fcs_out]

                # One MLP (n_fc_layers deep, each layer wrapped in sn —
                # presumably spectral norm, defined elsewhere) per generated
                # weight tensor.
                for n, l in enumerate(fc_names):
                    fc_in = ch_out if self.mul_label_ref else self.sh_fix * self.sw_fix
                    fc_layer = [sn(nn.Linear(fc_in, ch_out))]
                    for k in range(1, n_fc_layers):
                        fc_layer += [sn(nn.Linear(ch_out, ch_out))]
                    fc_layer += [sn(nn.Linear(ch_out, fc_outs[n]))]
                    setattr(self, '%s_%d' % (l, i), nn.Sequential(*fc_layer))

        ### label embedding network
        self.label_embedding = LabelEmbedder(
            opt,
            input_nc,
            opt.netS,
            params_free_layers=(self.n_adaptive_layers
                                if self.adap_embed else 0))

        ### main branch layers
        # Built coarsest-first; each block may receive externally generated
        # conv/norm parameters when it falls in the adaptive range.
        for i in reversed(range(n_downsample_G + 1)):
            setattr(
                self, 'up_%d' % i,
                SPADEResnetBlock(
                    ch[i + 1],
                    ch[i],
                    norm=norm,
                    hidden_nc=ch_hidden[i],
                    conv_ks=conv_ks,
                    spade_ks=spade_ks,
                    conv_params_free=(self.adap_conv
                                      and i < self.n_adaptive_layers),
                    norm_params_free=(self.adap_spade
                                      and i < self.n_adaptive_layers)))

        # Final RGB projection and x2 upsampler (F.interpolate default mode).
        self.conv_img = nn.Conv2d(nf, 3, kernel_size=3, padding=1)
        self.up = functools.partial(F.interpolate, scale_factor=2)

        ### for multiple reference images
        if opt.n_shot > 1:
            self.atn_query_first = SPADEConv2d(input_nc, nf, norm=norm_ref)
            self.atn_key_first = SPADEConv2d(input_nc, nf, norm=norm_ref)
            for i in range(n_downsample_A):
                f_in, f_out = ch[i], ch[i + 1]
                setattr(self, 'atn_key_%d' % i,
                        SPADEConv2d(f_in, f_out, stride=2, norm=norm_ref))
                setattr(self, 'atn_query_%d' % i,
                        SPADEConv2d(f_in, f_out, stride=2, norm=norm_ref))

        ### kld loss
        self.use_kld = opt.lambda_kld > 0
        self.z_dim = 256
        if self.use_kld:
            # f_in and f_out are computed identically: bottleneck channel
            # count times bottleneck spatial size.
            f_in = min(nf_max, nf * (2**n_downsample_G)) * self.sh * self.sw
            f_out = min(nf_max, nf * (2**n_downsample_G)) * self.sh * self.sw
            self.fc_mu_ref = nn.Linear(f_in, self.z_dim)
            self.fc_var_ref = nn.Linear(f_in, self.z_dim)
            self.fc = nn.Linear(self.z_dim, f_out)

        ### flow
        self.warp_prev = False  # whether to warp prev image (set when starting training multiple frames)
        self.warp_ref = opt.warp_ref and not opt.for_face  # whether to warp reference image and combine with the synthesized
        if self.warp_ref:
            self.flow_network_ref = FlowGenerator(opt, 2)
            if self.spade_combine:
                self.img_ref_embedding = LabelEmbedder(opt, opt.output_nc + 1,
                                                       opt.sc_arch)