Example #1
    def estimate_ELBO(query_images, z_t_param_array, pixel_mean,
                      pixel_log_sigma):
        # KL divergence between the posterior and prior at each generation step
        kl_divergence = 0
        for params_t in z_t_param_array:
            mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params_t
            normal_q = chainer.distributions.Normal(mean_z_q,
                                                    log_scale=ln_var_z_q)
            normal_p = chainer.distributions.Normal(mean_z_p,
                                                    log_scale=ln_var_z_p)
            kld_t = chainer.kl_divergence(normal_q, normal_p)
            kl_divergence += cf.sum(kld_t)
        kl_divergence = kl_divergence / args.batch_size

        # Negative log-likelihood of generated image
        batch_size = query_images.shape[0]
        num_pixels_per_batch = np.prod(query_images.shape[1:])
        normal = chainer.distributions.Normal(query_images,
                                              log_scale=pixel_log_sigma)

        log_px = cf.sum(normal.log_prob(pixel_mean)) / batch_size
        negative_log_likelihood = -log_px

        # Empirical ELBO
        ELBO = log_px - kl_divergence

        # https://arxiv.org/abs/1604.08772 Section.2
        # https://www.reddit.com/r/MachineLearning/comments/56m5o2/discussion_calculation_of_bitsdims/
        bits_per_pixel = -(ELBO / num_pixels_per_batch -
                           np.log(256)) / np.log(2)

        return ELBO, bits_per_pixel, negative_log_likelihood, kl_divergence
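
The bits-per-pixel value above is the usual nats-to-bits conversion from the references linked in the code, with the log(256) offset accounting for 8-bit pixel discretization. A minimal standalone sketch of the same arithmetic (the ELBO value and pixel count below are made up for illustration):

    import numpy as np

    def bits_per_pixel(elbo, num_pixels_per_batch):
        # Per-pixel ELBO in nats, shifted by log(256) for 8-bit pixels,
        # then converted from nats to bits.
        return -(elbo / num_pixels_per_batch - np.log(256)) / np.log(2)

    # Hypothetical numbers: one 64x64 RGB image and an ELBO of -2.0e4 nats.
    print(bits_per_pixel(-2.0e4, 3 * 64 * 64))  # roughly 10.3 bits per pixel
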
Example #2
    def __call__(self, x, x_=None, **kwargs):
        if x_ is None:
            x_ = x
        if self.k > 1:
            x_ = F.repeat(x_, self.k, axis=0)

        q_z = self.encode(x, **kwargs)
        z = self.sample(q_z)
        p_x = self.decode(z, **kwargs)
        p_z = self.prior()

        # Additional loss functions
        # mse_vel = F.mean_squared_error(x_, p_x.mean)
        # mse_vor = F.mean_squared_error(*map(vorticity, (x_, p_x.mean)))
        # y = F.sigmoid(p_x.mean)
        # mse_vel = self.batch_mean(F.squared_error(*map(logit, (x_, y))))
        # mse_vor = self.batch_mean(F.squared_error(*map(vorticity_logit15,
        #                                                (x_, y))))

        # reporter.report({'mse_vel': mse_vel}, self)
        # reporter.report({'mse_vor': mse_vor}, self)

        # Loss function
        reconstr = self.batch_mean(p_x.log_prob(x_))
        # reconstr = -self.batch_mean((p_x.mean - x_) ** 2)
        kl_penalty = self.batch_mean(chainer.kl_divergence(q_z, p_z))
        loss = self.beta * kl_penalty - reconstr

        reporter.report({'loss': loss}, self)
        reporter.report({'reconstr': reconstr}, self)
        reporter.report({'kl_penalty': kl_penalty}, self)

        return loss
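
The kl_penalty term above is the KL divergence between the encoder's posterior and the prior; when both are Normal, chainer.kl_divergence evaluates it in closed form. A self-contained sketch of that computation, with made-up shapes (batch of 4, latent dimension 8) standing in for what encode() and prior() would return:

    import numpy as np
    import chainer
    import chainer.functions as F

    mu = np.zeros((4, 8), dtype=np.float32)
    log_sigma = np.zeros((4, 8), dtype=np.float32)
    q_z = chainer.distributions.Normal(mu, log_scale=log_sigma)

    p_z = chainer.distributions.Normal(
        np.zeros((4, 8), dtype=np.float32),
        scale=np.ones((4, 8), dtype=np.float32))

    # One common reduction: element-wise KL summed over the latent axis,
    # then averaged over the batch.
    kl_penalty = F.mean(F.sum(chainer.kl_divergence(q_z, p_z), axis=-1))
    print(kl_penalty.array)  # 0.0, since q_z here equals the standard-normal prior
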
Example #3
    def cal_loss(self, x):
        q_z = self.encoder(x)
        z = q_z.sample(self.k)
        x_hat = self.decoder(z)
        loss_rec = F.mean_squared_error(x_hat, x)
        loss_kl = F.mean(
            F.sum(chainer.kl_divergence(q_z, self.prior()), axis=-1))
        return loss_rec, loss_kl
Example #4
    def check_kl(self, dist1, dist2):
        # Analytic KL divergence as computed by chainer.kl_divergence.
        kl = chainer.kl_divergence(dist1, dist2).data
        if isinstance(kl, cuda.ndarray):
            kl = kl.get()

        # Monte Carlo estimate: average of log q(x) - log p(x) over
        # 300,000 samples drawn from dist1.
        sample = dist1.sample(300000)
        mc_kl = dist1.log_prob(sample).data - dist2.log_prob(sample).data
        if isinstance(mc_kl, cuda.ndarray):
            mc_kl = mc_kl.get()
        mc_kl = numpy.nanmean(mc_kl, axis=0)

        # The two estimates should agree within a loose tolerance.
        testing.assert_allclose(kl, mc_kl, atol=1e-2, rtol=1e-2)
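
The test above checks the analytic KL against a Monte Carlo estimate. For two Normal distributions the analytic value has a simple closed form, which a short sketch can verify directly (the parameter values below are arbitrary):

    import numpy as np
    import chainer

    mu_q, sigma_q = 0.0, 1.0
    mu_p, sigma_p = 1.0, 2.0

    q = chainer.distributions.Normal(
        np.array(mu_q, dtype=np.float32),
        scale=np.array(sigma_q, dtype=np.float32))
    p = chainer.distributions.Normal(
        np.array(mu_p, dtype=np.float32),
        scale=np.array(sigma_p, dtype=np.float32))

    kl = chainer.kl_divergence(q, p).data

    # Closed form: log(s_p / s_q) + (s_q^2 + (mu_q - mu_p)^2) / (2 s_p^2) - 1/2
    analytic = (np.log(sigma_p / sigma_q)
                + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
                - 0.5)
    print(kl, analytic)  # both approximately 0.4431
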
Example #6
    def __call__(self, x):
        q_z = self.encoder(x)
        z = q_z.sample(self.k)
        p_x = self.decoder(z)
        p_z = self.prior()

        reconstr = F.mean(
            p_x.log_prob(F.broadcast_to(x[None, :], (self.k, ) + x.shape)))
        kl_penalty = F.mean(chainer.kl_divergence(q_z, p_z))
        loss = -(reconstr - self.beta * kl_penalty)
        reporter.report({'loss': loss}, self)
        reporter.report({'reconstr': reconstr}, self)
        reporter.report({'kl_penalty': kl_penalty}, self)
        return loss
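
The F.broadcast_to call above tiles the input so that the decoder's log-likelihood can be evaluated against all k latent samples at once. A toy sketch of the same broadcasting pattern with made-up shapes (k=3, batch of 2, 4 features):

    import numpy as np
    import chainer.functions as F

    k = 3
    x = np.arange(8, dtype=np.float32).reshape(2, 4)

    # Add a leading sample axis, then broadcast the batch along it.
    x_tiled = F.broadcast_to(x[None, :], (k,) + x.shape)
    print(x_tiled.shape)  # (3, 2, 4)
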
Example #7
    def __call__(self, x):
        q_z = self.encoder(x)
        z = q_z.sample(self.k)
        p_x = self.decoder(z)
        p_z = self.prior()

        reconstr = F.mean(p_x.log_prob(
            F.broadcast_to(x[None, :], (self.k,) + x.shape)))
        kl_penalty = F.mean(chainer.kl_divergence(q_z, p_z))
        loss = -(reconstr - self.beta * kl_penalty)
        reporter.report({'loss': loss}, self)
        reporter.report({'reconstr': reconstr}, self)
        reporter.report({'kl_penalty': kl_penalty}, self)
        return loss
Example #8
    def __call__(self, t, condition):
        # t(timesteps): 1-T

        distribution = chainer.distributions.Normal(
            self.xp.array(0, dtype='f'), self.xp.array(1, dtype='f'))
        z = distribution.sample(t.shape)
        # z(timesteps): 1-T

        condition = self.encoder(condition)
        # condition(timesteps): 1-T

        s_means, s_scales = self.student(z, condition)
        s_clipped_scales = F.maximum(
            s_scales, self.scalar_to_tensor(s_scales, -7))
        # s_means, s_scales(timesteps): 2-(T+1)

        x = z[:, :, 1:] * F.exp(s_scales[:, :, :-1]) + s_means[:, :, :-1]
        # x(timesteps): 2-T

        with chainer.using_config('train', False):
            y = self.teacher(x, condition[:, :, 1:])
        t_means, t_scales = y[:, 1:2], y[:, 2:3]
        t_clipped_scales = F.maximum(
            t_scales, self.scalar_to_tensor(t_scales, -7))
        # t_means, t_scales(timesteps): 3-(T+1)

        s_distribution = chainer.distributions.Normal(
            s_means[:, :, 1:], log_scale=s_clipped_scales[:, :, 1:])
        t_distribution = chainer.distributions.Normal(
            t_means, log_scale=t_clipped_scales)
        # s_distribution, t_distribution(timesteps): 3-(T+1)

        kl = chainer.kl_divergence(s_distribution, t_distribution)
        kl = F.minimum(
            kl, self.scalar_to_tensor(kl, 100))
        kl = F.average(kl)

        regularization = F.mean_squared_error(
            t_scales, s_scales[:, :, 1:])

        spectrogram_frame_loss = F.mean_squared_error(
            self.stft.magnitude(t[:, :, 1:]), self.stft.magnitude(x))

        loss = kl + self.lmd * regularization + spectrogram_frame_loss
        chainer.reporter.report({
            'loss': loss, 'kl_divergence': kl,
            'regularization': regularization,
            'spectrogram_frame_loss': spectrogram_frame_loss}, self)
        return loss
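
The F.maximum calls above clamp the predicted log-scales to a floor of -7 before the distributions are built, which keeps near-zero standard deviations from producing an exploding KL, and F.minimum later caps each KL term at 100. A toy sketch of the same clamping on made-up values:

    import numpy as np
    import chainer.functions as F

    log_scales = np.array([-12.0, -3.0, 0.5], dtype=np.float32)
    floor = np.full_like(log_scales, -7.0)

    clipped = F.maximum(log_scales, floor)
    print(clipped.array)  # [-7.  -3.   0.5]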