Example #1
 def test_stack_overwrite_failure(self):
     data1 = {"latent2": torch.randn(2)}
     data2 = {"latent2": torch.randn(2)}
     cm = poutine.condition(poutine.condition(self.model, data=data1),
                            data=data2)
     with pytest.raises(AssertionError):
         cm()
Example #2
    def predict(self, forecast=0):
        """
        Predict latent variables and optionally forecast forward.

        This may be run only after :meth:`fit_mcmc` and draws the same
        ``num_samples`` as passed to :meth:`fit_mcmc`.

        :param int forecast: The number of time steps to forecast forward.
        :returns: A dictionary mapping sample site name (or compartment name)
            to a tensor whose first dimension corresponds to sample batching.
        :rtype: dict
        """
        if self.num_quant_bins > 1:
            _require_double_precision()
        if not self.samples:
            raise RuntimeError("Missing samples, try running .fit_mcmc() first")

        samples = self.samples
        num_samples = len(next(iter(samples.values())))
        particle_plate = pyro.plate("particles", num_samples,
                                    dim=-1 - self.max_plate_nesting)

        # Sample discrete auxiliary variables conditioned on the continuous
        # variables sampled by _quantized_model. This samples only time steps
        # [0:duration]. Here infer_discrete runs a forward-filter
        # backward-sample algorithm.
        logger.info("Predicting latent variables for {} time steps..."
                    .format(self.duration))
        model = self._sequential_model
        model = poutine.condition(model, samples)
        model = particle_plate(model)
        if not self.relaxed:
            model = infer_discrete(model, first_available_dim=-2 - self.max_plate_nesting)
        trace = poutine.trace(model).get_trace()
        samples = OrderedDict((name, site["value"].expand(site["fn"].shape()))
                              for name, site in trace.nodes.items()
                              if site["type"] == "sample"
                              if not site_is_subsample(site)
                              if not site_is_factor(site))
        assert all(v.size(0) == num_samples for v in samples.values()), \
            {k: tuple(v.shape) for k, v in samples.items()}

        # Optionally forecast with the forward _generative_model. This samples
        # time steps [duration:duration+forecast].
        if forecast:
            logger.info("Forecasting {} steps ahead...".format(forecast))
            model = self._generative_model
            model = poutine.condition(model, samples)
            model = particle_plate(model)
            trace = poutine.trace(model).get_trace(forecast)
            samples = OrderedDict((name, site["value"])
                                  for name, site in trace.nodes.items()
                                  if site["type"] == "sample"
                                  if not site_is_subsample(site)
                                  if not site_is_factor(site))

        self._concat_series(samples, trace, forecast)
        assert all(v.size(0) == num_samples for v in samples.values()), \
            {k: tuple(v.shape) for k, v in samples.items()}
        return samples
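
A minimal end-to-end usage sketch for the predict() method above. It assumes the SimpleSIRModel helper from pyro.contrib.epidemiology.models (constructor arguments population, recovery_time, data); the case counts below are made up:

import torch
from pyro.contrib.epidemiology.models import SimpleSIRModel

# Hypothetical daily new-infection counts over 30 days.
data = torch.tensor([1., 2., 3., 5., 8., 13., 8., 5., 3., 2.] * 3)

model = SimpleSIRModel(1000, 7.0, data)            # (population, recovery_time, data)
model.fit_mcmc(num_samples=100, warmup_steps=100)  # draw posterior samples
samples = model.predict(forecast=14)               # latent series plus a 14-step forecast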
Example #3
def save_posterior_predictive(model, guide, filename, N=300):
    if N == 1:
        mock = {}
        guide_trace = poutine.trace(guide).get_trace()
        trace = poutine.trace(poutine.condition(model,
                                                data=guide_trace)).get_trace()
        for tag in trace:
            if trace.nodes[tag]["type"] == "sample":
                mock[tag] = trace.nodes[tag]["value"].detach().cpu().numpy()
    else:
        mock = defaultdict(list)
        for i in range(N):
            # Faster way if we don't need `deterministic` statements.
            # Literally just samples from the guide.
            #
            # for tag, value in guide()[1].items():
            #     mock[tag].append(value.detach().cpu().numpy())
            # continue

            guide_trace = poutine.trace(guide).get_trace()
            trace = poutine.trace(poutine.condition(
                model, data=guide_trace)).get_trace()
            for tag in trace:
                if trace.nodes[tag]["type"] == "sample":
                    mock[tag].append(
                        trace.nodes[tag]["value"].detach().cpu().numpy())

    np.savez(filename, **mock)
    print("Saved %i sample(s) from posterior predictive distribution to %s" %
          (N, filename))
Example #4
 def test_stack_overwrite_behavior(self):
     data1 = {"latent2": torch.randn(2)}
     data2 = {"latent2": torch.randn(2)}
     with poutine.trace() as tr:
         cm = poutine.condition(poutine.condition(self.model, data=data1),
                                data=data2)
         cm()
     assert tr.trace.nodes['latent2']['value'] is data2['latent2']
Example #5
def _predictive(model, posterior_samples, num_samples, return_sites=(),
                return_trace=False, parallel=False, model_args=(), model_kwargs={}):
    max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
    vectorize = pyro.plate("_num_predictive_samples", num_samples, dim=-max_plate_nesting-1)
    model_trace = prune_subsample_sites(poutine.trace(model).get_trace(*model_args, **model_kwargs))
    reshaped_samples = {}

    for name, sample in posterior_samples.items():
        sample_shape = sample.shape[1:]
        sample = sample.reshape((num_samples,) + (1,) * (max_plate_nesting - len(sample_shape)) + sample_shape)
        reshaped_samples[name] = sample

    if return_trace:
        trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\
            .get_trace(*model_args, **model_kwargs)
        return trace

    return_site_shapes = {}
    for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
        append_ndim = max_plate_nesting - len(model_trace.nodes[site]["fn"].batch_shape)
        site_shape = (num_samples,) + (1,) * append_ndim + model_trace.nodes[site]['value'].shape
        # non-empty return-sites
        if return_sites:
            if site in return_sites:
                return_site_shapes[site] = site_shape
        # special case (for guides): include all sites
        elif return_sites is None:
            return_site_shapes[site] = site_shape
        # default case: return sites = ()
        # include all sites not in posterior samples
        elif site not in posterior_samples:
            return_site_shapes[site] = site_shape

    # handle _RETURN site
    if return_sites is not None and '_RETURN' in return_sites:
        value = model_trace.nodes['_RETURN']['value']
        shape = (num_samples,) + value.shape if torch.is_tensor(value) else None
        return_site_shapes['_RETURN'] = shape

    if not parallel:
        return _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples,
                                      return_site_shapes, return_trace=False)

    trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\
        .get_trace(*model_args, **model_kwargs)
    predictions = {}
    for site, shape in return_site_shapes.items():
        value = trace.nodes[site]['value']
        if site == '_RETURN' and shape is None:
            predictions[site] = value
            continue
        if value.numel() < reduce((lambda x, y: x * y), shape):
            predictions[site] = value.expand(shape)
        else:
            predictions[site] = value.reshape(shape)

    return predictions
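
_predictive is the private helper behind pyro.infer.Predictive. A minimal sketch of the public wrapper; the toy model and the posterior draws for "probs" below are made up:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Predictive

def model(num_trials):
    probs = pyro.sample("probs", dist.Beta(1., 1.))
    return pyro.sample("obs", dist.Binomial(num_trials, probs))

# Hypothetical posterior draws for the latent site "probs".
posterior_samples = {"probs": torch.rand(100)}
predictive = Predictive(model, posterior_samples=posterior_samples)
samples = predictive(torch.ones(5) * 1000)  # samples["obs"] has shape (100, 5)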
Example #6
 def test_stack_success(self):
     data1 = {"latent1": torch.randn(2)}
     data2 = {"latent2": torch.randn(2)}
     tr = poutine.trace(
         poutine.condition(poutine.condition(self.model, data=data1),
                           data=data2)).get_trace()
     assert tr.nodes["latent1"]["type"] == "sample" and \
         tr.nodes["latent1"]["is_observed"]
     assert tr.nodes["latent1"]["value"] is data1["latent1"]
     assert tr.nodes["latent2"]["type"] == "sample" and \
         tr.nodes["latent2"]["is_observed"]
     assert tr.nodes["latent2"]["value"] is data2["latent2"]
Example #7
def test_nested():
    shape = (5, 6)

    @poutine.reparam(config={
        "x": HaarReparam(dim=-1),
        "x_haar": HaarReparam(dim=-2)
    })
    def model():
        pyro.sample("x", dist.Normal(torch.zeros(shape), 1).to_event(2))

    # Try without initialization, e.g. in AutoGuide._setup_prototype().
    trace = poutine.trace(model).get_trace()
    assert {"x", "x_haar", "x_haar_haar"}.issubset(trace.nodes)
    assert trace.nodes["x"]["is_observed"]
    assert trace.nodes["x_haar"]["is_observed"]
    assert not trace.nodes["x_haar_haar"]["is_observed"]
    assert trace.nodes["x"]["value"].shape == shape

    # Try conditioning on x_haar_haar, e.g. in Predictive.
    x = torch.randn(shape)
    x_haar = HaarTransform(dim=-1)(x)
    x_haar_haar = HaarTransform(dim=-2)(x_haar)
    with poutine.condition(data={"x_haar_haar": x_haar_haar}):
        trace = poutine.trace(model).get_trace()
        assert {"x", "x_haar", "x_haar_haar"}.issubset(trace.nodes)
        assert trace.nodes["x"]["is_observed"]
        assert trace.nodes["x_haar"]["is_observed"]
        assert trace.nodes["x_haar_haar"]["is_observed"]
        assert_close(trace.nodes["x"]["value"], x)
        assert_close(trace.nodes["x_haar"]["value"], x_haar)
        assert_close(trace.nodes["x_haar_haar"]["value"], x_haar_haar)

    # Try with custom initialization.
    # This is required for autoguides and MCMC.
    with InitMessenger(init_to_value(values={"x": x})):
        trace = poutine.trace(model).get_trace()
        assert {"x", "x_haar", "x_haar_haar"}.issubset(trace.nodes)
        assert trace.nodes["x"]["is_observed"]
        assert trace.nodes["x_haar"]["is_observed"]
        assert not trace.nodes["x_haar_haar"]["is_observed"]
        assert_close(trace.nodes["x"]["value"], x)

    # Try conditioning on x.
    x = torch.randn(shape)
    with poutine.condition(data={"x": x}):
        trace = poutine.trace(model).get_trace()
        assert {"x", "x_haar", "x_haar_haar"}.issubset(trace.nodes)
        assert trace.nodes["x"]["is_observed"]
        assert trace.nodes["x_haar"]["is_observed"]
        # TODO Decide whether it is worth fixing this failing assertion.
        # See https://github.com/pyro-ppl/pyro/issues/2878
        # assert trace.nodes["x_haar_haar"]["is_observed"]
        assert_close(trace.nodes["x"]["value"], x)
Example #8
def test_counterfactual_query(intervene, observe, flip):
    # x -> y -> z -> w

    sites = ["x", "y", "z", "w"]
    observations = {"x": 1., "y": None, "z": 1., "w": 1.}
    interventions = {"x": None, "y": 0., "z": 2., "w": 1.}

    def model():
        x = _item(pyro.sample("x", dist.Normal(0, 1)))
        y = _item(pyro.sample("y", dist.Normal(x, 1)))
        z = _item(pyro.sample("z", dist.Normal(y, 1)))
        w = _item(pyro.sample("w", dist.Normal(z, 1)))
        return dict(x=x, y=y, z=z, w=w)

    if not flip:
        if intervene:
            model = poutine.do(model, data=interventions)
        if observe:
            model = poutine.condition(model, data=observations)
    elif flip and intervene and observe:
        model = poutine.do(poutine.condition(model, data=observations),
                           data=interventions)

    tr = poutine.trace(model).get_trace()
    actual_values = tr.nodes["_RETURN"]["value"]
    for name in sites:
        # case 1: purely observational query like poutine.condition
        if not intervene and observe:
            if observations[name] is not None:
                assert tr.nodes[name]['is_observed']
                assert_equal(observations[name], actual_values[name])
                assert_equal(observations[name], tr.nodes[name]['value'])
            if interventions[name] != observations[name]:
                assert_not_equal(interventions[name], actual_values[name])
        # case 2: purely interventional query like old poutine.do
        elif intervene and not observe:
            assert not tr.nodes[name]['is_observed']
            if interventions[name] is not None:
                assert_equal(interventions[name], actual_values[name])
            assert_not_equal(observations[name], tr.nodes[name]['value'])
            assert_not_equal(interventions[name], tr.nodes[name]['value'])
        # case 3: counterfactual query mixing intervention and observation
        elif intervene and observe:
            if observations[name] is not None:
                assert tr.nodes[name]['is_observed']
                assert_equal(observations[name], tr.nodes[name]['value'])
            if interventions[name] is not None:
                assert_equal(interventions[name], actual_values[name])
            if interventions[name] != observations[name]:
                assert_not_equal(interventions[name], tr.nodes[name]['value'])
Example #9
 def test_condition(self):
     data = {"latent2": torch.randn(2)}
     tr2 = poutine.trace(poutine.condition(self.model, data=data)).get_trace()
     assert "latent2" in tr2
     assert tr2.nodes["latent2"]["type"] == "sample" and \
         tr2.nodes["latent2"]["is_observed"]
     assert tr2.nodes["latent2"]["value"] is data["latent2"]
Example #10
def generate_data(args):
    logging.info("Generating data...")
    params = {"R0": torch.tensor(args.basic_reproduction_number),
              "rho": torch.tensor(args.response_rate)}
    empty_data = [None] * (args.duration + args.forecast)

    # We'll retry until we get an actual outbreak.
    for attempt in range(100):
        with poutine.trace() as tr:
            with poutine.condition(data=params):
                discrete_model(args, empty_data)

        # Concatenate sequential time series into tensors.
        obs = torch.stack([site["value"]
                           for name, site in tr.trace.nodes.items()
                           if re.match("obs_[0-9]+", name)])
        S2I = torch.stack([site["value"]
                          for name, site in tr.trace.nodes.items()
                          if re.match("S2I_[0-9]+", name)])
        assert len(obs) == len(empty_data)

        obs_sum = int(obs[:args.duration].sum())
        S2I_sum = int(S2I[:args.duration].sum())
        if obs_sum >= args.min_observations:
            logging.info("Observed {:d}/{:d} infections:\n{}".format(
                obs_sum, S2I_sum, " ".join([str(int(x)) for x in obs[:args.duration]])))
            return {"S2I": S2I, "obs": obs}

    raise ValueError("Failed to generate {} observations. Try increasing "
                     "--population or decreasing --min-observations"
                     .format(args.min_observations))
Example #11
def _predictive_sequential(model,
                           posterior_samples,
                           model_args,
                           model_kwargs,
                           num_samples,
                           sample_sites,
                           return_trace=False):
    collected = []
    samples = [{k: v[i]
                for k, v in posterior_samples.items()}
               for i in range(num_samples)]
    for i in range(num_samples):
        trace = poutine.trace(poutine.condition(model, samples[i])).get_trace(
            *model_args, **model_kwargs)
        if return_trace:
            collected.append(trace)
        else:
            collected.append(
                {site: trace.nodes[site]['value']
                 for site in sample_sites})

    return collected if return_trace else {
        site: torch.stack([s[site] for s in collected])
        for site in sample_sites
    }
Example #12
def run_pyro(site_values, data, model, transformed_data, n_samples, params):

    # import model and transformed_data functions (if they exist) from the pyro module

    assert model is not None, "model couldn't be imported"

    variablize_params(params)

    log_pdfs = []
    n_log_probs = None
    for j in range(n_samples):
        if n_samples > 1:
            sample_site_values = {v: site_values[v][j] for v in site_values}
        else:
            sample_site_values = {v: float(site_values[v]) if site_values[v].shape == ()
                                  else site_values[v][0]
                                  for v in site_values}
        #print(sample_site_values)
        process_2d_sites(sample_site_values)

        variablize_params(sample_site_values)

        model_trace = poutine.trace(poutine.condition(model, data=sample_site_values),
                                    graph_type="flat").get_trace(data, params)
        log_p = model_trace.log_pdf()
        if n_log_probs is None:
            n_log_probs = get_num_log_probs(model_trace)
        else:
            assert n_log_probs == get_num_log_probs(model_trace)
        #print(log_p.data.numpy())
        log_pdfs.append(to_float(log_p))
    return log_pdfs, n_log_probs
Example #13
def posterior_replay(model, posterior_samples, *args, **kwargs):
    r"""
    Given a model and samples from the posterior (potentially with conjugate sites
    collapsed), return a `dict` of samples from the posterior with conjugate sites
    uncollapsed. Note that this can also be used to generate samples from the
    posterior predictive distribution.

    :param model: Python callable.
    :param dict posterior_samples: posterior samples keyed by site name.
    :param args: arguments to `model`.
    :param kwargs: keyword arguments to `model`.
    :return: `dict` of samples from the posterior.
    """
    posterior_samples = posterior_samples.copy()
    num_samples = kwargs.pop("num_samples", None)
    assert posterior_samples or num_samples, "`num_samples` must be provided if `posterior_samples` is empty."
    if num_samples is None:
        num_samples = list(posterior_samples.values())[0].shape[0]

    return_samples = defaultdict(list)
    for i in range(num_samples):
        conditioned_nodes = {k: v[i] for k, v in posterior_samples.items()}
        collapsed_trace = poutine.trace(poutine.condition(collapse_conjugate(model), conditioned_nodes))\
            .get_trace(*args, **kwargs)
        trace = poutine.trace(uncollapse_conjugate(model,
                                                   collapsed_trace)).get_trace(
                                                       *args, **kwargs)
        for name, site in trace.iter_stochastic_nodes():
            if not site_is_subsample(site):
                return_samples[name].append(site["value"])

    return {k: torch.stack(v) for k, v in return_samples.items()}
Example #14
def get_log_prob(mcmc, data, site_names):
    """Gets the pointwise log probability of the posterior density conditioned on the data
    
    Arguments:
        mcmc (pyro.infer.mcmc.MCMC): the fitted MC model
        data (dict): dictionary containing all the input data (including return sites)
        site_names (str or List[str]): names of return sites to measure log likelihood at
    Returns:
        Tensor: pointwise log-likelihood of shape (num posterior samples, num data points)
    """
    samples = mcmc.get_samples()
    model = mcmc.kernel.model
    # get number of samples
    N = [v.shape[0] for v in samples.values()]
    assert all(n == N[0] for n in N)
    N = N[0]
    if isinstance(site_names, str):
        site_names = [site_names]
    # iterate over samples
    log_prob = torch.zeros(N, len(data[site_names[0]]))
    for i in range(N):
        # condition on samples and get trace
        s = {k: v[i] for k, v in samples.items()}
        for nm in site_names:
            s[nm] = data[nm]
        tr = poutine.trace(poutine.condition(model, data=s)).get_trace(data)
        # get pointwise log probability
        for nm in site_names:
            node = tr.nodes[nm]
            log_prob[i] += node["fn"].log_prob(node["value"])
    return log_prob
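
A minimal, hypothetical usage sketch for get_log_prob(); the toy model, site name, and data below are made up, only the call shape comes from the function above:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import MCMC, NUTS

# Toy model: the site "y" is left latent so that it can be conditioned on the data.
def model(data):
    mu = pyro.sample("mu", dist.Normal(0., 10.))
    with pyro.plate("data", len(data["y"])):
        pyro.sample("y", dist.Normal(mu, 1.))

data = {"y": torch.randn(20) + 3.0}
cond_model = poutine.condition(model, data={"y": data["y"]})
mcmc = MCMC(NUTS(cond_model), num_samples=50, warmup_steps=50)
mcmc.run(data)

pointwise = get_log_prob(mcmc, data, "y")  # tensor of shape (50, 20)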
Example #15
 def test_trace_data(self):
     tr1 = poutine.trace(
         poutine.block(self.model, expose_types=["sample"])).get_trace()
     tr2 = poutine.trace(
         poutine.condition(self.model, data=tr1)).get_trace()
     assert tr2.nodes["latent2"]["type"] == "sample" and \
         tr2.nodes["latent2"]["is_observed"]
     assert tr2.nodes["latent2"]["value"] is tr1.nodes["latent2"]["value"]
Example #16
 def nested():
     true_probs = torch.ones(5) * 0.7
     num_trials = torch.ones(5) * 1000
     num_success = dist.Binomial(num_trials, true_probs).sample()
     conditioned_model = poutine.condition(model, data={"obs": num_success})
     nuts_kernel = NUTS(conditioned_model, adapt_step_size=True)
     mcmc_run = MCMC(nuts_kernel, num_samples=10, warmup_steps=2).run(num_trials)
     return mcmc_run
Example #17
 def __init__(self, model, data, descriptor, block_info=None, stan_info=None):
     self.descriptor = descriptor
     self.model = poutine.condition(model, data=data)
     self.parameters = get_rvs(self.model, False)
     self.observed = get_rvs(self.model, True)
     self.data = data
     self.block_info = block_info
     self.stan_info = stan_info
Example #18
 def _conditioned_model(self, model, ratings):
     data = dict()
     for x in range(self.n_user):
         for y in range(self.n_item):
             if ratings[x, y] != 0:
                 data["obs" + str(x * self.n_item + y)] = \
                   torch.tensor(ratings[x,y], dtype = torch.float64)
     return poutine.condition(model, data=data)()
Example #19
def conditioned_model(model, at_bats, hits):
    """
    Condition the model on observed data, for inference.

    :param model: python callable with Pyro primitives.
    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    """
    return poutine.condition(model, data={"obs": hits})(at_bats)
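
This wrapper appears in Pyro's baseball example, where it is handed directly to an MCMC kernel. A minimal sketch with a made-up fully-pooled toy model and toy data (the real example uses the Efron-Morris batting data):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import MCMC, NUTS

# Toy fully-pooled model; "obs" is left latent so that conditioned_model can
# observe it via poutine.condition.
def model(at_bats):
    phi = pyro.sample("phi", dist.Uniform(0., 1.))
    with pyro.plate("players", at_bats.size(0)):
        return pyro.sample("obs", dist.Binomial(at_bats, phi))

at_bats = torch.tensor([45., 45., 45.])
hits = torch.tensor([18., 17., 16.])

nuts_kernel = NUTS(conditioned_model)
mcmc = MCMC(nuts_kernel, num_samples=200, warmup_steps=100)
mcmc.run(model, at_bats, hits)           # forwarded to conditioned_model(model, at_bats, hits)
phi_samples = mcmc.get_samples()["phi"]  # shape (200,)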
Example #20
 def _potential_fn(self, params):
     params_constrained = {k: self.transforms[k].inv(v) for k, v in params.items()}
     cond_model = poutine.condition(self.model, params_constrained)
     model_trace = poutine.trace(cond_model).get_trace(*self.model_args,
                                                       **self.model_kwargs)
     log_joint = self.trace_prob_evaluator.log_prob(model_trace)
     for name, t in self.transforms.items():
         log_joint = log_joint - torch.sum(
             t.log_abs_det_jacobian(params_constrained[name], params[name]))
     return -log_joint
Example #21
def test_posterior_predictive():
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    nuts_kernel = NUTS(conditioned_model, adapt_step_size=True)
    mcmc_run = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200).run(num_trials)
    posterior_predictive = TracePredictive(model, mcmc_run, num_samples=10000).run(num_trials)
    marginal_return_vals = EmpiricalMarginal(posterior_predictive)
    assert_equal(marginal_return_vals.mean, torch.ones(5) * 700, prec=30)
Example #22
 def _conditioned_model(self, model, sigma, ratings):
     data = dict()

     rating = ratings.take(2, axis=1)
     rating_len = len(rating)

     for i in range(rating_len):
         data["obs" + str(i)] = torch.tensor(rating[i], dtype=torch.float64)

     return poutine.condition(model, data=data)(sigma)
Example #23
def score_latent(zs, ys):
    model = HarmonicModel()
    with poutine.trace() as trace:
        with poutine.condition(
                data={"z_{}".format(t): z
                      for t, z in enumerate(zs)}):
            model.init()
            for y in ys[1:]:
                model.step(y)

    return trace.trace.log_prob_sum()
Example #24
def conditioned_model(model, t, yt):
    # model must be a BaseModel
    assert isinstance(model, BaseModel)
    fcn = model.forward
    if model.output_type == "yt":
        obs = yt
    elif model.output_type == "logyt":
        obs = torch.log(torch.clamp(yt, min=1.0))
        # fcn = poutine.mask(fcn, mask=(yt > 0))

    return poutine.condition(fcn, data={model.output_type: obs})(t)
Example #25
def WAIC(model, x, y, out_var_nm, num_samples=100):
    p = torch.zeros((num_samples, len(y)))
    # Get log probability samples
    for i in range(num_samples):
        tr = poutine.trace(poutine.condition(model, data=model.guide())).get_trace(x)
        dist = tr.nodes[out_var_nm]["fn"]
        p[i] = dist.log_prob(y).detach()
    pmax = p.max(axis=0).values
    lppd = pmax + (p - pmax).exp().mean(axis=0).log() # numerically stable version
    penalty = p.var(axis=0)
    return -2*(lppd - penalty)
Example #26
def test_posterior_predictive_svi_auto_delta_guide(parallel):
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    guide = AutoDelta(conditioned_model)
    svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=1.0)), Trace_ELBO())
    for i in range(1000):
        svi.step(num_trials)
    posterior_predictive = Predictive(model, guide=guide, num_samples=10000, parallel=parallel)
    marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
    assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05)
Example #27
def test_posterior_predictive_svi_auto_diag_normal_guide(return_trace):
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    guide = AutoDiagonalNormal(conditioned_model)
    svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO())
    for i in range(1000):
        svi.step(num_trials)
    posterior_predictive = Predictive(model, guide=guide, num_samples=10000, parallel=True)
    if return_trace:
        marginal_return_vals = posterior_predictive.get_vectorized_trace(num_trials).nodes["obs"]["value"]
    else:
        marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
    assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05)
Example #28
def sample_and_calculate_log_impweight(x_data,
                                       y_data,
                                       model,
                                       guide,
                                       num_post_samples=int(1e3)):
    """  
    returns: samples and their log importance weights (adding to 0/exponentiated sum=1)
    """
    log_impweights = torch.zeros(torch.Size([num_post_samples]))
    samples = torch.zeros(torch.Size([num_post_samples, guide.latent_dim]))
    #sigmas = torch.zeros(torch.Size([replications]))
    #sigma_logprobs = torch.zeros(torch.Size([replications]))

    for i in range(num_post_samples):
        trace_guide = poutine.trace(guide).get_trace()
        param_sample = trace_guide.nodes["_RETURN"]["value"]
        # need to evaluate the log probs so that it appears in ["_AutoDiagonalNormal_latent"]["log_prob_sum"]
        trace_guide.log_prob_sum()
        # check log prob sum: ["_AutoDiagonalNormal_latent"]["log_prob_sum"] seems to be correct,
        # while the above <trace_guide.log_prob_sum()> also includes the log prob for sigma
        # (a deterministic function of log sigma), which should be zero but is not
        # hack
        if isinstance(guide, AutoDiagonalNormal):
            samples[i, :] = trace_guide.nodes["_AutoDiagonalNormal_latent"][
                "value"]
            param_sample_logprob = trace_guide.nodes[
                "_AutoDiagonalNormal_latent"]["log_prob_sum"]
        else:
            samples[i, :] = trace_guide.nodes[
                "_AutoMultivariateNormal_latent"]["value"]
            param_sample_logprob = trace_guide.nodes[
                "_AutoMultivariateNormal_latent"]["log_prob_sum"]
        #param_sample_logprob = trace_guide.log_prob_sum() - trace_guide.nodes["sigma"]["log_prob_sum"]
        #trace_guide.nodes["_AutoDiagonalNormal_latent"]["log_prob_sum"]

        cond_model = poutine.condition(model,
                                       data={
                                           "obs": y_data,
                                           **param_sample
                                       })
        trace_cond_model = poutine.trace(cond_model).get_trace(x=x_data)
        joint_logprob = trace_cond_model.log_prob_sum()
        #<=>estimated log-posterior
        log_impweights[i] = joint_logprob - param_sample_logprob

    log_impweights = log_impweights - torch.logsumexp(log_impweights, dim=0)

    return samples, log_impweights
Example #29
    def transform_samples(self, aux_samples, save_params=None):
        """
        Given latent samples from the warped posterior (with a possible batch dimension),
        return a `dict` of samples from the latent sites in the model.

        :param dict aux_samples: Dict site name to tensor value for each latent
            auxiliary site (or if ``save_params`` is specified, then for only
            those latent auxiliary sites needed to compute requested params).
        :param list save_params: An optional list of site names to save. This
            is useful in models with large nuisance variables. Defaults to
            None, saving all params.
        :return: a `dict` of samples keyed by latent sites in the model.
        :rtype: dict
        """
        with poutine.condition(data=aux_samples), poutine.mask(mask=False):
            deltas = self.guide.get_deltas(save_params)
        return {name: delta.v for name, delta in deltas.items()}
Example #30
    def fit(
        self,
        df,
        max_iter=6000,
        patience=200,
        optimiser_settings={"lr": 1.0e-2},
        elbo_kwargs={"num_particles": 5},
    ):
        teams = sorted(list(set(df["home_team"]) | set(df["away_team"])))
        home_team = df["home_team"].values
        away_team = df["away_team"].values
        home_goals = torch.tensor(df["home_goals"].values, dtype=torch.float32)
        away_goals = torch.tensor(df["away_goals"].values, dtype=torch.float32)
        gameweek = ((df["date"] - df["date"].min()).dt.days // 7).values

        self.team_to_index = {team: i for i, team in enumerate(teams)}
        self.index_to_team = {
            value: key
            for key, value in self.team_to_index.items()
        }
        self.n_teams = len(teams)
        self.min_date = df["date"].min()

        conditioned_model = condition(self.model,
                                      data={
                                          "home_goals": home_goals,
                                          "away_goals": away_goals
                                      })
        guide = AutoDiagonalNormal(conditioned_model)

        optimizer = Adam(optimiser_settings)
        elbo = Trace_ELBO(**elbo_kwargs)
        svi = SVI(conditioned_model, guide, optimizer, loss=elbo)

        pyro.clear_param_store()
        fitted_svi, losses = early_stopping(svi,
                                            home_team,
                                            away_team,
                                            gameweek,
                                            max_iter=max_iter,
                                            patience=patience)

        self.guide = guide

        return losses