Example #1
    def testKLBatchBroadcast(self):
        batch_shape = [2]
        event_shape = [3]
        mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
        # No batch shape.
        mu_b, sigma_b = self._random_mu_and_sigma([], event_shape)
        mvn_a = tfd.MultivariateNormalTriL(
            loc=mu_a,
            scale_tril=np.linalg.cholesky(sigma_a),
            validate_args=True)
        mvn_b = tfd.MultivariateNormalTriL(
            loc=mu_b,
            scale_tril=np.linalg.cholesky(sigma_b),
            validate_args=True)

        kl = tfd.kl_divergence(mvn_a, mvn_b)
        self.assertEqual(batch_shape, kl.shape)

        kl_v = self.evaluate(kl)
        expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :],
                                              mu_b, sigma_b)
        expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :],
                                              mu_b, sigma_b)
        self.assertAllClose(expected_kl_0, kl_v[0])
        self.assertAllClose(expected_kl_1, kl_v[1])
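The `_compute_non_batch_kl` helper used above is not shown in this example; a minimal NumPy sketch of what it is assumed to compute, namely the closed-form KL divergence between two multivariate normals, is:

import numpy as np

def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b):
    # KL[N(mu_a, sigma_a) || N(mu_b, sigma_b)] for a single (non-batch) pair.
    sigma_b_inv = np.linalg.inv(sigma_b)
    diff = mu_b - mu_a
    k = mu_a.shape[0]
    trace_term = np.trace(sigma_b_inv.dot(sigma_a))
    quad_term = diff.dot(sigma_b_inv).dot(diff)
    log_det_term = np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a))
    return 0.5 * (trace_term + quad_term - k + log_det_term)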
Example #2
    def test_log_prob_matches_linear_gaussian_ssm(self):
        dim = 2
        batch_shape = [3, 1]
        seed, *model_seeds = samplers.split_seed(test_util.test_seed(), n=6)

        # Sample a random linear Gaussian process.
        prior_loc = self.evaluate(
            tfd.Normal(0., 1.).sample(batch_shape + [dim],
                                      seed=model_seeds[0]))
        prior_scale = self.evaluate(
            tfd.InverseGamma(1., 1.).sample(batch_shape + [dim],
                                            seed=model_seeds[1]))
        transition_matrix = self.evaluate(
            tfd.Normal(0., 1.).sample([dim, dim], seed=model_seeds[2]))
        transition_bias = self.evaluate(
            tfd.Normal(0., 1.).sample(batch_shape + [dim],
                                      seed=model_seeds[3]))
        transition_scale_tril = self.evaluate(
            tf.linalg.cholesky(
                tfd.WishartTriL(
                    df=dim,
                    scale_tril=tf.eye(dim)).sample(seed=model_seeds[4])))

        initial_state_prior = tfd.MultivariateNormalDiag(
            loc=prior_loc, scale_diag=prior_scale, name='initial_state_prior')

        lgssm = tfd.LinearGaussianStateSpaceModel(
            num_timesteps=7,
            transition_matrix=transition_matrix,
            transition_noise=tfd.MultivariateNormalTriL(
                loc=transition_bias, scale_tril=transition_scale_tril),
            # Trivial observation model to pass through the latent state.
            observation_matrix=tf.eye(dim),
            observation_noise=tfd.MultivariateNormalDiag(
                loc=tf.zeros(dim), scale_diag=tf.zeros(dim)),
            initial_state_prior=initial_state_prior)

        markov_chain = tfd.MarkovChain(
            initial_state_prior=initial_state_prior,
            transition_fn=lambda _, x: tfd.MultivariateNormalTriL(  # pylint: disable=g-long-lambda
                loc=tf.linalg.matvec(transition_matrix, x) + transition_bias,
                scale_tril=transition_scale_tril),
            num_steps=7)

        x = markov_chain.sample(5, seed=seed)
        self.assertAllClose(lgssm.log_prob(x),
                            markov_chain.log_prob(x),
                            rtol=1e-5)
Example #3
        def solve_weight_space() -> tfd.Distribution:
            d_by_d = tf.matmul(x_train, x_train, transpose_a=True)
            precis = tf.linalg.diag(1 / weight_var) + noise_precis * d_by_d
            sqprec = jitter_cholesky(precis)
            solv_y = parallel_solve(tf.linalg.triangular_solve,
                                    sqprec,
                                    proj,
                                    lower=True)
            solv_x = parallel_solve(tf.linalg.triangular_solve,
                                    sqprec,
                                    tf.linalg.matrix_transpose(x_eval),
                                    lower=True)

            loc = noise_precis * tf.matmul(solv_x, solv_y, transpose_a=True)
            if self.mean_function is not None:
                loc += self.mean_function(x)

            if full_cov:
                scale_tril = jitter_cholesky(
                    tf.matmul(solv_x, solv_x, transpose_a=True))
                return tfd.MultivariateNormalTriL(loc=tf.squeeze(loc, axis=-1),
                                                  scale_tril=scale_tril)
            else:
                scale_diag = tf.sqrt(tf.reduce_sum(tf.square(solv_x), axis=-2))
                return tfd.MultivariateNormalDiag(loc=tf.squeeze(loc, axis=-1),
                                                  scale_diag=scale_diag)
Example #4
        def solve_function_space() -> tfd.Distribution:
            n_by_n = tf.matmul(x_train,
                               weight_var[..., None, :] * x_train,
                               transpose_b=True)
            precis = (1 / noise_precis) * eyeN + n_by_n
            sqprec = jitter_cholesky(precis)
            solv_x = parallel_solve(tf.linalg.triangular_solve,
                                    sqprec,
                                    x_train,
                                    lower=True)
            w_covar = solv_x * weight_var[..., None, :]  # scratch variable
            w_covar = tf.linalg.diag(weight_var) \
                      - tf.matmul(w_covar, w_covar, transpose_a=True)

            loc = noise_precis * tf.matmul(x_eval, tf.matmul(w_covar, proj))
            covar = tf.matmul(x_eval,
                              tf.matmul(w_covar, x_eval, transpose_b=True))
            if full_cov:
                scale_tril = jitter_cholesky(covar)
                return tfd.MultivariateNormalTriL(loc=tf.squeeze(loc, axis=-1),
                                                  scale_tril=scale_tril)
            else:
                scale_diag = tf.sqrt(tf.linalg.diag_part(covar))
                return tfd.MultivariateNormalDiag(loc=tf.squeeze(loc, axis=-1),
                                                  scale_diag=scale_diag)
Example #5
    def test_bijector(self):
        """Employ bijector when sampling."""
        dtype = np.float32
        true_mean = dtype([1, 1])
        true_cov = dtype([[1, 0.5], [0.5, 1]])
        target = tfd.MultivariateNormalTriL(
            loc=true_mean,
            scale_tril=tf.linalg.cholesky(true_cov)
        )

        def logp(x1, x2):
            return target.log_prob([x1, x2])

        def kernel_make_fn(target_log_prob_fn, state):
            inner_kernel = tfp.mcmc.RandomWalkMetropolis(target_log_prob_fn=target_log_prob_fn)
            return tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=inner_kernel,
                bijector=tfp.bijectors.Exp()
            )

        kernel_list = [(0, kernel_make_fn),
                       (1, kernel_make_fn)]
        kernel = GibbsKernel(
            target_log_prob_fn=logp,
            kernel_list=kernel_list
        )
        samples = tfp.mcmc.sample_chain(
            num_results=20,
            current_state=[dtype(1), dtype(1)],
            kernel=kernel,
            trace_fn=None)
Example #6
 def seasonal_transition_noise(t):
     noise_scale_tril = dist_util.pick_scalar_condition(
         is_last_day_of_season(t), drift_scale_tril,
         tf.zeros_like(drift_scale_tril))
     return tfd.MultivariateNormalTriL(
         loc=tf.zeros(num_seasons - 1, dtype=drift_scale.dtype),
         scale_tril=noise_scale_tril)
Example #7
    def run_chain_and_get_estimation_error():
      chain_state = tfp.mcmc.sample_chain(
          num_results=num_steps,
          num_burnin_steps=0,
          current_state=initial_state,
          kernel=tfp.mcmc.NoUTurnSampler(
              tfd.MultivariateNormalTriL(loc=mu,
                                         scale_tril=scale_tril).log_prob,
              step_size=np.asarray([sigma1, sigma2]),
              parallel_iterations=1,
              seed=strm()),
          parallel_iterations=1,
          trace_fn=None)
      variance_est = tf.square(chain_state - mu)
      correlation_est = tf.reduce_prod(
          chain_state - mu, axis=-1, keepdims=True) / (sigma1 * sigma2)
      mcmc_samples = tf.concat([chain_state, variance_est, correlation_est],
                               axis=-1)

      expected = tf.reduce_mean(mcmc_samples, axis=[0, 1])

      ess = tf.reduce_sum(tfp.mcmc.effective_sample_size(mcmc_samples), axis=0)
      avg_monte_carlo_standard_error = tf.reduce_mean(
          tf.math.reduce_std(mcmc_samples, axis=0),
          axis=0) / tf.sqrt(ess)
      scaled_error = (
          tf.abs(expected - true_param) / avg_monte_carlo_standard_error)

      return tfd.Normal(loc=0., scale=1.).survival_function(scaled_error)
Example #8
  def test_multipart_bijector(self):
    seed_stream = test_util.test_seed_stream()

    prior = tfd.JointDistributionSequential([
        tfd.Gamma(1., 1.),
        lambda scale: tfd.Uniform(0., scale),
        lambda concentration: tfd.CholeskyLKJ(4, concentration),
    ], validate_args=True)
    likelihood = lambda corr: tfd.MultivariateNormalTriL(scale_tril=corr)
    obs = self.evaluate(
        likelihood(
            prior.sample(seed=seed_stream())[-1]).sample(seed=seed_stream()))

    bij = prior.experimental_default_event_space_bijector()

    def target_log_prob(scale, conc, corr):
      return prior.log_prob(scale, conc, corr) + likelihood(corr).log_prob(obs)
    kernel = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob,
                                            num_leapfrog_steps=3, step_size=.5)
    kernel = tfp.mcmc.TransformedTransitionKernel(kernel, bij)

    init = self.evaluate(
        tuple(tf.random.uniform(s, -2., 2., seed=seed_stream())
              for s in bij.inverse_event_shape(prior.event_shape)))
    state = bij.forward(init)
    kr = kernel.bootstrap_results(state)
    next_state, next_kr = kernel.one_step(state, kr, seed=seed_stream())
    self.evaluate((state, kr, next_state, next_kr))
    expected = (target_log_prob(*state) -
                bij.inverse_log_det_jacobian(state, [0, 0, 2]))
    actual = kernel._inner_kernel.target_log_prob_fn(*init)  # pylint: disable=protected-access
    self.assertAllClose(expected, actual)
Example #9
  def testLangevinCorrectVolatilityGradient(self):
    """Check that the gradient of the volatility is computed correctly."""
    # Consider the example target distribution as in `testLangevin3DNormal`
    dtype = np.float32
    true_mean = dtype([1, 2, 7])
    true_cov = dtype([[1, 0.25, 0.25], [0.25, 1, 0.25], [0.25, 0.25, 1]])
    num_chains = 100

    chol = tf.linalg.cholesky(true_cov)
    target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

    def target_log_prob(x, y):
      # Stack the input tensors together
      z = tf.concat([x, y], axis=-1)
      return target.log_prob(z)

    def volatility_fn(x, y):
      # Stack the input tensors together
      return [1. / (0.5 + 0.1 * tf.abs(x + y)),
              1. / (0.5 + 0.1 * tf.abs(y))]

    # Initial state of the chain
    init_state = [np.ones([num_chains, 2], dtype=dtype),
                  np.ones([num_chains, 1], dtype=dtype)]

    strm = test_util.test_seed_stream()
    # Define MALA with constant volatility
    langevin_unit = tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
        target_log_prob_fn=target_log_prob,
        step_size=0.1,
        seed=strm())
    # Define MALA with volatility being `volatility_fn`
    langevin_general = tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
        target_log_prob_fn=target_log_prob,
        step_size=0.1,
        volatility_fn=volatility_fn,
        seed=strm())

    # Initialize the samplers
    kernel_unit_volatility = langevin_unit.bootstrap_results(init_state)
    kernel_general = langevin_general.bootstrap_results(init_state)

    # For `langevin_unit` the volatility gradient should be zero.
    grad_1, grad_2 = kernel_unit_volatility.accepted_results.grads_volatility
    self.assertAllEqual(self.evaluate(grad_1),
                        np.zeros(shape=init_state[0].shape, dtype=dtype))
    self.assertAllEqual(self.evaluate(grad_2),
                        np.zeros(shape=init_state[1].shape, dtype=dtype))

    # For `langevin_general` the volatility gradient should be about -0.583 in
    # the first state part and -0.926 in the second.
    grad_1, grad_2 = kernel_general.accepted_results.grads_volatility
    self.assertAllClose(self.evaluate(grad_1),
                        -0.583 * np.ones(shape=init_state[0].shape,
                                         dtype=dtype),
                        atol=0.01, rtol=0.01)
    self.assertAllClose(self.evaluate(grad_2),
                        -0.926 * np.ones(shape=init_state[1].shape,
                                         dtype=dtype),
                        atol=0.01, rtol=0.01)
Example #10
  def __init__(
      self,
      ndims=100,
      gamma_shape_parameter=0.5,
      max_eigvalue=None,
      seed=10,
      name='ill_conditioned_gaussian',
      pretty_name='Ill-Conditioned Gaussian',
  ):
    """Construct the ill-conditioned Gaussian.

    Args:
      ndims: Python `int`. Dimensionality of the Gaussian.
      gamma_shape_parameter: Python `float`. The shape parameter of the inverse
        Gamma distribution. Anything below 2 is likely to yield poorly
        conditioned covariance matrices.
      max_eigvalue: Python `float`. If set, will normalize the eigenvalues such
        that the maximum is this value.
      seed: Seed to use when generating the eigenvalues and the random
        orthogonal matrix.
      name: Python `str` name prefixed to Ops created by this class.
      pretty_name: A Python `str`. The pretty name of this model.
    """
    rng = onp.random.RandomState(seed=seed & (2**32 - 1))
    eigenvalues = 1. / onp.sort(
        rng.gamma(shape=gamma_shape_parameter, scale=1., size=ndims))
    if max_eigvalue is not None:
      eigenvalues *= max_eigvalue / eigenvalues.max()

    q, r = onp.linalg.qr(rng.randn(ndims, ndims))
    q *= onp.sign(onp.diag(r))

    covariance = (q * eigenvalues).dot(q.T)

    gaussian = tfd.MultivariateNormalTriL(
        loc=tf.zeros(ndims),
        scale_tril=tf.linalg.cholesky(
            tf.convert_to_tensor(covariance, dtype=tf.float32)))
    self._eigenvalues = eigenvalues

    sample_transformations = {
        'identity':
            model.Model.SampleTransformation(
                fn=lambda params: params,
                pretty_name='Identity',
                ground_truth_mean=onp.zeros(ndims),
                ground_truth_standard_deviation=onp.sqrt(onp.diag(covariance)),
            )
    }

    self._gaussian = gaussian

    super(IllConditionedGaussian, self).__init__(
        default_event_space_bijector=tfb.Identity(),
        event_shape=gaussian.event_shape,
        dtype=gaussian.dtype,
        name=name,
        pretty_name=pretty_name,
        sample_transformations=sample_transformations,
    )
Example #11
    def test_float64(self):
        """Sample with dtype float64."""
        dtype = np.float64
        true_mean = dtype([1, 1])
        true_cov = dtype([[1, 0.5], [0.5, 1]])
        target = tfd.MultivariateNormalTriL(
            loc=true_mean,
            scale_tril=tf.linalg.cholesky(true_cov)
        )

        def logp(x1, x2):
            return target.log_prob([x1, x2])

        def kernel_make_fn(target_log_prob_fn, state):
            return tfp.mcmc.RandomWalkMetropolis(target_log_prob_fn=target_log_prob_fn)

        kernel_list = [(0, kernel_make_fn),
                       (1, kernel_make_fn)]
        kernel = GibbsKernel(
            target_log_prob_fn=logp,
            kernel_list=kernel_list
        )
        samples = tfp.mcmc.sample_chain(
            num_results=20,
            current_state=[dtype(1), dtype(1)],
            kernel=kernel,
            trace_fn=None)
Example #12
    def testSampleWithSampleShape(self):
        mu = self._rng.rand(3, 5, 2)
        chol, sigma = self._random_chol(3, 5, 2, 2)
        chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
        chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]

        mvn = tfd.MultivariateNormalTriL(mu, chol, validate_args=True)
        samples_val = self.evaluate(
            mvn.sample((10, 11, 12), seed=tfp_test_util.test_seed()))

        # Check sample shape
        self.assertEqual((10, 11, 12, 3, 5, 2), samples_val.shape)

        # Check sample means
        x = samples_val[:, :, :, 1, 1, :]
        self.assertAllClose(x.reshape(10 * 11 * 12, 2).mean(axis=0),
                            mu[1, 1],
                            atol=0.05)

        # Check that log_prob(samples) works
        log_prob_val = self.evaluate(mvn.log_prob(samples_val))
        x_log_pdf = log_prob_val[:, :, :, 1, 1]
        expected_log_pdf = stats.multivariate_normal(
            mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).logpdf(x)
        self.assertAllClose(expected_log_pdf, x_log_pdf)
Example #13
    def test_3d_mvn(self):
        """Sample from 3-variate Gaussian Distribution."""
        dtype = np.float32

        true_mean = dtype([1., 2., 3.])
        true_cov = dtype([[0.36, 0.12, 0.06], [0.12, 0.29, -0.13],
                          [0.06, -0.13, 0.26]])
        target = tfd.MultivariateNormalTriL(
            loc=true_mean, scale_tril=tf.linalg.cholesky(true_cov))
        kernel = AdaptiveRandomWalkMetropolis(
            target_log_prob_fn=target.log_prob,
            initial_covariance=dtype(0.001) * np.eye(3, dtype=dtype))
        samples = tfp.mcmc.sample_chain(num_results=2000,
                                        current_state=dtype([0.1, 0.1, 0.1]),
                                        kernel=kernel,
                                        num_burnin_steps=500,
                                        trace_fn=None)

        sample_mean = tf.math.reduce_mean(samples, axis=0)
        [sample_mean_] = self.evaluate([sample_mean])
        self.assertAllClose(sample_mean_, true_mean, atol=0.1, rtol=0.1)

        sample_cov = tfp.stats.covariance(samples)
        sample_cov_ = self.evaluate(sample_cov)
        self.assertAllClose(sample_cov_, true_cov, atol=0.1, rtol=0.1)
Example #14
  def testFourDimNormal(self):
    """Sampling from a 4-D Multivariate Normal distribution."""

    dtype = np.float32
    true_mean = dtype([0, 4, -8, 2])
    true_cov = np.eye(4, dtype=dtype)
    num_results, tolerance = self._get_mode_dependent_settings()
    num_chains = 10
    # The covariance is the identity, so it is also its own Cholesky factor.
    target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=true_cov)

    # Initial state of the chain
    init_state = np.ones([num_chains, 4], dtype=dtype)

    # Run Slice Sampler for `num_results` iterations for `num_chains`
    # independent chains:
    states, _ = tfp.mcmc.sample_chain(
        num_results=num_results,
        current_state=init_state,
        kernel=tfp.mcmc.SliceSampler(
            target_log_prob_fn=tf.function(target.log_prob, autograph=False),
            step_size=1.0,
            max_doublings=5,
            seed=test_util.test_seed_stream()),
        num_burnin_steps=100,
        parallel_iterations=1)

    result = tf.reshape(states, [-1, 4])
    sample_mean = tf.reduce_mean(result, axis=0)
    sample_cov = tfp.stats.covariance(result)

    self.assertAllClose(true_mean, sample_mean, atol=tolerance, rtol=tolerance)
    self.assertAllClose(true_cov, sample_cov, atol=tolerance, rtol=tolerance)
Example #15
  def testKLNonBatch(self):
    batch_shape = []
    event_shape = [2]
    mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
    mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
    mvn_a = tfd.MultivariateNormalTriL(
        loc=mu_a, scale_tril=np.linalg.cholesky(sigma_a), validate_args=True)
    mvn_b = tfd.MultivariateNormalTriL(
        loc=mu_b, scale_tril=np.linalg.cholesky(sigma_b), validate_args=True)

    kl = tfd.kl_divergence(mvn_a, mvn_b)
    self.assertEqual(batch_shape, kl.shape)

    kl_v = self.evaluate(kl)
    expected_kl = _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b)
    self.assertAllClose(expected_kl, kl_v)
Example #16
    def _fn(kernel_size, bias_size, dtype=None):
        smallconst = np.log(np.expm1(1.))
        
        n_weights_block = kernel_size//C
        n_bias_block = bias_size//C

        n_weight_mean_params = n_weights_block
        n_weight_cov_params = tfp.layers.MultivariateNormalTriL.params_size(n_weights_block) - n_weights_block

        n_params_total = C*(n_weight_mean_params + n_weight_cov_params) + bias_size
        #print("{} params in total".format(n_params_total))

        block_param_indices = tf.split(np.arange(n_params_total - bias_size), C)
        split_array = [n_weight_mean_params, n_weight_cov_params]
        split_param_idxs = [tf.split(x, split_array, axis=0) for x in block_param_indices]

        model = tf.keras.Sequential([
            tfpl.VariableLayer(n_params_total, dtype=dtype),
            tfpl.DistributionLambda(lambda t: tfd.Blockwise(
                    [
                        tfd.MultivariateNormalTriL(
                            loc=tf.gather(t,split_param_idxs[c][0], axis=-1),
                            scale_tril=tfp.math.fill_triangular(
                                1e-5 + tf.nn.softplus(smallconst + tf.gather(t,split_param_idxs[c][1], axis=-1)))
                        ) for c in range(C)
                    ] +
                    [ tfd.VectorDeterministic(loc=t[...,-bias_size:]) ]
                ) 
            )
        ])
        return model
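The `(kernel_size, bias_size, dtype)` signature of `_fn` matches what `tfp.layers.DenseVariational` expects from its posterior and prior factory callables. A hedged usage sketch follows; the layer size and the value of `C` (left undefined in the snippet above) are assumptions, not part of the original example:

# Illustrative wiring only; assumes the surrounding scope defines C (here 4)
# and that C divides the kernel size of the layer.
C = 4
variational_layer = tfp.layers.DenseVariational(
    units=C,
    make_posterior_fn=_fn,
    make_prior_fn=_fn,
    kl_weight=1. / 60000.)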
Example #17
    def testTwoDimNormalDynamicShape(self):
        """Checks that dynamic batch shapes for the initial state are supported."""
        if tf.executing_eagerly(): return

        dtype = np.float32
        true_mean = dtype([0, 0])
        true_cov = dtype([[1, 0.5], [0.5, 1]])
        num_results = 200
        num_chains = 75
        # Target distribution is defined through the Cholesky decomposition.
        chol = tf.linalg.cholesky(true_cov)
        target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

        # Assume that the state is passed as a list of 1-d tensors `x` and `y`.
        # Then the target log-density is defined as follows:
        def target_log_prob(x, y):
            # Stack the input tensors together
            z = tf.stack([x, y], axis=-1) - true_mean
            return target.log_prob(z)

        # Initial state of the chain
        init_state = [
            np.ones([num_chains, 1], dtype=dtype),
            np.ones([num_chains, 1], dtype=dtype)
        ]
        placeholder_init_state = [
            tf1.placeholder_with_default(init_state[0], shape=[None, 1]),
            tf1.placeholder_with_default(init_state[1], shape=[None, 1])
        ]
        # Run Slice Sampler for `num_results` iterations for `num_chains`
        # independent chains:
        [x, y], _ = tfp.mcmc.sample_chain(
            num_results=num_results,
            current_state=placeholder_init_state,
            kernel=tfp.mcmc.SliceSampler(target_log_prob_fn=target_log_prob,
                                         step_size=1.0,
                                         max_doublings=5,
                                         seed=47),
            num_burnin_steps=200,
            num_steps_between_results=1,
            parallel_iterations=1)

        states = tf.stack([x, y], axis=-1)
        sample_mean = tf.reduce_mean(input_tensor=states, axis=[0, 1])
        z = states - sample_mean
        sample_cov = tf.reduce_mean(input_tensor=tf.matmul(z,
                                                           z,
                                                           transpose_a=True),
                                    axis=[0, 1])
        [sample_mean, sample_cov] = self.evaluate([sample_mean, sample_cov])

        self.assertAllClose(true_mean,
                            b=np.squeeze(sample_mean),
                            atol=0.1,
                            rtol=0.1)
        self.assertAllClose(true_cov,
                            b=np.squeeze(sample_cov),
                            atol=0.1,
                            rtol=0.1)
Example #18
 def testVariableLocation(self):
   loc = tf.Variable([1., 1.])
   scale = tf.eye(2)
   d = tfd.MultivariateNormalTriL(loc, scale, validate_args=True)
   self.evaluate(loc.initializer)
   with tf.GradientTape() as tape:
     lp = d.log_prob([0., 0.])
   self.assertIsNotNone(tape.gradient(lp, loc))
Example #19
 def testVariableScale(self):
   loc = tf.constant([1., 1.])
   scale = tf.Variable([[1., 0.], [0., 1.]])
   d = tfd.MultivariateNormalTriL(loc, scale, validate_args=True)
   self.evaluate(scale.initializer)
   with tf.GradientTape() as tape:
     lp = d.log_prob([0., 0.])
   self.assertIsNotNone(tape.gradient(lp, scale))
Example #20
 def new(params, event_size, validate_args=False, name=None):
   """Create the distribution instance from a `params` vector."""
   with tf.name_scope(name or 'MultivariateNormalTriL'):
     return tfd.MultivariateNormalTriL(
         loc=params[..., :event_size],
         scale_tril=tfb.ScaleTriL(validate_args=validate_args)(
             params[..., event_size:]),
         validate_args=validate_args)
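A hedged usage sketch of the factory above (shapes are illustrative, not from the source): for `event_size = 3` the params vector needs `event_size + event_size * (event_size + 1) // 2 = 9` entries, the first three becoming the location and the remaining six, mapped through `tfb.ScaleTriL`, becoming the lower-triangular scale.

# Illustrative only.
event_size = 3
params = tf.zeros([tfp.layers.MultivariateNormalTriL.params_size(event_size)])
dist = new(params, event_size)
# dist.event_shape == [3]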
Example #21
 def _mvn_pair(self, loc, cov_operator):
     """Construct a pair of MVNs."""
     tril = tfd.MultivariateNormalTriL(
         loc=loc, scale_tril=cov_operator.cholesky().to_dense())
     low_rank_update = (
         MultivariateNormalLowRankUpdateLinearOperatorCovariance(
             loc=loc, cov_operator=cov_operator, validate_args=True))
     return MVNPair(low_rank_update=low_rank_update, tril=tril)
Example #22
 def solve_weight_space() -> tfd.Distribution:
     d_by_d = tf.matmul(x, x, transpose_a=True)
     precis = tf.linalg.diag(1 / weight_var) + noise_precis * d_by_d
     sqprec = jitter_cholesky(precis)
     means = noise_precis * parallel_solve(tf.linalg.cholesky_solve,
                                           sqprec, proj)
     scale_tril = CholeskyToInvCholesky()(sqprec)  # [!] improve me
     return tfd.MultivariateNormalTriL(loc=tf.squeeze(means, -1),
                                       scale_tril=scale_tril)
Example #23
def ensemble_kalman_filter_log_marginal_likelihood(state,
                                                   observation,
                                                   observation_fn,
                                                   seed=None,
                                                   name=None):
    """Ensemble Kalman Filter Log Marginal Likelihood.

  The [Ensemble Kalman Filter](
  https://en.wikipedia.org/wiki/Ensemble_Kalman_filter) is a Monte Carlo
  version of the traditional Kalman Filter.

  This method estimates (logarithm of) the marginal likelihood of the
  observation at step `k`, `Y_k`, given previous observations from steps
  `1` to `k-1`, `Y_{1:k}`. In other words, `Log[p(Y_k | Y_{1:k})]`.
  This function's approximation to `p(Y_k | Y_{1:k})` is correct under a
  Linear Gaussian state space model assumption, as ensemble size --> infinity.

  Args:
    state: Instance of `EnsembleKalmanFilterState` at step `k`,
      conditioned on previous observations `Y_{1:k}`. Typically this is the
      output of `ensemble_kalman_filter_predict`.
    observation: `Tensor` representing the observation at step `k`.
    observation_fn: callable returning an instance of
      `tfd.MultivariateNormalLinearOperator` along with an extra information
      to be returned in the `EnsembleKalmanFilterState`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name for ops created by this method.
      Default value: `None`
      (i.e., `'ensemble_kalman_filter_log_marginal_likelihood'`).

  Returns:
    log_marginal_likelihood: `Tensor` with same dtype as `state`.
  """

    with tf.name_scope(name
                       or 'ensemble_kalman_filter_log_marginal_likelihood'):
        observation_particles_dist, unused_extra = observation_fn(
            state.step, state.particles, state.extra)

        common_dtype = dtype_util.common_dtype(
            [observation_particles_dist, observation], dtype_hint=tf.float32)

        observation = tf.convert_to_tensor(observation, dtype=common_dtype)

        if not isinstance(observation_particles_dist,
                          distributions.MultivariateNormalLinearOperator):
            raise ValueError(
                'Expected `observation_fn` to return an instance of '
                '`MultivariateNormalLinearOperator`')

        observation_particles = observation_particles_dist.sample(seed=seed)
        observation_dist = distributions.MultivariateNormalTriL(
            loc=tf.reduce_mean(observation_particles, axis=0),
            scale_tril=tf.linalg.cholesky(_covariance(observation_particles)))

        return observation_dist.log_prob(observation)
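A hedged usage sketch, assuming the function above lives in `tfp.experimental.sequential` (so that `distributions`, `dtype_util`, and `_covariance` resolve) and that `EnsembleKalmanFilterState` is the usual namedtuple with fields `(step, particles, extra)`:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

def observation_fn(step, particles, extra):
    # Observe the latent particles directly with unit observation noise.
    return tfd.MultivariateNormalDiag(
        loc=particles, scale_diag=tf.ones_like(particles)), extra

state = tfp.experimental.sequential.EnsembleKalmanFilterState(
    step=0,
    particles=tf.random.normal([100, 2]),  # 100-member ensemble of a 2-d state
    extra={})
log_ml = ensemble_kalman_filter_log_marginal_likelihood(
    state, observation=tf.zeros([2]), observation_fn=observation_fn)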
Example #24
 def testGradientWorksForMultivariateNormalTriL(self):
     # TODO(b/72831017): Remove this once bijector caching is fixed for
     # graph mode.
     if not tf.executing_eagerly():
         self.skipTest('Gradients get None values in graph mode.')
     d = tfd.MultivariateNormalTriL(scale_tril=tf.eye(2))
     x = d.sample(seed=test_util.test_seed())
     fn_result, grads = util.maybe_call_fn_and_grads(d.log_prob, x)
     self.assertAllEqual(False, fn_result is None)
     self.assertAllEqual([False], [g is None for g in grads])
Example #25
  def testEntropy(self):
    mu = self._rng.rand(2)
    chol, sigma = self._random_chol(2, 2)
    mvn = tfd.MultivariateNormalTriL(mu, chol, validate_args=True)
    entropy = mvn.entropy()

    scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
    expected_entropy = scipy_mvn.entropy()
    self.assertEqual(entropy.shape, ())
    self.assertAllClose(expected_entropy, self.evaluate(entropy))
Example #26
 def testSample(self):
   mu = self._rng.rand(2)
   chol, sigma = self._random_chol(2, 2)
   n = tf.constant(100000)
   mvn = tfd.MultivariateNormalTriL(mu, chol, validate_args=True)
   samples = mvn.sample(n, seed=test_util.test_seed())
   sample_values = self.evaluate(samples)
   self.assertEqual(samples.shape, [int(100e3), 2])
   self.assertAllClose(sample_values.mean(axis=0), mu, atol=1e-2)
   self.assertAllClose(np.cov(sample_values, rowvar=0), sigma, atol=0.06)
Example #27
    def testLangevin3DNormal(self):
        """Sampling from a 3-D Multivariate Normal distribution."""
        dtype = np.float32
        true_mean = dtype([1, 2, 7])
        true_cov = dtype([[1, 0.25, 0.25], [0.25, 1, 0.25], [0.25, 0.25, 1]])
        num_results = 500
        num_chains = 500

        # Target distribution is defined through the Cholesky decomposition
        chol = tf.linalg.cholesky(true_cov)
        target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

        # Assume that the state is passed as a list of tensors `x` and `y`.
        # Then the target log-density is defined as follows:
        def target_log_prob(x, y):
            # Stack the input tensors together
            z = tf.concat([x, y], axis=-1)
            return target.log_prob(z)

        # Initial state of the chain
        init_state = [
            np.ones([num_chains, 2], dtype=dtype),
            np.ones([num_chains, 1], dtype=dtype)
        ]

        # Run MALA with normal proposal for `num_results` iterations for
        # `num_chains` independent chains:
        states, _ = tfp.mcmc.sample_chain(
            num_results=num_results,
            current_state=init_state,
            kernel=tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
                target_log_prob_fn=target_log_prob,
                step_size=.1,
                seed=test_util.test_seed()),
            num_burnin_steps=200,
            num_steps_between_results=1,
            parallel_iterations=1)

        states = tf.concat(states, axis=-1)
        sample_mean = tf.reduce_mean(states, axis=[0, 1])
        x = (states - sample_mean)[..., tf.newaxis]
        sample_cov = tf.reduce_mean(tf.matmul(
            x, tf.transpose(a=x, perm=[0, 1, 3, 2])),
                                    axis=[0, 1])

        sample_mean_, sample_cov_ = self.evaluate([sample_mean, sample_cov])

        self.assertAllClose(np.squeeze(sample_mean_),
                            true_mean,
                            atol=0.1,
                            rtol=0.1)
        self.assertAllClose(np.squeeze(sample_cov_),
                            true_cov,
                            atol=0.1,
                            rtol=0.1)
Example #28
 def testDocstrSliceExample(self):
     x = tf.random.normal([5, 3, 2, 2])
     cov = tf.matmul(x, x, transpose_b=True)
     chol = tf.linalg.cholesky(cov)
     loc = tf.random.normal([4, 1, 3, 1])
     mvn = tfd.MultivariateNormalTriL(loc, chol)
     self.assertAllEqual((4, 5, 3), mvn.batch_shape)
     self.assertAllEqual((2, ), mvn.event_shape)
     mvn2 = mvn[:, 3:, ..., ::-1, tf.newaxis]
     self.assertAllEqual((4, 2, 3, 1), mvn2.batch_shape)
     self.assertAllEqual((2, ), mvn2.event_shape)
Example #29
  def testEntropyMultidimensional(self):
    mu = self._rng.rand(3, 5, 2)
    chol, sigma = self._random_chol(3, 5, 2, 2)
    mvn = tfd.MultivariateNormalTriL(mu, chol, validate_args=True)
    entropy = mvn.entropy()

    # Scipy doesn't do batches, so test one of them.
    expected_entropy = stats.multivariate_normal(
        mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).entropy()
    self.assertEqual(entropy.shape, (3, 5))
    self.assertAllClose(expected_entropy, self.evaluate(entropy)[1, 1])
Example #30
    def testJacobianDiagonal3DListInput(self):
        """Tests that the diagonal of the Jacobian matrix computes correctly."""

        dtype = np.float32
        true_mean = dtype([0, 0, 0])
        true_cov = dtype([[1, 0.25, 0.25], [0.25, 2, 0.25], [0.25, 0.25, 3]])
        chol = tf.linalg.cholesky(true_cov)
        target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

        # Assume that the state is passed as a list of tensors `x` and `y`.
        # Then the target function is defined as follows:
        def target_fn(x, y):
            # Stack the input tensors together
            z = tf.concat([x, y], axis=-1) - true_mean
            return target.log_prob(z)

        sample_shape = [3, 5]
        state = [
            tf.ones(sample_shape + [2], dtype=dtype),
            tf.ones(sample_shape + [1], dtype=dtype)
        ]
        fn_val, grads = tfp.math.value_and_gradient(target_fn, state)
        grad_fn = lambda *args: tfp.math.value_and_gradient(target_fn, args)[1]

        _, diag_jacobian_shape_passed = tfp.math.diag_jacobian(
            xs=state, ys=grads, fn=grad_fn, sample_shape=ps.shape(fn_val))
        _, diag_jacobian_shape_none = tfp.math.diag_jacobian(xs=state,
                                                             ys=grads,
                                                             fn=grad_fn)

        true_diag_jacobian_1 = np.zeros(sample_shape + [2])
        true_diag_jacobian_1[..., 0] = -1.05
        true_diag_jacobian_1[..., 1] = -0.52

        true_diag_jacobian_2 = -0.34 * np.ones(sample_shape + [1])
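        # These expected values are -diag(inv(true_cov)): the gradient of the MVN
        # log-density is -inv(true_cov) @ (z - true_mean), so the diagonal of its
        # Jacobian is the negated diagonal of the precision matrix.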

        self.assertAllClose(self.evaluate(diag_jacobian_shape_passed[0]),
                            true_diag_jacobian_1,
                            atol=0.01,
                            rtol=0.01)
        self.assertAllClose(self.evaluate(diag_jacobian_shape_none[0]),
                            true_diag_jacobian_1,
                            atol=0.01,
                            rtol=0.01)

        self.assertAllClose(self.evaluate(diag_jacobian_shape_passed[1]),
                            true_diag_jacobian_2,
                            atol=0.01,
                            rtol=0.01)
        self.assertAllClose(self.evaluate(diag_jacobian_shape_none[1]),
                            true_diag_jacobian_2,
                            atol=0.01,
                            rtol=0.01)