def make_xi_kernel(scale, bounded_convergence, name):
    """Build a Gibbs step that updates state part 1 with random-walk Metropolis.

    The target density is ``logp``, taken from the enclosing scope rather than
    passed in.

    Args:
        scale: First argument to ``random_walk_mvnorm_fn``. Presumably the
            proposal scale/covariance; confirm against that helper.
        bounded_convergence: Forwarded to ``random_walk_mvnorm_fn`` as ``p_u``.
        name: Name given to the returned ``GibbsStep``.

    Returns:
        A ``GibbsStep`` wrapping a ``tfp.mcmc.RandomWalkMetropolis`` kernel.
    """
    proposal_fn = random_walk_mvnorm_fn(scale, p_u=bounded_convergence)
    rwm_kernel = tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=logp,
        new_state_fn=proposal_fn,
    )
    # The leading 1 is presumably the index of the state part this step
    # updates. Confirm against GibbsStep's signature.
    return GibbsStep(1, rwm_kernel, name=name)
def make_theta_kernel(scale, bounded_convergence, name):
    """Build a Gibbs step that updates state part 0 with an MH-corrected log random walk.

    The proposal kernel is ``UncalibratedLogRandomWalk``, wrapped in
    ``tfp.mcmc.MetropolisHastings`` to add the accept/reject step. The target
    density is ``logp``, taken from the enclosing scope.

    Args:
        scale: First argument to ``random_walk_mvnorm_fn``. Presumably the
            proposal scale/covariance; confirm against that helper.
        bounded_convergence: Forwarded to ``random_walk_mvnorm_fn`` as ``p_u``.
        name: Name given to the returned ``GibbsStep``.

    Returns:
        A ``GibbsStep`` wrapping ``MetropolisHastings(UncalibratedLogRandomWalk)``.
    """
    proposal_fn = random_walk_mvnorm_fn(scale, p_u=bounded_convergence)
    log_rw_kernel = UncalibratedLogRandomWalk(
        target_log_prob_fn=logp,
        new_state_fn=proposal_fn,
    )
    mh_kernel = tfp.mcmc.MetropolisHastings(inner_kernel=log_rw_kernel)
    # The leading 0 is presumably the index of the state part this step
    # updates. Confirm against GibbsStep's signature.
    return GibbsStep(0, mh_kernel, name=name)
def make_occults_step(
    prev_event_id,
    target_event_id,
    next_event_id,
    name,
    *,
    occult_window=21,
):
    """Build a Gibbs step on state part 2 that applies an MH-corrected occult update.

    Wraps ``UncalibratedOccultUpdate`` in ``tfp.mcmc.MetropolisHastings``.
    It uses these names from the enclosing scope: ``logp`` (target density),
    ``initial_state`` (passed as ``cumulative_event_offset``), ``config``
    (``config["mcmc"]["occult_nmax"]``) and ``events`` (its ``shape[1]`` sets
    the upper end of ``t_range``).

    Args:
        prev_event_id: Passed to ``TransitionTopology``.
        target_event_id: Passed to ``TransitionTopology``.
        next_event_id: Passed to ``TransitionTopology``.
        name: Used as the name of both the inner ``UncalibratedOccultUpdate``
            and the returned ``GibbsStep``.
        occult_window: Length of the trailing range passed as ``t_range``.
            The range is ``(events.shape[1] - occult_window, events.shape[1])``
            with the lower bound clipped at 0. Defaults to 21, the value that
            was previously hard-coded. Presumably this limits which time points
            may receive occult proposals; confirm against
            ``UncalibratedOccultUpdate``.

    Returns:
        A ``GibbsStep`` wrapping ``MetropolisHastings(UncalibratedOccultUpdate)``.
    """
    num_times = events.shape[1]  # presumably the time axis of `events`; confirm
    # Clip at 0 so a series shorter than the window yields the whole series
    # instead of a negative (out-of-range) lower bound.
    t_range = (max(num_times - occult_window, 0), num_times)

    occult_kernel = UncalibratedOccultUpdate(
        target_log_prob_fn=logp,
        topology=TransitionTopology(prev_event_id, target_event_id, next_event_id),
        cumulative_event_offset=initial_state,
        nmax=config["mcmc"]["occult_nmax"],
        t_range=t_range,
        name=name,
    )
    # The leading 2 is presumably the index of the state part this step
    # updates. Confirm against GibbsStep's signature.
    return GibbsStep(
        2,
        tfp.mcmc.MetropolisHastings(inner_kernel=occult_kernel),
        name=name,
    )
def make_partially_observed_step(
    target_event_id, prev_event_id=None, next_event_id=None, name=None
):
    """Build a Gibbs step on state part 2 that applies an MH-corrected event-times update.

    Wraps ``UncalibratedEventTimesUpdate`` in ``tfp.mcmc.MetropolisHastings``.
    It uses these names from the enclosing scope: ``logp`` (target density),
    ``initial_state`` and ``config["mcmc"]`` (proposal size limits).

    Args:
        target_event_id: Forwarded to ``UncalibratedEventTimesUpdate``.
        prev_event_id: Forwarded to ``UncalibratedEventTimesUpdate``;
            defaults to ``None``.
        next_event_id: Forwarded to ``UncalibratedEventTimesUpdate``;
            defaults to ``None``.
        name: Name given to the returned ``GibbsStep``; defaults to ``None``.

    Returns:
        A ``GibbsStep`` wrapping ``MetropolisHastings(UncalibratedEventTimesUpdate)``.
    """
    mcmc_config = config["mcmc"]
    event_times_kernel = UncalibratedEventTimesUpdate(
        target_log_prob_fn=logp,
        target_event_id=target_event_id,
        prev_event_id=prev_event_id,
        next_event_id=next_event_id,
        initial_state=initial_state,
        dmax=mcmc_config["dmax"],
        # The config key is "m", not "mmax".
        mmax=mcmc_config["m"],
        nmax=mcmc_config["nmax"],
    )
    mh_kernel = tfp.mcmc.MetropolisHastings(inner_kernel=event_times_kernel)
    # The leading 2 is presumably the index of the state part this step
    # updates. Confirm against GibbsStep's signature.
    return GibbsStep(2, mh_kernel, name=name)