Example #1
0
def run_parametrised_hmc(model_config,
                         interceptor,
                         num_samples=2000,
                         burnin=1000,
                         num_leapfrog_steps=4,
                         num_adaptation_steps=500,
                         num_optimization_steps=2000):
    """Run HMC on a (centred) model reparameterised by `interceptor`.

    The model in `model_config` is wrapped so that every call runs under
    `ed.interception(interceptor)`. HMC is then run on the log joint of that
    wrapped (reparameterised) model.

    Args:
      model_config: Model description. This function reads its `model`,
        `model_args`, `observed_data` (a mapping from model variable name to
        observed value) and `to_centered` attributes.
      interceptor: Edward2 interceptor applied around every model call.
      num_samples: Forwarded to `_run_hmc`.
      burnin: Forwarded to `_run_hmc`.
      num_leapfrog_steps: Forwarded to `_run_hmc`.
      num_adaptation_steps: Forwarded to `_run_hmc`.
      num_optimization_steps: Forwarded to `util.approximate_mcmc_step_size`.

    Returns:
      The dict returned by `_run_hmc`, with two extra entries: 'elbo' and
      'vi_time'. Both come from `util.approximate_mcmc_step_size`.
    """
    def model_ncp(*params):
        with ed.interception(interceptor):
            return model_config.model(*params)

    log_joint_noncentered = ed.make_log_joint_fn(model_ncp)

    # Trace the reparameterised model once to discover its random variables.
    with ed.tape() as model_tape:
        _ = model_ncp(*model_config.model_args)

    observed_data = model_config.observed_data
    # Latent (non-observed) variables in tape order; HMC samples these.
    param_shapes = collections.OrderedDict()
    observed_kwargs = {}
    for param in model_tape.keys():
        if param in observed_data:
            observed_kwargs[param] = observed_data[param]
        else:
            param_shapes[param] = model_tape[param].shape
    latent_names = list(param_shapes.keys())

    def target_ncp(*param_args):
        """Log joint of the reparameterised model at the given latent values.

        `param_args` are positional, in the same order as `param_shapes`.
        """
        # Fresh dict per call, so calls never share mutable state.
        target_ncp_kwargs = dict(observed_kwargs)
        for i, param in enumerate(latent_names):
            target_ncp_kwargs[param] = param_args[i]
        return log_joint_noncentered(*model_config.model_args,
                                     **target_ncp_kwargs)

    # NOTE(review): num_leapfrog_steps is not forwarded to the step-size
    # helper. An earlier version put it in this dict and then immediately
    # overwrote the dict, so it never reached the helper. The remaining kwargs
    # carry observed data, so confirm the helper's signature before adding it.
    stepsize_kwargs = {'num_optimization_steps': num_optimization_steps}
    stepsize_kwargs.update(observed_data)
    (step_size_init_ncp, stepsize_elbo_ncp,
     vi_time) = util.approximate_mcmc_step_size(model_ncp,
                                                *model_config.model_args,
                                                **stepsize_kwargs)

    results = _run_hmc(target_ncp,
                       param_shapes,
                       step_size_init=step_size_init_ncp,
                       # presumably maps samples back to the centred space
                       transform=model_config.to_centered,
                       num_samples=num_samples,
                       burnin=burnin,
                       num_adaptation_steps=num_adaptation_steps,
                       num_leapfrog_steps=num_leapfrog_steps)

    results['elbo'] = stepsize_elbo_ncp
    results['vi_time'] = vi_time
    return results
Example #2
0
def run_centered_hmc(model_config,
                     num_samples=2000,
                     burnin=1000,
                     num_leapfrog_steps=4,
                     num_adaptation_steps=500,
                     num_optimization_steps=2000):
    """Run HMC on the provided (centred) model.

    Resets the default TensorFlow graph first, then builds the model's log
    joint and runs HMC over its non-observed random variables.

    Args:
      model_config: Model description. This function reads its `model`,
        `model_args` and `observed_data` (a mapping from model variable name
        to observed value) attributes.
      num_samples: Forwarded to `_run_hmc`.
      burnin: Forwarded to `_run_hmc`.
      num_leapfrog_steps: Forwarded to `_run_hmc`.
      num_adaptation_steps: Forwarded to `_run_hmc`.
      num_optimization_steps: Forwarded to `util.approximate_mcmc_step_size`.

    Returns:
      The dict returned by `_run_hmc`, with two extra entries: 'elbo' and
      'vi_time'. Both come from `util.approximate_mcmc_step_size`.
    """

    tf.reset_default_graph()

    log_joint_centered = ed.make_log_joint_fn(model_config.model)

    # Trace the model once to discover its random variables.
    with ed.tape() as model_tape:
        _ = model_config.model(*model_config.model_args)

    observed_data = model_config.observed_data
    # Latent (non-observed) variables in tape order; HMC samples these.
    param_shapes = collections.OrderedDict()
    observed_kwargs = {}
    for param in model_tape.keys():
        if param in observed_data:
            observed_kwargs[param] = observed_data[param]
        else:
            param_shapes[param] = model_tape[param].shape
    latent_names = list(param_shapes.keys())

    def target_cp(*param_args):
        """Log joint of the centred model at the given latent values.

        `param_args` are positional, in the same order as `param_shapes`.
        """
        # Fresh dict per call, so calls never share mutable state.
        target_cp_kwargs = dict(observed_kwargs)
        for i, param in enumerate(latent_names):
            target_cp_kwargs[param] = param_args[i]
        return log_joint_centered(*model_config.model_args, **target_cp_kwargs)

    # NOTE(review): num_leapfrog_steps is not forwarded to the step-size
    # helper. An earlier version put it in this dict and then immediately
    # overwrote the dict, so it never reached the helper. The remaining kwargs
    # carry observed data, so confirm the helper's signature before adding it.
    stepsize_kwargs = {'num_optimization_steps': num_optimization_steps}
    stepsize_kwargs.update(observed_data)
    (step_size_init_cp, stepsize_elbo_cp,
     vi_time) = util.approximate_mcmc_step_size(model_config.model,
                                                *model_config.model_args,
                                                **stepsize_kwargs)

    results = _run_hmc(target_cp,
                       param_shapes,
                       step_size_init=step_size_init_cp,
                       num_samples=num_samples,
                       burnin=burnin,
                       num_adaptation_steps=num_adaptation_steps,
                       num_leapfrog_steps=num_leapfrog_steps)

    results['elbo'] = stepsize_elbo_cp
    results['vi_time'] = vi_time
    return results