Example #1
    def guide(self):
        """Approximate posterior for the horseshoe prior. We assume posterior in the form
        of the multivariate normal distriburtion for the global mean and standard deviation
        and multivariate normal distribution for the parameters of each subject independently.
        """
        nsub = self.runs  # number of subjects
        npar = self.npar  # number of parameters
        trns = biject_to(constraints.positive)

        m_hyp = param('m_hyp', zeros(2 * npar))
        st_hyp = param('scale_tril_hyp',
                       torch.eye(2 * npar),
                       constraint=constraints.lower_cholesky)
        hyp = sample('hyp',
                     dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
                     infer={'is_auxiliary': True})

        unc_mu = hyp[..., :npar]
        unc_tau = hyp[..., npar:]

        c_tau = trns(unc_tau)

        ld_tau = trns.inv.log_abs_det_jacobian(c_tau, unc_tau)
        ld_tau = sum_rightmost(ld_tau, ld_tau.dim() - c_tau.dim() + 1)

        sample("mu", dist.Delta(unc_mu, event_dim=1))
        sample("tau", dist.Delta(c_tau, log_density=ld_tau, event_dim=1))

        m_locs = param('m_locs', zeros(nsub, npar))
        st_locs = param('scale_tril_locs',
                        torch.eye(npar).repeat(nsub, 1, 1),
                        constraint=constraints.lower_cholesky)

        with plate('runs', nsub):
            sample("locs", dist.MultivariateNormal(m_locs, scale_tril=st_locs))
Example #2
def new_guide(obsmat):
    # These are just the previous values we can use to initialize params here
    initial_topic_weights = pyro.get_param_store()['AutoDelta.topic_weights']
    initial_alpha = pyro.get_param_store()['AutoDelta.topic_weights']
    initial_topic_a = pyro.get_param_store()['AutoDelta.topic_a']
    initial_topic_b = pyro.get_param_store()['AutoDelta.topic_b']

    # Use poutine.block to keep our learned values of the global parameters.
    with poutine.block(hide_types=["param"]):

        # This has to match the structure of the model
        with pyro.plate('topic', tm.K):
            # We manually define the AutoDelta params we had from before here
            topic_weights_q = pyro.param('AutoDelta.topic_weights',
                                         initial_topic_weights)
            topic_a_q = pyro.param('AutoDelta.topic_a', initial_topic_a)
            topic_b_q = pyro.param('AutoDelta.topic_b', initial_topic_b)

            # Each of the sample statements in the above model needs to have a corresponding
            # statement here where we insert our tuneable params
            pyro.sample("topic_weights", dist.Delta(topic_weights_q))
            pyro.sample('topic_a', dist.Delta(topic_a_q).to_event(2))
            pyro.sample('topic_b', dist.Delta(topic_b_q).to_event(2))

    # We define a new learnable parameter for the new participant that
    # sums to 1 (via constraint) and plug this in as their topic probabilities
    probs = pyro.param('new_participant_topic_q',
                       initial_alpha,
                       constraint=constraints.simplex)
    participant_topics = pyro.sample("new_participant_topic",
                                     dist.Delta(probs).to_event(1))
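The `poutine.block(hide_types=["param"])` context is what freezes the previously learned globals: param sites inside it are hidden from the guide trace, so SVI leaves them untouched. A tiny sketch of that effect:

import torch
import pyro
import pyro.poutine as poutine

def fn():
    pyro.param("w", torch.zeros(2))

tr = poutine.trace(poutine.block(fn, hide_types=["param"])).get_trace()
print("w" in tr.nodes)  # False: the param site is hidden from the trace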
Example #3
def _get_sample_fn(module, name):
    if module.mode == "model":
        return module._priors[name]

    dist_constructor, dist_args = module._guides[name]

    if dist_constructor is dist.Delta:
        p_map = getattr(module, "{}_map".format(name))
        return dist.Delta(p_map, event_dim=p_map.dim())

    # create guide
    dist_args = {
        arg: getattr(module, "{}_{}".format(name, arg))
        for arg in dist_args
    }
    guide = dist_constructor(**dist_args)

    # no need to do transforms when support is real (for mean field ELBO)
    support = module._priors[name].support
    if _is_real_support(support):
        return guide.to_event()

    # otherwise, we do inference in unconstrained space and transform the value
    # back to original space
    # TODO: move this logic to infer.autoguide or somewhere else
    unconstrained_value = pyro.sample(module._pyro_get_fullname(
        "{}_latent".format(name)),
                                      guide.to_event(),
                                      infer={"is_auxiliary": True})
    transform = biject_to(support)
    value = transform(unconstrained_value)
    log_density = transform.inv.log_abs_det_jacobian(value,
                                                     unconstrained_value)
    return dist.Delta(value, log_density.sum(), event_dim=value.dim())
Example #4
    def test_batch_log_prob(self):
        log_px_torch = dist.Delta(self.vs_expanded).log_prob(
            self.batch_test_data_1).data
        assert_equal(log_px_torch.sum().item(), 0)
        log_px_torch = dist.Delta(self.vs_expanded).log_prob(
            self.batch_test_data_2).data
        assert_equal(log_px_torch.sum().item(), float('-inf'))
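The test exercises `Delta`'s defining property: `log_prob` is 0 exactly at the supported point and -inf everywhere else (plus any `log_density` offset, which defaults to 0). For example:

import torch
import pyro.distributions as dist

d = dist.Delta(torch.tensor(3.0))
print(d.log_prob(torch.tensor(3.0)))  # tensor(0.)
print(d.log_prob(torch.tensor(2.0)))  # tensor(-inf)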
Example #5
def MAP_guide(prior_params, logits, labels):
    """ Defines a guide for use in MAP inference. """
    n_cls = logits.shape[1]  # Num classes

    beta_MAP = pyro.param('beta_MAP', torch.ones(n_cls, requires_grad=True))
    delta_MAP = pyro.param('delta_MAP', torch.zeros(n_cls, requires_grad=True))
    pyro.sample('beta', dist.Delta(beta_MAP))
    pyro.sample('delta', dist.Delta(delta_MAP))
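Training a Delta guide like this with SVI recovers the MAP estimate. A minimal sketch, assuming a matching `MAP_model` with the same signature as the guide:

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

svi = SVI(MAP_model, MAP_guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())
for step in range(1000):
    svi.step(prior_params, logits, labels)
beta_map = pyro.param('beta_MAP')  # the fitted point estimate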
Example #6
    def setUp(self):
        self.v = Variable(torch.Tensor([3]))
        self.vs = Variable(torch.Tensor([[0], [1], [2], [3]]))
        self.test_data = Variable(torch.Tensor([3, 3, 3]))
        self.batch_test_data = Variable(
            torch.arange(0, 4).unsqueeze(1).expand(4, 3))

        self.dist = dist.Delta(self.v)
        self.batch_dist = dist.Delta(self.vs, batch_size=2)
Example #7
    def templates_guide_mvn(self):
        """ Multivariate normal guide for template parameters
        """

        loc = _deep_getattr(self, "mvn.loc")
        scale_tril = _deep_getattr(self, "mvn.scale_tril")

        dt = dist.MultivariateNormal(loc, scale_tril=scale_tril)
        states = pyro.sample("states_" + self.name_prefix,
                             dt,
                             infer={"is_auxiliary": True})

        result = {}

        for i_poiss in torch.arange(self.n_poiss):
            transform = biject_to(self.poiss_priors[i_poiss].support)
            value = transform(states[i_poiss])
            log_density = transform.inv.log_abs_det_jacobian(
                value, states[i_poiss])
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - value.dim() +
                self.poiss_priors[i_poiss].event_dim)

            result[self.poiss_labels[i_poiss]] = pyro.sample(
                self.poiss_labels[i_poiss],
                dist.Delta(value,
                           log_density=log_density,
                           event_dim=self.poiss_priors[i_poiss].event_dim))

        i_param = self.n_poiss

        for i_ps in torch.arange(self.n_ps):
            for i_ps_param in torch.arange(self.n_ps_params):

                transform = biject_to(self.ps_priors[i_ps][i_ps_param].support)

                value = transform(states[i_param])

                log_density = transform.inv.log_abs_det_jacobian(
                    value, states[i_param])
                log_density = sum_rightmost(
                    log_density,
                    log_density.dim() - value.dim() +
                    self.ps_priors[i_ps][i_ps_param].event_dim)

                result[self.ps_param_labels[i_ps_param] + "_" +
                       self.ps_labels[i_ps]] = pyro.sample(
                           self.ps_param_labels[i_ps_param] + "_" +
                           self.ps_labels[i_ps],
                           dist.Delta(value,
                                      log_density=log_density,
                                      event_dim=self.ps_priors[i_ps]
                                      [i_ps_param].event_dim))
                i_param += 1

        return result
Example #8
    def module_ppca_gm_means_sigma_guide(self, input_batch, epsilon):
        batch_size = input_batch.shape[0]
        if self.likelihood == 'normal':
            if self.group_isotropic:
                ppca_gm_sigma_p = pyro.param(f'ppca_gm_sigma_p',
                                             input_batch.new_ones(1, 1),
                                             constraint=constraints.positive)
                ppca_gm_sigma = pyro.sample(
                    f'ppca_gm_sigma',
                    dist.Delta(ppca_gm_sigma_p).independent(1))
            else:
                ppca_gm_sigma = input_batch.new_ones(1, 1)
                ppca_gm_sigma_list = []
                for i in range(self.d):
                    ppca_gm_sigma_p = pyro.param(
                        f'ppca_gm_sigma_{i}_p',
                        input_batch.new_ones(1, self.n[i]),
                        constraint=constraints.positive)
                    ppca_gm_sigma_list.append(
                        pyro.sample(
                            f'ppca_gm_sigma_{i}',
                            dist.Delta(ppca_gm_sigma_p).independent(1)))
                    ppca_gm_sigma = torch_utils.krp_cw_torch(
                        ppca_gm_sigma_list[i], ppca_gm_sigma, column=False)
        else:
            ppca_gm_sigma = input_batch.new_ones(1, 1)
            ppca_gm_sigma_list = [
                input_batch.new_ones(1, self.n[i]) for i in range(self.d)
            ]

        alpha_gm_p = pyro.param(
            f'alpha_gm_p', input_batch.new_ones([1, self.group_hidden_dim]))
        alpha_gm = pyro.sample(f'alpha_gm',
                               dist.Delta(alpha_gm_p).independent(1))

        if self.group_iterm is None:
            z_mu = self.group_term.linear_mapping.inverse_batch(input_batch)
        else:
            z_mu = self.group_iterm(torch_utils.flatten_torch(input_batch),
                                    T=True)
        if self.group_isotropic:
            zk_mean, zk_cov = self.group_term.get_posterior_gaussian_mean_covariance(
                input_batch,
                noise_sigma=ppca_gm_sigma[0]
                if ppca_gm_sigma is not None else input_batch.new_ones(1),
                z_mu=z_mu,
                z_sigma=alpha_gm[0])
        else:
            zk_mean, zk_cov = self.group_term.get_posterior_gaussian_mean_covariance(
                input_batch,
                noise_sigma=[x for x in ppca_gm_sigma_list],
                z_mu=z_mu,
                z_sigma=alpha_gm[0])
        ppca_gm_means = self.group_term(
            zk_mean + epsilon[:, :self.group_hidden_dim].mm(
                zk_cov.view(self.group_hidden_dim, self.group_hidden_dim)))
        return ppca_gm_means, ppca_gm_sigma
Example #9
    def guide(self):

        a_locs = pyro.param("a_locs", torch.full((self.n_params, ), 0.0))
        a_scales_tril = pyro.param(
            "a_scales",
            lambda: 0.1 * eye_like(a_locs, self.n_params),
            constraint=constraints.lower_cholesky)

        dt = dist.MultivariateNormal(a_locs, scale_tril=a_scales_tril)
        states = pyro.sample("states", dt, infer={"is_auxiliary": True})

        result = {}

        for i_poiss in torch.arange(self.n_poiss):
            transform = biject_to(self.poiss_priors[i_poiss].support)
            value = transform(states[i_poiss])
            log_density = transform.inv.log_abs_det_jacobian(
                value, states[i_poiss])
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - value.dim() +
                self.poiss_priors[i_poiss].event_dim)

            result[self.labels_poiss[i_poiss]] = pyro.sample(
                self.labels_poiss[i_poiss],
                dist.Delta(value,
                           log_density=log_density,
                           event_dim=self.poiss_priors[i_poiss].event_dim))

        i_param = self.n_poiss

        for i_ps in torch.arange(self.n_ps):
            for i_ps_param in torch.arange(self.n_ps_params):

                transform = biject_to(self.ps_priors[i_ps][i_ps_param].support)

                value = transform(states[i_param])

                log_density = transform.inv.log_abs_det_jacobian(
                    value, states[i_param])
                log_density = sum_rightmost(
                    log_density,
                    log_density.dim() - value.dim() +
                    self.ps_priors[i_ps][i_ps_param].event_dim)

                result[self.labels_ps_params[i_ps_param] + "_" +
                       self.labels_ps[i_ps]] = pyro.sample(
                           self.labels_ps_params[i_ps_param] + "_" +
                           self.labels_ps[i_ps],
                           dist.Delta(value,
                                      log_density=log_density,
                                      event_dim=self.ps_priors[i_ps]
                                      [i_ps_param].event_dim))
                i_param += 1

        return result
Example #10
def reward(state, i=0):
    """Reward function given a state"""
    # Goal is state 15, reward 1 point
    if state == 15:
        return pyro.sample(f'reward{state}{i}', dist.Delta(torch.tensor(1.)))
    # Holes are states 5, 7, 11, 12; penalize 10 points
    if state in [5, 7, 11, 12]:
        return pyro.sample(f'reward{state}{i}', dist.Delta(torch.tensor(-10.)))
    # Create a reward that grows as we get closer to the goal
    r = float(1 / (15 - state + 1))
    return pyro.sample(f'reward{state}{i}', dist.Delta(torch.tensor(r)))
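Since `pyro.sample` on a `Delta` just returns the support point, calling the function outside inference yields the deterministic shaped reward:

print(reward(15))  # tensor(1.)
print(reward(7))   # tensor(-10.)
print(reward(14))  # tensor(0.5000), from 1 / (15 - 14 + 1)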
Example #11
    def guide(self, evidence={}, noise=None):
        """A "smart" guide function for the SCM model above. It propagates information
        from an observed deterministic node back to its noise node, so that sampling
        does not produce many rejected samples. The structure differs slightly from the
        model schema for the sake of sampling efficiency.

        Args:
            evidence (dict): a dictionary of {node_name: value} evidence data.
            noise (None): an unused parameter that exists because, in Pyro, the guide
                takes the same arguments as the model.
        Returns:
            guide_dict (dict): a sample from the guide in {node_name: value} format.

        TODO: Make all endogenous nodes deterministic rather than delta variables
        """
        guide_dict = {}
        # Ordering matters: observed nodes go first, then non-twin endogenous nodes, then twin nodes.
        for node in self._get_guide_order(evidence):
            exog_parent = [
                n for n in self.G_inference.predecessors(node)
                if self.scm._is_exog(n, self.G_inference)
            ][0]
            endog_parents = sorted([
                n for n in self.G_inference.predecessors(node)
                if not self.scm._is_exog(n, self.G_inference)
            ])
            if endog_parents:
                parent_values = [guide_dict[n] for n in endog_parents]
            else:
                parent_values = []
            if node not in evidence:
                if exog_parent not in guide_dict:  # if you haven't already sampled the exog_parent
                    guide_dict[exog_parent] = pyro.sample(
                        exog_parent, self.exog_fn)
            else:
                if not endog_parents:  # if node only has an exogenous parent
                    if exog_parent not in guide_dict:
                        guide_dict[exog_parent] = pyro.sample(
                            exog_parent,
                            dist.Delta(evidence[node]))  # TODO: Choose
                        # guide_dict[exog_parent] = self._assign_delta_node(exog_parent, evidence[node])
                else:  # if a node has exog & endog parents
                    if exog_parent not in guide_dict:
                        predicted_val = self._scm_function(node, parent_values)
                        exog_val = self.invert_fn(evidence[node],
                                                  predicted_val)
                        guide_dict[exog_parent] = pyro.sample(
                            exog_parent, dist.Delta(exog_val))  # TODO: Choose
                        # guide_dict[exog_parent] = self._assign_delta_node(exog_parent, exog_val)
            val = self._scm_function(node, parent_values,
                                     guide_dict[exog_parent])
            guide_dict[node] = pyro.sample(node,
                                           dist.Delta(val))  # TODO: Choose
            # guide_dict[node] = self._assign_delta_node(node, val)
        return guide_dict
Example #12
def guide():
    # only contains params and the corresponding samples from Delta distributions
    av_price_log_mean_param = pyro.param('average_price_log_mean_param',
                                         torch.tensor(1.0))
    av_price_log_var_param = pyro.param('average_price_log_var_param',
                                        torch.tensor(1.0))
    return (
        pyro.sample('average_price_log_mean',
                    dist.Delta(av_price_log_mean_param)),
        pyro.sample('average_price_log_var',
                    dist.Delta(av_price_log_var_param)),
    )
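After SVI training, the MAP estimates are simply the learned parameter values; a sketch of reading them back from the param store:

mean_map = pyro.param('average_price_log_mean_param').item()
var_map = pyro.param('average_price_log_var_param').item()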
Example #13
    def guide_map(anime_matrix_train, k=k):
        m = anime_matrix_train.shape[0]
        n = anime_matrix_train.shape[1]

        u_map = pyro.param('u_map', torch.zeros([m, k]))
        v_map = pyro.param('v_map', torch.zeros([n, k]))
        sigma_map = pyro.param("sigma_map",
                               torch.tensor(1.0),
                               constraint=constraints.positive)

        pyro.sample("u", dist.Delta(u_map).to_event(2))
        pyro.sample("v", dist.Delta(v_map).to_event(2))
        pyro.sample("sigma", dist.Delta(sigma_map))
Example #14
    def guide_horseshoe_plus(self):

        npar = self.npars  # number of parameters
        nsub = self.runs  # number of subjects
        trns = biject_to(constraints.positive)

        m_hyp = param('m_hyp', zeros(2 * npar))
        st_hyp = param('scale_tril_hyp',
                       torch.eye(2 * npar),
                       constraint=constraints.lower_cholesky)
        hyp = sample('hyp',
                     dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
                     infer={'is_auxiliary': True})

        unc_mu = hyp[:npar]
        unc_sigma = hyp[npar:]

        c_sigma = trns(unc_sigma)

        ld_sigma = trns.inv.log_abs_det_jacobian(c_sigma, unc_sigma)
        ld_sigma = sum_rightmost(ld_sigma, ld_sigma.dim() - c_sigma.dim() + 1)

        mu_g = sample("mu_g", dist.Delta(unc_mu, event_dim=1))
        sigma_g = sample("sigma_g", dist.Delta(c_sigma, log_density=ld_sigma, event_dim=1))

        m_tmp = param('m_tmp', zeros(nsub, 2 * npar))
        st_tmp = param('s_tmp',
                       torch.eye(2 * npar).repeat(nsub, 1, 1),
                       constraint=constraints.lower_cholesky)

        with plate('subjects', nsub):
            tmp = sample('tmp',
                         dist.MultivariateNormal(m_tmp, scale_tril=st_tmp),
                         infer={'is_auxiliary': True})

            unc_locs = tmp[..., :npar]
            unc_scale = tmp[..., npar:]

            c_scale = trns(unc_scale)

            ld_scale = trns.inv.log_abs_det_jacobian(c_scale, unc_scale)
            ld_scale = sum_rightmost(ld_scale, ld_scale.dim() - c_scale.dim() + 1)

            x = sample("x", dist.Delta(unc_locs, event_dim=1))
            sigma_x = sample("sigma_x", dist.Delta(c_scale, log_density=ld_scale, event_dim=1))

        return {'mu_g': mu_g, 'sigma_g': sigma_g, 'sigma_x': sigma_x, 'x': x}
Example #15
    def forward(self, design, target_labels=None):
        """
        Sample the posterior.

        :param torch.Tensor design: tensor of possible designs.
        :param list target_labels: list indicating the sample sites that are targets, i.e. for which information gain
                                   should be measured.
        """
        if target_labels is None:
            target_labels = list(self.means.keys())

        pyro.module("laplace_guide", self)
        with ExitStack() as stack:
            for plate in iter_plates_to_shape(design.shape[:-2]):
                stack.enter_context(plate)

            if self.training:
                # MAP via Delta guide
                for l in target_labels:
                    w_dist = dist.Delta(self.means[l]).to_event(1)
                    pyro.sample(l, w_dist)
            else:
                # Laplace approximation via MVN with hessian
                for l in target_labels:
                    w_dist = dist.MultivariateNormal(
                        self.means[l], scale_tril=self.scale_trils[l])
                    pyro.sample(l, w_dist)
Example #16
    def __call__(self, name, fn, obs):
        assert fn.event_dim >= self.event_dim
        assert obs is None, "SplitReparam does not support observe statements"

        # Draw independent parts.
        dim = fn.event_dim - self.event_dim
        left_shape = fn.event_shape[:dim]
        right_shape = fn.event_shape[1 + dim:]
        parts = []
        for i, size in enumerate(self.sections):
            event_shape = left_shape + (size, ) + right_shape
            parts.append(
                pyro.sample(
                    "{}_split_{}".format(name, i),
                    dist.ImproperUniform(fn.support, fn.batch_shape,
                                         event_shape)))
        value = torch.cat(parts, dim=-self.event_dim)

        # Combine parts.
        if poutine.get_mask() is False:
            log_density = 0.0
        else:
            log_density = fn.log_prob(value)
        new_fn = dist.Delta(value,
                            event_dim=fn.event_dim,
                            log_density=log_density)
        return new_fn, value
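A hedged usage sketch for this reparameterizer: `poutine.reparam` splits one event-dim-1 site into auxiliary pieces and recombines them through the `Delta` returned above (the toy model is an assumption):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import SplitReparam

def model():
    pyro.sample("x", dist.Normal(torch.zeros(5), 1.0).to_event(1))

split_model = poutine.reparam(model, config={"x": SplitReparam([2, 3], event_dim=1)})
tr = poutine.trace(split_model).get_trace()
print([n for n in tr.nodes if n.startswith("x")])  # ['x_split_0', 'x_split_1', 'x']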
Example #17
    def apply(self, msg):
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]
        if is_observed:
            raise NotImplementedError(
                "ProjectedNormalReparam does not support observe statements"
            )

        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, dist.ProjectedNormal)

        # Differentiably invert transform.
        value_normal = None
        if value is not None:
            # We use an arbitrary injection, which works only for initialization.
            value_normal = value - fn.concentration

        # Draw parameter-free noise.
        new_fn = dist.Normal(torch.zeros_like(fn.concentration), 1).to_event(1)
        x = pyro.sample(
            "{}_normal".format(name),
            self._wrap(new_fn, event_dim),
            obs=value_normal,
            infer={"is_observed": is_observed},
        )

        # Differentiably transform.
        if value is None:
            value = safe_normalize(x + fn.concentration)

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(value, event_dim=event_dim).mask(False)
        return {"fn": new_fn, "value": value, "is_observed": True}
Example #18
def test_kl_delta_normal_shape(batch_shape):
    v = torch.randn(batch_shape)
    loc = torch.randn(batch_shape)
    scale = torch.randn(batch_shape).exp()
    p = dist.Delta(v)
    q = dist.Normal(loc, scale)
    assert kl_divergence(p, q).shape == batch_shape
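The shape check relies on the point-mass identity KL(Delta(v) || q) = -q.log_prob(v) when the Delta carries no extra log_density; a quick numerical check under that assumption:

import torch
import pyro.distributions as dist
from torch.distributions import kl_divergence

v = torch.tensor(1.0)
q = dist.Normal(torch.tensor(0.0), torch.tensor(2.0))
print(torch.allclose(kl_divergence(dist.Delta(v), q), -q.log_prob(v)))  # True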
Example #19
    def map_estimate(self, name):
        """
        Construct a maximum a posteriori (MAP) guide using Delta distributions.

        :param str name: The name of a model sample site.
        :return: A sampled value.
        :rtype: torch.Tensor
        """
        site = self.prototype_trace.nodes[name]
        fn = site["fn"]
        event_dim = fn.event_dim
        init_needed = not hasattr(self, name)
        if init_needed:
            init_value = site["value"].detach()
        with ExitStack() as stack:
            for frame in site["cond_indep_stack"]:
                plate = self.plate(frame.name)
                if plate not in runtime._PYRO_STACK:
                    stack.enter_context(plate)
                elif init_needed and plate.subsample_size < plate.size:
                    # Repeat the init_value to full size.
                    dim = plate.dim - event_dim
                    assert init_value.size(dim) == plate.subsample_size
                    ind = torch.arange(plate.size, device=init_value.device)
                    ind = ind % plate.subsample_size
                    init_value = init_value.index_select(dim, ind)
            if init_needed:
                setattr(self, name, PyroParam(init_value, fn.support,
                                              event_dim))
            value = getattr(self, name)
            return pyro.sample(name, dist.Delta(value, event_dim=event_dim))
Example #20
    def unpack(self, group_z: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
        model_zs = {}
        for pos, (name, fn, frames), transform in zip(
            # lazy cumsum!! python ftw!
            (s for s in [0] for x in self.sizes.values() for s in [x+s]),
            map(itemgetter('name', 'fn', 'cond_indep_stack'), self.sites.values()),
            self.transforms.values()
        ):
            fn: dist.TorchDistribution
            zs = group_z[..., pos-self.sizes[name]:pos]
            z = self.inits[name].expand(zs.shape[:-1] + self.masks[name].shape).clone()
            z[..., self.masks[name]] = zs

            x = transform(z)

            if self.include_det_jac and transform.bijective:
                log_density = transform.inv.log_abs_det_jacobian(x, z)
                log_density = log_density.sum(list(range(-(log_density.ndim - z.ndim + fn.event_dim), 0)))
            else:
                log_density = 0.

            delta = dist.Delta(x, log_density=log_density, event_dim=fn.event_dim)
            model_zs[name] = pyro.sample(name, delta)

        return model_zs
Example #21
    def get_posterior(self, *args, **kwargs):
        """
        Returns a Delta posterior distribution for MAP inference.
        """
        loc = pyro.param("{}_loc".format(self.prefix),
                         lambda: torch.zeros(self.latent_dim))
        return dist.Delta(loc).to_event(1)
Example #22
def parametrized_guide(predictor, data, args, batch_size=None):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive,
    )
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5),
    )
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.0))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(args.num_words,
                             ind.size(0)).scatter_add(0, data,
                                                      torch.ones(data.shape))
        doc_topics = predictor(counts.transpose(0, 1))
        pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))
Example #23
    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            with ExitStack() as stack:
                for frame in site["cond_indep_stack"]:
                    if frame.vectorized:
                        stack.enter_context(plates[frame.name])
                attr_get = operator.attrgetter(name)
                result[name] = pyro.sample(name, dist.Delta(attr_get(self),
                                                            event_dim=site["fn"].event_dim))
        return result
Example #24
    def apply(self, msg):
        name = msg["name"]
        fn = msg["fn"]
        # ignore msg["value"]
        is_observed = msg["is_observed"]

        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, dist.Stable) and fn.coords == "S0"
        if is_observed:
            raise NotImplementedError(
                f"At pyro.sample({repr(name)},...), "
                "LatentStableReparam does not support observe statements")

        # Draw parameter-free noise.
        proto = fn.stability
        half_pi = proto.new_tensor(math.pi / 2)
        one = proto.new_ones(proto.shape)
        u = pyro.sample(
            "{}_uniform".format(name),
            self._wrap(
                dist.Uniform(-half_pi, half_pi).expand(proto.shape),
                event_dim),
        )
        e = pyro.sample("{}_exponential".format(name),
                        self._wrap(dist.Exponential(one), event_dim))

        # Differentiably transform.
        x = _standard_stable(fn.stability, fn.skew, u, e, coords="S0")
        value = fn.loc + fn.scale * x

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(value, event_dim=event_dim).mask(False)
        return {"fn": new_fn, "value": value, "is_observed": True}
Example #25
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        latent = self.sample_latent(*args, **kwargs)
        plates = self._create_plates()

        # unpack continuous latent samples
        result = {}
        for site, unconstrained_value in self._unpack_latent(latent):
            name = site["name"]
            transform = biject_to(site["fn"].support)
            value = transform(unconstrained_value)
            log_density = transform.inv.log_abs_det_jacobian(value, unconstrained_value)
            log_density = sum_rightmost(log_density, log_density.dim() - value.dim() + site["fn"].event_dim)
            delta_dist = dist.Delta(value, log_density=log_density, event_dim=site["fn"].event_dim)

            with ExitStack() as stack:
                for frame in self._cond_indep_stacks[name]:
                    stack.enter_context(plates[frame.name])
                result[name] = pyro.sample(name, delta_dist)

        return result
Example #26
def parametrized_guide(predictor, data, num_words_per_doc, args):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    for doc in pyro.plate("documents", args.num_docs, args.batch_size):
        # data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(args.num_words, 1)
        for i in data[doc]:
            counts[i] += 1
        doc_topics = predictor(counts.transpose(0, 1))
        pyro.sample("doc_topics_{}".format(doc), dist.Delta(doc_topics, event_dim=1))
        # added this part since the model also samples a topic for each word
        with pyro.plate("words_{}".format(doc), num_words_per_doc[doc]):
            word_topics = pyro.sample("word_topics_{}".format(doc), dist.Categorical(doc_topics))
Example #27
    def sample(self, guide_name, fn, infer=None):
        """
        Wrapper around ``pyro.sample()`` to create a single auxiliary sample
        site and then unpack to multiple sample sites for model replay.

        :param str guide_name: The name of the auxiliary guide site.
        :param callable fn: A distribution with shape ``self.event_shape``.
        :param dict infer: Optional inference configuration dict.
        :returns: A pair ``(guide_z, model_zs)`` where ``guide_z`` is the
            single concatenated blob and ``model_zs`` is a dict mapping
            site name to constrained model sample.
        :rtype: tuple
        """
        # Sample a packed tensor.
        if fn.event_shape != self.event_shape:
            raise ValueError(
                "Invalid fn.event_shape for group: expected {}, actual {}".
                format(tuple(self.event_shape), tuple(fn.event_shape)))
        if infer is None:
            infer = {}
        infer["is_auxiliary"] = True
        guide_z = pyro.sample(guide_name, fn, infer=infer)
        common_batch_shape = guide_z.shape[:-1]

        model_zs = {}
        pos = 0
        for site in self.prototype_sites:
            name = site["name"]
            fn = site["fn"]

            # Extract slice from packed sample.
            size = self._site_sizes[name]
            batch_shape = broadcast_shape(common_batch_shape,
                                          self._site_batch_shapes[name])
            unconstrained_z = guide_z[..., pos:pos + size]
            unconstrained_z = unconstrained_z.reshape(batch_shape +
                                                      fn.event_shape)
            pos += size

            # Transform to constrained space.
            transform = biject_to(fn.support)
            z = transform(unconstrained_z)
            log_density = transform.inv.log_abs_det_jacobian(
                z, unconstrained_z)
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - z.dim() + fn.event_dim)
            delta_dist = dist.Delta(z,
                                    log_density=log_density,
                                    event_dim=fn.event_dim)

            # Replay model sample statement.
            with ExitStack() as stack:
                for frame in site["cond_indep_stack"]:
                    plate = self.guide.plate(frame.name)
                    if plate not in runtime._PYRO_STACK:
                        stack.enter_context(plate)
                model_zs[name] = pyro.sample(name, delta_dist)

        return guide_z, model_zs
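This is the `Group.sample` machinery from `pyro.contrib.easyguide`; a hedged sketch of calling it from an `EasyGuide` subclass (the site pattern "z_.*" and param names are assumptions):

import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.easyguide import EasyGuide

class MyGuide(EasyGuide):
    def guide(self, *args, **kwargs):
        group = self.group(match="z_.*")  # pack all matching latent sites
        loc = pyro.param("z_loc", lambda: torch.zeros(group.event_shape))
        scale = pyro.param("z_scale", lambda: torch.ones(group.event_shape),
                           constraint=dist.constraints.positive)
        guide_z, model_zs = group.sample("z_aux",
                                         dist.Normal(loc, scale).to_event(1))
        return model_zs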
Example #28
    def __call__(self, name, fn, obs):
        assert obs is None, "LocScaleReparam does not support observe statements"
        centered = self.centered
        if is_identically_one(centered):
            return name, fn, obs
        event_shape = fn.event_shape
        fn, event_dim = self._unwrap(fn)

        # Apply a partial decentering transform.
        params = {key: getattr(fn, key) for key in self.shape_params}
        if self.centered is None:
            centered = pyro.param("{}_centered",
                                  lambda: fn.loc.new_full(event_shape, 0.5),
                                  constraint=constraints.unit_interval)
        params["loc"] = fn.loc * centered
        params["scale"] = fn.scale ** centered
        decentered_fn = type(fn)(**params)

        # Draw decentered noise.
        decentered_value = pyro.sample("{}_decentered".format(name),
                                       self._wrap(decentered_fn, event_dim))

        # Differentiably transform.
        delta = decentered_value - centered * fn.loc
        value = fn.loc + fn.scale.pow(1 - centered) * delta

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(value, event_dim=event_dim).mask(False)
        return new_fn, value
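A usage sketch: apply the reparameterizer to a hierarchical site via `poutine.reparam` (the toy model is an assumption; `centered=0.0` gives full decentering):

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import LocScaleReparam

def model():
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    pyro.sample("x", dist.Normal(loc, 1.0))

decentered_model = poutine.reparam(model, config={"x": LocScaleReparam(centered=0.0)})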
Example #29
    def apply(self, msg):
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]
        if name not in self.guide.prototype_trace.nodes:
            return {"fn": fn, "value": value, "is_observed": is_observed}
        if is_observed:
            raise NotImplementedError(
                f"At pyro.sample({repr(name)},...), "
                "StructuredReparam does not support observe statements")

        if name not in self.deltas:  # On first sample site.
            with ExitStack() as stack:
                for plate in self.guide.plates.values():
                    stack.enter_context(
                        block_plate(dim=plate.dim, strict=False))
                self.deltas = self.guide.get_deltas()
        new_fn = self.deltas.pop(name)
        value = new_fn.v

        if poutine.get_mask() is not False:
            log_density = new_fn.log_density + fn.log_prob(value)
            new_fn = dist.Delta(value, log_density, new_fn.event_dim)
        return {"fn": new_fn, "value": value, "is_observed": True}
Example #30
def belief_value_model(belief, action, t, discount=1.0, discount_factor=0.95, max_depth=10,
                       bu_nsteps=10, bu_lr=0.1):
    """Returns Pr(Value | b,a)"""
    if t > max_depth:
        return tensor(1e-9)

    # Somehow compute the value
    state = states[pyro.sample("s%d" % t, belief)]
    next_state = states[pyro.sample("next_s%d" % t, transition_dist(state, action))]
    reward = pyro.sample("r%d" % t, reward_dist(state, action, next_state))

    if next_state == "terminal":
        return pyro.sample("v%d" % t, dist.Delta(reward))
    else:
        # compute future value
        discount = discount*discount_factor
        observation = observations[pyro.sample("o%d" % t,
                                               observation_dist(next_state, action))]
        with poutine.block(hide_fn=lambda site: site["name"].startswith("bu")):
            next_belief = belief_update(belief, action, observation,
                                        num_steps=bu_nsteps, lr=bu_lr, suffix=str(t))
        # action_weights = pyro.param("action_weights", action_weights)
        next_action = belief_policy_model(next_belief, t+1,
                                          discount=discount,
                                          discount_factor=discount_factor,
                                          max_depth=max_depth)
        return reward + discount*belief_value_model(next_belief, next_action, t+1,
                                                    discount=discount,
                                                    discount_factor=discount_factor,
                                                    max_depth=max_depth,
                                                    bu_nsteps=bu_nsteps,
                                                    bu_lr=bu_lr)