def propose_and_update_log_weights_fn(_, weighted_particles, seed=None):
    """Gaussian random-walk proposal reweighted toward a fixed target.

    Each particle is moved by sampling from `Normal(loc=particle, scale=1.)`.
    The log-density of the moved particle under `Normal(loc=-2.6, scale=0.1)`
    is then added to that particle's incoming log weight.

    Args:
      _: Ignored positional argument (presumably the step index supplied by
        the SMC kernel -- TODO confirm against the caller).
      weighted_particles: `WeightedParticles` holding the current `particles`
        and `log_weights`.
      seed: Seed forwarded to the proposal's `sample` call.

    Returns:
      A new `WeightedParticles` with the moved particles and updated weights.
    """
    random_walk = tfd.Normal(loc=weighted_particles.particles, scale=1.)
    moved = random_walk.sample(seed=seed)
    target_log_density = tfd.Normal(loc=-2.6, scale=0.1).log_prob(moved)
    return WeightedParticles(
        particles=moved,
        log_weights=weighted_particles.log_weights + target_log_density)
    def test_steps_are_reproducible(self):
        """Replaying two SMC steps from the same seed gives identical output.

        The global seed (`tf.random.set_seed`) and a salted
        `tfp.util.SeedStream` are both reset before each run. The evaluated
        state and kernel results of the two runs are then compared.
        """

        def propose_and_update_log_weights_fn(_,
                                              weighted_particles,
                                              seed=None):
            # Unit-scale Gaussian random walk, reweighted by the log-density
            # of the moved particles under a narrow Normal(-2.6, 0.1).
            random_walk = tfd.Normal(loc=weighted_particles.particles,
                                     scale=1.)
            moved = random_walk.sample(seed=seed)
            target_log_density = tfd.Normal(loc=-2.6,
                                            scale=0.1).log_prob(moved)
            return WeightedParticles(
                particles=moved,
                log_weights=weighted_particles.log_weights +
                target_log_density)

        num_particles = 16
        initial_state = self.evaluate(
            WeightedParticles(
                particles=tf.random.normal([num_particles],
                                           seed=test_util.test_seed()),
                log_weights=tf.fill([num_particles],
                                    -tf.math.log(float(num_particles)))))

        def make_kernel():
            # A fresh kernel per run, so no state is shared between replays.
            return SequentialMonteCarlo(
                propose_and_update_log_weights_fn=(
                    propose_and_update_log_weights_fn),
                resample_fn=tfp.experimental.mcmc.resample_systematic,
                resample_criterion_fn=(
                    tfp.experimental.mcmc.ess_below_threshold))

        def run_two_steps(smc_kernel, base_seed):
            # Reset both the global seed and the per-step seed stream so that
            # every call starts from identical randomness.
            tf.random.set_seed(base_seed)
            stream = tfp.util.SeedStream(seed=base_seed, salt='test')
            step_state, step_results = smc_kernel.one_step(
                state=initial_state,
                kernel_results=smc_kernel.bootstrap_results(initial_state),
                seed=stream())
            step_state, step_results = smc_kernel.one_step(
                state=step_state,
                kernel_results=step_results,
                seed=stream())
            return self.evaluate(
                (tf.nest.map_structure(tf.convert_to_tensor, step_state),
                 tf.nest.map_structure(tf.convert_to_tensor, step_results)))

        # Run a couple of steps.
        first_kernel = make_kernel()
        seed = test_util.test_seed()
        state, results = run_two_steps(first_kernel, seed)

        # Re-initialize and run the same steps with the same seed.
        state2, results2 = run_two_steps(make_kernel(), seed)

        # Results should match.
        self.assertAllCloseNested(state, state2)
        self.assertAllCloseNested(results, results2)
 def propose_and_update_log_weights_fn(_, weighted_particles, transition_scale,
                                       seed=None):
     """Unit-scale Gaussian proposal with an importance-weight correction.

     New particles are drawn from `Normal(loc=particles, scale=1.)`. Each log
     weight then gains the new particle's log-density under
     `Normal(loc=particles, scale=transition_scale)` and loses its log-density
     under the proposal.

     Args:
       _: Ignored positional argument (presumably the step index supplied by
         the SMC kernel -- TODO confirm against the caller).
       weighted_particles: `WeightedParticles` holding the current `particles`
         and `log_weights`.
       transition_scale: Scale of the Gaussian transition density.
       seed: Seed forwarded to the proposal's `sample` call.

     Returns:
       A new `WeightedParticles` with the proposed particles and corrected
       log weights.
     """
     current = weighted_particles.particles
     q = tfd.Normal(loc=current, scale=1.)
     p = tfd.Normal(loc=current, scale=transition_scale)
     candidates = q.sample(seed=seed)
     # Keep the original left-to-right association (lw + log p) - log q so
     # the floating-point result is unchanged.
     updated_log_weights = (weighted_particles.log_weights +
                            p.log_prob(candidates) -
                            q.log_prob(candidates))
     return WeightedParticles(particles=candidates,
                              log_weights=updated_log_weights)
    def testMarginalLikelihoodGradientIsDefined(self):
        """The SMC log marginal likelihood is differentiable in the scale.

        Two SMC steps are run with a transition density parameterized by
        `transition_scale`. The gradient of the final
        `accumulated_log_marginal_likelihood` with respect to that scale must
        exist and must not be all zero.
        """
        num_particles = 16
        seeds = samplers.split_seed(test_util.test_seed(), n=3)
        initial_state = self.evaluate(
            WeightedParticles(
                particles=samplers.normal([num_particles], seed=seeds[0]),
                log_weights=tf.fill([num_particles],
                                    -tf.math.log(float(num_particles)))))

        def propose_and_update_log_weights_fn(_, weighted_particles,
                                              transition_scale, seed=None):
            # Unit-scale Gaussian proposal, importance-corrected toward a
            # Gaussian transition whose scale is the differentiated quantity.
            current = weighted_particles.particles
            q = tfd.Normal(loc=current, scale=1.)
            p = tfd.Normal(loc=current, scale=transition_scale)
            candidates = q.sample(seed=seed)
            return WeightedParticles(
                particles=candidates,
                log_weights=(weighted_particles.log_weights +
                             p.log_prob(candidates) -
                             q.log_prob(candidates)))

        def marginal_logprob(transition_scale):
            smc = SequentialMonteCarlo(
                propose_and_update_log_weights_fn=functools.partial(
                    propose_and_update_log_weights_fn,
                    transition_scale=transition_scale))
            smc_state = initial_state
            smc_results = smc.bootstrap_results(initial_state)
            # Two steps, each with its own pre-split seed.
            for step_seed in (seeds[1], seeds[2]):
                smc_state, smc_results = smc.one_step(
                    state=smc_state,
                    kernel_results=smc_results,
                    seed=step_seed)
            return smc_results.accumulated_log_marginal_likelihood

        _, dlogml_dscale = tfp.math.value_and_gradient(marginal_logprob, 1.5)
        self.assertIsNotNone(dlogml_dscale)
        self.assertNotAllZero(dlogml_dscale)
    def test_steps_are_reproducible(self):
        """Replaying two SMC steps with the same stateless seeds is exact.

        Two independently constructed kernels are each stepped twice from the
        same initial state with the same pair of split stateless seeds. Their
        evaluated states and kernel results are then compared.
        """

        def propose_and_update_log_weights_fn(_,
                                              weighted_particles,
                                              seed=None):
            # Unit-scale Gaussian random walk, reweighted by the log-density
            # of the moved particles under a narrow Normal(-2.6, 0.1).
            random_walk = tfd.Normal(loc=weighted_particles.particles,
                                     scale=1.)
            moved = random_walk.sample(seed=seed)
            target_log_density = tfd.Normal(loc=-2.6,
                                            scale=0.1).log_prob(moved)
            return WeightedParticles(
                particles=moved,
                log_weights=weighted_particles.log_weights +
                target_log_density)

        num_particles = 16
        initial_state = self.evaluate(
            WeightedParticles(
                particles=tf.random.normal([num_particles],
                                           seed=test_util.test_seed()),
                log_weights=tf.fill([num_particles],
                                    -tf.math.log(float(num_particles)))))

        # Run a couple of steps.
        seeds = samplers.split_seed(
            test_util.test_seed(sampler_type='stateless'), n=2)

        def build_kernel():
            # A fresh kernel per run, so no state is shared between replays.
            return SequentialMonteCarlo(
                propose_and_update_log_weights_fn=(
                    propose_and_update_log_weights_fn),
                resample_fn=tfp.experimental.mcmc.resample_systematic,
                resample_criterion_fn=(
                    tfp.experimental.mcmc.ess_below_threshold))

        def run_two_steps(smc_kernel):
            # Same initial state and same seeds on every call.
            mid_state, mid_results = smc_kernel.one_step(
                state=initial_state,
                kernel_results=smc_kernel.bootstrap_results(initial_state),
                seed=seeds[0])
            end_state, end_results = smc_kernel.one_step(
                state=mid_state,
                kernel_results=mid_results,
                seed=seeds[1])
            return self.evaluate(
                (tf.nest.map_structure(tf.convert_to_tensor, end_state),
                 tf.nest.map_structure(tf.convert_to_tensor, end_results)))

        state, results = run_two_steps(build_kernel())

        # Re-initialize and run the same steps with the same seed.
        state2, results2 = run_two_steps(build_kernel())

        def compare_fn(x, y):
            # TODO(b/223267515): PRNGKeyArrays have no dtype, so they need a
            # dedicated seed comparison instead of assertAllClose.
            if not hasattr(x, 'dtype'):
                self.assertSeedsEqual(x, y)
                return
            self.assertAllClose(x, y)

        # Results should match.
        self.assertAllCloseNested(state, state2)
        self.assertAllAssertsNested(compare_fn, results, results2)