Example #1
    def _transform_values(
        self,
        aux_values: Dict[str, torch.Tensor],
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Union[float, torch.Tensor]]]:
        # Learnably transform auxiliary values to user-facing values.
        values = {}
        log_densities = defaultdict(float)
        compute_density = am_i_wrapped() and poutine.get_mask() is not False
        for name, site in self._factors.items():
            if site["is_observed"]:
                continue
            loc = deep_getattr(self.locs, name)
            scale = deep_getattr(self.scales, name)
            unconstrained = aux_values[name] * scale + loc

            # Transform to constrained space.
            transform = biject_to(site["fn"].support)
            values[name] = transform(unconstrained)
            if compute_density:
                assert transform.codomain.event_dim == site["fn"].event_dim
                log_densities[name] = transform.inv.log_abs_det_jacobian(
                    values[name], unconstrained
                ) - scale.log().reshape(site["fn"].batch_shape + (-1,)).sum(-1)

        return values, log_densities
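
A hedged, self-contained sketch of the pattern above (a hypothetical single site with positive support, not part of the guide): the auxiliary value is mapped through a learnable affine transform and then bijected into the support, with the matching log-density correction.

import torch
from torch.distributions import biject_to, constraints

loc, scale = torch.zeros(3), torch.full((3,), 0.5)
aux = torch.randn(3)                         # whitened auxiliary value
unconstrained = aux * scale + loc            # learnable affine reparameterization
transform = biject_to(constraints.positive)
value = transform(unconstrained)             # user-facing constrained value
# Density correction: Jacobian of the inverse bijection, minus log|scale|.
log_density = transform.inv.log_abs_det_jacobian(
    value, unconstrained
) - scale.log().sum(-1)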
Example #2
def sample(name, fn, *args, **kwargs):
    """
    Calls the stochastic function ``fn`` with additional side-effects depending
    on ``name`` and the enclosing context (e.g. an inference algorithm).  See
    `Introduction to Pyro <http://pyro.ai/examples/intro_long.html>`_ for a discussion.

    :param name: name of sample
    :param fn: distribution class or function
    :param obs: optional observed datum, passed in kwargs; should only be
        used in the context of inference
    :param obs_mask: Optional boolean tensor mask of shape
        broadcastable with ``fn.batch_shape``. If provided, events with
        mask=True will be conditioned on ``obs`` and remaining events will be
        imputed by sampling. This introduces a latent sample site named ``name
        + "_unobserved"`` which should be used by guides.
    :type obs_mask: bool or ~torch.Tensor
    :param dict infer: Optional dictionary of inference parameters specified
        in kwargs. See inference documentation for details.
    :returns: sample
    """
    # Transform obs_mask into multiple sample statements.
    obs = kwargs.pop("obs", None)
    obs_mask = kwargs.pop("obs_mask", None)
    if obs_mask is not None:
        return _masked_observe(name, fn, obs, obs_mask, *args, **kwargs)

    # If the effect stack is empty, fall back to default behavior here.
    infer = kwargs.pop("infer", {}).copy()
    is_observed = infer.pop("is_observed", obs is not None)
    if not am_i_wrapped():
        if obs is not None and not infer.get("_deterministic"):
            warnings.warn(
                "trying to observe a value outside of inference at " + name,
                RuntimeWarning,
            )
            return obs
        return fn(*args, **kwargs)
    # Otherwise, route the site through every handler on the stack.
    else:
        # initialize data structure to pass up/down the stack
        msg = {
            "type": "sample",
            "name": name,
            "fn": fn,
            "is_observed": is_observed,
            "args": args,
            "kwargs": kwargs,
            "value": obs,
            "infer": infer,
            "scale": 1.0,
            "mask": None,
            "cond_indep_stack": (),
            "done": False,
            "stop": False,
            "continuation": None,
        }
        # apply the stack and return its return value
        apply_stack(msg)
        return msg["value"]
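
As a hedged usage sketch (hypothetical data and site names), tracing a model makes the auxiliary "_unobserved" site introduced by obs_mask visible:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model(data, mask):
    # Entries with mask=True are conditioned on data; the rest are imputed
    # via the auxiliary latent site "x_unobserved".
    return pyro.sample("x", dist.Normal(0.0, 1.0), obs=data, obs_mask=mask)

data = torch.tensor([0.5, 1.2, -0.3])
mask = torch.tensor([True, False, True])
trace = poutine.trace(model).get_trace(data, mask)
assert "x_unobserved" in trace.nodes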
Example #3
def param(name, *args, **kwargs):
    """
    Saves the variable as a parameter in the param store.
    To interact with the param store or write to disk,
    see `Parameters <parameters.html>`_.

    :param name: name of parameter
    :returns: parameter
    """
    if not am_i_wrapped():
        return _PYRO_PARAM_STORE.get_param(name, *args, **kwargs)
    else:
        msg = {
            "type": "param",
            "name": name,
            "args": args,
            "kwargs": kwargs,
            "infer": {},
            "scale": 1.0,
            "cond_indep_stack": (),
            "value": None,
            "done": False,
            "stop": False,
            "continuation": None
        }
        # apply the stack and return its return value
        apply_stack(msg)
        return msg["value"]
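
A hedged usage sketch (hypothetical parameter name): the first call initializes and registers the parameter, and later calls with the same name retrieve the stored value.

import torch
import pyro

w = pyro.param("w", torch.randn(2))  # initializes and registers "w"
w_again = pyro.param("w")            # later lookups return the stored value
assert torch.equal(w, w_again)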
Example #4
    def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch.Tensor]:
        funsor = _import_funsor()

        # Convert torch to funsor.
        particle_plates = frozenset(get_plates())
        plate_to_dim = self._funsor_plate_to_dim.copy()
        plate_to_dim.update({f.name: f.dim for f in particle_plates})
        factors = {}
        for d, inputs in self._funsor_factor_inputs.items():
            batch_shape = torch.Size(
                p.size for p in sorted(self._plates[d], key=lambda p: p.dim)
            )
            white_vec = deep_getattr(self.white_vecs, d)
            prec_sqrt = deep_getattr(self.prec_sqrts, d)
            factors[d] = funsor.gaussian.Gaussian(
                white_vec=white_vec.reshape(batch_shape + white_vec.shape[-1:]),
                prec_sqrt=prec_sqrt.reshape(batch_shape + prec_sqrt.shape[-2:]),
                inputs=inputs,
            )

        # Perform Gaussian tensor variable elimination.
        if temperature == 1:
            samples, log_prob = _try_possibly_intractable(
                funsor.recipes.forward_filter_backward_rsample,
                factors=factors,
                eliminate=self._funsor_eliminate,
                plates=frozenset(plate_to_dim),
                sample_inputs={f.name: funsor.Bint[f.size] for f in particle_plates},
            )

        else:
            samples, log_prob = _try_possibly_intractable(
                funsor.recipes.forward_filter_backward_precondition,
                factors=factors,
                eliminate=self._funsor_eliminate,
                plates=frozenset(plate_to_dim),
            )

            # Substitute noise.
            sample_shape = torch.Size(f.size for f in particle_plates)
            noise = torch.randn(sample_shape + log_prob.inputs["aux"].shape)
            noise.mul_(temperature)
            aux = funsor.Tensor(noise)[tuple(f.name for f in particle_plates)]
            with funsor.interpretations.memoize():
                samples = {k: v(aux=aux) for k, v in samples.items()}
                log_prob = log_prob(aux=aux)

        # Convert funsor to torch.
        if am_i_wrapped() and poutine.get_mask() is not False:
            log_prob = funsor.to_data(log_prob, name_to_dim=plate_to_dim)
            pyro.factor(f"_{self._pyro_name}_latent", log_prob, has_rsample=True)
        samples = {
            k: funsor.to_data(v, name_to_dim=plate_to_dim) for k, v in samples.items()
        }
        return samples
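
The temperature logic above reduces to a simple pattern: standard normal noise is scaled by temperature before being pushed through the learned Gaussian, so temperature=1 yields exact reparameterized samples and temperature=0 collapses to the mode. A plain-torch sketch with hypothetical names, not the actual funsor code path:

import torch

def sample_with_temperature(loc, scale_tril, temperature):
    # temperature=1: exact reparameterized sample; temperature=0: the mode.
    noise = torch.randn(loc.shape) * temperature
    return loc + scale_tril @ noise

loc, scale_tril = torch.zeros(2), torch.eye(2)
assert torch.equal(sample_with_temperature(loc, scale_tril, 0.0), loc)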
Example #5
def sample(name, fn, *args, **kwargs):
    """
    Calls the stochastic function `fn` with additional side-effects depending
    on `name` and the enclosing context (e.g. an inference algorithm).
    See `Intro I <http://pyro.ai/examples/intro_part_i.html>`_ and
    `Intro II <http://pyro.ai/examples/intro_part_ii.html>`_ for a discussion.

    :param name: name of sample
    :param fn: distribution class or function
    :param obs: optional observed datum, passed in kwargs; should only be
        used in the context of inference
    :param dict infer: Optional dictionary of inference parameters specified
        in kwargs. See inference documentation for details.
    :returns: sample
    """
    obs = kwargs.pop("obs", None)
    infer = kwargs.pop("infer", {}).copy()
    # If the effect stack is empty, fall back to default behavior here.
    if not am_i_wrapped():
        if obs is not None:
            warnings.warn(
                "trying to observe a value outside of inference at " + name,
                RuntimeWarning)
            return obs
        return fn(*args, **kwargs)
    # Otherwise, route the site through every handler on the stack.
    else:
        # initialize data structure to pass up/down the stack
        msg = {
            "type": "sample",
            "name": name,
            "fn": fn,
            "is_observed": False,
            "args": args,
            "kwargs": kwargs,
            "value": None,
            "infer": infer,
            "scale": 1.0,
            "mask": None,
            "cond_indep_stack": (),
            "done": False,
            "stop": False,
            "continuation": None
        }
        # handle observation
        if obs is not None:
            msg["value"] = obs
            msg["is_observed"] = True
        # apply the stack and return its return value
        apply_stack(msg)
        return msg["value"]
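
Every handler on the effect stack sees the msg dict built above before the site executes. A minimal hedged sketch (assuming Pyro's poutine.messenger API) of a handler that inspects it:

import pyro
import pyro.distributions as dist
from pyro.poutine.messenger import Messenger

class NamePrinter(Messenger):
    def _pyro_sample(self, msg):
        # Runs for each sample msg travelling down the stack; a handler may
        # also mutate fields such as "scale", "mask", or "value" here.
        print(msg["name"], "observed:", msg["is_observed"])

def model():
    return pyro.sample("z", dist.Normal(0.0, 1.0))

with NamePrinter():
    model()  # prints: z observed: False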
Example #6
    def __iter__(self):
        if not am_i_wrapped():
            for i in self.subsample:
                yield i if isinstance(i, numbers.Number) else i.item()
        else:
            indep_context = poutine.indep(name=self.name, size=self.subsample_size)
            with poutine.scale(scale=self.size / self.subsample_size):
                for i in self.subsample:
                    indep_context.next_context()
                    with indep_context:
                        # convert to python numeric type as functions like torch.ones(*args)
                        # do not work with dim 0 torch.Tensor instances.
                        yield i if isinstance(i, numbers.Number) else i.item()
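
A hedged usage sketch of the sequential form above (current Pyro exposes this iterator via pyro.plate; data and names are hypothetical):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model(data):
    # Each index runs in its own independence context, and log-probs inside
    # are rescaled by size / subsample_size.
    for i in pyro.plate("data", size=len(data), subsample_size=2):
        pyro.sample(f"x_{i}", dist.Normal(0.0, 1.0), obs=data[i])

poutine.trace(model)(torch.randn(10))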
Example #7
    def __enter__(self):
        self._wrapped = am_i_wrapped()
        self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim)
        if self._wrapped:
            try:
                self._scale_messenger = poutine.scale(scale=self.size / self.subsample_size)
                self._indep_messenger = poutine.indep(
                    name=self.name, size=self.subsample_size, dim=self.dim
                )
                self._scale_messenger.__enter__()
                self._indep_messenger.__enter__()
            except BaseException:
                # Free the allocated dim if entering the messengers fails.
                _DIM_ALLOCATOR.free(self.name, self.dim)
                raise
        return self.subsample
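
The corresponding vectorized usage, again as a hedged sketch (current Pyro spells this pyro.plate used as a context manager; data and names are hypothetical):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model(data):
    with pyro.plate("data", size=len(data), subsample_size=4) as idx:
        # idx is a tensor of subsample indices; log-probs inside the context
        # are scaled by size / subsample_size.
        pyro.sample("x", dist.Normal(0.0, 1.0), obs=data[idx])

poutine.trace(model)(torch.randn(10))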
Example #8
def sample(name, fn, *args, **kwargs):
    """
    Calls the stochastic function `fn` with additional side-effects depending
    on `name` and the enclosing context (e.g. an inference algorithm).
    See `Intro I <http://pyro.ai/examples/intro_part_i.html>`_ and
    `Intro II <http://pyro.ai/examples/intro_part_ii.html>`_ for a discussion.

    :param name: name of sample
    :param fn: distribution class or function
    :param obs: optional observed datum, passed in kwargs; should only be
        used in the context of inference
    :param dict infer: Optional dictionary of inference parameters specified
        in kwargs. See inference documentation for details.
    :returns: sample
    """
    obs = kwargs.pop("obs", None)
    infer = kwargs.pop("infer", {})
    # If the effect stack is empty, fall back to default behavior here.
    if not am_i_wrapped():
        if obs is not None:
            warnings.warn("trying to observe a value outside of inference at " + name,
                          RuntimeWarning)
            return obs
        return fn(*args, **kwargs)
    # Otherwise, route the site through every handler on the stack.
    else:
        # initialize data structure to pass up/down the stack
        msg = {
            "type": "sample",
            "name": name,
            "fn": fn,
            "is_observed": False,
            "args": args,
            "kwargs": kwargs,
            "value": None,
            "infer": infer,
            "scale": 1.0,
            "cond_indep_stack": (),
            "done": False,
            "stop": False,
            "continuation": None
        }
        # handle observation
        if obs is not None:
            msg["value"] = obs
            msg["is_observed"] = True
        # apply the stack and return its return value
        apply_stack(msg)
        return msg["value"]
Example #9
    def forward(self, zs):
        embedding = self.pre_recurrence_linear(zs)
        hiddens = [None, None]
        teacher = None
        # If the model is being conditioned (e.g. via poutine.condition), pull
        # the observed sequence out of the innermost ConditionMessenger and use
        # it below for teacher forcing.
        if runtime.am_i_wrapped() and isinstance(
                runtime._PYRO_STACK[-1], ConditionMessenger):
            data = runtime._PYRO_STACK[-1].data
            if '$%s$' % self._smiles_name in data:
                teacher = data['$%s$' % self._smiles_name]

        probs = []
        for i in range(self._max_len):
            hiddens[0] = self.recurrence1(embedding, hiddens[0])
            hiddens[1] = self.recurrence2(hiddens[0], hiddens[1])
            embedding = self.decoder(hiddens[1])

            probs.append(embedding)
            if teacher is not None:
                embedding = teacher[:, i]
        probs = torch.stack(probs, dim=1)

        categorical = dist.OneHotCategorical(probs=probs).to_event(1)
        return pyro.sample('$%s$' % self._smiles_name, categorical)
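
The stack inspection above relies on poutine.condition pushing a ConditionMessenger whose .data attribute maps site names to observed values. A self-contained sketch of that idiom with a hypothetical one-site model:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model():
    return pyro.sample("y", dist.Normal(0.0, 1.0))

conditioned = poutine.condition(model, data={"y": torch.tensor(2.0)})
assert conditioned() == torch.tensor(2.0)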