Example #1
0
    def __init__(self,
                 outer_nc,
                 inner_nc,
                 opt,
                 innerCos_list,
                 shift_list,
                 mask_global,
                 input_nc,
                 submodule=None,
                 shift_layer=None,
                 outermost=False,
                 innermost=False,
                 norm_layer=nn.BatchNorm2d,
                 use_spectral_norm=False,
                 layer_to_last=3):
        """Assemble one U-Net skip block that carries a shift and an InnerCos layer.

        The block is stored in ``self.model`` as an ``nn.Sequential``.  Its
        layout depends on the position flags:

        * ``outermost``: down-conv, ``submodule``, ReLU, up-conv, Tanh.
        * ``innermost``: LeakyReLU, down-conv, ReLU, up-conv, norm.
        * otherwise: LeakyReLU, down-conv, norm, ``submodule``, ReLU,
          InnerCos, shift, up-conv, norm.

        A shift layer and an InnerCos layer are always constructed, given
        ``mask_global`` via ``set_mask`` and appended to ``shift_list`` /
        ``innerCos_list``, even for the outermost / innermost layouts where
        they are not inserted into ``self.model``.

        Args:
            outer_nc: Output channels of the up-convolution (also the input
                channels of the down-convolution when ``input_nc`` is None).
            inner_nc: Output channels of the down-convolution.
            opt: Options object; ``gpu_ids``, ``shift_sz``, ``stride``,
                ``mask_thred``, ``triple_weight``, ``fuse``, ``strength`` and
                ``skip`` are read from it.
            innerCos_list: List that receives the created InnerCos layer.
            shift_list: List that receives the created shift layer.
            mask_global: Mask handed to both created layers.
            input_nc: Input channels of the down-convolution, or None.
            submodule: Inner block wrapped by this one (unused if innermost).
            shift_layer: Accepted but not read by this constructor.
            outermost: Build the outermost layout.
            innermost: Build the innermost layout.
            norm_layer: Factory called with a channel count.
            use_spectral_norm: Forwarded to ``spectral_norm``.
            layer_to_last: Forwarded to the shift and InnerCos layers.
        """
        super(ResPatchSoftUnetSkipConnectionShiftTriple, self).__init__()
        self.outermost = outermost
        down_in_nc = outer_nc if input_nc is None else input_nc

        # Down path pieces first, then the two norm layers / activations.
        down_conv = spectral_norm(
            nn.Conv2d(down_in_nc, inner_nc, kernel_size=4, stride=2, padding=1),
            use_spectral_norm)
        down_act = nn.LeakyReLU(0.2, True)
        down_bn = norm_layer(inner_nc)
        up_act = nn.ReLU(True)
        up_bn = norm_layer(outer_nc)

        # The helper layers take a device tag ('cpu' or 'gpu'), not a torch device.
        device = 'gpu' if len(opt.gpu_ids) > 0 else 'cpu'

        # The down-conv maps outer_nc -> inner_nc, so the shift is sized on inner_nc.
        shift = InnerResPatchSoftShiftTriple(inner_nc,
                                             opt.shift_sz,
                                             opt.stride,
                                             opt.mask_thred,
                                             opt.triple_weight,
                                             opt.fuse,
                                             layer_to_last=layer_to_last,
                                             device=device)
        shift.set_mask(mask_global)
        shift_list.append(shift)

        # Latent constraint layer; it needs the mask as well and is tracked
        # in the caller's constraint list.
        innerCos = InnerCos(strength=opt.strength,
                            skip=opt.skip,
                            layer_to_last=layer_to_last,
                            device=device)
        innerCos.set_mask(mask_global)
        innerCos_list.append(innerCos)

        # Only the pure innermost layout feeds `inner_nc` channels to the
        # up-conv; every other layout feeds `inner_nc * 2`.  For this residual
        # shift the middle layout is `*2`, not `*3`.
        if innermost and not outermost:
            up_in_nc = inner_nc
        else:
            up_in_nc = inner_nc * 2
        up_conv = spectral_norm(
            nn.ConvTranspose2d(up_in_nc,
                               outer_nc,
                               kernel_size=4,
                               stride=2,
                               padding=1), use_spectral_norm)

        if outermost:
            # Outermost: no activation/norm on the way down, Tanh at the end.
            layers = [down_conv, submodule, up_act, up_conv, nn.Tanh()]
        elif innermost:
            # Innermost: no submodule and no norm after the down-conv.
            layers = [down_act, down_conv, up_act, up_conv, up_bn]
        else:
            # Middle: InnerCos sits before the shift, both right after the
            # up-path ReLU.
            layers = [down_act, down_conv, down_bn,
                      submodule,
                      up_act, innerCos, shift, up_conv, up_bn]

        self.model = nn.Sequential(*layers)
Example #2
0
class FaceUnetGenerator(nn.Module):
    """Hand-unrolled 8-level U-Net generator with a shift and an InnerCos layer.

    The encoder is eight stride-2 convolutions (``e1_c`` .. ``e8_c``) and the
    decoder is eight stride-2 transposed convolutions (``d1_dc`` .. ``d8_dc``).
    Each decoder stage after the first consumes the previous decoder output
    concatenated with the matching encoder feature.  The shift / InnerCos pair
    is applied to the concatenation of ``d5`` and ``e3`` right before
    ``d6_dc``.
    """

    def __init__(self,
                 input_nc,
                 output_nc,
                 innerCos_list,
                 shift_list,
                 mask_global,
                 opt,
                 ngf=64,
                 norm_layer=nn.BatchNorm2d,
                 use_spectral_norm=False):
        """Create all layers and register the shift / InnerCos layers.

        Args:
            input_nc: Channels of the input image.
            output_nc: Channels of the generated image.
            innerCos_list: List that receives ``self.innerCos``.
            shift_list: List that receives ``self.shift``.
            mask_global: Mask handed to both layers via ``set_mask``.
            opt: Options object; ``gpu_ids``, ``shift_sz``, ``stride``,
                ``mask_thred``, ``triple_weight``, ``strength`` and ``skip``
                are read from it.
            ngf: Base number of filters.
            norm_layer: Factory called with a channel count.
            use_spectral_norm: Forwarded to ``spectral_norm``.
        """
        super(FaceUnetGenerator, self).__init__()

        # Encoder layers (no norm on the first and on the innermost conv).
        self.e1_c = spectral_norm(
            nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1),
            use_spectral_norm)

        self.e2_c = spectral_norm(
            nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=1),
            use_spectral_norm)
        self.e2_norm = norm_layer(ngf * 2)

        # Note: this is the only encoder conv with a 6x6 kernel / padding 2.
        self.e3_c = spectral_norm(
            nn.Conv2d(ngf * 2, ngf * 4, kernel_size=6, stride=2, padding=2),
            use_spectral_norm)
        self.e3_norm = norm_layer(ngf * 4)

        self.e4_c = spectral_norm(
            nn.Conv2d(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1),
            use_spectral_norm)
        self.e4_norm = norm_layer(ngf * 8)

        self.e5_c = spectral_norm(
            nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1),
            use_spectral_norm)
        self.e5_norm = norm_layer(ngf * 8)

        self.e6_c = spectral_norm(
            nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1),
            use_spectral_norm)
        self.e6_norm = norm_layer(ngf * 8)

        self.e7_c = spectral_norm(
            nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1),
            use_spectral_norm)
        self.e7_norm = norm_layer(ngf * 8)

        self.e8_c = spectral_norm(
            nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1),
            use_spectral_norm)

        # Decoder layers.  From d2 on, the input width is doubled because the
        # decoder feature is concatenated with an encoder feature.
        self.d1_dc = spectral_norm(
            nn.ConvTranspose2d(ngf * 8,
                               ngf * 8,
                               kernel_size=4,
                               stride=2,
                               padding=1), use_spectral_norm)
        self.d1_norm = norm_layer(ngf * 8)

        self.d2_dc = spectral_norm(
            nn.ConvTranspose2d(ngf * 8 * 2,
                               ngf * 8,
                               kernel_size=4,
                               stride=2,
                               padding=1), use_spectral_norm)
        self.d2_norm = norm_layer(ngf * 8)

        self.d3_dc = spectral_norm(
            nn.ConvTranspose2d(ngf * 8 * 2,
                               ngf * 8,
                               kernel_size=4,
                               stride=2,
                               padding=1), use_spectral_norm)
        self.d3_norm = norm_layer(ngf * 8)

        self.d4_dc = spectral_norm(
            nn.ConvTranspose2d(ngf * 8 * 2,
                               ngf * 8,
                               kernel_size=4,
                               stride=2,
                               padding=1), use_spectral_norm)
        self.d4_norm = norm_layer(ngf * 8)

        self.d5_dc = spectral_norm(
            nn.ConvTranspose2d(ngf * 8 * 2,
                               ngf * 4,
                               kernel_size=4,
                               stride=2,
                               padding=1), use_spectral_norm)
        self.d5_norm = norm_layer(ngf * 4)

        # The shift is applied before this layer, hence `* 3` instead of `* 2`
        # (presumably the shift appends one more ngf*4-wide feature -- confirm
        # against InnerFaceShiftTriple).
        self.d6_dc = spectral_norm(
            nn.ConvTranspose2d(ngf * 4 * 3,
                               ngf * 2,
                               kernel_size=4,
                               stride=2,
                               padding=1), use_spectral_norm)
        self.d6_norm = norm_layer(ngf * 2)

        self.d7_dc = spectral_norm(
            nn.ConvTranspose2d(ngf * 2 * 2,
                               ngf,
                               kernel_size=4,
                               stride=2,
                               padding=1), use_spectral_norm)
        self.d7_norm = norm_layer(ngf)

        self.d8_dc = spectral_norm(
            nn.ConvTranspose2d(ngf * 2,
                               output_nc,
                               kernel_size=4,
                               stride=2,
                               padding=1), use_spectral_norm)

        # Construct shift and innerCos.  The helper layers take a device tag
        # ('cpu' or 'gpu'), not a torch device.
        device = 'cpu' if len(opt.gpu_ids) == 0 else 'gpu'
        self.shift = InnerFaceShiftTriple(opt.shift_sz,
                                          opt.stride,
                                          opt.mask_thred,
                                          opt.triple_weight,
                                          layer_to_last=3,
                                          device=device)
        self.shift.set_mask(mask_global)
        shift_list.append(self.shift)

        self.innerCos = InnerCos(strength=opt.strength,
                                 skip=opt.skip,
                                 layer_to_last=3,
                                 device=device)
        self.innerCos.set_mask(
            mask_global)  # Here we need to set mask for innerCos layer too.
        innerCos_list.append(self.innerCos)

    # In this case, we have very flexible unet construction mode.
    def forward(self, input, flip_feat=None):
        """Run the encoder/decoder and return ``(output, innerFeat)``.

        Args:
            input: Input batch fed to the first encoder convolution.
            flip_feat: Optional second argument forwarded to ``self.shift``.

        Returns:
            A tuple ``(d8, innerFeat)``: ``d8`` is the tanh of the last
            transposed convolution, ``innerFeat`` is the second value
            returned by ``self.shift``.

        Note:
            The activations are the in-place variants, so ``e1`` .. ``e7``
            are already activated when they are reused as skip features.
        """
        # Encoder
        # No norm on the first layer
        e1 = self.e1_c(input)
        e2 = self.e2_norm(self.e2_c(F.leaky_relu_(e1, negative_slope=0.2)))
        e3 = self.e3_norm(self.e3_c(F.leaky_relu_(e2, negative_slope=0.2)))
        e4 = self.e4_norm(self.e4_c(F.leaky_relu_(e3, negative_slope=0.2)))
        e5 = self.e5_norm(self.e5_c(F.leaky_relu_(e4, negative_slope=0.2)))
        e6 = self.e6_norm(self.e6_c(F.leaky_relu_(e5, negative_slope=0.2)))

        e7 = self.e7_norm(self.e7_c(F.leaky_relu_(e6, negative_slope=0.2)))
        # No norm in the inner_most layer
        e8 = self.e8_c(F.leaky_relu_(e7, negative_slope=0.2))

        # Decoder
        d1 = self.d1_norm(self.d1_dc(F.relu_(e8)))
        d2 = self.d2_norm(self.d2_dc(F.relu_(self.cat_feat(d1, e7))))
        d3 = self.d3_norm(self.d3_dc(F.relu_(self.cat_feat(d2, e6))))
        d4 = self.d4_norm(self.d4_dc(F.relu_(self.cat_feat(d3, e5))))
        d5 = self.d5_norm(self.d5_dc(F.relu_(self.cat_feat(d4, e4))))
        # InnerCos runs first, then the shift; the shift also hands back the
        # feature that is returned to the caller.
        tmp, innerFeat = self.shift(
            self.innerCos(F.relu_(self.cat_feat(d5, e3))), flip_feat)
        d6 = self.d6_norm(self.d6_dc(tmp))
        d7 = self.d7_norm(self.d7_dc(F.relu_(self.cat_feat(d6, e2))))
        # No norm on the last layer
        d8 = self.d8_dc(F.relu_(self.cat_feat(d7, e1)))

        d8 = torch.tanh(d8)

        return d8, innerFeat

    def cat_feat(self, de_feat, en_feat):
        """Concatenate a decoder feature with an encoder feature on dim 1.

        If the two 4-D tensors differ in height or width, ``de_feat`` is first
        resized bilinearly to the spatial size of ``en_feat``.

        Args:
            de_feat: Decoder feature (4-D tensor).
            en_feat: Encoder feature (4-D tensor).

        Returns:
            ``torch.cat([de_feat, en_feat], 1)`` after the optional resize.
        """
        _, _, h1, w1 = de_feat.size()
        _, _, h2, w2 = en_feat.size()
        if h1 != h2 or w1 != w2:
            # align_corners=False is what the implicit default resolves to;
            # spelling it out keeps the result identical while avoiding the
            # "default upsampling behavior" warning some releases emit.
            de_feat = F.interpolate(de_feat, (h2, w2),
                                    mode='bilinear',
                                    align_corners=False)
        return torch.cat([de_feat, en_feat], 1)
Example #3
0
    def __init__(self,
                 input_nc,
                 output_nc,
                 innerCos_list,
                 shift_list,
                 mask_global,
                 opt,
                 ngf=64,
                 norm_layer=nn.BatchNorm2d,
                 use_spectral_norm=False):
        """Create the encoder/decoder layers and register shift / InnerCos.

        Eight stride-2 convolutions (``e1_c`` .. ``e8_c``) and eight stride-2
        transposed convolutions (``d1_dc`` .. ``d8_dc``) are created, each
        wrapped by ``spectral_norm(..., use_spectral_norm)``.  Norm layers are
        created for every stage except ``e1``, ``e8`` and ``d8``.  Finally a
        shift layer and an InnerCos layer are built, given ``mask_global`` and
        appended to ``shift_list`` / ``innerCos_list``.

        Args:
            input_nc: Channels of the input image.
            output_nc: Channels of the generated image.
            innerCos_list: List that receives ``self.innerCos``.
            shift_list: List that receives ``self.shift``.
            mask_global: Mask handed to both layers via ``set_mask``.
            opt: Options object; ``gpu_ids``, ``shift_sz``, ``stride``,
                ``mask_thred``, ``triple_weight``, ``strength`` and ``skip``
                are read from it.
            ngf: Base number of filters.
            norm_layer: Factory called with a channel count.
            use_spectral_norm: Forwarded to ``spectral_norm``.
        """
        super(FaceUnetGenerator, self).__init__()

        def down(in_nc, out_nc, kernel_size=4, padding=1):
            # Stride-2 encoder convolution, optionally spectrally normalised.
            return spectral_norm(
                nn.Conv2d(in_nc,
                          out_nc,
                          kernel_size=kernel_size,
                          stride=2,
                          padding=padding), use_spectral_norm)

        def up(in_nc, out_nc):
            # Stride-2 decoder transposed convolution (always 4x4, padding 1).
            return spectral_norm(
                nn.ConvTranspose2d(in_nc,
                                   out_nc,
                                   kernel_size=4,
                                   stride=2,
                                   padding=1), use_spectral_norm)

        # ---- Encoder (no norm on the first and on the innermost layer) ----
        self.e1_c = down(input_nc, ngf)

        self.e2_c = down(ngf, ngf * 2)
        self.e2_norm = norm_layer(ngf * 2)

        # The third stage is the only one using a 6x6 kernel with padding 2.
        self.e3_c = down(ngf * 2, ngf * 4, kernel_size=6, padding=2)
        self.e3_norm = norm_layer(ngf * 4)

        self.e4_c = down(ngf * 4, ngf * 8)
        self.e4_norm = norm_layer(ngf * 8)

        self.e5_c = down(ngf * 8, ngf * 8)
        self.e5_norm = norm_layer(ngf * 8)

        self.e6_c = down(ngf * 8, ngf * 8)
        self.e6_norm = norm_layer(ngf * 8)

        self.e7_c = down(ngf * 8, ngf * 8)
        self.e7_norm = norm_layer(ngf * 8)

        self.e8_c = down(ngf * 8, ngf * 8)

        # ---- Decoder (no norm on the last layer) ----
        self.d1_dc = up(ngf * 8, ngf * 8)
        self.d1_norm = norm_layer(ngf * 8)

        self.d2_dc = up(ngf * 8 * 2, ngf * 8)
        self.d2_norm = norm_layer(ngf * 8)

        self.d3_dc = up(ngf * 8 * 2, ngf * 8)
        self.d3_norm = norm_layer(ngf * 8)

        self.d4_dc = up(ngf * 8 * 2, ngf * 8)
        self.d4_norm = norm_layer(ngf * 8)

        self.d5_dc = up(ngf * 8 * 2, ngf * 4)
        self.d5_norm = norm_layer(ngf * 4)

        # The shift sits right before this layer, hence the `* 3` input width.
        self.d6_dc = up(ngf * 4 * 3, ngf * 2)
        self.d6_norm = norm_layer(ngf * 2)

        self.d7_dc = up(ngf * 2 * 2, ngf)
        self.d7_norm = norm_layer(ngf)

        self.d8_dc = up(ngf * 2, output_nc)

        # ---- Shift and InnerCos ----
        # The helper layers take a device tag ('cpu' or 'gpu').
        device = 'gpu' if len(opt.gpu_ids) > 0 else 'cpu'

        face_shift = InnerFaceShiftTriple(opt.shift_sz,
                                          opt.stride,
                                          opt.mask_thred,
                                          opt.triple_weight,
                                          layer_to_last=3,
                                          device=device)
        self.shift = face_shift
        face_shift.set_mask(mask_global)
        shift_list.append(face_shift)

        inner_cos = InnerCos(strength=opt.strength,
                             skip=opt.skip,
                             layer_to_last=3,
                             device=device)
        self.innerCos = inner_cos
        # The InnerCos layer needs the mask as well.
        inner_cos.set_mask(mask_global)
        innerCos_list.append(inner_cos)
    def __init__(self, outer_nc, inner_nc, innerCos_list=None, shift_list=None, mask_global=None, input_nc=None, opt=None,
                 submodule=None, shift_layer=False, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False, layer_to_last=3):
        """Assemble one Inception-style U-Net skip block, optionally with a shift.

        The block is stored in ``self.model`` as an ``nn.Sequential``:

        * ``outermost``: plain 4x4 stride-2 conv, ``submodule``, ReLU,
          transposed conv, Tanh.
        * ``innermost``: LeakyReLU, ``InceptionDown``, ReLU, ``InceptionUp``,
          norm.
        * otherwise: LeakyReLU, ``InceptionDown``, norm, ``submodule``, ReLU,
          [InnerCos, shift,] ``InceptionUp``, norm, [Dropout(0.5)].  The
          InnerCos / shift pair is present only when ``shift_layer`` is true.

        When ``shift_layer`` is true, a shift layer and an InnerCos layer are
        created, given ``mask_global`` via ``set_mask`` and appended to
        ``shift_list`` / ``innerCos_list`` regardless of the position flags.

        Args:
            outer_nc: Output channels of the up path (and input channels of
                the down path when ``input_nc`` is None).
            inner_nc: Output channels of the down path.
            innerCos_list: List receiving the InnerCos layer; required when
                ``shift_layer`` is true.
            shift_list: List receiving the shift layer; required when
                ``shift_layer`` is true.
            mask_global: Mask handed to both layers.
            input_nc: Input channels of the down path, or None.
            opt: Options object; ``shift_sz``, ``stride``, ``mask_thred``,
                ``triple_weight``, ``strength`` and ``skip`` are read from it
                when ``shift_layer`` is true.
            submodule: Inner block wrapped by this one (unused if innermost).
            shift_layer: Whether to create the shift / InnerCos pair.
            outermost: Build the outermost layout.
            innermost: Build the innermost layout.
            norm_layer: Factory called with a channel count.
            use_dropout: Append ``nn.Dropout(0.5)`` in the middle layout.
            layer_to_last: Forwarded to the shift and InnerCos layers.
        """
        super(InceptionShiftUnetSkipConnectionBlock, self).__init__()

        self.outermost = outermost
        if input_nc is None:
            input_nc = outer_nc

        # Bound up-front so the middle layout can be built without a shift.
        shift = None
        innerCosBefore = None
        if shift_layer:
            # As the downconv layer is outer_nc in and inner_nc out.
            # So the shift define like this:
            shift = InnerShiftTriple(opt.shift_sz,
                                     opt.stride,
                                     opt.mask_thred,
                                     opt.triple_weight,
                                     layer_to_last=layer_to_last)

            shift.set_mask(mask_global)
            shift_list.append(shift)

            # Add latent constraint
            # Then add the constraint to the constrain layer list!
            # NOTE(review): `layer_to_last` is now forwarded instead of being
            # hard-coded to 3, so both layers placed at this position receive
            # the same value (unchanged for the default) -- confirm no caller
            # relied on the previous mismatch.
            innerCosBefore = InnerCos(strength=opt.strength,
                                      skip=opt.skip,
                                      layer_to_last=layer_to_last)
            innerCosBefore.set_mask(
                mask_global
            )  # Here we need to set mask for innerCos layer too.
            innerCos_list.append(innerCosBefore)

        downconv = InceptionDown(
            input_nc, inner_nc
        )  # nn.Conv2d(input_nc, inner_nc, kernel_size=4,stride=2, padding=1)

        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        # Different position only has differences in `upconv`
        # for the outermost, the special is `tanh`
        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2,
                                        outer_nc,
                                        kernel_size=4,
                                        stride=2,
                                        padding=1)
            # The outermost block uses a plain convolution; this replaces the
            # InceptionDown built above.
            downconv = nn.Conv2d(input_nc,
                                 inner_nc,
                                 kernel_size=4,
                                 stride=2,
                                 padding=1)

            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        # for the innermost, the special is `inner_nc` instead of `inner_nc*2`
        elif innermost:
            upconv = InceptionUp(
                inner_nc, outer_nc
            )  # nn.ConvTranspose2d(inner_nc, outer_nc,kernel_size=4, stride=2,padding=1)

            down = [downrelu, downconv
                    ]  # for the innermost, no submodule, and delete the bn
            up = [uprelu, upconv, upnorm]
            model = down + up
        # else, the normal
        else:
            down = [downrelu, downconv, downnorm]
            if shift_layer:
                # With the shift in place the up-conv is sized for three
                # inner_nc-wide feature groups.
                upconv = InceptionUp(inner_nc * 3, outer_nc)
                up = [uprelu, innerCosBefore, shift, upconv, upnorm]
            else:
                # Without a shift the up path is sized like the outermost
                # layout above (`inner_nc * 2` inputs).  Previously this path
                # raised UnboundLocalError because `innerCosBefore` / `shift`
                # were never bound.
                upconv = InceptionUp(inner_nc * 2, outer_nc)
                up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)