Example #1
0
    def __call__(self, x, m):
        """Run the residual block: a shortcut branch plus two SPADE/conv stages.

        Args:
            x: Input feature map, (N, C, H, W) per the original comment.
            m: Conditioning input, handed unchanged to every ``spade`` call
                and to ``self.shortcut``.

        Returns:
            The sum of the shortcut output and the second convolution output.
        """
        skip = self.shortcut(x, m)

        # The intermediate width is the smaller of the input and output widths.
        mid_channels = min(x.shape[1], self.out_dim)
        stages = (
            ("res_layer1", mid_channels),
            ("res_layer2", self.out_dim),
        )

        out = x
        for scope_name, channels in stages:
            # Each stage: SPADE -> activation -> 3x3 conv, under its own scope.
            with ps(scope_name):
                out = spade(out, m)
                out = self.act(out)
                out = PF.convolution(out,
                                     channels,
                                     kernel=(3, 3),
                                     pad=(1, 1),
                                     w_init=w_init(out, channels),
                                     **self.conv_opts)

        return skip + out
Example #2
0
    def __call__(self, z, m):
        """Generate an output tensor from latent ``z`` conditioned on ``m``.

        Args:
            z: Latent input, (N, z_dim) per the original comment.
            m: Conditioning input with the target image shape,
                (N, emb, H, W) per the original comment; it is passed to
                every residual block.

        Returns:
            The tanh-activated output of the final 3-output convolution.
        """
        batch_size = m.shape[0]
        height, width = self.image_shape

        # Spatial size of the first feature map: the target size divided by
        # 2 ** num_upsample.
        shrink = 2 ** self.num_upsample
        start_h = height // shrink
        start_w = width // shrink

        start_channels = 16 * self.nf
        flat_size = start_channels * start_h * start_w

        # The four up stages that always run: (scope name, block attribute).
        fixed_up_stages = (
            ("up0", "up_0"),
            ("up1", "up_1"),
            ("up2", "up_2"),
            ("up3", "up_3"),
        )

        with ps("spade_generator"):
            with ps("z_embedding"):
                # Project z to a flat vector, then fold it into a feature map.
                feat = PF.affine(z, flat_size, w_init=w_init(z, flat_size))
                feat = F.reshape(
                    feat, (batch_size, start_channels, start_h, start_w))

            with ps("head"):
                feat = self.head_0(feat, m)

            with ps("middle0"):
                feat = self.up(feat)
                feat = self.G_middle_0(feat, m)

            # NOTE(review): "middel1" looks like a misspelling of "middle1";
            # it is left untouched because the string is passed to ps() at
            # runtime -- confirm before renaming.
            with ps("middel1"):
                # The extra self.up call only runs when num_upsample > 5.
                if self.num_upsample > 5:
                    feat = self.up(feat)
                feat = self.G_middle_1(feat, m)

            for scope_name, block_attr in fixed_up_stages:
                with ps(scope_name):
                    feat = self.up(feat)
                    feat = getattr(self, block_attr)(feat, m)

            # self.up_4 is only touched when num_upsample > 6.
            if self.num_upsample > 6:
                with ps("up4"):
                    feat = self.up(feat)
                    feat = self.up_4(feat, m)

            with ps("last_conv"):
                activated = F.leaky_relu(feat, 2e-1)
                # w_init receives the pre-activation tensor.
                feat = PF.convolution(activated,
                                      3,
                                      kernel=(3, 3),
                                      pad=(1, 1),
                                      w_init=w_init(feat, 3))
                feat = F.tanh(feat)

        return feat
Example #3
0
    def shortcut(self, x, m):
        """Return the shortcut branch for ``x``.

        When ``x.shape[1]`` already equals ``self.out_dim`` the input is
        returned untouched. Otherwise it goes through ``spade`` and a 1x1
        convolution with ``self.out_dim`` outputs, inside the "shortcut"
        scope.

        Args:
            x: Input whose ``shape[1]`` is compared with ``self.out_dim``.
            m: Conditioning input forwarded to ``spade``.

        Returns:
            ``x`` itself, or the projected tensor.
        """
        # Identity path: nothing to project.
        if x.shape[1] == self.out_dim:
            return x

        with ps("shortcut"):
            normalized = spade(x, m)
            projected = PF.convolution(
                normalized,
                self.out_dim,
                kernel=(1, 1),
                w_init=w_init(normalized, self.out_dim),
                **self.conv_opts)

        return projected