def testBasic(self, centered, use_fft):
  """Validates `log_prob` and sample transformations on unconstrained draws.

  Builds the model and delegates to `validate_log_prob_and_transforms`, which
  checks that unconstrained samples yield finite `log_prob` values and finite
  expectations. The expected event shape of each output of the `identity`
  sample transformation is listed explicitly below.

  Args:
    centered: If `True`, build the model with the centered parameterization.
    use_fft: If `True`, implement the centering transformation with an
      FFT-based convolution.
  """
  model = vectorized_stochastic_volatility.VectorizedStochasticVolatility(
      centered_returns=_test_dataset(),
      centered=centered,
      use_fft=use_fft,
  )
  # Expected per-variable shapes produced by the 'identity' transformation.
  identity_shapes = {
      'persistence_of_volatility': [],
      'mean_log_volatility': [],
      'white_noise_shock_scale': [],
      'log_volatility': [5],
  }
  self.validate_log_prob_and_transforms(
      model, sample_transformation_shapes={'identity': identity_shapes})
def testParameterizationsConsistent(self):
  """Runs HMC for every parameterization and compares posterior means.

  The centered model serves as the reference. The non-centered models (with
  and without FFT-based centering) must have posterior means that agree with
  it to within three Monte-Carlo standard errors of the difference. The
  standard error is the sample variance divided by an ESS computed with
  `tfp.mcmc.effective_sample_size`.
  """
  self.skipTest('Broken by omnistaging b/168705919')
  centered_returns = _test_dataset()
  centered_model = (
      vectorized_stochastic_volatility.VectorizedStochasticVolatility(
          centered_returns, centered=True))
  non_centered_model = (
      vectorized_stochastic_volatility.VectorizedStochasticVolatility(
          centered_returns, centered=False, use_fft=False))
  non_centered_fft_model = (
      vectorized_stochastic_volatility.VectorizedStochasticVolatility(
          centered_returns, centered=False, use_fft=True))

  def run_hmc(model, label):
    """Runs HMC on `model` with the shared settings and logs diagnostics."""
    logging.info(label)
    results = self.evaluate(
        test_util.run_hmc_on_model(
            model,
            num_chains=4,
            num_steps=1000,
            num_leapfrog_steps=10,
            step_size=0.1,
            # TF XLA is very slow on this problem
            use_xla=BACKEND == 'backend_jax',
            target_accept_prob=0.7))
    logging.info('Acceptance rate: %s', results.accept_rate)
    logging.info('ESS: %s', results.ess)
    logging.info('r_hat: %s', results.r_hat)
    return results

  # All three HMC runs complete before any transformation is evaluated,
  # matching the original evaluation order.
  centered_results = run_hmc(centered_model, 'Centered:')
  non_centered_results = run_hmc(
      non_centered_model, 'Non-centered (without FFT):')
  non_centered_fft_results = run_hmc(
      non_centered_fft_model, 'Non-centered (with FFT):')

  def to_identity_params(model, results):
    """Evaluates the model's 'identity' sample transformation on the chain."""
    return self.evaluate(
        tf.nest.map_structure(
            tf.identity,
            model.sample_transformations['identity'](results.chain)))

  centered_params = to_identity_params(centered_model, centered_results)
  non_centered_params = to_identity_params(
      non_centered_model, non_centered_results)
  non_centered_fft_params = to_identity_params(
      non_centered_fft_model, non_centered_fft_results)

  def get_mean_and_var(chain):
    """Returns (mean, variance / ESS) structures mirroring `chain`.

    NOTE(review): assumes each leaf is an array whose axes 0 and 1 are the
    step and chain dimensions respectively -- confirm against test_util.
    """
    def one_part(part):
      mean = part.mean((0, 1))
      var = part.var((0, 1))
      # Harmonic mean, over axis 0 of the ESS output, of the ESS values.
      ess = 1. / (1. / self.evaluate(
          tfp.mcmc.effective_sample_size(
              part, filter_beyond_positive_pairs=True))).mean(0)
      return mean, var / ess

    mean_var = tf.nest.map_structure(one_part, chain)
    return (nest.map_structure_up_to(chain, lambda x: x[0], mean_var),
            nest.map_structure_up_to(chain, lambda x: x[1], mean_var))

  centered_mean, centered_var = get_mean_and_var(centered_params)
  non_centered_mean, non_centered_var = get_mean_and_var(non_centered_params)
  non_centered_fft_mean, non_centered_fft_var = get_mean_and_var(
      non_centered_fft_params)

  def get_atol(var1, var2):
    """Returns the largest 3-sigma tolerance across all random variables."""
    # TODO(b/144290399): Use the full atol vector.
    three_sigma_per_rv = tf.nest.map_structure(
        lambda v1, v2: (3. * np.sqrt(v1 + v2)).max(), var1, var2)
    return max(tf.nest.flatten(three_sigma_per_rv))

  self.assertAllCloseNested(
      centered_mean,
      non_centered_mean,
      atol=get_atol(centered_var, non_centered_var),
  )
  self.assertAllCloseNested(
      centered_mean,
      non_centered_fft_mean,
      atol=get_atol(centered_var, non_centered_fft_var),
  )