def __call__(self, x, m):
    """Apply the residual block to ``x`` using the conditioning input ``m``.

    Two stages run back to back, each inside its own parameter scope
    (``res_layer1`` then ``res_layer2``): the ``spade`` helper, the
    block's activation ``self.act``, and a 3x3 convolution with padding 1.
    The result is added to the output of ``self.shortcut(x, m)``.

    Args:
        x: Input features; per the original comment, shaped (N, C, H, W).
        m: Conditioning input forwarded unchanged to ``spade`` and to
            ``self.shortcut``.

    Returns:
        The sum of the shortcut branch and the two-stage branch.
    """
    # x: (N, C, H, W)
    skip = self.shortcut(x, m)

    # Width of the first convolution: the smaller of the input channel
    # count and the block's output width.
    bottleneck_dim = min(x.shape[1], self.out_dim)

    out = x
    stages = (("res_layer1", bottleneck_dim), ("res_layer2", self.out_dim))
    for scope_name, channels in stages:
        with ps(scope_name):
            out = spade(out, m)
            out = self.act(out)
            # w_init receives the activated tensor, i.e. the conv's input.
            out = PF.convolution(out, channels, kernel=(3, 3), pad=(1, 1),
                                 w_init=w_init(out, channels),
                                 **self.conv_opts)

    return skip + out
def __call__(self, z, m):
    """Run the generator on latent ``z`` with conditioning input ``m``.

    ``z`` is projected by an affine layer and reshaped to a feature map of
    shape ``(N, 16 * nf, H // 2**num_upsample, W // 2**num_upsample)``,
    where ``(H, W)`` is ``self.image_shape``. That map is then passed
    through the head, middle and up blocks, each preceded by ``self.up``
    where the code below says so, and finished with leaky-ReLU, a
    3-channel 3x3 convolution and ``tanh``.

    Args:
        z: Latent input; per the original comment, shaped (N, z_dim).
        m: Conditioning input; per the original comment it has the target
            image shape (N, emb, H, W). Passed to every block.

    Returns:
        The ``tanh`` output of the final 3-channel convolution.
    """
    # m has target image shape: (N, emb, H, W)
    # z: (N, z_dim)
    batch_size = m.shape[0]
    height, width = self.image_shape
    downscale = 2 ** self.num_upsample
    start_h = height // downscale
    start_w = width // downscale

    with ps("spade_generator"):
        with ps("z_embedding"):
            base_channels = 16 * self.nf
            flat_dim = base_channels * start_h * start_w
            feat = PF.affine(z, flat_dim, w_init=w_init(z, flat_dim))
            feat = F.reshape(
                feat, (batch_size, base_channels, start_h, start_w))

        with ps("head"):
            feat = self.head_0(feat, m)

        with ps("middle0"):
            feat = self.up(feat)
            feat = self.G_middle_0(feat, m)

        # NOTE: the scope name is spelled "middel1" in the original; it is a
        # runtime string and is kept verbatim.
        with ps("middel1"):
            # The extra upsampling here only happens for num_upsample > 5;
            # G_middle_1 itself always runs.
            if self.num_upsample > 5:
                feat = self.up(feat)
            feat = self.G_middle_1(feat, m)

        # up0..up3 always run: scope "up<i>" pairs with block self.up_<i>.
        for idx in range(4):
            with ps("up%d" % idx):
                feat = self.up(feat)
                feat = getattr(self, "up_%d" % idx)(feat, m)

        # A fifth up stage is used only for num_upsample > 6.
        if self.num_upsample > 6:
            with ps("up4"):
                feat = self.up(feat)
                feat = self.up_4(feat, m)

        with ps("last_conv"):
            activated = F.leaky_relu(feat, 2e-1)
            # w_init is given the tensor from before the leaky-ReLU, exactly
            # as in the original call.
            feat = PF.convolution(activated, 3, kernel=(3, 3), pad=(1, 1),
                                  w_init=w_init(feat, 3))
            feat = F.tanh(feat)

    return feat
def shortcut(self, x, m):
    """Compute the skip branch for input ``x``.

    When the channel count ``x.shape[1]`` already equals ``self.out_dim``
    the input is returned untouched. Otherwise, inside the ``shortcut``
    parameter scope, ``x`` goes through the ``spade`` helper and a 1x1
    convolution that produces ``self.out_dim`` channels.

    Args:
        x: Input features; axis 1 is read as the channel axis.
        m: Conditioning input forwarded to ``spade``.

    Returns:
        ``x`` itself, or the projected tensor with ``self.out_dim``
        channels.
    """
    # Nothing to project when the channel counts already agree.
    if x.shape[1] == self.out_dim:
        return x

    with ps("shortcut"):
        normalized = spade(x, m)
        projected = PF.convolution(normalized, self.out_dim, kernel=(1, 1),
                                   w_init=w_init(normalized, self.out_dim),
                                   **self.conv_opts)
    return projected