Example #1
def Inference_MCMC(model,
                   data,
                   polls,
                   n_samples=500,
                   n_warmup=500,
                   n_chains=1,
                   max_tree_depth=6):

    nuts_kernel = NUTS(model,
                       adapt_step_size=True,
                       jit_compile=True,
                       ignore_jit_warnings=True,
                       max_tree_depth=max_tree_depth)

    mcmc = MCMC(nuts_kernel,
                num_samples=n_samples,
                warmup_steps=n_warmup,
                num_chains=n_chains)

    mcmc.run(data, polls)

    # the post-warmup samples, i.e. the actual
    # draws from the posterior distribution
    posterior_samples = mcmc.get_samples()

    # convert the sample tensors to a dict of NumPy arrays
    hmc_samples = {
        k: v.detach().cpu().numpy()
        for k, v in mcmc.get_samples().items()
    }

    return posterior_samples, hmc_samples
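
A minimal usage sketch for the helper above (not from the source): `polls_model`, `poll_data`, and `n_polls` are hypothetical stand-ins, and the NUTS/MCMC imports used inside Inference_MCMC are assumed to come from pyro.infer as in the other examples.

import torch
import pyro
import pyro.distributions as dist

def polls_model(data, polls):
    # toy model: one shared location/scale over `polls` observations
    mu = pyro.sample("mu", dist.Normal(0.0, 1.0))
    sigma = pyro.sample("sigma", dist.HalfNormal(1.0))
    with pyro.plate("obs_plate", polls):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)

n_polls = 50
poll_data = torch.randn(n_polls)
posterior_samples, hmc_samples = Inference_MCMC(
    polls_model, poll_data, n_polls, n_samples=200, n_warmup=200)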
Example #2
    def __init__(self, model, data, covariates=None, *,
                 num_warmup=1000, num_samples=1000, num_chains=1,
                 dense_mass=False, jit_compile=False, max_tree_depth=10):
        assert data.size(-2) == covariates.size(-2)
        super().__init__()
        self.model = model
        max_plate_nesting = _guess_max_plate_nesting(model, (data, covariates), {})
        self.max_plate_nesting = max(max_plate_nesting, 1)  # force a time plate

        kernel = NUTS(model, full_mass=dense_mass, jit_compile=jit_compile, ignore_jit_warnings=True,
                      max_tree_depth=max_tree_depth, max_plate_nesting=max_plate_nesting)
        mcmc = MCMC(kernel, warmup_steps=num_warmup, num_samples=num_samples, num_chains=num_chains)
        mcmc.run(data, covariates)
        # conditions to compute rhat
        if (num_chains == 1 and num_samples >= 4) or (num_chains > 1 and num_samples >= 2):
            mcmc.summary()

        # inspect the model with particles plate = 1, so that we can reshape samples to
        # add any missing plate dim in front.
        with poutine.trace() as tr:
            with pyro.plate("particles", 1, dim=-self.max_plate_nesting - 1):
                model(data, covariates)

        self._trace = tr.trace
        self._samples = mcmc.get_samples()
        self._num_samples = num_samples * num_chains
        for name, node in list(self._trace.nodes.items()):
            if name not in self._samples:
                del self._trace.nodes[name]
Example #3
    def mcmc(self, data: torch.Tensor, y: torch.Tensor, tensorboard: bool,
             log_dir: str):
        """
        Perform Markov-Chain Monte-Carlo sampling on the (unknown) posterior.

        Parameters
        ----------
        data_input : np.ndarray, shape=(n_samples, n_features)
            NumPy 2-D array with data input.
        y : np.ndarray, shape=(n_samples,)
            NumPy array with ground truth labels as 1-D vector (binary).
        """

        if tensorboard:
            writer = SummaryWriter(log_dir=log_dir)
            distribution = defaultdict(list)

            def log(kernel, samples, stage, i):
                """ Log after each MCMC iteration """

                # loop through all sites and log their value as well as the underlying distribution
                # approximated by a Gaussian
                for key, value in samples.items():
                    distribution[key].append(value)
                    stacked = torch.stack(distribution[key], dim=0)
                    mean, scale = torch.mean(stacked,
                                             dim=0), torch.std(stacked, dim=0)

                    for d, x in enumerate(value):
                        writer.add_scalar("%s_%s_%d" % (stage, key, d), x, i)
                        writer.add_scalar("%s_%s_mean_%d" % (stage, key, d),
                                          mean[d], i)
                        writer.add_scalar("%s_%s_scale_%d" % (stage, key, d),
                                          scale[d], i)

                        writer.add_histogram(
                            "%s_histogram_%s_%d" % (stage, key, d),
                            stacked[:, d], i)

        # if logging is not requested, return empty lambda
        else:
            log = lambda kernel, samples, stage, i: None

        # set up MCMC kernel
        kernel = NUTS(self.model)

        # initialize MCMC sampler and run sampling algorithm
        mcmc = MCMC(kernel,
                    num_samples=self.mcmc_steps,
                    warmup_steps=self.mcmc_warmup,
                    num_chains=self.mcmc_chains,
                    hook_fn=log)
        mcmc.run(data.float(), y.float())

        # get samples from MCMC chains and store weights
        samples = mcmc.get_samples()
        self.mcmc_model = samples

        if tensorboard:
            writer.close()
Example #4
def _infer_hmc(args, data, model, init_values={}):
    logging.info("Running inference...")
    kernel = NUTS(model,
                  full_mass=[("R0", "rho")],
                  max_tree_depth=args.max_tree_depth,
                  init_strategy=init_to_value(values=init_values),
                  jit_compile=args.jit, ignore_jit_warnings=True)

    # We'll define a hook_fn to log potential energy values during inference.
    # This is helpful to diagnose whether the chain is mixing.
    energies = []

    def hook_fn(kernel, *unused):
        e = float(kernel._potential_energy_last)
        energies.append(e)
        if args.verbose:
            logging.info("potential = {:0.6g}".format(e))

    mcmc = MCMC(kernel, hook_fn=hook_fn,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps)
    mcmc.run(args, data)
    mcmc.summary()
    if args.plot:
        import matplotlib.pyplot as plt
        plt.figure(figsize=(6, 3))
        plt.plot(energies)
        plt.xlabel("MCMC step")
        plt.ylabel("potential energy")
        plt.title("MCMC energy trace")
        plt.tight_layout()

    samples = mcmc.get_samples()
    return samples
Example #5
 def predict(mcmc_instance_object: MCMC, x: torch.Tensor) -> ndarray:
     samples = mcmc_instance_object.get_samples()
     pred = []
     for i in range(samples['sigma'].shape[0]):
         pred.append(np.dot((x.numpy()), samples['linear.weight'][i].T))
         if 'linear.bias' in samples:
             pred[-1] += float(samples['linear.bias'][i])
     return np.array(pred)
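
A short sketch of calling predict() above, assuming an MCMC run over a Bayesian linear model whose sites are named "sigma", "linear.weight", and optionally "linear.bias"; `mcmc` and `x_new` are hypothetical objects from such a run.

y_samples = predict(mcmc, x_new)   # one prediction array per posterior draw
y_mean = y_samples.mean(axis=0)    # posterior-mean prediction for each input row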
Example #6
def tomtom_mcmc(data, seed, nsample=5000, burnin=1000):
    pyro.clear_param_store()
    pyro.set_rng_seed(seed)

    # #declare dataset to be modeled
    # dtname = 't{}_{}_{}_3d'.format(target, dtype, auto)
    # print("running MCMC with: {}".format(dtname))
    # data = globals()[dtname]

    nuts_kernel = NUTS(model)

    mcmc = MCMC(nuts_kernel, num_samples=nsample, warmup_steps=burnin)
    mcmc.run(data)

    posterior_samples = mcmc.get_samples()
    return posterior_samples
Example #7
def test_neals_funnel_smoke(jit):
    dim = 10

    guide = AutoIAFNormal(neals_funnel)
    svi = SVI(neals_funnel, guide, optim.Adam({"lr": 1e-10}), Trace_ELBO())
    for _ in range(1000):
        svi.step(dim)

    neutra = NeuTraReparam(guide.requires_grad_(False))
    model = neutra.reparam(neals_funnel)
    nuts = NUTS(model, jit_compile=jit)
    mcmc = MCMC(nuts, num_samples=50, warmup_steps=50)
    mcmc.run(dim)
    samples = mcmc.get_samples()
    # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites, not uniformly at -max_plate_nesting-1;
    # hence the unsqueeze
    transformed_samples = neutra.transform_sample(
        samples['y_shared_latent'].unsqueeze(-2))
    assert 'x' in transformed_samples
    assert 'y' in transformed_samples
Example #8
def run_hmc(
    x_data,
    y_data,
    model,
    num_samples=1000,
    warmup_steps=200,
):
    """
    Runs NUTS
    returns: samples
    """
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel,
                num_samples=num_samples,
                warmup_steps=warmup_steps)
    mcmc.run(x_data, y_data)
    hmc_samples = {
        k: v.detach().cpu().numpy()
        for k, v in mcmc.get_samples().items()
    }
    hmc_samples["linear.weight"] = hmc_samples["linear.weight"].reshape(
        num_samples, -1)
    return hmc_samples
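
A hedged sketch of a model run_hmc could be pointed at: a PyroModule Bayesian linear regression in the style of Pyro's regression tutorial, so that the "linear.weight" site reshaped above actually exists. The data here is synthetic and purely illustrative.

import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features=1):
        super().__init__()
        # linear layer whose weight and bias are treated as latent variables
        self.linear = PyroModule[torch.nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(
            dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(
            dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

x_data = torch.randn(100, 3)
y_data = x_data @ torch.tensor([1.0, -2.0, 0.5]) + 0.1 * torch.randn(100)
hmc_samples = run_hmc(x_data, y_data, BayesianRegression(3),
                      num_samples=200, warmup_steps=100)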
Example #9
    def test_inference_data_constant_data(self):
        import pyro.distributions as dist
        from pyro.infer import MCMC, NUTS

        x1 = 10
        x2 = 12
        y1 = torch.randn(10)

        def model_constant_data(x, y1=None):
            _x = pyro.sample("x", dist.Normal(1, 3))
            pyro.sample("y1", dist.Normal(x * _x, 1), obs=y1)

        nuts_kernel = NUTS(model_constant_data)
        mcmc = MCMC(nuts_kernel, num_samples=10)
        mcmc.run(x=x1, y1=y1)
        posterior = mcmc.get_samples()
        posterior_predictive = Predictive(model_constant_data, posterior)(x1)
        predictions = Predictive(model_constant_data, posterior)(x2)
        inference_data = from_pyro(
            mcmc,
            posterior_predictive=posterior_predictive,
            predictions=predictions,
            constant_data={"x1": x1},
            predictions_constant_data={"x2": x2},
        )
        test_dict = {
            "posterior": ["x"],
            "posterior_predictive": ["y1"],
            "sample_stats": ["diverging"],
            "log_likelihood": ["y1"],
            "predictions": ["y1"],
            "observed_data": ["y1"],
            "constant_data": ["x1"],
            "predictions_constant_data": ["x2"],
        }
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails
Example #10
def test_neals_funnel_smoke(jit):
    dim = 10

    guide = AutoStructured(
        neals_funnel,
        conditionals={
            "y": "normal",
            "x": "mvn"
        },
        dependencies={"x": {
            "y": "linear"
        }},
    )
    Elbo = JitTrace_ELBO if jit else Trace_ELBO
    svi = SVI(neals_funnel, guide, optim.Adam({"lr": 1e-10}), Elbo())
    for _ in range(1000):
        try:
            svi.step(dim=dim)
        except SystemError as e:
            if "returned a result with an error set" in str(e):
                pytest.xfail(reason="PyTorch jit bug")
            else:
                raise e from None

    rep = StructuredReparam(guide)
    model = rep.reparam(neals_funnel)
    nuts = NUTS(model, max_tree_depth=3, jit_compile=jit)
    mcmc = MCMC(nuts, num_samples=50, warmup_steps=50)
    mcmc.run(dim)
    samples = mcmc.get_samples()
    # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites,
    # not uniformly at -max_plate_nesting-1; hence the unsqueeze.
    samples = {k: v.unsqueeze(1) for k, v in samples.items()}
    transformed_samples = rep.transform_samples(samples)
    assert isinstance(transformed_samples, dict)
    assert set(transformed_samples) == {"x", "y"}
Example #11
def test_neals_funnel_smoke():
    dim = 10

    def model():
        y = pyro.sample('y', dist.Normal(0, 3))
        with pyro.plate("D", dim):
            pyro.sample('x', dist.Normal(0, torch.exp(y/2)))

    guide = AutoIAFNormal(model)
    svi = SVI(model, guide,  optim.Adam({"lr": 1e-10}), Trace_ELBO())
    for _ in range(1000):
        svi.step()

    neutra = NeuTraReparam(guide)
    model = neutra.reparam(model)
    nuts = NUTS(model)
    mcmc = MCMC(nuts, num_samples=50, warmup_steps=50)
    mcmc.run()
    samples = mcmc.get_samples()
    # XXX: `MCMC.get_samples` adds a leftmost batch dim to all sites, not uniformly at -max_plate_nesting-1;
    # hence the unsqueeze
    transformed_samples = neutra.transform_sample(samples['y_shared_latent'].unsqueeze(-2))
    assert 'x' in transformed_samples
    assert 'y' in transformed_samples
Example #12
def main():
    start = time.time()
    pyro.clear_param_store()

    # the kernel we will use
    hmc_kernel = HMC(conditioned_model, step_size=0.1)

    # the sampler which will run the kernel
    mcmc = MCMC(hmc_kernel, num_samples=14000, warmup_steps=100)

    # the .run method takes the same arguments as our model function
    mcmc.run(model, data)
    end = time.time()
    print('Time taken ', end - start, ' seconds')

    sample_dict = mcmc.get_samples(num_samples=5000)

    plt.figure(figsize=(10, 7))
    sns.distplot(sample_dict['latent_fairness'].numpy(), color="orange")
    plt.xlabel("Observed probability value")
    plt.ylabel("Observed frequency")
    plt.show()

    mcmc.summary(prob=0.95)
Example #13
    # sigma = dist.Uniform(0., 5.).sample()
    # sigma = dist.Uniform(sigma_loc, 5.).sample()
    # sigma = dist.Normal(sigma_loc, 0.2).sample()
    pyro.sample("obs", dist.Normal(mean, sigma), obs=obserations)


dims = 4
num_samples = 100

# generate observations
x = torch.rand(dims, num_samples)
noise = torch.distributions.Normal(torch.tensor([0.] * num_samples),
                                   torch.tensor([0.2] *
                                                num_samples)).rsample()
s, fm, Zn, Vr = x
a, b, c, d = 1.5, 1.8, 2.1, 2.3
# a, b, c, d = 1., 1., 1., 1.
observations = s * fm**a * Zn**b * c / Vr**d + noise[0]

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=200, warmup_steps=400)
mcmc.run(x, dims, observations)

hmc_samples = {
    k: v.detach().cpu().numpy()
    for k, v in mcmc.get_samples().items()
}
for site, values in summary(hmc_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")
Example #14
import torch
import numpy as np
import pyro
from pyro import sample
from pyro.infer import NUTS, MCMC
from pyro.distributions import Normal
from matplotlib import pyplot as plt


def bad():
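    # NOTE: this deliberately re-uses the sample site name 'x' on every loop
    # iteration, which Pyro rejects with a "multiple sample sites named 'x'"
    # error -- presumably why the function is called `bad`.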
    x = sample('x', pyro.distributions.Normal(0, 1))
    for i in range(10):
        x = sample('x', pyro.distributions.Normal(x, 3))
    return x


nuts_kernel = NUTS(bad)

mcmc = MCMC(nuts_kernel, num_samples=10, warmup_steps=10)
mcmc.run()
mcmc.summary()
samples = mcmc.get_samples()

print(samples.keys())

fig, ax = plt.subplots()
ax.hist(np.array(samples["x"]), bins=50)
plt.show()
Example #15
def do_fixed_structure_hmc_with_constraint_penalties(
        grammar,
        original_tree,
        num_samples=100,
        subsample_step=5,
        verbose=0,
        kernel_type="NUTS",
        fix_observeds=False,
        with_nonpenetration=False,
        with_static_stability=False,
        zmq_url=None,
        prefix="hmc_sample",
        constraints=None,
        structure_vis_kwargs={},
        **kwargs):
    ''' Given a scene tree, resample its continuous variables
    (i.e. the node poses) while keeping the root and observed
    node poses fixed, trying to keep the constraints implied
    by the tree and grammar satisfied. Returns a population of trees sampled
    from the joint distribution over node poses given the fixed structure.

    Verbose = 0: Print nothing
    Verbose = 1: Print updates about accept rate and steps
    Verbose = 2: Print NLP output and scoring info
    '''

    # Strategy sketch:
    # - Initialize at the current scene tree, asserting that it's a feasible configuration.
    # - Form a probabilistic model that samples all of the node poses,
    #   and uses observe() statements to implement constraint factors as tightly peaked
    #   energy terms.
    # - Use Pyro HMC to sample from the model.

    # Make a bookkeeping copy of the tree
    scene_tree = deepcopy(original_tree)

    # Do steps of random-walk MCMC on those variables.
    initial_score = scene_tree.score()
    assert torch.isfinite(initial_score), "Bad initialization for MCMC."

    constraints = list(constraints) if constraints is not None else []
    if with_nonpenetration:
        constraints.append(NonpenetrationConstraint(0.00))

    # Form probabilistic model
    root = scene_tree.get_root()

    def model():
        # Resample the continuous structure of the tree.
        node_queue = [root]
        while len(node_queue) > 0:
            parent = node_queue.pop(0)
            children, rules = scene_tree.get_children_and_rules(parent)
            for child, rule in zip(children, rules):
                with scope(prefix=parent.name):
                    rule.sample_child(parent, child)
                node_queue.append(child)

        # Implement observation constraints
        if fix_observeds:
            xyz_observed_variance = 1E-2
            rot_observed_variance = 1E-2
            for node, original_node in zip(scene_tree.nodes,
                                           original_tree.nodes):
                if node.observed:
                    xyz_observed_dist = dist.Normal(original_node.translation,
                                                    xyz_observed_variance)
                    rot_observed_dist = dist.Normal(original_node.rotation,
                                                    rot_observed_variance)
                    pyro.sample("%s_xyz_observed" % node.name,
                                xyz_observed_dist,
                                obs=node.translation)
                    pyro.sample("%s_rotation_observed" % node.name,
                                rot_observed_dist,
                                obs=node.rotation)

        for k, constraint in enumerate(constraints):
            clamped_error_distribution = dist.Normal(0., 0.001)
            violation, _, _ = constraint.eval_violation(scene_tree)
            positive_violations = torch.clamp(violation, 0., np.inf)
            pyro.sample("%s_%d_err" % (type(constraint).__name__, k),
                        clamped_error_distribution,
                        obs=positive_violations)

    # Ugh, I shouldn't need to manually reproduce the site names here.
    # Can I rearrange how traces get assembled to extract these?
    initial_values = {}
    for parent in original_tree.nodes:
        children, rules = original_tree.get_children_and_rules(parent)
        for child, rule in zip(children, rules):
            for key, site_value in rule.get_site_values(parent, child).items():
                initial_values["%s/%s/%s" % (parent.name, child.name,
                                             key)] = site_value.value
    trace = pyro.poutine.trace(model).get_trace()
    for key in initial_values.keys():
        if key not in trace.nodes.keys():
            print("Trace keys: ", trace.nodes.keys())
            print("Initial values keys: ", initial_values.keys())
            raise ValueError("%s not in trace keys" % key)

    print("Initial trace log prob: ", trace.log_prob_sum())
    # If I let MCMC auto-tune its step size, it seems to do well,
    # but sometimes seems to get lost, and then gets stuck with big step size and
    # zero acceptances.
    if kernel_type == "NUTS":
        kernel = NUTS(model,
                      init_strategy=pyro.infer.autoguide.init_to_value(
                          values=initial_values),
                      **kwargs)
    elif kernel_type == "HMC":
        kernel = HMC(model,
                     init_strategy=pyro.infer.autoguide.init_to_value(
                         values=initial_values),
                     **kwargs)
    else:
        raise NotImplementedError(kernel_type)

    # Get MBP for viz
    if zmq_url is not None:
        builder, mbp, sg, node_to_free_body_ids_map, body_id_to_node_map = compile_scene_tree_to_mbp_and_sg(
            scene_tree)
        mbp.Finalize()
        visualizer = ConnectMeshcatVisualizer(builder,
                                              sg,
                                              zmq_url=zmq_url,
                                              prefix=prefix)
        diagram = builder.Build()
        diagram_context = diagram.CreateDefaultContext()
        mbp_context = diagram.GetMutableSubsystemContext(mbp, diagram_context)
        vis_context = diagram.GetMutableSubsystemContext(
            visualizer, diagram_context)
        visualizer.load()

        def hook_fn(kernel, samples, stage, i):
            # Update the MBP context with the current node poses and publish for visualization.
            for node, body_ids in node_to_free_body_ids_map.items():
                for body_id in body_ids:
                    mbp.SetFreeBodyPose(mbp_context, mbp.get_body(body_id),
                                        torch_tf_to_drake_tf(node.tf))
            diagram.Publish(diagram_context)
            draw_scene_tree_structure_meshcat(scene_tree,
                                              zmq_url=zmq_url,
                                              prefix=prefix + "/structure",
                                              delete=False,
                                              **structure_vis_kwargs)
            time.sleep(0.1)
    else:
        hook_fn = None

    mcmc = MCMC(kernel,
                num_samples=num_samples,
                warmup_steps=min(int(num_samples / 2), 50),
                num_chains=1,
                disable_progbar=(verbose == -1),
                hook_fn=hook_fn)
    mcmc.run()
    if verbose > 1:
        mcmc.summary(prob=0.5)

    samples = mcmc.get_samples()
    sampled_trees = []
    for k in range(0, num_samples, subsample_step):
        condition = {key: value[k, ...] for key, value in samples.items()}
        with pyro.condition(data=condition):
            model()
        sampled_trees.append(deepcopy(scene_tree))

    return sampled_trees
Example #16

if __name__ == '__main__':
    # create the params of NB distribution
    alpha = torch.tensor(args.alpha)
    beta = torch.tensor(args.beta)
    r = torch.tensor(args.r)
    data = torch.tensor([12, 11, 6, 12, 11, 0, 4, 6, 5, 6])

    nb_post = NB_Post(alpha, beta, args.r)
    # create hmc and mcmc object and sample
    hmc_kernel = HMC(nb_post.model, step_size=args.step_size, num_steps=args.num_steps)
    mcmc = MCMC(hmc_kernel, num_samples=args.num_samples, warmup_steps=args.warm_steps)

    # sample the posterior
    mcmc.run(data, args.logit)
    if args.logit:
        param = 'eta'
        posterior_samples = mcmc.get_samples()[param]
        # inverse-logit (sigmoid) transform
        posterior_samples = torch.exp(posterior_samples) / (1. + torch.exp(posterior_samples))
        # plot the estimated posterior density
        plot_logit_density(posterior_samples)
    else:
        param = 'p'
        posterior_samples = mcmc.get_samples()[param]
        poster_alpha = (alpha + data.sum()).numpy()
        poster_beta = (len(data) * r + beta).numpy()
        # plot the estimated and ground truth density
        plot_density(poster_alpha, poster_beta, posterior_samples)
Example #17
        sys.stderr.write("Requires Python 3\n")

    genr = Decoder()
    genr.load_state_dict(torch.load('gaae-decd-1024.tch'))
    genr.eval()

    data = qPCRData('second.txt', randomize=False, test=False)

    # Do it with CUDA if possible.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        genr.cuda()

    model = GeneratorModel(genr)

    nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=True)
    mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=2000)
    for batch in data.batches(btchsz=8192, randomize=False, test=False):
        obs = batch[:, 90:].to(device)
        mcmc.run(obs)
        z = mcmc.get_samples()['z']
        # Propagate forward and sample observable 'x'.
        with torch.no_grad():
            mu, sd = genr(z)
        for i in range(batch.shape[0]):
            x = Normal(mu[:, i, :90], sd[:, i, :90]).sample()
            orig = batch[i, 90:].expand([1000, 45])
            out = torch.cat([x, orig], dim=1)
            np.savetxt(sys.stdout, out.cpu().numpy(), fmt='%.4f')
Example #18
class BayesInferece:
    def __init__(self, pred_func, dim, lb, ub):
        self.pred_func = pred_func
        self.fit_dim = dim
        self.lb, self.ub = torch.tensor(lb).float(), torch.tensor(ub).float()

    def model(self, data):
        """
        这个data是每个省份每一天的真实值,是2D数据
        """
        params_sample = []
        # for i in pyro.plate("param_plate", self.fit_dim):
        for i in range(self.fit_dim):
            param_one = pyro.sample("fit_params_%d" % i,
                                    dist.Uniform(self.lb[i], self.ub[i]))
            params_sample.append(param_one.detach().item())
        params_sample = np.array(params_sample)
        # import ipdb; ipdb.set_trace()
        pred_mean = self.pred_func(params_sample)  # predict with the current parameters
        pred_mean = torch.tensor(pred_mean)
        pyro.sample("obs", dist.Poisson(pred_mean), obs=data)
        # plate1 = pyro.plate("obs_plate_1", data.shape[-1])
        # plate2 = pyro.plate("obs_plate_2", data.shape[-2])
        # for i in plate1:
        #     for j in plate2:
        #         pyro.sample("obs_i%d_j%d" % (i, j),
        #                     dist.Poisson(pred_mean), obs=data[j, i])

    def guide(self, data):
        mean = pyro.param(
            "mean",
            torch.rand(self.fit_dim),
            # constraint=constraints.positive
        )
        std = pyro.param("std",
                         torch.rand(self.fit_dim),
                         constraint=constraints.positive)
        with pyro.plate("guide_plate", self.fit_dim):
            pyro.sample("fit_params", dist.Normal(mean, std))

    def multi_norm_guide(self):
        return AutoMultivariateNormal(self.model, init_loc_fn=init_to_mean)

    def inference_svi(self, data, steps=3000, lr=0.01):
        self.inference_method = "svi"
        pyro.clear_param_store()
        self.optimizer = Adam({"lr": lr, "betas": (0.90, 0.999)})
        self.svi = SVI(self.model,
                       self.multi_norm_guide(),
                       self.optimizer,
                       loss=Trace_ELBO())
        self.history = {"losses": []}
        data = torch.tensor(data).float()
        bar = trange(steps)
        for i in bar:
            loss = self.svi.step(data)
            if (i + 1) % 100 == 1:
                bar.write("Now step %d completed, loss is %.4f" % (i, loss))
            self.history["losses"].append(loss)

    def inference_mcmc(self,
                       data,
                       num_samples=3000,
                       warmup_steps=2000,
                       num_chains=4):
        self.inference_method = "mcmc"
        data = torch.tensor(data).float()
        self.nuts_kernel = NUTS(self.model, adapt_step_size=True)
        self.mcmc = MCMC(self.nuts_kernel,
                         num_samples=num_samples,
                         warmup_steps=warmup_steps,
                         num_chains=num_chains)
        self.mcmc.run(data)

    def save(self, path):
        pass

    def fit_params_estimate(self, data):
        if self.inference_method == "svi":
            data = torch.tensor(data).float()
            predictive = Predictive(self.model,
                                    guide=self.multi_norm_guide(),
                                    num_samples=1000)
            svi_samples = {
                k: v.reshape(1000).detach().numpy()
                for k, v in predictive(data).items() if not k.startswith("obs")
            }
            return {
                "mean": np.array([v.mean() for v in svi_samples.values()]),
                "std": np.array([v.std() for v in svi_samples.values()])
            }

        elif self.inference_method == "mcmc":
            return {"mean": self.mcmc.get_samples()["fit_params"].mean(0)}
        else:
            raise NotImplementedError
Example #19
 def draw_posterior_samples(self,num_samples):
     nuts_kernel = NUTS(self.model, adapt_step_size=True)
     mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=300)
     mcmc.run()
     res = mcmc.get_samples()
     return res
Example #20
class BayesianVSCalibrator:
    """ This class implements the Bayesian VS calibrator, with bias.
    Performs inference using NUTS.
    """
    def __init__(self, prior_params, num_classes, **kwargs):
        self.num_classes = num_classes
        # Inference parameters
        self.NUTS_params = {
            'adapt_step_size': kwargs.pop('adapt_step_size', True),
            'target_accept_prob': kwargs.pop('target_accept_prob', 0.8),
            'max_plate_nesting': 1
        }
        self.mcmc_params = {
            'num_samples': kwargs.pop('num_samples', 250),
            'warmup_steps': kwargs.pop('num_warmup', 1000),
            'num_chains': kwargs.pop('num_chains', 4)
        }

        # Prior parameters on beta / delta ; assumes each weight/bias is i.i.d from its respective distribution.
        self.prior_params = {
            'mu_beta':
            torch.empty(self.num_classes).fill_(prior_params['mu_beta']),
            'sigma_beta':
            torch.empty(self.num_classes).fill_(prior_params['sigma_beta']),
            'mu_delta':
            torch.empty(self.num_classes).fill_(prior_params['mu_delta']),
            'sigma_delta':
            torch.empty(self.num_classes).fill_(prior_params['sigma_delta'])
        }

        # Posterior parameters after ADF
        # TODO
        self.posterior_params = {'mu_beta': None, 'sigma_beta': None}

        # Drift parameters for sequential updating
        self.sigma_drift = kwargs.pop('sigma_drift', 0.0)

        # Tracking params
        # TODO: Prior/posterior trace
        self.timestep = 0
        self.mcmc = None  # Contains the most recent Pyro MCMC api object
        self.verbose = kwargs.pop('verbose', True)

        if self.verbose:
            print('\nInitializing VS model:\n'
                  '----| Prior: {} \n----| Inference Method: NUTS \n'
                  '----| MCMC parameters: {}'
                  ''.format(prior_params, self.mcmc_params))

    def update(self, logits, labels):
        """ Performs an update given new observations.

        Args:
            logits: tensor ; shape (batch_size, num_classes)
            labels: tensor ; shape (batch_size, )
        """
        assert len(
            labels.shape
        ) == 1, 'Got label tensor with shape {} -- labels must be dense'.format(
            labels.shape)
        assert len(logits.shape) == 2, 'Got logit tensor with shape {}'.format(
            logits.shape)
        assert (labels.shape[0] == logits.shape[0]), 'Shape mismatch between logits ({}) and labels ({})' \
            .format(logits.shape[0], labels.shape[0])

        logits = logits.detach().clone().requires_grad_()
        labels = labels.detach().clone()

        batch_size = labels.shape[0]
        if self.verbose:
            print(
                '----| Updating HBC model\n--------| Got a batch size of: {}'.
                format(batch_size))

        # TODO
        # self._update_prior_params()
        if self.verbose:
            print('--------| Updated priors: {}'.format(self.prior_params))
            print('--------| Running inference ')
        nuts_kernel = NUTS(bvs_model, **self.NUTS_params)
        self.mcmc = MCMC(
            nuts_kernel, **self.mcmc_params,
            disable_progbar=not self.verbose)  # Progbar if verbose
        self.mcmc.run(self.prior_params, logits, labels)

        # TODO
        # self._update_posterior_params()
        self.timestep += 1

        return self.mcmc

    def _update_prior_params(self):
        """ Updates the prior parameters using the ADF posterior from the previous timestep, plus the drift.

        If this is the first batch, i.e. timestep == 0, do nothing.
        """
        # TODO
        if self.timestep > 0:
            self.prior_params['mu_beta'] = self.posterior_params['mu_beta']
            self.prior_params['sigma_beta'] = self.posterior_params[
                'sigma_beta'] + self.sigma_drift

    def _update_posterior_params(self):
        """ Fits a normal distribution to the current beta samples using moment matching.
        """
        # TODO
        beta_samples = self.get_current_posterior_samples()
        self.posterior_params['mu_beta'] = beta_samples.mean().item()
        self.posterior_params['sigma_beta'] = beta_samples.std().item()

    def get_current_posterior_samples(self):
        """ Returns the current posterior samples for beta.
        """
        if self.mcmc is None:
            return None

        return self.mcmc.get_samples()

    def calibrate(self, logit):
        """ Calibrates the given batch of logits using the current posterior samples.

        Args:
            logit: tensor ; shape (batch_size, num_classes)
        """
        # Get beta samples
        beta_samples = self.get_current_posterior_samples()[
            'beta']  # Shape (num_samples, num_classes)
        delta_samples = self.get_current_posterior_samples()[
            'delta']  # Shape (num_samples, num_classes)

        # Get a batch of logits for each sampled parameter vector
        # Shape (num_samples, batch_size, num_classes)
        tempered_logit_samples = beta_samples.view(-1, 1, self.num_classes) * logit + \
                                 delta_samples.view(-1, 1, self.num_classes)

        # Softmax the sampled logits to get sampled probabilities
        prob_samples = softmax(
            tempered_logit_samples,
            dim=2)  # Shape (num_samples, batch_size, num_classes)

        # Average over the sampled probabilities to get Monte Carlo estimate
        calibrated_probs = prob_samples.mean(
            dim=0)  # Shape (batch_size, num_classes)

        return calibrated_probs

    def get_MAP_temperature(self, logits, labels):
        """ Performs MAP estimation using the current prior and given data.
         NB: This should only be called after .update() if used in a sequential setting, as this method
         does not update the prior with sigma_drift.

         See: https://pyro.ai/examples/mle_map.html
         """
        pyro.clear_param_store()
        svi = pyro.infer.SVI(model=bvs_model,
                             guide=MAP_guide,
                             optim=pyro.optim.Adam({'lr': 0.001}),
                             loss=pyro.infer.Trace_ELBO())

        loss = []
        num_steps = 5000
        for _ in range(num_steps):
            loss.append(svi.step(self.prior_params, logits, labels))

        eps = 2e-2
        loss_sddev = np.std(loss[-25:])
        if loss_sddev > eps:
            warnings.warn(
                'MAP optimization may not have converged ; sddev {}'.format(
                    loss_sddev))

        beta_MAP = pyro.param('beta_MAP').detach()
        delta_MAP = pyro.param('delta_MAP').detach()
        return beta_MAP, delta_MAP
Example #21
class BayesianHillModel():
    def __init__(self):
        '''
        '''
        self.prior = dict()

    def get_priors(self):
        '''
        Default Prior values stored here. 
        '''
        E0_mean = 1.
        E0_std = self.prior.get('E0_std', 0.1)

        alpha_emax, beta_emax = gamma_modes_to_params(
            self.prior.get('Emax_Mean', 0.5), self.prior.get('Emax_var', 1.))

        alpha_H, beta_H = gamma_modes_to_params(self.prior.get('H_Mean', 1.5),
                                                self.prior.get('H_var', 1.))

        log10_ec50_mean = self.prior.get('log10_EC50_Mean', -3)
        log10_ec50_std = self.prior.get('log10_EC50_Var', 2.)

        alpha_obs, beta_obs = gamma_modes_to_params(
            self.prior.get('Obs_std_Mean', 1.),
            self.prior.get('Obs_std_Var', 1.))

        return E0_mean, E0_std, alpha_emax, beta_emax, alpha_H, beta_H, log10_ec50_mean, log10_ec50_std, alpha_obs, beta_obs

    def plot_priors(self):
        '''
        '''
        E0_mean, E0_std, alpha_emax, beta_emax, alpha_H, beta_H, log10_ec50_mean, log10_ec50_std, alpha_obs, beta_obs = self.get_priors(
        )

        f, axes = plt.subplots(2, 3, figsize=(12, 7))

        # E0
        xx = np.linspace(0, 2, 50)
        rv = norm(E0_mean, E0_std)
        yy = rv.pdf(xx)
        axes.flat[0].set_title('E0 parameter')
        axes.flat[0].set_xlabel('E0')
        axes.flat[0].set_ylabel('probability')
        axes.flat[0].plot(xx, yy, 'r-')

        # EMAX
        xx = np.linspace(0, 2, 50)
        rv = gamma(alpha_emax, scale=1 / beta_emax, loc=0)
        yy = rv.pdf(xx)
        axes.flat[1].set_title('Emax parameter')
        axes.flat[1].set_xlabel('Emax')
        axes.flat[1].set_ylabel('probability')
        axes.flat[1].plot(xx, yy, 'r-')

        # H
        xx = np.linspace(0, 5, 100)
        rv = gamma(alpha_H, scale=1 / beta_H, loc=0)
        yy = rv.pdf(xx)
        axes.flat[2].set_title('Hill Coefficient (H) parameter')
        axes.flat[2].set_xlabel('H')
        axes.flat[2].set_ylabel('probability')
        axes.flat[2].plot(xx, yy, 'r-')

        # EC50
        xx = np.logspace(-7, 1, 100)
        rv = norm(log10_ec50_mean, log10_ec50_std)
        yy = rv.pdf(np.log10(xx))
        axes.flat[3].set_title('EC50 parameter')
        axes.flat[3].set_xlabel('EC50 [uM]')
        axes.flat[3].set_ylabel('probability')
        axes.flat[3].plot(xx, yy, 'r-')

        # Log10 EC50
        axes.flat[4].set_title('Log10 EC50 parameter [~ Normal]')
        axes.flat[4].set_xlabel('Log10( EC50 [uM] )')
        axes.flat[4].set_ylabel('probability')
        axes.flat[4].plot(np.log10(xx), yy, 'r-')

        # OBS
        xx = np.linspace(0, 5, 100)
        rv = gamma(alpha_obs, scale=1 / beta_obs, loc=0)
        yy = rv.pdf(xx)
        axes.flat[5].set_title('Observation Std parameter')
        axes.flat[5].set_xlabel('Obs. Std')
        axes.flat[5].set_ylabel('probability')
        axes.flat[5].plot(xx, yy, 'r-')

        plt.tight_layout()
        plt.show()

    def plot_prior_regression(self,
                              n_samples=1000,
                              savepath=None,
                              verbose=True):
        '''
        '''
        XX = torch.tensor(np.logspace(-9, 4, 100))

        samples = [self.model(XX) for i in range(n_samples)]

        plt.figure(figsize=(7, 7))

        for i, s in enumerate(samples):
            if (i % 100 == 0) and verbose:
                print(f'plotting prior regression...{i/n_samples*100:.1f}%',
                      end='\r')
            plt.plot(np.log10(XX), s, 'r-', alpha=0.005, linewidth=4.0)

        plt.xlabel('log10 Concentration')
        plt.ylabel('cell_viability')
        plt.ylim(0, 1.2)
        plt.legend()
        plt.title('Prior Probability Hill Regression')

        if savepath is not None:
            plt.savefig(savepath + '/prior_regressions.png')
        else:
            plt.show()

    def model(self, X, Y=None):
        '''
        
        '''
        E0_mean, E0_std, alpha_emax, beta_emax, alpha_H, beta_H, log10_ec50_mean, log10_ec50_std, alpha_obs, beta_obs = self.get_priors(
        )

        E0 = pyro.sample('E0', dist.Normal(E0_mean, E0_std))

        Emax = pyro.sample('Emax', dist.Beta(alpha_emax, beta_emax))

        H = pyro.sample('H', dist.Gamma(alpha_H, beta_H))

        EC50 = 10**pyro.sample('log_EC50',
                               dist.Normal(log10_ec50_mean, log10_ec50_std))

        obs_sigma = pyro.sample("obs_sigma", dist.Gamma(alpha_obs, beta_obs))

        obs_mean = E0 + (Emax - E0) / (1 + (EC50 / X)**H)

        with pyro.plate("data", X.shape[0]):
            obs = pyro.sample("obs",
                              dist.Normal(obs_mean.squeeze(-1), obs_sigma),
                              obs=Y)

        return obs_mean

    def check_converged(self, Rhat_tol=0.05, verbose=False):
        '''
        '''
        results = self.summary()
        max_rhat = max(results.loc['r_hat', :].values)
        min_rhat = min(results.loc['r_hat', :].values)
        if verbose: print('max/min rhat:', (max_rhat, min_rhat))
        return not (max_rhat > (1 + Rhat_tol) or min_rhat < (1 - Rhat_tol))

    def fit(self, X, Y, num_samples=500, burnin=150, num_chains=1, seed=1):
        '''
        '''
        self.X = X
        self.Y = Y

        if seed is not None:
            torch.manual_seed(seed)

        nuts_kernel = NUTS(self.model, adapt_step_size=True)
        self.mcmc_res = MCMC(nuts_kernel,
                             num_samples=num_samples,
                             warmup_steps=burnin,
                             num_chains=num_chains)
        self.mcmc_res.run(X, Y)

    def plot_fitted_params(self, savepath=None):
        samples = {
            k: v.detach().cpu().numpy()
            for k, v in self.mcmc_res.get_samples().items()
        }

        f, axes = plt.subplots(3, 2, figsize=(10, 5))

        for ax, key in zip(axes.flat, samples.keys()):

            ax.set_title(key)
            ax.hist(samples[key],
                    bins=np.linspace(min(samples[key]), max(samples[key]), 50),
                    density=True)
            ax.set_xlabel(key)
            ax.set_ylabel('probability')

        axes.flat[-1].hist(10**samples['log_EC50'],
                           bins=np.linspace(min(10**(samples['log_EC50'])),
                                            max(10**(samples['log_EC50'])),
                                            50))
        axes.flat[-1].set_title('EC50')
        axes.flat[-1].set_xlabel('EC50 [uM]')

        plt.tight_layout()

        if savepath is not None:
            plt.savefig(savepath)
        else:
            plt.show()

    def plot_fit(self, savepath=None):
        samples = {
            k: v.detach().cpu().numpy()
            for k, v in self.mcmc_res.get_samples().items()
        }

        plt.figure(figsize=(7, 7))

        xx = np.logspace(-7, 6, 200)

        for i, s in pd.DataFrame(samples).iterrows():
            yy = s.E0 + (s.Emax - s.E0) / (1 + (10**s.log_EC50 / xx)**s.H)
            plt.plot(np.log10(xx), yy, 'ro', alpha=0.01)

        plt.plot(np.log10(self.X), self.Y, 'b.', label='data')
        plt.xlabel('log10 Concentration')
        plt.ylabel('cell_viability')
        plt.ylim(0, 1.2)
        plt.legend()
        plt.title('MCMC results')

        if savepath is not None:
            plt.savefig(savepath)
        else:
            plt.show()

    def summary(self, p=0.9, verbose=False):
        self.results = pd.DataFrame(summary(self.mcmc_res._samples, prob=p))
        if verbose: print(self.results)
        return self.results

    def get_samples(self):
        return pd.DataFrame({
            key: np.array(self.mcmc_res._samples[key]).flatten()
            for key in self.mcmc_res._samples
        })

    def get_ICxx(self, xx):
        # returned in log concentration
        ics = []
        for i, row in self.get_samples().iterrows():
            try:
                ic = np.exp((1 / row.H) * (np.log((row.Emax - row.E0) /
                                                  (xx - row.E0))) -
                            np.log(10**row.log_EC50))
                ics.append(ic)
            except Exception:
                ics.append(np.inf)
        return ics
Example #22
    def fit_mcmc(self, **options):
        r"""
        Runs NUTS inference to generate posterior samples.

        This uses the :class:`~pyro.infer.mcmc.nuts.NUTS` kernel to run
        :class:`~pyro.infer.mcmc.api.MCMC`, setting the ``.samples``
        attribute on completion.

        This uses an asymptotically exact enumeration-based model when
        ``num_quant_bins > 1``, and a cheaper moment-matched approximate model
        when ``num_quant_bins == 1``.

        :param \*\*options: Options passed to
            :class:`~pyro.infer.mcmc.api.MCMC`. The remaining options are
            pulled out and have special meaning.
        :param int num_samples: Number of posterior samples to draw via mcmc.
            Defaults to 100.
        :param int max_tree_depth: (Default 5). Max tree depth of the
            :class:`~pyro.infer.mcmc.nuts.NUTS` kernel.
        :param full_mass: Specification of mass matrix of the
            :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to full mass
            over global random variables.
        :param bool arrowhead_mass: Whether to treat ``full_mass`` as the head
            of an arrowhead matrix versus simply as a block. Defaults to False.
        :param int num_quant_bins: If greater than 1, use asymptotically exact
            inference via local enumeration over this many quantization bins.
            If equal to 1, use continuous-valued relaxed approximate inference.
            Note that computational cost is exponential in `num_quant_bins`.
            Defaults to 1 for relaxed inference.
        :param bool haar: Whether to use a Haar wavelet reparameterizer.
            Defaults to True.
        :param int haar_full_mass: Number of low frequency Haar components to
            include in the full mass matrix. If ``haar=False`` then this is
            ignored. Defaults to 10.
        :param int heuristic_num_particles: Passed to :meth:`heuristic` as
            ``num_particles``. Defaults to 1024.
        :returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
        :rtype: ~pyro.infer.mcmc.api.MCMC
        """
        _require_double_precision()

        # Parse options, saving some for use in .predict().
        num_samples = options.setdefault("num_samples", 100)
        num_chains = options.setdefault("num_chains", 1)
        self.num_quant_bins = options.pop("num_quant_bins", 1)
        assert isinstance(self.num_quant_bins, int)
        assert self.num_quant_bins >= 1
        self.relaxed = self.num_quant_bins == 1

        # Setup Haar wavelet transform.
        haar = options.pop("haar", False)
        haar_full_mass = options.pop("haar_full_mass", 10)
        full_mass = options.pop("full_mass", self.full_mass)
        assert isinstance(haar, bool)
        assert isinstance(haar_full_mass, int) and haar_full_mass >= 0
        assert isinstance(full_mass, (bool, list))
        haar_full_mass = min(haar_full_mass, self.duration)
        if not haar:
            haar_full_mass = 0
        if full_mass is True:
            haar_full_mass = 0  # No need to split.
        elif haar_full_mass >= self.duration:
            full_mass = True  # Effectively full mass.
            haar_full_mass = 0
        if haar:
            time_dim = -2 if self.is_regional else -1
            dims = {"auxiliary": time_dim}
            supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
            for name, (fn, is_regional) in self._non_compartmental.items():
                dims[name] = time_dim - fn.event_dim
                supports[name] = fn.support
            haar = _HaarSplitReparam(haar_full_mass, self.duration, dims, supports)
        if haar_full_mass:
            assert full_mass and isinstance(full_mass, list)
            full_mass = full_mass[:]
            full_mass[0] += tuple(name + "_haar_split_0" for name in sorted(dims))

        # Heuristically initialize to feasible latents.
        heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
                             for k in list(options)
                             if k.startswith("heuristic_")}
        init_strategy = init_to_generated(
            generate=functools.partial(self._heuristic, haar, **heuristic_options))

        # Configure a kernel.
        logger.info("Running inference...")
        model = self._relaxed_model if self.relaxed else self._quantized_model
        if haar:
            model = haar.reparam(model)
        kernel = NUTS(model,
                      full_mass=full_mass,
                      init_strategy=init_strategy,
                      max_plate_nesting=self.max_plate_nesting,
                      jit_compile=options.pop("jit_compile", False),
                      jit_options=options.pop("jit_options", None),
                      ignore_jit_warnings=options.pop("ignore_jit_warnings", True),
                      target_accept_prob=options.pop("target_accept_prob", 0.8),
                      max_tree_depth=options.pop("max_tree_depth", 5))
        if options.pop("arrowhead_mass", False):
            kernel.mass_matrix_adapter = ArrowheadMassMatrix()

        # Run mcmc.
        options.setdefault("disable_validation", None)
        mcmc = MCMC(kernel, **options)
        mcmc.run()
        self.samples = mcmc.get_samples()
        if haar:
            haar.aux_to_user(self.samples)

        # Unsqueeze samples to align particle dim for use in poutine.condition.
        # TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples().
        model = self._relaxed_model if self.relaxed else self._quantized_model
        self.samples = align_samples(self.samples, model,
                                     particle_dim=-1 - self.max_plate_nesting)
        assert all(v.size(0) == num_samples * num_chains for v in self.samples.values()), \
            {k: tuple(v.shape) for k, v in self.samples.items()}

        return mcmc  # E.g. so user can run mcmc.summary().
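
A sketch (not from the source) of how fit_mcmc might be invoked on one of Pyro's compartmental models; SimpleSIRModel, the toy counts, and the option values below are illustrative assumptions.

import torch
from pyro.contrib.epidemiology.models import SimpleSIRModel

torch.set_default_dtype(torch.float64)  # fit_mcmc requires double precision

# toy counts of new infections per time step (illustrative only)
new_infections = torch.tensor([1., 2., 2., 4., 5., 8., 10., 14., 15., 18.])
model = SimpleSIRModel(1000, 10.0, new_infections)  # (population, recovery_time, data)
mcmc = model.fit_mcmc(num_samples=200,
                      warmup_steps=200,
                      max_tree_depth=5,
                      haar=True,
                      num_quant_bins=1)
mcmc.summary()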
Example #23
File: bnn.py  Project: TyXe-BDL/TyXe
class MCMC_BNN(_BNN):
    """Supervised BNN class with an interface to pyro's MCMC that is unified with the VariationalBNN class.

    :param callable kernel_builder: function or class that returns an object that will be accepted as a kernel by
        pyro.infer.mcmc.MCMC, e.g. pyro.infer.mcmc.HMC or NUTS. It will be called with the entire model, i.e. it also
        infers variables in the likelihood."""
    def __init__(self, net, prior, likelihood, kernel_builder, name=""):
        super().__init__(net, prior, name=name)
        self.likelihood = likelihood
        self.kernel = kernel_builder(self.model)
        self._mcmc = None

    def model(self, x, obs=None):
        predictions = self(*_as_tuple(x))
        self.likelihood(predictions, obs)
        return predictions

    def fit(self,
            data_loader,
            num_samples,
            device=None,
            batch_data=False,
            **mcmc_kwargs):
        """Runs MCMC on the data from data loader using the kernel that was used to instantiate the class.

        :param data_loader: iterable or list of batched inputs to the net. If iterable treated like the data_loader
            of VariationalBNN and all network inputs are concatenated via torch.cat. Otherwise must be a tuple of
            a single or list of network inputs and a tensor for the targets.
        :param int num_samples: number of MCMC samples to draw.
        :param torch.device device: optional device to send the data to.
        :param batch_data: whether to treat data_loader as a full batch of data or an iterable over mini-batches.
        :param dict mcmc_kwargs: keyword arguments for initializing the pyro.infer.mcmc.MCMC object."""
        if batch_data:
            input_data, observation_data = data_loader
        else:
            input_data_lists = defaultdict(list)
            observation_data_list = []
            for in_data, obs_data in iter(data_loader):
                for i, data in enumerate(_as_tuple(in_data)):
                    input_data_lists[i].append(data.to(device))
                observation_data_list.append(obs_data.to(device))
            input_data = tuple(
                torch.cat(input_data_lists[i])
                for i in range(len(input_data_lists)))
            observation_data = torch.cat(observation_data_list)
        self._mcmc = MCMC(self.kernel, num_samples, **mcmc_kwargs)
        self._mcmc.run(input_data, observation_data)

        return self._mcmc

    def predict(self, *input_data, num_predictions=1, aggregate=True):
        if self._mcmc is None:
            raise RuntimeError(
                "Call .fit to run MCMC and obtain samples from the posterior first."
            )

        preds = []
        weight_samples = self._mcmc.get_samples(num_samples=num_predictions)
        with torch.no_grad():
            for i in range(num_predictions):
                weights = {
                    name: sample[i]
                    for name, sample in weight_samples.items()
                }
                preds.append(poutine.condition(self, weights)(*input_data))
        predictions = torch.stack(preds)
        return self.likelihood.aggregate_predictions(
            predictions) if aggregate else predictions
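
A minimal sketch of driving MCMC_BNN, assuming `prior` and `likelihood` are TyXe prior/likelihood objects constructed elsewhere (they remain placeholders here), with a NUTS kernel_builder from pyro.infer.mcmc.

import functools
import torch
from pyro.infer.mcmc import NUTS

net = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.Tanh(),
                          torch.nn.Linear(16, 1))
kernel_builder = functools.partial(NUTS, adapt_step_size=True)
bnn = MCMC_BNN(net, prior, likelihood, kernel_builder)  # prior/likelihood: assumed TyXe objects

x, y = torch.randn(64, 4), torch.randn(64, 1)
bnn.fit((x, y), num_samples=200, batch_data=True, warmup_steps=100)
preds = bnn.predict(x, num_predictions=50)  # stacked predictions from 50 posterior draws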
Example #24
# In[21]:


my_mcmc1.summary()


# Those don't look quite right... The means seem very different from the point estimates found by the regression from `sklearn`.
# 
# Let's grab the individual samples from our sampler, and turn those into a dataframe (they are returned as a dictionary).
# We can grab the mean of each distribution as a coefficient point estimate, and then calculate a set of predictions for our data points. Then, we can compare them to our known values for house prices (a rough sketch of that computation follows the next cell).

# In[22]:


beta_df = pd.DataFrame(my_mcmc1.get_samples())
beta_df.head()
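

# Sketch (not a cell from the original notebook): turning posterior means into
# point estimates and predictions. `X` (the feature matrix used for the fit) and
# the site names below are assumptions about the regression model fitted upstream.

coef_means = beta_df.mean()                          # point estimate per sampled site
samples = my_mcmc1.get_samples()
w_hat = samples["linear.weight"].mean(0).squeeze()   # assumed site name
b_hat = samples["linear.bias"].mean(0).squeeze()     # assumed site name
y_hat = X @ w_hat + b_hat                            # predictions to compare with observed prices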


# In[23]:


# TO DELETE
from pandas.plotting import table # EDIT: see deprecation warnings below

plt.figure(figsize=(30,5))
ax = plt.subplot(111, frame_on=False) # no visible frame
ax.xaxis.set_visible(False)  # hide the x axis
ax.yaxis.set_visible(False)  # hide the y axis

tabla = table(ax, beta_df.head(10), loc='upper center')  # where df is your data frame
Example #25
class HierarchicalBayesianCalibrator:

    def __init__(self, prior_params, num_classes, **kwargs):
        self.num_classes = num_classes
        # Inference parameters
        self.NUTS_params = {'adapt_step_size': kwargs.pop('adapt_step_size', True),
                            'target_accept_prob': kwargs.pop('target_accept_prob', 0.8),
                            'max_plate_nesting': 1
                            }
        self.mcmc_params = {'num_samples': kwargs.pop('num_samples', 250),
                            'warmup_steps': kwargs.pop('num_warmup', 1000),
                            'num_chains': kwargs.pop('num_chains', 4)
                            }

        # Constraints on delta; choices are None, 'soft', 'hard'
        self.delta_constraint = kwargs.pop('delta_constraint', 'soft')
        assert self.delta_constraint in [None, 'soft', 'hard'], 'Invalid delta constraint'

        # Prior parameters on beta / delta ; assumes each delta is iid
        self.prior_params = {'mu_beta': prior_params['mu_beta'],
                             'sigma_beta': prior_params['sigma_beta'],
                             'mu_delta': torch.empty(self.num_classes).fill_(prior_params['mu_delta']),
                             'sigma_delta': torch.empty(self.num_classes).fill_(prior_params['sigma_delta'])}

        # Tracking params
        # TODO: Prior/posterior trace
        self.timestep = 0
        self.mcmc = None  # Contains the most recent Pyro MCMC api object

        print('\nInitializing HBC model:\n'
              '----| Prior: {} \n----| Inference Method: NUTS \n'
              '----| MCMC parameters: {}'
              ''.format(prior_params, self.mcmc_params))

    def update(self, logits, labels):
        """ Performs an update given new observations.

        Args:
            logits: tensor ; shape (batch_size, num_classes)
            labels: tensor ; shape (batch_size, )
        """
        assert len(labels.shape) == 1, 'Got label tensor with shape {} -- labels must be dense'.format(labels.shape)
        assert len(logits.shape) == 2, 'Got logit tensor with shape {}'.format(logits.shape)
        assert (labels.shape[0] == logits.shape[0]), 'Shape mismatch between logits ({}) and labels ({})' \
            .format(logits.shape[0], labels.shape[0])

        print('delta constraint::')
        print(self.delta_constraint)
        self.timestep += 1

        logits = logits.detach().clone().requires_grad_()
        labels = labels.detach().clone()

        batch_size = labels.shape[0]
        print('----| Updating HBC model\n--------| Got a batch size of: {}'.format(batch_size))

        # TODO: Update prior (for sequential)
        # self._update_prior()
        # print('--------| Updated priors: {}'.format(self.prior_params))

        print('--------| Running inference ')
        nuts_kernel = NUTS(hbc_model, **self.NUTS_params)
        self.mcmc = MCMC(nuts_kernel, **self.mcmc_params, disable_progbar=False)
        print('.')
        self.mcmc.run(self.prior_params, logits, labels, self.delta_constraint)
        print('..')

        #  TODO: update posterior (for sequential)
        # self._update_posterior(posterior_samples)

        return self.mcmc

    def get_current_posterior_samples(self):
        """ Returns the current posterior samples for beta, delta, and alpha.
        """
        if self.mcmc is None:
            return None

        posterior_samples = self.mcmc.get_samples()
        posterior_samples['alpha'] = posterior_samples['beta'].view(-1, 1) + posterior_samples['delta']

        return posterior_samples

    def calibrate(self, logit):
        """ Calibrates the given batch of logits using the current posterior samples.

        Args:
            logit: tensor ; shape (batch_size, num_classes)
        """
        # Get alpha samples
        alpha_samples = self.get_current_posterior_samples()['alpha']  # Shape (num_samples, num_classes)
        num_samples, num_classes = alpha_samples.shape

        # Map alphas to temperatures
        temperature_samples = torch.exp(-1. * alpha_samples)  # Shape (num_samples, num_classes)

        # Get a batch of logits for each sampled temperature
        # Shape (num_samples, batch_size, num_classes)
        tempered_logit_samples = temperature_samples.view(num_samples, 1, num_classes) * logit

        # Softmax the sampled logits to get sampled probabilities
        prob_samples = softmax(tempered_logit_samples, dim=2)  # Shape (num_samples, batch_size, num_classes)

        # Average over the sampled probabilities to get Monte Carlo estimate
        calibrated_probs = prob_samples.mean(dim=0)   # Shape (batch_size, num_classes)

        return calibrated_probs

    def get_MAP_temperature(self, logits, labels):
        """ Performs MAP estimation using the current prior and given data.
         NB: This should only be called after .update() if used in a sequential setting, as this method
         does not update the prior with sigma_drift.

         See: https://pyro.ai/examples/mle_map.html
         """
        pyro.clear_param_store()
        svi = pyro.infer.SVI(model=hbc_model, guide=MAP_guide,
                             optim=pyro.optim.Adam({'lr': 0.001}), loss=pyro.infer.Trace_ELBO())

        loss = []
        num_steps = 5000
        for _ in range(num_steps):
            loss.append(svi.step(self.prior_params, logits, labels))

        eps = 2e-2
        loss_stddev = np.std(loss[-25:])
        if loss_stddev > eps:
            warnings.warn('MAP optimization may not have converged; std dev of last 25 losses: {}'.format(loss_stddev))

        beta_MAP = pyro.param('beta_MAP').detach()
        delta_MAP = pyro.param('delta_MAP').detach()
        return beta_MAP, delta_MAP
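
Neither hbc_model nor MAP_guide appears in this excerpt. The sketch below is an assumption inferred from the surrounding code (sample sites named 'beta' and 'delta', alpha = beta + delta, temperature = exp(-alpha), and the 'beta_MAP'/'delta_MAP' params above); it is not the project's actual definition.

import torch
import pyro
import pyro.distributions as dist
from torch.nn.functional import softmax

def hbc_model(prior_params, logits, labels, delta_constraint=None):
    # Global bias and per-class offsets; alpha = beta + delta, matching
    # get_current_posterior_samples(). delta_constraint handling is omitted in this sketch.
    beta = pyro.sample('beta', dist.Normal(prior_params['mu_beta'], prior_params['sigma_beta']))
    delta = pyro.sample('delta', dist.Normal(prior_params['mu_delta'], prior_params['sigma_delta']).to_event(1))
    alpha = beta + delta
    # Per-class temperatures, as in calibrate(): T = exp(-alpha).
    probs = softmax(torch.exp(-alpha) * logits, dim=-1)
    with pyro.plate('data', logits.shape[0]):
        pyro.sample('obs', dist.Categorical(probs=probs), obs=labels)

def MAP_guide(prior_params, logits, labels, delta_constraint=None):
    # Point-mass (Delta) guide for MAP estimation, following https://pyro.ai/examples/mle_map.html
    beta_MAP = pyro.param('beta_MAP', torch.tensor(0.))
    delta_MAP = pyro.param('delta_MAP', torch.zeros_like(prior_params['mu_delta']))
    pyro.sample('beta', dist.Delta(beta_MAP))
    pyro.sample('delta', dist.Delta(delta_MAP).to_event(1))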
Example #26
        tavg_norm_noauto_3d, tavg_raw_all_3d, tavg_raw_noauto_3d
    ] = pickle.load(f)

tm.mtype = 'group'
tm.target = 'self'  # 'self','targ','avg'
tm.dtype = 'norm'  # 'norm','raw'
tm.auto = 'all'  # 'noauto','all'
tm.stickbreak = False
tm.optim = pyro.optim.Adam({'lr': 0.0005, 'betas': [0.8, 0.99]})
tm.elbo = TraceEnum_ELBO(max_plate_nesting=1)

tm.K = 3

pyro.clear_param_store()
pyro.set_rng_seed(99)

# #declare dataset to be modeled
# dtname = 't{}_{}_{}_3d'.format(target, dtype, auto)
# print("running MCMC with: {}".format(dtname))
# data = globals()[dtname]

nuts_kernel = NUTS(tm.model)

mcmc = MCMC(nuts_kernel, num_samples=5000, warmup_steps=1000)
mcmc.run(tself_norm_all_3d)

posterior_samples = mcmc.get_samples()

abc = az.from_pyro(mcmc, log_likelihood=True)
az.stats.waic(abc.posterior.weights)
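
WAIC is normally computed from the InferenceData object that carries the log_likelihood group built above, rather than from a single posterior variable; a minimal (assumed) variant of the call:

waic_result = az.waic(abc)
print(waic_result)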
Example #27
class BayesianBinaryTester(BayesianTester):
    def __init__(self,
                 outcome: tensor,
                 traffic_size: tensor,
                 warmup_steps: int = 100,
                 num_samples: int = 1000) -> None:
        pyro.clear_param_store()
        self.outcome = outcome
        self.traffic_size = traffic_size
        self.model = self.model_generator()
        self.kernel = NUTS(self.model)
        self.mcmc = MCMC(self.kernel,
                         warmup_steps=warmup_steps,
                         num_samples=num_samples)

    def model_generator(self) -> Callable:
        def _model_(self):
            control_prior = pyro.sample('control_p', dist.Beta(1, 1))
            treatment_prior = pyro.sample('treatment_p', dist.Beta(1, 1))
            return pyro.sample(
                'obs',
                dist.Binomial(self.traffic_size,
                              torch.stack([control_prior, treatment_prior])),
                obs=self.outcome)

        return partial(_model_, self)

    def run(self, n_samples=3000) -> None:
        self.mcmc.run()
        self.posterior_samples = self.mcmc.get_samples(n_samples)

    def expected_loss_switch(self) -> float:
        return torch.mean(
            self.loss(self.posterior_samples['treatment_p'],
                      self.posterior_samples['control_p']))

    def expected_loss_stay(self) -> float:
        return torch.mean(
            self.loss(self.posterior_samples['control_p'],
                      self.posterior_samples['treatment_p']))

    def improvement_probability(self) -> float:
        return torch.mean(
            (self.loss(self.posterior_samples['treatment_p'],
                       self.posterior_samples['control_p']) > 0).float())

    def summary(self) -> None:
        cost_of_switching = self.expected_loss_switch()
        cost_of_stay = self.expected_loss_stay()
        prob_improve = self.improvement_probability()
        print('Potential cost of not switching: {:.2%}'.format(
            cost_of_switching))
        print('Cost of wrong switch: {:.2%}'.format(cost_of_stay))
        print(
            'Probability of Treatment is better: {:.2%}'.format(prob_improve))

    def plot_joint_posterior(self) -> None:
        posterior_df = pd.DataFrame(self.posterior_samples)
        g = sns.jointplot(x='control_p',
                          y='treatment_p',
                          data=posterior_df,
                          kind='kde',
                          fill=True,
                          levels=10)
        minimum = min([i.min() for i in self.posterior_samples.values()])
        maximum = max([i.max() for i in self.posterior_samples.values()])
        g.ax_marg_x.set_xlim(minimum, maximum)
        g.ax_marg_y.set_ylim(minimum, maximum)
        x0, x1 = g.ax_joint.get_xlim()
        y0, y1 = g.ax_joint.get_ylim()
        lims = [max(x0, y0), min(x1, y1)]
        g.ax_joint.plot(lims, lims, ':k')
        plt.show()

    def plot_posterior(self) -> None:
        posterior_melted_df = pd.melt(pd.DataFrame(self.posterior_samples))
        sns.kdeplot(data=posterior_melted_df, hue='variable', x='value')
        plt.show()
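
A minimal usage sketch for the class above. The conversion counts and traffic sizes are hypothetical, and the loss function used by the expected-loss methods is assumed to come from the BayesianTester base class, which is not shown in this excerpt.

import torch

outcome = torch.tensor([120., 135.])         # observed conversions for control / treatment (hypothetical)
traffic_size = torch.tensor([1000., 1000.])  # visitors per variant (hypothetical)

tester = BayesianBinaryTester(outcome, traffic_size, warmup_steps=200, num_samples=1000)
tester.run()                 # runs NUTS and stores posterior samples
tester.summary()             # prints expected losses and P(treatment better)
tester.plot_joint_posterior()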
Example #28
    plt.ylim((-2,2))
    for _ in range(50):
        s_pred = (system.guide(X, actions)(X,actions)).detach()[:,:-1]
        
        plt.plot(s_pred[n,:,i], alpha = 0.2)
    plt.show()

#%%

#%% MCMC
from pyro.infer import NUTS, MCMC, EmpiricalMarginal
nuts_kernel = NUTS(system.model, adapt_step_size=True)
mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=30)
mcmc_run = mcmc.run(X, actions)
#%%
hmc_samples = {k: v.detach().unsqueeze(1) for k, v in mcmc.get_samples().items()}
for key in hmc_samples:
    print(key)
#%%
#hmc_samples["network$$$a_emb.weight"][0]
hmc_samples["network$$$a_emb.weight"].mean(0)
#%%
for a in range(num_actions):
    plt.plot(hmc_samples["network$$$a_emb.weight"][:,0,a,:])

plt.plot(hmc_samples["network$$$a_emb.weight"][:,0,2,1])
#%%

plt.plot(st[0,:,0],st[0,:,1])
plt.plot(st_pred[0,:,0],st_pred[0,:,1])
Example #29
File: baseball.py  Project: youqad/pyro
def main(args):
    baseball_dataset = pd.read_csv(DATA_URL, sep="\t")
    train, _, player_names = train_test_split(baseball_dataset)
    at_bats, hits = train[:, 0], train[:, 1]
    logging.info("Original Dataset:")
    logging.info(baseball_dataset)

    # (1) Full Pooling Model
    # In this model, we illustrate how to use MCMC with general potential_fn.
    init_params, potential_fn, transforms, _ = initialize_model(
        fully_pooled, model_args=(at_bats, hits), num_chains=args.num_chains,
        jit_compile=args.jit, skip_jit_warnings=True)
    nuts_kernel = NUTS(potential_fn=potential_fn)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains,
                initial_params=init_params,
                transforms=transforms)
    mcmc.run(at_bats, hits)
    samples_fully_pooled = mcmc.get_samples()
    logging.info("\nModel: Fully Pooled")
    logging.info("===================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(fully_pooled, samples_fully_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(fully_pooled, samples_fully_pooled, baseball_dataset)

    # (2) No Pooling Model
    nuts_kernel = NUTS(not_pooled, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_not_pooled = mcmc.get_samples()
    logging.info("\nModel: Not Pooled")
    logging.info("=================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(not_pooled, samples_not_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(not_pooled, samples_not_pooled, baseball_dataset)

    # (3) Partially Pooled Model
    nuts_kernel = NUTS(partially_pooled, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_partially_pooled = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled")
    logging.info("=======================")
    logging.info("\nphi:")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["phi"],
                                   player_names=player_names,
                                   diagnostics=True,
                                   group_by_chain=True)["phi"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled, samples_partially_pooled, baseball_dataset)
    evaluate_pointwise_pred_density(partially_pooled, samples_partially_pooled, baseball_dataset)

    # (4) Partially Pooled with Logit Model
    nuts_kernel = NUTS(partially_pooled_with_logit, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(nuts_kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(at_bats, hits)
    samples_partially_pooled_logit = mcmc.get_samples()
    logging.info("\nModel: Partially Pooled with Logit")
    logging.info("==================================")
    logging.info("\nSigmoid(alpha):")
    logging.info(get_summary_table(mcmc.get_samples(group_by_chain=True),
                                   sites=["alpha"],
                                   player_names=player_names,
                                   transforms={"alpha": torch.sigmoid},
                                   diagnostics=True,
                                   group_by_chain=True)["alpha"])
    num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
    logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
    sample_posterior_predictive(partially_pooled_with_logit, samples_partially_pooled_logit,
                                baseball_dataset)
    evaluate_pointwise_pred_density(partially_pooled_with_logit, samples_partially_pooled_logit,
                                    baseball_dataset)
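
Models (2) through (4) above repeat the same NUTS/MCMC setup; a small helper along these lines (not part of the original example) could factor out that boilerplate, leaving only the potential_fn-based model (1) as a special case:

def run_nuts(model, args, *model_args):
    # Shared NUTS/MCMC boilerplate used by models (2)-(4) above (a sketch).
    kernel = NUTS(model, jit_compile=args.jit, ignore_jit_warnings=True)
    mcmc = MCMC(kernel,
                num_samples=args.num_samples,
                warmup_steps=args.warmup_steps,
                num_chains=args.num_chains)
    mcmc.run(*model_args)
    return mcmc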
for site, values in summary(svi_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")



from pyro.infer import MCMC, NUTS


nuts_kernel = NUTS(model)

mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run(is_cont_africa, ruggedness, log_gdp)

hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}




for site, values in summary(hmc_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")




sites = ["a", "bA", "bR", "bAR", "sigma"]

fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12, 10))
fig.suptitle("Marginal Posterior density - Regression Coefficients", fontsize=16)