Esempio n. 1
0
    def testMarkovChainLogprobMatchesOriginal(self):
        """Compares log densities of the plain and Markov-chain parameterizations.

        Twenty draws are taken from the prior of the model built with
        `use_markov_chain=False`. The draws are repacked into the structure
        expected by the `use_markov_chain=True` variant: the first two
        components are passed through unchanged, and the remaining components
        are stacked along axis -2. Both models' unnormalized log probabilities
        must then agree to within an absolute tolerance of 1e-2.
        """
        plain_model = lorenz_system.ConvectionLorenzBridgeUnknownScales(
            use_markov_chain=False)
        chained_model = lorenz_system.ConvectionLorenzBridgeUnknownScales(
            use_markov_chain=True)

        prior = plain_model.prior_distribution()
        draws = self.evaluate(
            prior.sample(20, seed=tfp_test_util.test_seed()))

        expected_lp = plain_model.unnormalized_log_prob(draws)

        # NOTE(review): assumes `chained_model.dtype` is a structured
        # (namedtuple-like) value whose type accepts the three components
        # positionally -- confirm against lorenz_system.
        chain_structure_cls = type(chained_model.dtype)
        stacked_tail = tf.stack(draws[2:], axis=-2)
        repacked_draws = chain_structure_cls(draws[0], draws[1], stacked_tail)
        actual_lp = chained_model.unnormalized_log_prob(repacked_draws)

        self.assertAllClose(expected_lp, actual_lp, atol=1e-2)

    def testConvectionLorenzBridgeHMC(self):
        """Validates HMC-based approximate samples against the ground truth.

        The test is skipped unconditionally with reason 'b/171518508'.
        `skipTest` raises, so nothing after it executes; the HMC configuration
        is kept so the check can be re-enabled by removing the skip.
        """
        self.skipTest('b/171518508')
        hmc_kwargs = {
            'num_chains': 2,
            'num_steps': 2000,
            'num_leapfrog_steps': 240,
            'step_size': 0.03,
        }
        target = lorenz_system.ConvectionLorenzBridgeUnknownScales()
        self.validate_ground_truth_using_hmc(target, **hmc_kwargs)

 def testConvectionLorenzBridge(self):
     """Validates log-prob evaluation and transforms for the Lorenz bridge model.

     Checks that unconstrained parameters produce finite joint densities.
     Also passes the expected shapes for the 'identity' sample transformation
     and enables the ground-truth mean, mean-standard-error and
     standard-deviation checks.
     """
     target = lorenz_system.ConvectionLorenzBridgeUnknownScales()
     # Expected shape of each named component under the 'identity'
     # sample transformation.
     identity_shapes = {
         'innovation_scale': [],
         'observation_scale': [],
         'latents': [30, 3],
     }
     self.validate_log_prob_and_transforms(
         target,
         sample_transformation_shapes={'identity': identity_shapes},
         check_ground_truth_mean_standard_error=True,
         check_ground_truth_mean=True,
         check_ground_truth_standard_deviation=True,
     )