예제 #1
0
def marginal_loglikelihood(X, num_samples=512):
    """
    Importance-sampled estimate of the marginal log-likelihood of ``X``.

    For every input, ``num_samples`` latent codes are drawn from the
    encoder's Gaussian q(z | x) with the reparametrization trick. Each code
    gets a log importance weight log p(z) + log p(x | z) - log q(z | x).
    The weights are combined with a numerically stable log-sum-exp over
    the sample axis, and log(num_samples) is subtracted so the estimate is
    a log-mean.

    Parameters
    ----------
    X : tensor_like
        Input batch. It is indexed up to ``X.shape[3]`` and flattened with
        ``X.flatten(2)``, so it is 4-D with the batch on axis 0. The meaning
        of the other axes is assumed; TODO confirm.
    num_samples : int
        Number of latent samples drawn per data point.

    Returns
    -------
    tensor_like
        Scalar: the per-example estimates averaged over the batch.
    """
    # `log_sigma` is used as a log-variance: the std is exp(0.5 * log_sigma).
    mu, log_sigma = conv_encoder(X, *enc_params)
    noise = t_rng.normal((num_samples, X.shape[0], mu.shape[1]))

    # Add a leading broadcastable sample axis so the posterior parameters
    # line up with the (num_samples, batch, latent) noise tensor.
    mu_b = mu.dimshuffle('x', 0, 1)
    log_var = log_sigma.dimshuffle('x', 0, 1)

    z = mu_b + T.exp(0.5 * log_var) * noise

    # The decoder consumes a matrix of codes. Merge the sample and batch
    # axes, decode, then split them again with each input flattened.
    merged_rows = noise.shape[0] * noise.shape[1]
    decoded, _ = conv_decoder(X, z.reshape((merged_rows, noise.shape[2])), *dec_params)
    n_pixels = X.shape[1] * X.shape[2] * X.shape[3]
    decoded = decoded.reshape((noise.shape[0], noise.shape[1], n_pixels))

    # Diagonal-Gaussian posterior and standard-normal prior log-densities,
    # summed over the latent axis.
    log_2pi = T.log(2 * math.pi)
    log_q_z_x = -0.5 * (log_2pi + log_var + (z - mu_b) ** 2 / T.exp(log_var)).sum(axis=2)
    log_p_z = -0.5 * (log_2pi + z ** 2).sum(axis=2)

    # Binary-data likelihood via binary cross-entropy.
    # TODO: a Gaussian likelihood for continuous inputs was sketched here
    # once but never finished; only the binary case is implemented.
    targets = X.flatten(2).dimshuffle('x', 0, 1)
    log_p_x_z = -T.nnet.binary_crossentropy(decoded, targets).sum(axis=2)

    log_weights = log_p_z + log_p_x_z - log_q_z_x
    per_example = log_sum_exp(log_weights, axis=0) - T.log(T.cast(num_samples, 'float32'))
    return T.mean(per_example)
 def calculate_logprob(self, y_target, weights, means, sigmas):
     """
     Log-likelihood of ``y_target[:, 0]`` under a mixture of log-normals.

     Component ``i`` has mixing weight ``weights[:, i]``. It models
     ``log(y)`` as a normal with mean ``means[:, i]`` and standard deviation
     ``sigmas[:, i]``. Its log-density at y is

         log p_i(y) = -(log y - mu_i) ** 2 / (2 * sigma_i ** 2)
                      - log(y * sigma_i * sqrt(2 * pi))

     and the mixture log-likelihood is ``logsumexp_i(log w_i + log p_i(y))``.

     Parameters
     ----------
     y_target : tensor_like
         2-D; only column 0 is used. Must be strictly positive, because
         its log is taken.
     weights, means, sigmas : tensor_like
         2-D, one column per mixture component. Columns
         ``0 .. self.num_mixtures - 1`` are read.

     Returns
     -------
     tensor_like
         ``log_sum_exp`` over the stacked per-component terms (axis 0).
         Presumably one value per row of ``y_target``; confirm against
         ``log_sum_exp``.
     """
     y = y_target[:, 0]
     log_y = T.log(y)
     # Both the squared-error term and the normaliser are *subtracted*:
     # log w_i + log p_i(y). Adding them (as before) combined the log weight
     # with the component's negative log-density.
     component_terms = [
         T.log(weights[:, i])
         - ((log_y - means[:, i]) ** 2 / (2 * sigmas[:, i] ** 2)
            + T.log(y * sigmas[:, i] * T.sqrt(2 * np.pi)))
         for i in range(self.num_mixtures)
     ]
     return log_sum_exp(component_terms, axis=0)
예제 #3
0
def test_log_sum_exp_2():
    """
    Tests that the stable log sum exp succeeds for extreme values.
    """
    # log(exp(-100) + exp(100)) = 100 + log1p(exp(-200)), which equals 100 to
    # floating-point precision. exp(100) (~2.7e43) is outside the float32
    # range, so a naive log(sum(exp(x))) would overflow if sharedX produces a
    # float32 variable. Whether it does depends on sharedX / floatX, which
    # are not visible here.
    x = np.array([-100., 100.])
    x = sharedX(x)
    stable = log_sum_exp(x).eval()
    assert np.allclose(stable, 100.)
예제 #4
0
def test_log_sum_exp_1():
    """
    Check that the numerically stable log-sum-exp agrees with the direct
    ``log(sum(exp(x)))`` computation on inputs clustered around 1, where
    the naive formula has no overflow or underflow problems.
    """
    rng = np.random.RandomState([2015, 2, 9])
    values = 1. + rng.randn(5) / 10.
    expected = np.log(np.sum(np.exp(values)))
    result = log_sum_exp(sharedX(values)).eval()
    assert np.allclose(expected, result)
예제 #5
0
    def log_likelihood_approximation(self, X, num_samples):
        """
        Importance-sampling estimate of the marginal log-likelihood of X.

        The approximate posterior q(z | x) serves as the proposal. Latent
        codes are drawn with the reparametrization trick. The log importance
        weights log p(z) + log p(x | z) - log q(z | x) are combined with a
        log-sum-exp over axis 0 (the sample axis of the noise / z tensors),
        and log(num_samples) is subtracted to turn the sum into a mean.

        Parameters
        ----------
        X : tensor_like
            Input batch.
        num_samples : int
            Number of posterior samples per data point, i.e. how many times
            z is drawn for each x.

        Returns
        -------
        approximation : tensor_like
            Approximation of the marginal log-likelihood.
        """
        # Reparametrized posterior samples: noise -> z given q's parameters.
        noise = self.sample_from_epsilon(shape=(num_samples, X.shape[0], self.nhid))
        phi = self.encode_phi(X)
        z = self.sample_from_q_z_given_x(epsilon=noise, phi=phi)

        # The decoder is an MLP that expects a matrix. Fold the sample and
        # batch axes together for decoding, then unfold every decoder output.
        n_draws, n_rows, n_latent = noise.shape[0], noise.shape[1], noise.shape[2]
        decoded = self.decode_theta(z.reshape((n_draws * n_rows, n_latent)))
        theta = tuple([
            param.reshape((n_draws, n_rows, param.shape[1]))
            for param in decoded
        ])

        # Per-sample log importance weights.
        log_q = self.log_q_z_given_x(z=z, phi=phi)
        log_prior = self.log_p_z(z)
        log_lik = self.log_p_x_given_z(X=X.dimshuffle(('x', 0, 1)), theta=theta)
        log_weights = log_prior + log_lik - log_q

        # log((1 / S) * sum_s exp(log_weights_s)), computed stably.
        return log_sum_exp(log_weights, axis=0) - T.log(num_samples)
 def calculate_logprob(self, y_target, weights, means, sigmas):
     """
     Log-likelihood of ``y_target[:, 0]`` under a mixture of log-normals.

     Component ``i`` has mixing weight ``weights[:, i]``. It models
     ``log(y)`` as a normal with mean ``means[:, i]`` and standard deviation
     ``sigmas[:, i]``. Its log-density at y is

         log p_i(y) = -(log y - mu_i) ** 2 / (2 * sigma_i ** 2)
                      - log(y * sigma_i * sqrt(2 * pi))

     and the mixture log-likelihood is ``logsumexp_i(log w_i + log p_i(y))``.

     Parameters
     ----------
     y_target : tensor_like
         2-D; only column 0 is used. Must be strictly positive, because
         its log is taken.
     weights, means, sigmas : tensor_like
         2-D, one column per mixture component. Columns
         ``0 .. self.num_mixtures - 1`` are read.

     Returns
     -------
     tensor_like
         ``log_sum_exp`` over the stacked per-component terms (axis 0).
         Presumably one value per row of ``y_target``; confirm against
         ``log_sum_exp``.
     """
     y = y_target[:, 0]
     log_y = T.log(y)
     # Both the squared-error term and the normaliser are *subtracted*:
     # log w_i + log p_i(y). Adding them (as before) combined the log weight
     # with the component's negative log-density.
     component_terms = [
         T.log(weights[:, i])
         - ((log_y - means[:, i]) ** 2 / (2 * sigmas[:, i] ** 2)
            + T.log(y * sigmas[:, i] * T.sqrt(2 * np.pi)))
         for i in range(self.num_mixtures)
     ]
     return log_sum_exp(component_terms, axis=0)