Example #1
import numpy as np

import chainer
import chainer.functions as F
from chainer import Chain, Parameter


class MVN(Chain):
    """Particles under a diagonal multivariate normal, trained with an SVGD-style kernelized loss."""

    def __init__(self, mu, sigma, x0):
        super(MVN, self).__init__()
        self.mu = mu
        self.sigma = sigma
        self.n_particle = x0.shape[0]
        with self.init_scope():
            # Each row of theta is one particle; x0 supplies the initial positions.
            self.theta = Parameter(initializer=x0)

    def logp(self):
        # Log-density of each particle under a diagonal Gaussian N(mu, sigma^2 I).
        d = self.mu.shape[0]
        mean = np.broadcast_to(self.mu, (self.n_particle, d))
        ln_var = np.broadcast_to(2 * np.log(self.sigma), (self.n_particle, d))

        # gaussian_nll with reduce='no' keeps per-dimension terms; negate for log p.
        logp = -F.gaussian_nll(self.theta, mean, ln_var, reduce='no')
        logp = F.sum(logp, axis=1).reshape(-1)
        # Tile the per-particle log p into an (n, n) matrix for the pairwise kernel sum.
        logp = F.broadcast_to(logp, (self.n_particle, self.n_particle))
        return logp

    def __call__(self):
        # rbf is an RBF-kernel helper assumed to be defined elsewhere in the
        # original module; it returns the (n, n) kernel matrix of the particles.
        ker = rbf(self.theta.reshape(self.n_particle, -1))
        nlogp = -self.logp()
        # Kernel-weighted energy plus a repulsion term; ker.data is a raw array,
        # so gradients of the energy term flow only through nlogp.
        loss = F.mean(F.sum(ker.data * nlogp + ker, axis=1))

        chainer.report(
            {
                'loss': loss,
                'nlogp': F.mean(nlogp[0]),
            },
            observer=self,
        )
        return loss
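
A minimal sketch of how this Chain could be driven, assuming the rbf helper from the original module is in scope; the toy target, particle count, and optimizer settings below are illustrative assumptions, not part of the original example.

# Sketch only: target parameters, particle count, and optimizer are assumptions.
mu = np.zeros(2, dtype=np.float32)
sigma = np.ones(2, dtype=np.float32)
x0 = np.random.randn(100, 2).astype(np.float32)  # 100 particles in 2-D

model = MVN(mu, sigma, x0)
optimizer = chainer.optimizers.Adam(alpha=0.1)
optimizer.setup(model)

for _ in range(1000):
    model.cleargrads()
    loss = model()
    loss.backward()
    optimizer.update()  # moves the particles toward (and apart within) the target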
Example #2
from chainer import Chain, Parameter
from chainer.functions import broadcast_to
from chainer.initializers import Normal

# StyleAffineTransformation, WeightModulatedConvolution, NoiseAdder, LeakyRelu,
# and ToRGB are custom links from the same project; WeightModulatedConvolution
# is shown in Example #3 below.


class InitialSkipArchitecture(Chain):
    """First block of a StyleGAN2-style skip generator: a learned constant input
    followed by a style-modulated convolution and an RGB output branch."""

    def __init__(self, size, in_channels, out_channels):
        super().__init__()
        with self.init_scope():
            # Learned constant 4x4 feature map shared by all samples.
            self.c1 = Parameter(shape=(in_channels, 4, 4),
                                initializer=Normal(1.0))
            self.s1 = StyleAffineTransformation(size, in_channels)
            self.w1 = WeightModulatedConvolution(in_channels, out_channels)
            self.n1 = NoiseAdder()
            self.a1 = LeakyRelu()
            self.s2 = StyleAffineTransformation(size, out_channels)
            self.trgb = ToRGB(out_channels)

    def __call__(self, w):
        batch = w.shape[0]
        # Broadcast the learned constant across the batch dimension.
        h1 = self.c1.reshape(1, *self.c1.shape)
        h2 = broadcast_to(h1, (batch, *self.c1.shape))
        # Style-modulated convolution, then per-pixel noise and the activation.
        h3 = self.w1(h2, self.s1(w))
        h4 = self.n1(h3)
        h5 = self.a1(h4)
        # Return the feature map for the next block and its RGB projection.
        return h5, self.trgb(h5, self.s2(w))
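
Assuming the custom links above are importable, the block would be called with a batch of style vectors; the style dimensionality and channel counts here are illustrative assumptions.

import numpy as np

# Sketch only: style size and channel counts are assumptions.
w = np.random.randn(8, 512).astype(np.float32)  # batch of 8 style vectors
block = InitialSkipArchitecture(512, 512, 256)
feature, rgb = block(w)                         # feature: (8, 256, 4, 4);
                                                # rgb presumably (8, 3, 4, 4) via ToRGB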
Example #3
from math import sqrt as root  # assumed alias: scalar square root, needed at class-definition time

from chainer import Link, Parameter
from chainer.functions import convolution_2d, pad, sqrt, sum  # note: shadows builtin sum/sqrt, as the snippet calls them
from chainer.initializers import Normal, Zero


class WeightModulatedConvolution(Link):
    """StyleGAN2-style weight-modulated (and optionally demodulated) convolution."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 pointwise=False,
                 demod=True,
                 gain=root(2)):
        super().__init__()
        self.demod = demod
        # A pointwise layer is a 1x1 convolution; otherwise 3x3 with padding 1.
        self.ksize = 1 if pointwise else 3
        self.pad = 0 if pointwise else 1
        # Equalized learning-rate constant (He-style fan-in scaling).
        self.c = gain * root(1 / (in_channels * self.ksize**2))
        with self.init_scope():
            self.w = Parameter(shape=(out_channels, in_channels, self.ksize,
                                      self.ksize),
                               initializer=Normal(1.0))
            self.b = Parameter(shape=out_channels, initializer=Zero())

    def __call__(self, x, y):
        # y holds per-sample, per-input-channel style scales.
        out_channels = self.b.shape[0]
        batch, in_channels, height, width = x.shape
        # Modulate: scale the shared kernel per sample and input channel,
        # giving a (batch, out, in, k, k) weight tensor.
        modulated_w = self.w * y.reshape(batch, 1, in_channels, 1, 1)
        # Demodulate: rescale each output filter to unit L2 norm (eps for stability).
        w = modulated_w / sqrt(
            sum(modulated_w**2, axis=(2, 3, 4), keepdims=True) +
            1e-08) if self.demod else modulated_w
        # Fold the batch into the channel axis and run one grouped convolution,
        # so each sample is convolved with its own modulated weights.
        grouped_w = w.reshape(batch * out_channels, in_channels, self.ksize,
                              self.ksize)
        grouped_x = x.reshape(1, batch * in_channels, height, width)
        padded_grouped_x = pad(grouped_x,
                               ((0, 0), (0, 0), (self.pad, self.pad),
                                (self.pad, self.pad)),
                               mode="edge")
        h = convolution_2d(padded_grouped_x,
                           grouped_w,
                           stride=1,
                           pad=0,
                           groups=batch)
        return h.reshape(batch, out_channels, height, width) + self.b.reshape(
            1, out_channels, 1, 1)
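
A minimal sketch of calling the link directly; the batch size, channel counts, spatial size, and constant style scales below are illustrative assumptions.

import numpy as np

# Sketch only: all shapes and the style values are assumptions.
conv = WeightModulatedConvolution(64, 128)
x = np.random.randn(4, 64, 16, 16).astype(np.float32)  # input feature maps
y = np.ones((4, 64), dtype=np.float32)                  # per-channel style scales
h = conv(x, y)                                          # Variable of shape (4, 128, 16, 16)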