class MVN(Chain):
    """Isotropic multivariate-normal target for particle-based inference.

    Holds ``n_particle`` particles in the learnable parameter ``theta`` and
    scores them under N(mu, sigma^2 I).  ``__call__`` combines the negative
    log-density with an RBF kernel over the particles to produce an
    SVGD-style loss.
    """

    def __init__(self, mu, sigma, x0):
        """mu: target mean vector; sigma: scalar std; x0: (n, d) initial particles."""
        super(MVN, self).__init__()
        self.mu = mu
        self.sigma = sigma
        self.n_particle = x0.shape[0]
        with self.init_scope():
            self.theta = Parameter(initializer=x0)

    def logp(self):
        """Per-particle Gaussian log-density, tiled to (n_particle, n_particle).

        Each row of the returned variable is the same length-n vector of
        per-particle log-densities.
        """
        dim = self.mu.shape[0]
        mean = np.broadcast_to(self.mu, (self.n_particle, dim))
        # gaussian_nll takes log-variance; 2*log(sigma) == log(sigma^2).
        ln_var = np.broadcast_to(2 * np.log(self.sigma), (self.n_particle, dim))
        per_dim = -F.gaussian_nll(self.theta, mean, ln_var, reduce='no')
        per_particle = F.sum(per_dim, axis=1).reshape(-1)
        return F.broadcast_to(per_particle, (self.n_particle, self.n_particle))

    def __call__(self):
        """Return the kernelized loss and report 'loss'/'nlogp' to chainer."""
        kernel = rbf(self.theta.reshape(self.n_particle, -1))
        neg_logp = -self.logp()
        # kernel.data detaches the kernel in the attractive term, so gradients
        # flow through neg_logp there and through the kernel only in the
        # second (repulsive) term.
        loss = F.mean(F.sum(kernel.data * neg_logp + kernel, axis=1))
        chainer.report(
            {
                'loss': loss,
                'nlogp': F.mean(neg_logp[0]),
            },
            observer=self,
        )
        return loss
class InitialSkipArchitecture(Chain):
    """First block of a skip-connection generator.

    Starts from a learned constant 4x4 feature map, applies one
    style-modulated convolution with noise and a leaky-ReLU activation, and
    also emits an RGB image via ``trgb`` for the skip path.
    """

    def __init__(self, size, in_channels, out_channels):
        """size: style-vector width; in/out_channels: conv channel counts."""
        super().__init__()
        with self.init_scope():
            # Learned constant input shared by every sample in the batch.
            self.c1 = Parameter(shape=(in_channels, 4, 4), initializer=Normal(1.0))
            self.s1 = StyleAffineTransformation(size, in_channels)
            self.w1 = WeightModulatedConvolution(in_channels, out_channels)
            self.n1 = NoiseAdder()
            self.a1 = LeakyRelu()
            self.s2 = StyleAffineTransformation(size, out_channels)
            self.trgb = ToRGB(out_channels)

    def __call__(self, w):
        """Map a batch of style vectors w to (features, RGB skip output)."""
        batch = w.shape[0]
        # Tile the learned constant across the batch dimension.
        tiled = broadcast_to(self.c1.reshape(1, *self.c1.shape),
                             (batch, *self.c1.shape))
        feats = self.a1(self.n1(self.w1(tiled, self.s1(w))))
        return feats, self.trgb(feats, self.s2(w))
class WeightModulatedConvolution(Link):
    """Convolution whose weights are modulated per sample by a style vector.

    3x3 with edge padding by default, or 1x1 when ``pointwise``.  With
    ``demod`` enabled each per-sample output filter is renormalized to unit
    norm (StyleGAN2-style demodulation).  The whole batch is evaluated with
    a single grouped convolution.
    """

    def __init__(self, in_channels, out_channels, pointwise=False, demod=True,
                 gain=root(2)):
        super().__init__()
        self.demod = demod
        self.ksize = 1 if pointwise else 3
        self.pad = 0 if pointwise else 1
        # Equalized-learning-rate scale constant.
        # NOTE(review): self.c is never applied inside __call__ — presumably
        # consumed elsewhere; verify against the rest of the file.
        self.c = gain * root(1 / (in_channels * self.ksize**2))
        with self.init_scope():
            self.w = Parameter(shape=(out_channels, in_channels, self.ksize,
                                      self.ksize), initializer=Normal(1.0))
            self.b = Parameter(shape=out_channels, initializer=Zero())

    def __call__(self, x, y):
        """Apply the modulated convolution.

        x: (batch, in_channels, H, W) feature maps.
        y: per-sample style scales, reshaped to (batch, 1, in_channels, 1, 1).
        Returns (batch, out_channels, H, W).
        """
        out_channels = self.b.shape[0]
        batch, in_channels, height, width = x.shape
        # Scale every input channel of the shared weight by the style.
        styled = self.w * y.reshape(batch, 1, in_channels, 1, 1)
        if self.demod:
            # Demodulation: normalize each (sample, output-filter) to unit norm.
            norm = sqrt(sum(styled**2, axis=(2, 3, 4), keepdims=True) + 1e-08)
            weight = styled / norm
        else:
            weight = styled
        # Fold the batch into the channel axis so one grouped convolution
        # applies a different filter bank to each sample.
        grouped_w = weight.reshape(batch * out_channels, in_channels,
                                   self.ksize, self.ksize)
        grouped_x = x.reshape(1, batch * in_channels, height, width)
        # "edge" mode replicates border pixels instead of zero-padding.
        padded = pad(grouped_x,
                     ((0, 0), (0, 0), (self.pad, self.pad),
                      (self.pad, self.pad)),
                     mode="edge")
        h = convolution_2d(padded, grouped_w, stride=1, pad=0, groups=batch)
        out = h.reshape(batch, out_channels, height, width)
        return out + self.b.reshape(1, out_channels, 1, 1)