Exemple #1
0
  def testBasic(self, centered, use_fft):
    """Verifies finite values are produced from unconstrained samples.

    Both `log_prob` and the values of the expectations are checked, by
    handing the model built from the test dataset to
    `validate_log_prob_and_transforms` together with the expected shapes of
    the `identity` sample transformation.

    Args:
      centered: Whether or not to use the centered parameterization.
      use_fft: Whether or not to use FFT-based convolution to implement the
        centering transformation.
    """
    model_cls = vectorized_stochastic_volatility.VectorizedStochasticVolatility
    model = model_cls(
        centered_returns=_test_dataset(), centered=centered, use_fft=use_fft)

    # Expected shape of each component returned by the `identity`
    # transformation.
    expected_identity_shapes = {
        'persistence_of_volatility': [],
        'mean_log_volatility': [],
        'white_noise_shock_scale': [],
        'log_volatility': [5],
    }
    self.validate_log_prob_and_transforms(
        model,
        sample_transformation_shapes={'identity': expected_identity_shapes})
Exemple #2
0
  def testParameterizationsConsistent(self):
    """Samples each parameterization with HMC and compares posterior means.

    Three models are built from the same dataset: centered, non-centered
    without FFT, and non-centered with FFT. Each is sampled with identical HMC
    settings, the `identity` sample transformation is applied to the chains,
    and the resulting means of both non-centered variants are required to
    match the centered means. The absolute tolerance is three times the square
    root of the summed variance estimates, using the single largest such value
    over all entries.

    The method calls `skipTest` as its first statement, so the remainder is
    currently not exercised.
    """
    self.skipTest('Broken by omnistaging b/168705919')
    centered_returns = _test_dataset()
    model_cls = vectorized_stochastic_volatility.VectorizedStochasticVolatility
    centered_model = model_cls(centered_returns, centered=True)
    non_centered_model = model_cls(
        centered_returns, centered=False, use_fft=False)
    non_centered_fft_model = model_cls(
        centered_returns, centered=False, use_fft=True)

    def sample_and_log(header, model):
      """Logs `header`, runs HMC on `model`, then logs its diagnostics."""
      logging.info(header)
      hmc_kwargs = dict(
          num_chains=4,
          num_steps=1000,
          num_leapfrog_steps=10,
          step_size=0.1,
          # TF XLA is very slow on this problem
          use_xla=BACKEND == 'backend_jax',
          target_accept_prob=0.7)
      results = self.evaluate(test_util.run_hmc_on_model(model, **hmc_kwargs))
      logging.info('Acceptance rate: %s', results.accept_rate)
      logging.info('ESS: %s', results.ess)
      logging.info('r_hat: %s', results.r_hat)
      return results

    def transformed_samples(model, results):
      """Evaluates the `identity` transformation of `results.chain`."""
      transform = model.sample_transformations['identity']
      return self.evaluate(
          tf.nest.map_structure(tf.identity, transform(results.chain)))

    # All three samplers run before any transformation is evaluated.
    centered_results = sample_and_log('Centered:', centered_model)
    non_centered_results = sample_and_log(
        'Non-centered (without FFT):', non_centered_model)
    non_centered_fft_results = sample_and_log(
        'Non-centered (with FFT):', non_centered_fft_model)

    centered_params = transformed_samples(centered_model, centered_results)
    non_centered_params = transformed_samples(
        non_centered_model, non_centered_results)
    non_centered_fft_params = transformed_samples(
        non_centered_fft_model, non_centered_fft_results)

    def get_mean_and_var(chain):
      """Returns (means, variances / ESS), each structured like `chain`."""

      def summarize(samples):
        sample_mean = samples.mean((0, 1))
        sample_var = samples.var((0, 1))
        ess_estimates = self.evaluate(
            tfp.mcmc.effective_sample_size(
                samples, filter_beyond_positive_pairs=True))
        # Harmonic mean of the ESS estimates along their leading axis.
        pooled_ess = 1. / (1. / ess_estimates).mean(0)
        return sample_mean, sample_var / pooled_ess

      # Every leaf of `stats` is a (mean, scaled variance) pair; split the
      # pairs into two structures that each mirror `chain`.
      stats = tf.nest.map_structure(summarize, chain)
      means = nest.map_structure_up_to(chain, lambda pair: pair[0], stats)
      scaled_vars = nest.map_structure_up_to(
          chain, lambda pair: pair[1], stats)
      return means, scaled_vars

    centered_mean, centered_var = get_mean_and_var(centered_params)
    non_centered_mean, non_centered_var = get_mean_and_var(non_centered_params)
    non_centered_fft_mean, non_centered_fft_var = get_mean_and_var(
        non_centered_fft_params)

    def get_atol(var1, var2):
      """Returns the largest `3 * sqrt(v1 + v2)` over all entries."""

      # TODO(b/144290399): Use the full atol vector.
      def largest_bound(v1, v2):
        return (3. * np.sqrt(v1 + v2)).max()

      bound_per_rv = tf.nest.map_structure(largest_bound, var1, var2)
      return functools.reduce(max, tf.nest.flatten(bound_per_rv))

    # Each non-centered variant is compared against the centered reference.
    comparisons = (
        (non_centered_mean, non_centered_var),
        (non_centered_fft_mean, non_centered_fft_var),
    )
    for other_mean, other_var in comparisons:
      self.assertAllCloseNested(
          centered_mean,
          other_mean,
          atol=get_atol(centered_var, other_var),
      )