Example #1
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)

        final_nc = nf

        if opt.num_upsampling_layers == 'most':
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)
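
Most of the examples below follow the same skeleton: compute_latent_vector_size picks the spatial size of the coarsest feature map, and forward seeds that map (from z or from a downsampled segmentation map), then alternates SPADE blocks with 2x upsampling. The following is a minimal sketch of those two methods, assuming the upstream SPADE conventions (option fields crop_size, aspect_ratio, z_dim; torch and torch.nn.functional as F imported as in the source files); it is not copied from any of the projects listed here.

    def compute_latent_vector_size(self, opt):
        # Coarsest spatial size = output size divided by 2 per upsampling layer.
        if opt.num_upsampling_layers == 'normal':
            num_up_layers = 5
        elif opt.num_upsampling_layers == 'more':
            num_up_layers = 6
        elif opt.num_upsampling_layers == 'most':
            num_up_layers = 7
        else:
            raise ValueError('opt.num_upsampling_layers [%s] not recognized' %
                             opt.num_upsampling_layers)
        sw = opt.crop_size // (2 ** num_up_layers)
        sh = round(sw / opt.aspect_ratio)
        return sw, sh

    def forward(self, seg, z=None):
        # Seed the coarsest feature map either from a random z (VAE) or from
        # the downsampled segmentation map.
        if self.opt.use_vae:
            if z is None:
                z = torch.randn(seg.size(0), self.opt.z_dim, device=seg.device)
            x = self.fc(z).view(-1, 16 * self.opt.ngf, self.sh, self.sw)
        else:
            x = self.fc(F.interpolate(seg, size=(self.sh, self.sw)))
        # Every SPADE block re-injects the full-resolution map; self.up doubles
        # the spatial size between blocks.
        x = self.head_0(x, seg)
        x = self.up(x)
        x = self.G_middle_0(x, seg)
        if self.opt.num_upsampling_layers in ('more', 'most'):
            x = self.up(x)
        x = self.G_middle_1(x, seg)
        for block in (self.up_0, self.up_1, self.up_2, self.up_3):
            x = self.up(x)
            x = block(x, seg)
        if self.opt.num_upsampling_layers == 'most':
            x = self.up(x)
            x = self.up_4(x, seg)
        x = self.conv_img(F.leaky_relu(x, 2e-1))
        return torch.tanh(x)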
Example #2
File: generator.py Project: XianWu18/SMIS
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)
        # print(self.opt.semantic_nc)
        self.fc = nn.Conv2d(opt.semantic_nc * 8,
                            16 * nf,
                            kernel_size=3,
                            padding=1,
                            groups=self.opt.semantic_nc)
        self.head_0 = SPADEV2ResnetBlock(16 * nf, 16 * nf, opt,
                                         self.opt.semantic_nc)

        self.G_middle_0 = SPADEV2ResnetBlock(16 * nf, 16 * nf, opt,
                                             self.opt.semantic_nc)
        self.G_middle_1 = SPADEV2ResnetBlock(16 * nf, 16 * nf, opt, 20)
        self.up_0 = SPADEV2ResnetBlock(16 * nf, 8 * nf, opt, 14)
        self.up_1 = SPADEV2ResnetBlock(8 * nf, 4 * nf, opt, 10)
        self.up_2 = SPADEV2ResnetBlock(4 * nf, 2 * nf, opt, 4)
        self.up_3 = SPADEV2ResnetBlock(2 * nf, 1 * nf, opt, 1)
        final_nc = nf

        if opt.num_upsampling_layers == 'most':
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)
Example #3
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

        ic = 0 + (3 if 'warp' in self.opt.CBN_intype else 0) + (self.opt.semantic_nc if 'mask' in self.opt.CBN_intype else 0)
        self.fc = nn.Conv2d(ic, 16 * nf, 3, padding=1)
        if opt.eqlr_sn:
            self.fc = equal_lr(self.fc)

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
        if opt.use_attention:
            self.attn = Attention(4 * nf, 'spectral' in opt.norm_G)
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)

        final_nc = nf

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
        self.up = nn.Upsample(scale_factor=2)
Example #4
 def __init__(self, opt, ic, oc, size):
     super().__init__()
     self.opt = opt
     self.downsample = True if size == 256 else False
     nf = opt.ngf
     opt.spade_ic = ic
     if opt.warp_reverseG_s:
         self.backbone_0 = SPADEResnetBlock(4 * nf, 4 * nf, opt)
     else:
         self.backbone_0 = SPADEResnetBlock(4 * nf, 8 * nf, opt)
         self.backbone_1 = SPADEResnetBlock(8 * nf, 8 * nf, opt)
         self.backbone_2 = SPADEResnetBlock(8 * nf, 8 * nf, opt)
         self.backbone_3 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
     self.backbone_4 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
     self.backbone_5 = SPADEResnetBlock(2 * nf, nf, opt)
     del opt.spade_ic
     if self.downsample:
         kw = 3
         pw = int(np.ceil((kw - 1.0) / 2))
         ndf = opt.ngf
         norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
         self.layer1 = norm_layer(nn.Conv2d(ic, ndf, kw, stride=1, padding=pw))
         self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, 4, stride=2, padding=pw))
         self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=1, padding=pw))
         self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 4, 4, stride=2, padding=pw))
         self.up = nn.Upsample(scale_factor=2)
     self.actvn = nn.LeakyReLU(0.2, False)
     self.conv_img = nn.Conv2d(nf, oc, 3, padding=1)
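
When size == 256 this constructor adds a small strided encoder (layer1 to layer4) in front of the SPADE backbone. A hedged sketch of how such a stack is usually applied; the helper name encode is an assumption, not the project's API:

     def encode(self, x):
         # Two stride-2 convolutions reduce the 256px input by a factor of 4;
         # every conv is followed by the shared LeakyReLU activation.
         x = self.actvn(self.layer1(x))   # stride 1
         x = self.actvn(self.layer2(x))   # stride 2
         x = self.actvn(self.layer3(x))   # stride 1
         x = self.actvn(self.layer4(x))   # stride 2
         return x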
Example #5
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        elif opt.use_encoder:
            # In case of an encoder, we encode the image
            if self.opt.Image_encoder_mode == 'norm':
                self.fc = ImageEncoder(opt, self.sw, self.sh)
            elif self.opt.Image_encoder_mode == 'instance':
                self.fc = ImageEncoder2(opt, self.sw, self.sh)
            elif self.opt.Image_encoder_mode == 'partialconv':
                self.fc = ImageEncoder3(opt, self.sw, self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            if not opt.no_orientation:
                # self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1) # for mask input
                self.fc = nn.Conv2d(3, 16 * nf, 3,
                                    padding=1)  # for image input
            else:
                # self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1) # for mask input
                self.fc = nn.Conv2d(3, 16 * nf, 3,
                                    padding=1)  # for image input

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)

        final_nc = nf

        if opt.num_upsampling_layers == 'most':
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

        if not self.opt.noise_background:
            self.backgroud_enc = BackgroundEncode(opt)
        else:
            self.backgroud_enc = BackgroundEncode2(opt)
Example #6
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

        self.Zencoder = Zencoder(3, 512)

        self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * nf,
                                       16 * nf,
                                       opt,
                                       Block_Name='head_0')

        self.G_middle_0 = SPADEResnetBlock(16 * nf,
                                           16 * nf,
                                           opt,
                                           Block_Name='G_middle_0')
        self.G_middle_1 = SPADEResnetBlock(16 * nf,
                                           16 * nf,
                                           opt,
                                           Block_Name='G_middle_1')

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt, Block_Name='up_0')
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt, Block_Name='up_1')
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt, Block_Name='up_2')
        self.up_3 = SPADEResnetBlock(2 * nf,
                                     1 * nf,
                                     opt,
                                     Block_Name='up_3',
                                     use_rgb=False)

        final_nc = nf

        if opt.num_upsampling_layers == 'most':
            self.up_4 = SPADEResnetBlock(1 * nf,
                                         nf // 2,
                                         opt,
                                         Block_Name='up_4')
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)
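
This variant follows the SEAN pattern: the Zencoder distills per-region style codes from a reference image, and each named block consumes them together with the segmentation map. A hedged sketch of the call pattern (the forward_stem helper and the exact argument order are assumptions, not the project's code):

        def forward_stem(self, seg, real_image):
            style_codes = self.Zencoder(real_image, seg)
            x = self.fc(F.interpolate(seg, size=(self.sh, self.sw)))
            x = self.head_0(x, seg, style_codes)  # Block_Name='head_0'
            return x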
Example #7
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc_0 = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            self.fc_0 = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        # embedding and fc layer for conditioning on the scanner class as well
        self.embedding_0 = nn.Embedding(opt.condition_nc, 1024)
        self.fc_1 = nn.Linear(1024, 16 * nf * self.sh * self.sw)

        # the input channels of the head SPADE block are doubled to make room for the scanner embedding
        self.head_0 = SPADEResnetBlock(16 * nf * 2, 16 * nf, opt)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)

        final_nc = nf

        if opt.num_upsampling_layers == 'most':
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, opt.output_nc, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)
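
The doubled input of head_0 only makes sense once the scanner embedding is laid out spatially. A hedged sketch of that wiring (conditioned_stem is a hypothetical helper; the project's actual forward may differ):

        def conditioned_stem(self, seg, scanner_id):
            # fc_0 maps the downsampled segmentation map to 16*nf channels,
            # fc_1 turns the scanner embedding into a matching feature grid,
            # and the concatenation feeds head_0 with 2 * 16 * nf channels.
            x = self.fc_0(F.interpolate(seg, size=(self.sh, self.sw)))
            c = self.fc_1(self.embedding_0(scanner_id))
            c = c.view(-1, 16 * self.opt.ngf, self.sh, self.sw)
            return self.head_0(torch.cat([x, c], dim=1), seg)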
Example #8
    def __init__(self, opt):
        # TODO: kernel=4, concat noise, or change architecture to vgg feature pyramid
        super().__init__()
        self.opt = opt
        kw = 3
        pw = int(np.ceil((kw - 1.0) / 2))
        ndf = opt.ngf
        norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
        self.layer1 = norm_layer(nn.Conv2d(opt.spade_ic, ndf, kw, stride=1, padding=pw))
        self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, opt.adaptor_kernel, stride=2, padding=pw))
        self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=1, padding=pw))
        if opt.warp_stride == 2:
            self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=1, padding=pw))
        else:
            self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, opt.adaptor_kernel, stride=2, padding=pw))
        self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=1, padding=pw))

        self.actvn = nn.LeakyReLU(0.2, False)
        self.opt = opt
        
        nf = opt.ngf

        self.head_0 = SPADEResnetBlock(8 * nf, 8 * nf, opt, use_se=opt.adaptor_se)
        if opt.adaptor_nonlocal:
            self.attn = Attention(8 * nf, False)
        self.G_middle_0 = SPADEResnetBlock(8 * nf, 8 * nf, opt, use_se=opt.adaptor_se)
        self.G_middle_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt, use_se=opt.adaptor_se)

        if opt.adaptor_res_deeper:
            self.deeper0 = SPADEResnetBlock(4 * nf, 4 * nf, opt)
            if opt.dilation_conv:
                self.deeper1 = SPADEResnetBlock(4 * nf, 4 * nf, opt, dilation=2)
                self.deeper2 = SPADEResnetBlock(4 * nf, 4 * nf, opt, dilation=4)
                self.degridding0 = norm_layer(nn.Conv2d(ndf * 4, ndf * 4, 3, stride=1, padding=2, dilation=2))
                self.degridding1 = norm_layer(nn.Conv2d(ndf * 4, ndf * 4, 3, stride=1, padding=1))
            else:
                self.deeper1 = SPADEResnetBlock(4 * nf, 4 * nf, opt)
                self.deeper2 = SPADEResnetBlock(4 * nf, 4 * nf, opt)
Example #9
    def __init__(self,
                 norm_ref,
                 nf,
                 ch,
                 n_shot,
                 n_downsample_G,
                 n_downsample_A,
                 isTrain,
                 ref_nc=3,
                 lmark_nc=1):
        super().__init__()

        # parameters for model
        self.conv1 = SPADEConv2d(ref_nc, nf, norm=norm_ref)
        self.conv2 = SPADEConv2d(lmark_nc, nf, norm=norm_ref)

        for i in range(n_downsample_G):
            ch_in, ch_out = ch[i], ch[i + 1]
            setattr(self, 'ref_down_img_%d' % i,
                    SPADEConv2d(ch_in, ch_out, stride=2, norm=norm_ref))
            setattr(self, 'ref_down_lmark_%d' % i,
                    SPADEConv2d(ch_in, ch_out, stride=2, norm=norm_ref))
            if n_shot > 1 and i == n_downsample_A - 1:
                self.fusion1 = SPADEConv2d(ch_out * 2, ch_out, norm=norm_ref)
                self.fusion2 = SPADEConv2d(ch_out * 2, ch_out, norm=norm_ref)
                self.fusion = SPADEResnetBlock(ch_out * 2,
                                               ch_out,
                                               norm=norm_ref)
                self.atten1 = SPADEConv2d(ch_out * n_shot,
                                          ch_out,
                                          norm=norm_ref)
                self.atten2 = SPADEConv2d(ch_out * n_shot,
                                          ch_out,
                                          norm=norm_ref)

        # other parameters
        self.isTrain = isTrain
        self.n_shot = n_shot
        self.n_downsample_G = n_downsample_G
        self.n_downsample_A = n_downsample_A
Example #10
 def __init__(self, opt, n_frames_G):
     super().__init__()
     self.opt = opt
     input_nc = (opt.label_nc if opt.label_nc != 0 else opt.input_nc) * n_frames_G
     input_nc += opt.output_nc * (n_frames_G - 1)        
     nf = opt.nff
     n_blocks = opt.n_blocks_F
     n_downsample_F = opt.n_downsample_F
     self.flow_multiplier = opt.flow_multiplier        
     nf_max = 1024
     ch = [min(nf_max, nf * (2 ** i)) for i in range(n_downsample_F + 1)]
             
     norm = opt.norm_F
     norm_layer = get_nonspade_norm_layer(opt, norm)
     activation = nn.LeakyReLU(0.2)
     
     down_flow = [norm_layer(nn.Conv2d(input_nc, nf, kernel_size=3, padding=1)), activation]        
     for i in range(n_downsample_F):            
         down_flow += [norm_layer(nn.Conv2d(ch[i], ch[i+1], kernel_size=3, padding=1, stride=2)), activation]            
                
     ### resnet blocks
     res_flow = []
     ch_r = min(nf_max, nf * (2**n_downsample_F))        
     for i in range(n_blocks):
         res_flow += [SPADEResnetBlock(ch_r, ch_r, norm=norm)]
 
     ### upsample
     up_flow = []                         
     for i in reversed(range(n_downsample_F)):
         up_flow += [nn.Upsample(scale_factor=2), norm_layer(nn.Conv2d(ch[i+1], ch[i], kernel_size=3, padding=1)), activation]
                           
     conv_flow = [nn.Conv2d(nf, 2, kernel_size=3, padding=1)]
     conv_mask = [nn.Conv2d(nf, 1, kernel_size=3, padding=1), nn.Sigmoid()] 
   
     self.down_flow = nn.Sequential(*down_flow)        
     self.res_flow = nn.Sequential(*res_flow)                                            
     self.up_flow = nn.Sequential(*up_flow)
     self.conv_flow = nn.Sequential(*conv_flow)        
     self.conv_mask = nn.Sequential(*conv_mask)
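
A hedged sketch of how these flow sub-networks are typically chained (the single concatenated label argument is an assumption; the project's forward signature may differ):

 def forward(self, label):
     # Encode, refine with the SPADE resnet blocks, decode, then predict a
     # 2-channel flow field and a 1-channel soft occlusion mask.
     feat = self.down_flow(label)
     feat = self.res_flow(feat)
     feat = self.up_flow(feat)
     flow = self.conv_flow(feat) * self.flow_multiplier
     mask = self.conv_mask(feat)
     return flow, mask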
Example #11
    def __init__(self, opt):
        super().__init__(opt)
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh, self.scale_ratio = self.compute_latent_vector_size(opt)

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, 16 * nf)

        # 20200211 test 4x with only 3 stage

        self.ups = nn.ModuleList([
            SPADEResnetBlock(16 * nf, 8 * nf, opt, 8 * nf),
            SPADEResnetBlock(8 * nf, 4 * nf, opt, 4 * nf),
            SPADEResnetBlock(4 * nf, 2 * nf, opt, 2 * nf),
            SPADEResnetBlock(2 * nf, 1 * nf, opt, 1 * nf)  # here
            ])

        self.to_rgbs = nn.ModuleList([
            nn.Conv2d(8 * nf, 3, 3, padding=1),
            nn.Conv2d(4 * nf, 3, 3, padding=1),
            nn.Conv2d(2 * nf, 3, 3, padding=1),
            nn.Conv2d(1 * nf, 3, 3, padding=1)      # here
            ])

        self.up = nn.Upsample(scale_factor=2)
        self.encoder = ContentAdaptiveSuppresor(opt, self.sw, self.sh, self.scale_ratio)
Example #12
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw = 7
        self.sh = 7

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)

        final_nc = nf

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)
Example #13
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf  # number of generator filters in the first conv layer

        self.sw, self.sh = self.compute_latent_vector_size(opt)
        # print(self.sw, self.sh) 8, 4

        if opt.use_vae:  # False
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3,
                                padding=1)  # print(self.opt.semantic_nc) # 36

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)

        final_nc = nf

        if opt.num_upsampling_layers == 'most':  # opt.num_upsampling_layers: more
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

        # local branch
        self.conv1 = nn.Conv2d(151, 64, 7, 1, 0)  # change
        self.conv1_norm = nn.InstanceNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 3, 2, 1)
        self.conv2_norm = nn.InstanceNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 3, 2, 1)
        self.conv3_norm = nn.InstanceNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, 3, 2, 1)
        self.conv4_norm = nn.InstanceNorm2d(512)
        self.conv5 = nn.Conv2d(512, 1024, 3, 2, 1)
        self.conv5_norm = nn.InstanceNorm2d(1024)

        self.resnet_blocks1 = resnet_block(256, 3, 1, 1)
        self.resnet_blocks1.weight_init(0, 0.02)
        self.resnet_blocks2 = resnet_block(256, 3, 1, 1)
        self.resnet_blocks2.weight_init(0, 0.02)
        self.resnet_blocks3 = resnet_block(256, 3, 1, 1)
        self.resnet_blocks3.weight_init(0, 0.02)
        self.resnet_blocks4 = resnet_block(256, 3, 1, 1)
        self.resnet_blocks4.weight_init(0, 0.02)
        self.resnet_blocks5 = resnet_block(256, 3, 1, 1)
        self.resnet_blocks5.weight_init(0, 0.02)
        self.resnet_blocks6 = resnet_block(256, 3, 1, 1)
        self.resnet_blocks6.weight_init(0, 0.02)
        self.resnet_blocks7 = resnet_block(256, 3, 1, 1)
        self.resnet_blocks7.weight_init(0, 0.02)
        self.resnet_blocks8 = resnet_block(256, 3, 1, 1)
        self.resnet_blocks8.weight_init(0, 0.02)
        self.resnet_blocks9 = resnet_block(256, 3, 1, 1)
        self.resnet_blocks9.weight_init(0, 0.02)

        self.deconv3_local = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1)
        self.deconv3_norm_local = nn.InstanceNorm2d(128)
        self.deconv4_local = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1)
        self.deconv4_norm_local = nn.InstanceNorm2d(64)

        self.deconv9 = nn.Conv2d(3 * 52, 3, 3, 1, 1)

        # 52 identical per-label output heads (deconv5_0 ... deconv5_51),
        # registered in a loop instead of 52 hand-written assignments.
        for i in range(52):
            setattr(self, 'deconv5_%d' % i, nn.Conv2d(64, 3, 7, 1, 0))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc1 = nn.Linear(64*256 * 512, 512)
        self.fc2 = nn.Linear(64, 51)

        self.deconv3_attention = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1)
        self.deconv3_norm_attention = nn.InstanceNorm2d(128)
        self.deconv4_attention = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1)
        self.deconv4_norm_attention = nn.InstanceNorm2d(64)
        self.deconv5_attention = nn.Conv2d(64, 2, 1, 1, 0)
Example #14
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

        print("The size of the latent vector size is [%d,%d]" % (self.sw, self.sh))

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            if self.opt.no_parsing_map:
                self.fc = nn.Conv2d(3, 16 * nf, 3, padding=1)
            else:
                self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "1":
            self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        else:
            self.head_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "2":
            self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
            self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        else:
            self.G_middle_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)
            self.G_middle_1 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "3":
            self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
        else:
            self.up_0 = SPADEResnetBlock_non_spade(16 * nf, 8 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "4":
            self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
        else:
            self.up_1 = SPADEResnetBlock_non_spade(8 * nf, 4 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "5":
            self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
        else:
            self.up_2 = SPADEResnetBlock_non_spade(4 * nf, 2 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "6":
            self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)
        else:
            self.up_3 = SPADEResnetBlock_non_spade(2 * nf, 1 * nf, opt)

        final_nc = nf

        if opt.num_upsampling_layers == "most":
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)
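
The six injection_layer checks above all repeat the same pattern. A hedged refactoring sketch that would construct the same blocks; pick_block is a hypothetical helper, not part of the project:

        def pick_block(self, stage, fin, fout):
            # Use a SPADE block when conditioning is injected at this stage,
            # otherwise fall back to the plain (non-SPADE) residual block.
            use_spade = self.opt.injection_layer in ('all', stage)
            block_cls = SPADEResnetBlock if use_spade else SPADEResnetBlock_non_spade
            return block_cls(fin, fout, self.opt)

        # e.g. self.up_0 = self.pick_block('3', 16 * nf, 8 * nf)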
Example #15
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)
        '''
        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc_surface = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)                   
        else:
            #Otherwise, we make the network deterministic by starting with 
            #downsampled segmentation map instead of random z
            self.fc_surface = nn.Conv2d(3, 16 * nf, 3, padding=1)
        '''

        self.surface_down_0 = SPADEResnetBlock(1, 1 * nf, opt)
        self.surface_down_1 = SPADEResnetBlock(1 * nf, 2 * nf, opt)
        self.surface_down_2 = SPADEResnetBlock(2 * nf, 4 * nf, opt)
        self.surface_down_3 = SPADEResnetBlock(4 * nf, 8 * nf, opt)
        self.surface_down_4 = SPADEResnetBlock(8 * nf, 16 * nf, opt)

        self.head_0_surface = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.G_middle_0_surface = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_1_surface = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        # Surface generator layers
        self.surface_up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
        self.surface_up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
        self.surface_up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
        self.surface_up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)

        final_nc = nf

        if opt.num_upsampling_layers == 'most':
            self.surface_up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            final_nc = nf // 2

        self.surface_conv_img = nn.Conv2d(final_nc, 1, 3, padding=1)

        # Color generator layers
        extra_channels = True
        self.head_0_color = ModifiedSPADEResnetBlock(16 * nf, 16 * nf, opt,
                                                     extra_channels)

        self.G_middle_0_color = ModifiedSPADEResnetBlock(
            16 * nf, 16 * nf, opt, extra_channels)
        self.G_middle_1_color = ModifiedSPADEResnetBlock(
            16 * nf, 16 * nf, opt, extra_channels)

        self.color_up_0 = ModifiedSPADEResnetBlock(16 * nf, 8 * nf, opt,
                                                   extra_channels)
        self.color_up_1 = ModifiedSPADEResnetBlock(8 * nf, 4 * nf, opt,
                                                   extra_channels)
        self.color_up_2 = ModifiedSPADEResnetBlock(4 * nf, 2 * nf, opt,
                                                   extra_channels)
        self.color_up_3 = ModifiedSPADEResnetBlock(2 * nf, 1 * nf, opt,
                                                   extra_channels)

        final_nc = nf

        if opt.num_upsampling_layers == 'most':
            self.color_up_4 = ModifiedSPADEResnetBlock(1 * nf, nf // 2, opt,
                                                       extra_channels)
            final_nc = nf // 2

        self.color_conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.b1x1_conv = nn.Conv2d(8 * nf, 16, 3, padding=1)
        self.b1x2_conv = nn.Conv2d(4 * nf, 16, 3, padding=1)
        self.b1x3_conv = nn.Conv2d(2 * nf, 16, 3, padding=1)
        self.b1x4_conv = nn.Conv2d(1 * nf, 16, 3, padding=1)
        self.b1x5_conv = nn.Conv2d(nf // 2, 16, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)
        self.down = nn.Upsample(scale_factor=0.5)
Example #16
    def __init__(self, opt):
        super().__init__()
        ########################### define params ##########################
        self.opt = opt
        self.n_downsample_G = n_downsample_G = opt.n_downsample_G  # number of downsamplings in generator
        self.n_downsample_A = n_downsample_A = opt.n_downsample_A  # number of downsamplings in attention module
        self.nf = nf = opt.ngf  # base channel size
        self.nf_max = nf_max = min(1024, nf * (2**n_downsample_G))
        self.ch = ch = [
            min(nf_max, nf * (2**i)) for i in range(n_downsample_G + 2)
        ]

        ### SPADE
        self.norm = norm = opt.norm_G
        self.conv_ks = conv_ks = opt.conv_ks  # conv kernel size in main branch
        self.embed_ks = embed_ks = opt.embed_ks  # conv kernel size in embedding network
        self.spade_ks = spade_ks = opt.spade_ks  # conv kernel size in SPADE
        self.spade_combine = opt.spade_combine  # combine ref/prev frames with current using SPADE
        self.n_sc_layers = opt.n_sc_layers  # number of layers to perform spade combine
        ch_hidden = []  # hidden channel size in SPADE module
        for i in range(n_downsample_G + 1):
            ch_hidden += ([[ch[i]]]
                          if not self.spade_combine or i >= self.n_sc_layers
                          else [[ch[i]] * 3])
        self.ch_hidden = ch_hidden

        ### adaptive SPADE / Convolution
        self.adap_spade = opt.adaptive_spade  # use adaptive weight generation for SPADE
        self.adap_embed = opt.adaptive_spade and not opt.no_adaptive_embed  # use adaptive for the label embedding network
        self.adap_conv = opt.adaptive_conv  # use adaptive for convolution layers in the main branch
        self.n_adaptive_layers = opt.n_adaptive_layers if opt.n_adaptive_layers != -1 else n_downsample_G  # number of adaptive layers

        # for reference image encoding
        self.concat_label_ref = 'concat' in opt.use_label_ref  # how to utilize the reference label map: concat | mul
        self.mul_label_ref = 'mul' in opt.use_label_ref
        self.sh_fix = self.sw_fix = 32  # output spatial size for adaptive pooling layer
        self.sw = opt.fineSize // (
            2**opt.n_downsample_G
        )  # output spatial size at the bottle neck of generator
        self.sh = int(self.sw / opt.aspect_ratio)

        # weight generation
        self.n_fc_layers = n_fc_layers = opt.n_fc_layers  # number of fc layers in weight generation

        ########################### define network ##########################
        norm_ref = norm.replace('spade', '')
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
        ref_nc = opt.output_nc + (0 if not self.concat_label_ref else input_nc)
        self.ref_img_first = SPADEConv2d(ref_nc, nf, norm=norm_ref)
        if self.mul_label_ref:
            self.ref_label_first = SPADEConv2d(input_nc, nf, norm=norm_ref)
        ref_conv = SPADEConv2d if not opt.res_for_ref else SPADEResnetBlock

        ### reference image encoding
        for i in range(n_downsample_G):
            ch_in, ch_out = ch[i], ch[i + 1]
            setattr(self, 'ref_img_down_%d' % i,
                    ref_conv(ch_in, ch_out, stride=2, norm=norm_ref))
            setattr(self, 'ref_img_up_%d' % i,
                    ref_conv(ch_out, ch_in, norm=norm_ref))
            if self.mul_label_ref:
                setattr(self, 'ref_label_down_%d' % i,
                        ref_conv(ch_in, ch_out, stride=2, norm=norm_ref))
                setattr(self, 'ref_label_up_%d' % i,
                        ref_conv(ch_out, ch_in, norm=norm_ref))

        ### SPADE / main branch weight generation
        if self.adap_spade or self.adap_conv:
            for i in range(self.n_adaptive_layers):
                ch_in, ch_out = ch[i], ch[i + 1]
                conv_ks2 = conv_ks**2
                embed_ks2 = embed_ks**2
                spade_ks2 = spade_ks**2
                ch_h = ch_hidden[i][0]

                fc_names, fc_outs = [], []
                if self.adap_spade:
                    fc0_out = fcs_out = (ch_h * spade_ks2 + 1) * 2
                    fc1_out = (ch_h * spade_ks2 +
                               1) * (1 if ch_in != ch_out else 2)
                    fc_names += ['fc_spade_0', 'fc_spade_1', 'fc_spade_s']
                    fc_outs += [fc0_out, fc1_out, fcs_out]
                    if self.adap_embed:
                        fc_names += ['fc_spade_e']
                        fc_outs += [ch_in * embed_ks2 + 1]
                if self.adap_conv:
                    fc0_out = ch_out * conv_ks2 + 1
                    fc1_out = ch_in * conv_ks2 + 1
                    fcs_out = ch_out + 1
                    fc_names += ['fc_conv_0', 'fc_conv_1', 'fc_conv_s']
                    fc_outs += [fc0_out, fc1_out, fcs_out]

                for n, l in enumerate(fc_names):
                    fc_in = ch_out if self.mul_label_ref else self.sh_fix * self.sw_fix
                    fc_layer = [sn(nn.Linear(fc_in, ch_out))]
                    for k in range(1, n_fc_layers):
                        fc_layer += [sn(nn.Linear(ch_out, ch_out))]
                    fc_layer += [sn(nn.Linear(ch_out, fc_outs[n]))]
                    setattr(self, '%s_%d' % (l, i), nn.Sequential(*fc_layer))

        ### label embedding network
        self.label_embedding = LabelEmbedder(
            opt,
            input_nc,
            opt.netS,
            params_free_layers=(self.n_adaptive_layers
                                if self.adap_embed else 0))

        ### main branch layers
        for i in reversed(range(n_downsample_G + 1)):
            setattr(
                self, 'up_%d' % i,
                SPADEResnetBlock(
                    ch[i + 1],
                    ch[i],
                    norm=norm,
                    hidden_nc=ch_hidden[i],
                    conv_ks=conv_ks,
                    spade_ks=spade_ks,
                    conv_params_free=(self.adap_conv
                                      and i < self.n_adaptive_layers),
                    norm_params_free=(self.adap_spade
                                      and i < self.n_adaptive_layers)))

        self.conv_img = nn.Conv2d(nf, 3, kernel_size=3, padding=1)
        self.up = functools.partial(F.interpolate, scale_factor=2)

        ### for multiple reference images
        if opt.n_shot > 1:
            self.atn_query_first = SPADEConv2d(input_nc, nf, norm=norm_ref)
            self.atn_key_first = SPADEConv2d(input_nc, nf, norm=norm_ref)
            for i in range(n_downsample_A):
                f_in, f_out = ch[i], ch[i + 1]
                setattr(self, 'atn_key_%d' % i,
                        SPADEConv2d(f_in, f_out, stride=2, norm=norm_ref))
                setattr(self, 'atn_query_%d' % i,
                        SPADEConv2d(f_in, f_out, stride=2, norm=norm_ref))

        ### kld loss
        self.use_kld = opt.lambda_kld > 0
        self.z_dim = 256
        if self.use_kld:
            f_in = min(nf_max, nf * (2**n_downsample_G)) * self.sh * self.sw
            f_out = min(nf_max, nf * (2**n_downsample_G)) * self.sh * self.sw
            self.fc_mu_ref = nn.Linear(f_in, self.z_dim)
            self.fc_var_ref = nn.Linear(f_in, self.z_dim)
            self.fc = nn.Linear(self.z_dim, f_out)

        ### flow
        self.warp_prev = False  # whether to warp prev image (set when starting training multiple frames)
        self.warp_ref = opt.warp_ref and not opt.for_face  # whether to warp reference image and combine with the synthesized
        if self.warp_ref:
            self.flow_network_ref = FlowGenerator(opt, 2)
            if self.spade_combine:
                self.img_ref_embedding = LabelEmbedder(opt, opt.output_nc + 1,
                                                       opt.sc_arch)
Example #17
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # Does the forward pass use sean or not?
        if opt.norm_mode == 'sean':
            self.use_sean = True
        else:
            self.use_sean = False
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

        # We are not going to use the variational encoding in our project.
        # if opt.use_vae:
        #     # In case of VAE, we will sample from random z vector
        #     self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        # else:
        #     # Otherwise, we make the network deterministic by starting with
        #     # downsampled segmentation map instead of random z

        # We don't want to use the style encoder when we are using CLADE or SPADE,
        # that's why we need this if statement.
        if opt.norm_mode != 'sean':
            self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

            self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

            self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
            self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

            self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
            self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
            self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
            self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)

            final_nc = nf
        elif opt.norm_mode == 'sean':
            self.Zencoder = Zencoder(3, 512)

            self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

            self.head_0 = SPADEResnetBlock(16 * nf,
                                           16 * nf,
                                           opt,
                                           Block_Name='head_0')

            self.G_middle_0 = SPADEResnetBlock(16 * nf,
                                               16 * nf,
                                               opt,
                                               Block_Name='G_middle_0')
            self.G_middle_1 = SPADEResnetBlock(16 * nf,
                                               16 * nf,
                                               opt,
                                               Block_Name='G_middle_1')

            self.up_0 = SPADEResnetBlock(16 * nf,
                                         8 * nf,
                                         opt,
                                         Block_Name='up_0')
            self.up_1 = SPADEResnetBlock(8 * nf,
                                         4 * nf,
                                         opt,
                                         Block_Name='up_1')
            self.up_2 = SPADEResnetBlock(4 * nf,
                                         2 * nf,
                                         opt,
                                         Block_Name='up_2')
            self.up_3 = SPADEResnetBlock(2 * nf,
                                         1 * nf,
                                         opt,
                                         Block_Name='up_3',
                                         use_rgb=False)

            final_nc = nf

        if opt.num_upsampling_layers == 'most':
            if opt.norm_mode != 'sean':
                self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            elif opt.norm_mode == 'sean':
                self.up_4 = SPADEResnetBlock(1 * nf,
                                             nf // 2,
                                             opt,
                                             Block_Name='up_4')

            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)
Example #18
    def __init__(self, opt):
        super().__init__()
        ########################### define params ##########################
        self.opt = opt
        self.add_raw_loss = opt.add_raw_loss and opt.spade_combine
        self.n_downsample_G = n_downsample_G = opt.n_downsample_G  # number of downsamplings in generator
        self.n_downsample_A = opt.n_downsample_A  # number of downsamplings in attention module
        self.nf = nf = opt.ngf  # base channel size

        nf_max = min(1024, nf * (2**n_downsample_G))
        self.ch = ch = [
            min(nf_max, nf * (2**i)) for i in range(n_downsample_G + 2)
        ]

        ### SPADE
        self.norm = norm = opt.norm_G
        self.conv_ks = conv_ks = opt.conv_ks  # conv kernel size in main branch
        self.embed_ks = opt.embed_ks  # conv kernel size in embedding network
        self.spade_ks = spade_ks = opt.spade_ks  # conv kernel size in SPADE
        self.spade_combine = opt.spade_combine  # combine ref/prev frames with current using SPADE
        self.n_sc_layers = opt.n_sc_layers  # number of layers to perform spade combine
        ch_hidden = []  # hidden channel size in SPADE module
        for i in range(n_downsample_G + 1):
            ch_hidden += ([[ch[i]]]
                          if not self.spade_combine or i >= self.n_sc_layers
                          else [[ch[i]] * 4])
        self.ch_hidden = ch_hidden

        ### adaptive SPADE / Convolution
        self.adap_spade = opt.adaptive_spade  # use adaptive weight generation for SPADE
        self.adap_embed = opt.adaptive_spade and not opt.no_adaptive_embed  # use adaptive for the label embedding network
        self.n_adaptive_layers = opt.n_adaptive_layers if opt.n_adaptive_layers != -1 else n_downsample_G  # number of adaptive layers

        # weight generation
        self.n_fc_layers = opt.n_fc_layers  # number of fc layers in weight generation

        ########################### define network ##########################
        norm_ref = norm.replace('spade', '')
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc  #1
        ref_nc = opt.output_nc  #3
        if not opt.use_new or opt.transfer_initial:
            self.image_encoder = Encoder(norm_ref=norm_ref,
                                         nf=self.nf,
                                         ch=self.ch,
                                         n_shot=self.opt.n_shot,
                                         n_downsample_G=self.n_downsample_G,
                                         n_downsample_A=self.n_downsample_A,
                                         isTrain=self.opt.isTrain,
                                         ref_nc=ref_nc)

            self.lmark_encoder = Encoder(norm_ref=norm_ref,
                                         nf=self.nf,
                                         ch=self.ch,
                                         n_shot=self.opt.n_shot,
                                         n_downsample_G=self.n_downsample_G,
                                         n_downsample_A=self.n_downsample_A,
                                         isTrain=self.opt.isTrain,
                                         ref_nc=input_nc)
        if opt.use_new:
            self.encoder = EncoderSelfAtten(norm_ref=norm_ref,
                                            nf=self.nf,
                                            ch=self.ch,
                                            n_shot=self.opt.n_shot,
                                            n_downsample_G=self.n_downsample_G,
                                            n_downsample_A=self.n_downsample_A,
                                            isTrain=self.opt.isTrain,
                                            ref_nc=ref_nc,
                                            lmark_nc=input_nc)

        self.comb_encoder = CombEncoder(norm_ref=norm_ref,
                                        ch=self.ch,
                                        n_shot=self.opt.n_shot,
                                        n_downsample_G=self.n_downsample_G)

        if not opt.no_atten:
            self.atten_gen = AttenGen(norm_ref=norm_ref,
                                      input_nc=input_nc,
                                      nf=self.nf,
                                      ch=self.ch,
                                      n_shot=self.opt.n_shot,
                                      n_downsample_A=self.n_downsample_A)

        ### SPADE / main branch weight generation
        if self.adap_spade:
            self.weight_gen = WeightGen(
                ch_hidden=self.ch_hidden,
                embed_ks=self.embed_ks,
                spade_ks=self.spade_ks,
                n_fc_layers=self.n_fc_layers,
                n_adaptive_layers=self.n_adaptive_layers,
                ch=self.ch,
                adap_embed=self.adap_embed)

        ### label embedding network
        self.label_embedding = LabelEmbedder(
            opt,
            input_nc,
            opt.netS,
            params_free_layers=(self.n_adaptive_layers
                                if self.adap_embed else 0))

        ### main branch layers
        for i in reversed(range(n_downsample_G + 1)):
            hidden_nc = ch_hidden[i]
            if i >= self.n_sc_layers or not opt.use_new:
                setattr(
                    self, 'up_%d' % i,
                    SPADEResnetBlock(
                        ch[i + 1],
                        ch[i],
                        norm=norm,
                        hidden_nc=hidden_nc,
                        conv_ks=conv_ks,
                        spade_ks=spade_ks,
                        conv_params_free=False,
                        norm_params_free=(self.adap_spade
                                          and i < self.n_adaptive_layers)))
            else:
                setattr(
                    self, 'up_%d' % i,
                    SPADEResnetBlockConcat(
                        ch[i + 1],
                        ch[i],
                        norm=norm,
                        hidden_nc=hidden_nc,
                        conv_ks=conv_ks,
                        spade_ks=spade_ks,
                        conv_params_free=False,
                        norm_params_free=(self.adap_spade
                                          and i < self.n_adaptive_layers)))

        self.conv_img = nn.Conv2d(nf, 3, kernel_size=3, padding=1)
        self.up = functools.partial(F.interpolate, scale_factor=2)
Example #19
 def __init__(self, fin, fout, opt, reduction=8):
     super().__init__()
     self.spade_resnet_block = SPADEResnetBlock(fin, fout, opt)
     self.se_block = SEBlock(fout, reduction)
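
A hedged sketch of how this wrapper is normally applied (an assumption, not the project's code): run the SPADE residual block conditioned on the segmentation map, then re-weight its channels with squeeze-and-excitation.

     def forward(self, x, segmap):
         out = self.spade_resnet_block(x, segmap)
         return self.se_block(out)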
Example #20
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)
            if self.opt.retrival_memory:
                self.fc = nn.Conv2d(self.opt.semantic_nc * 1,
                                    16 * nf,
                                    3,
                                    padding=1)

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)

        final_nc = nf

        if opt.num_upsampling_layers == 'most':
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            final_nc = nf // 2
        self.softmax = nn.LogSoftmax()
        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
        self.conv_seg = nn.Conv2d(final_nc, self.opt.semantic_nc, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

        # fcn encoder
        nhidden = 128
        self.encode_shared = nn.Sequential(
            nn.Conv2d(self.opt.semantic_nc,
                      self.opt.semantic_nc,
                      kernel_size=3,
                      padding=1), nn.ReLU())
        self.memory_shared = nn.Sequential(
            nn.Conv2d(self.opt.semantic_nc,
                      self.opt.semantic_nc,
                      kernel_size=3,
                      padding=1), nn.ReLU())
        self.conv1_1 = nn.Conv2d(35, 64, 3, padding=1)
        self.conv1_1_memory = nn.Conv2d(70, 64, 3, padding=1)
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(64, 64, 5, padding=1)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(4, stride=4, ceil_mode=True)  # 1/2

        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, 5, padding=1)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(4, stride=4, ceil_mode=True)  # 1/4
        self.upscore = nn.ConvTranspose2d(128, 35, 32, stride=16, bias=False)

        self.drop2 = nn.Dropout2d()
Example #21
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        output_nc = 3
        label_nc = opt.label_nc

        input_nc = label_nc + (1 if opt.contain_dontcare_label else
                               0) + (0 if opt.no_instance else 1)
        if opt.mix_input_gen:
            input_nc += 4

        norm_layer = get_nonspade_norm_layer(opt, 'instance')
        activation = nn.ReLU(False)

        # initial block
        self.init_block = nn.Sequential(*[
            nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2),
            norm_layer(
                nn.Conv2d(input_nc,
                          opt.ngf,
                          kernel_size=opt.resnet_initial_kernel_size,
                          padding=0)), activation
        ])

        # Downsampling blocks
        self.downlayers = nn.ModuleList()
        mult = 1
        for i in range(opt.resnet_n_downsample):
            self.downlayers.append(
                nn.Sequential(*[
                    norm_layer(
                        nn.Conv2d(opt.ngf * mult,
                                  opt.ngf * mult * 2,
                                  kernel_size=3,
                                  stride=2,
                                  padding=1)), activation
                ]))
            mult *= 2

        # Semantic core blocks
        self.resnet_core = nn.ModuleList()
        if opt.wide:
            self.resnet_core += [
                ResnetBlock(opt.ngf * mult,
                            dim2=opt.ngf * mult * 2,
                            norm_layer=norm_layer,
                            activation=activation,
                            kernel_size=opt.resnet_kernel_size)
            ]
            mult *= 2
        else:
            self.resnet_core += [
                ResnetBlock(opt.ngf * mult,
                            norm_layer=norm_layer,
                            activation=activation,
                            kernel_size=opt.resnet_kernel_size)
            ]

        for i in range(opt.resnet_n_blocks - 1):
            self.resnet_core += [
                ResnetBlock(opt.ngf * mult,
                            norm_layer=norm_layer,
                            activation=activation,
                            kernel_size=opt.resnet_kernel_size,
                            dilation=2)
            ]

        self.spade_core = nn.ModuleList()
        for i in range(opt.spade_n_blocks - 1):
            self.spade_core += [
                SPADEResnetBlock(opt.ngf * mult,
                                 opt.ngf * mult,
                                 opt,
                                 dilation=2)
            ]

        if opt.wide:
            self.spade_core += [
                SPADEResnetBlock(
                    opt.ngf * mult *
                    (2 if not self.opt.no_skip_connections else 1),
                    opt.ngf * mult // 2, opt)
            ]
            mult //= 2
        else:
            self.spade_core += [
                SPADEResnetBlock(
                    opt.ngf * mult *
                    (2 if not self.opt.no_skip_connections else 1),
                    opt.ngf * mult, opt)
            ]

        # Upsampling blocks
        self.uplayers = nn.ModuleList()
        for i in range(opt.resnet_n_downsample):
            self.uplayers.append(
                SPADEResnetBlock(
                    mult * opt.ngf *
                    (3 if not self.opt.no_skip_connections else 2) // 2,
                    opt.ngf * mult // 2, opt))
            mult //= 2

        final_nc = opt.ngf

        self.conv_img = nn.Conv2d(
            (input_nc +
             final_nc) if not self.opt.no_skip_connections else final_nc,
            output_nc,
            3,
            padding=1)

        self.up = nn.Upsample(scale_factor=2)