Example #1
    def __init__(self,
                 base_dist,
                 *,
                 gate=None,
                 gate_logits=None,
                 validate_args=None):
        if (gate is None) == (gate_logits is None):
            raise ValueError(
                "Either `gate` or `gate_logits` must be specified, but not both."
            )
        if gate is not None:
            batch_shape = broadcast_shape(gate.shape, base_dist.batch_shape)
            self.gate = gate.expand(batch_shape)
        else:
            batch_shape = broadcast_shape(gate_logits.shape,
                                          base_dist.batch_shape)
            self.gate_logits = gate_logits.expand(batch_shape)
        if base_dist.event_shape:
            raise ValueError("ZeroInflatedDistribution expected empty "
                             "base_dist.event_shape but got {}".format(
                                 base_dist.event_shape))

        self.base_dist = base_dist.expand(batch_shape)
        event_shape = torch.Size()

        super().__init__(batch_shape, event_shape, validate_args)
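
A minimal usage sketch, assuming this constructor is exposed as pyro.distributions.ZeroInflatedDistribution: the gate's shape and the base distribution's batch shape are broadcast together into the batch shape.

# Sketch only: a (3,)-shaped gate broadcast against a scalar-batch Poisson.
import torch
import pyro.distributions as dist

gate = torch.tensor([0.1, 0.5, 0.9])
d = dist.ZeroInflatedDistribution(dist.Poisson(torch.tensor(2.0)), gate=gate)
assert d.batch_shape == (3,)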
Example #2
def find_domain(op, *domains):
    r"""
    Finds the :class:`Domain` resulting when applying ``op`` to ``domains``.
    :param callable op: An operation.
    :param Domain \*domains: One or more input domains.
    """
    assert callable(op), op
    assert all(isinstance(arg, Domain) for arg in domains)
    if len(domains) == 1:
        dtype = domains[0].dtype
        shape = domains[0].shape
        if op is ops.log or op is ops.exp:
            dtype = 'real'
        elif isinstance(op, ops.ReshapeOp):
            shape = op.shape
        elif isinstance(op, ops.AssociativeOp):
            shape = ()
        return Domain(shape, dtype)

    lhs, rhs = domains
    if isinstance(op, ops.GetitemOp):
        dtype = lhs.dtype
        shape = lhs.shape[:op.offset] + lhs.shape[1 + op.offset:]
        return Domain(shape, dtype)
    elif op == ops.matmul:
        assert lhs.shape and rhs.shape
        if len(rhs.shape) == 1:
            assert lhs.shape[-1] == rhs.shape[-1]
            shape = lhs.shape[:-1]
        elif len(lhs.shape) == 1:
            assert lhs.shape[-1] == rhs.shape[-2]
            shape = rhs.shape[:-2] + rhs.shape[-1:]
        else:
            assert lhs.shape[-1] == rhs.shape[-2]
            shape = broadcast_shape(lhs.shape[:-1], rhs.shape[:-2] +
                                    (1, )) + rhs.shape[-1:]
        return Domain(shape, 'real')

    if lhs.dtype == 'real' or rhs.dtype == 'real':
        dtype = 'real'
    elif op in (ops.add, ops.mul, ops.pow, ops.max, ops.min):
        dtype = op(lhs.dtype - 1, rhs.dtype - 1) + 1
    elif op in (ops.and_, ops.or_, ops.xor):
        dtype = 2
    elif lhs.dtype == rhs.dtype:
        dtype = lhs.dtype
    else:
        raise NotImplementedError('TODO')

    if lhs.shape == rhs.shape:
        shape = lhs.shape
    else:
        shape = broadcast_shape(lhs.shape, rhs.shape)
    return Domain(shape, dtype)
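
A hypothetical call, reusing the Domain and ops names from the snippet above, illustrates the broadcasting branch at the end:

# Binary ops broadcast the two shapes, so (3, 1) and (4,) combine to (3, 4).
d = find_domain(ops.add, Domain((3, 1), 'real'), Domain((4,), 'real'))
assert d.shape == (3, 4) and d.dtype == 'real'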
Example #3
def test_broadcast(mask_shape, component0_shape, component1_shape, value_shape):
    mask = torch.empty(torch.Size(mask_shape)).bernoulli_(0.5).bool()
    component0 = dist.Normal(torch.zeros(component0_shape), 1.0)
    component1 = dist.Exponential(torch.ones(component1_shape))
    value = torch.ones(value_shape)

    d = dist.MaskedMixture(mask, component0, component1)
    d_shape = broadcast_shape(mask_shape, component0_shape, component1_shape)
    assert d.batch_shape == d_shape

    log_prob_shape = broadcast_shape(d_shape, value_shape)
    assert d.log_prob(value).shape == log_prob_shape
Example #4
def test_stable_hmm_shape(init_shape, trans_mat_shape, trans_dist_shape,
                          obs_mat_shape, obs_dist_shape, hidden_dim, obs_dim):
    stability = dist.Uniform(0, 2).sample()
    init_dist = random_stable(stability,
                              init_shape + (hidden_dim, )).to_event(1)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_stable(stability,
                               trans_dist_shape + (hidden_dim, )).to_event(1)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_stable(stability,
                             obs_dist_shape + (obs_dim, )).to_event(1)
    d = dist.LinearHMM(init_dist,
                       trans_mat,
                       trans_dist,
                       obs_mat,
                       obs_dist,
                       duration=4)

    shape = broadcast_shape(init_shape + (4, ), trans_mat_shape,
                            trans_dist_shape, obs_mat_shape, obs_dist_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim, )
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    x = d.rsample()
    assert x.shape == d.shape()
    x = d.rsample((6, ))
    assert x.shape == (6, ) + d.shape()
    x = d.expand((6, 5)).rsample()
    assert x.shape == (6, 5) + d.event_shape
Example #5
def _eager_contract_tensors(reduced_vars, terms, backend):
    iter_symbols = map(opt_einsum.get_symbol, itertools.count())
    symbols = defaultdict(functools.partial(next, iter_symbols))

    inputs = OrderedDict()
    einsum_inputs = []
    operands = []
    for term in terms:
        inputs.update(term.inputs)
        einsum_inputs.append("".join(symbols[k] for k in term.inputs) +
                             "".join(symbols[i - len(term.shape)]
                                     for i, size in enumerate(term.shape)
                                     if size != 1))

        # Squeeze absent event dims to be compatible with einsum.
        data = term.data
        batch_shape = data.shape[:data.dim() - len(term.shape)]
        event_shape = tuple(size for size in term.shape if size != 1)
        data = data.reshape(batch_shape + event_shape)
        operands.append(data)

    for k in reduced_vars:
        del inputs[k]
    batch_shape = tuple(v.size for v in inputs.values())
    event_shape = broadcast_shape(*(term.shape for term in terms))
    einsum_output = (
        "".join(symbols[k] for k in inputs) +
        "".join(symbols[dim]
                for dim in range(-len(event_shape), 0) if dim in symbols))
    equation = ",".join(einsum_inputs) + "->" + einsum_output
    data = opt_einsum.contract(equation, *operands, backend=backend)
    data = data.reshape(batch_shape + event_shape)
    return Tensor(data, inputs)
Example #6
 def get_scipy_batch_logpdf(self, idx):
     if not self.scipy_arg_fn:
         return
     dist_params = self.get_dist_params(idx, wrap_tensor=False)
     dist_params_wrapped = self.get_dist_params(idx)
     dist_params = self._convert_logits_to_ps(dist_params)
     test_data = self.get_test_data(idx, wrap_tensor=False)
     test_data_wrapped = self.get_test_data(idx)
     shape = broadcast_shape(
         self.pyro_dist(**dist_params_wrapped).shape(),
         test_data_wrapped.size())
     log_prob = []
     for i in range(len(test_data)):
         batch_params = {}
         for k in dist_params:
             param = np.broadcast_to(dist_params[k], shape)
             batch_params[k] = param[i]
         args, kwargs = self.scipy_arg_fn(**batch_params)
         if self.is_discrete:
             log_prob.append(
                 self.scipy_dist.logpmf(test_data[i], *args, **kwargs))
         else:
             log_prob.append(
                 self.scipy_dist.logpdf(test_data[i], *args, **kwargs))
     return log_prob
Example #7
def test_gaussian_hmm_log_prob(init_shape, trans_mat_shape, trans_mvn_shape,
                               obs_mat_shape, obs_mvn_shape, hidden_dim,
                               obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)

    actual_dist = GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                              obs_dist)
    expected_dist = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat,
                                     obs_dist)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    shape = broadcast_shape(init_shape + (1, ), trans_mat_shape,
                            trans_mvn_shape, obs_mat_shape, obs_mvn_shape)
    data = obs_dist.expand(shape).sample()
    assert data.shape == actual_dist.shape()

    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
    check_expand(actual_dist, data)
Example #8
def test_gamma_gaussian_hmm_shape(scale_shape, init_shape, trans_mat_shape,
                                  trans_mvn_shape, obs_mat_shape,
                                  obs_mvn_shape, hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    scale_dist = random_gamma(scale_shape)
    d = dist.GammaGaussianHMM(scale_dist, init_dist, trans_mat, trans_dist,
                              obs_mat, obs_dist)

    shape = broadcast_shape(scale_shape + (1, ), init_shape + (1, ),
                            trans_mat_shape, trans_mvn_shape, obs_mat_shape,
                            obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim, )
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    data = obs_dist.expand(shape).sample()
    assert data.shape == d.shape()
    actual = d.log_prob(data)
    assert actual.shape == expected_batch_shape
    check_expand(d, data)

    mixing, final = d.filter(data)
    assert isinstance(mixing, dist.Gamma)
    assert mixing.batch_shape == d.batch_shape
    assert mixing.event_shape == ()
    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == (hidden_dim, )
Example #9
def test_discrete_hmm_shape(ok, init_shape, trans_shape, obs_shape,
                            event_shape, state_dim):
    init_logits = torch.randn(init_shape + (state_dim, ))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    obs_logits = torch.randn(obs_shape + (state_dim, ) + event_shape)
    obs_dist = dist.Bernoulli(logits=obs_logits).to_event(len(event_shape))
    data = obs_dist.sample()[(slice(None), ) * len(obs_shape) + (0, )]

    if not ok:
        with pytest.raises(ValueError):
            d = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
            d.log_prob(data)
        return

    d = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert d.support.event_dim == d.event_dim

    actual = d.log_prob(data)
    expected_shape = broadcast_shape(init_shape, trans_shape[:-1],
                                     obs_shape[:-1])
    assert actual.shape == expected_shape
    check_expand(d, data)

    final = d.filter(data)
    assert isinstance(final, dist.Categorical)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == ()
    assert final.support.upper_bound == state_dim - 1
Example #10
    def __init__(self,
                 initial_dist,
                 transition_dist,
                 observation_dist,
                 validate_args=None):
        assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
        assert isinstance(transition_dist,
                          torch.distributions.MultivariateNormal)
        assert isinstance(observation_dist,
                          torch.distributions.MultivariateNormal)
        hidden_dim = initial_dist.event_shape[0]
        assert transition_dist.event_shape[0] == hidden_dim + hidden_dim
        obs_dim = observation_dist.event_shape[0] - hidden_dim

        shape = broadcast_shape(initial_dist.batch_shape + (1, ),
                                transition_dist.batch_shape,
                                observation_dist.batch_shape)
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + (obs_dim, )
        super(GaussianMRF, self).__init__(batch_shape,
                                          event_shape,
                                          validate_args=validate_args)
        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
        self._init = mvn_to_gaussian(initial_dist)
        self._trans = mvn_to_gaussian(transition_dist)
        self._obs = mvn_to_gaussian(observation_dist)
Example #11
 def __init__(self, base_dist, mask):
     if broadcast_shape(mask.shape, base_dist.batch_shape) != base_dist.batch_shape:
         raise ValueError("Expected mask.shape to be broadcastable to base_dist.batch_shape, "
                          "actual {} vs {}".format(mask.shape, base_dist.batch_shape))
     self.base_dist = base_dist
     self._mask = mask
     super(MaskedDistribution, self).__init__(base_dist.batch_shape, base_dist.event_shape)
Example #12
def test_gaussian_hmm_shape(diag, init_shape, trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    shape = broadcast_shape(init_shape + (1,),
                            trans_mat_shape,
                            trans_mvn_shape,
                            obs_mat_shape,
                            obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape

    data = obs_dist.expand(shape).sample()
    assert data.shape == d.shape()
    actual = d.log_prob(data)
    assert actual.shape == expected_batch_shape
    check_expand(d, data)

    final = d.filter(data)
    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == (hidden_dim,)
Example #13
    def sample(self, guide_name, fn, infer=None):
        """
        Wrapper around ``pyro.sample()`` to create a single auxiliary sample
        site and then unpack to multiple sample sites for model replay.

        :param str guide_name: The name of the auxiliary guide site.
        :param callable fn: A distribution with shape ``self.event_shape``.
        :param dict infer: Optional inference configuration dict.
        :returns: A pair ``(guide_z, model_zs)`` where ``guide_z`` is the
            single concatenated blob and ``model_zs`` is a dict mapping
            site name to constrained model sample.
        :rtype: tuple
        """
        # Sample a packed tensor.
        if fn.event_shape != self.event_shape:
            raise ValueError(
                "Invalid fn.event_shape for group: expected {}, actual {}".
                format(tuple(self.event_shape), tuple(fn.event_shape)))
        if infer is None:
            infer = {}
        infer["is_auxiliary"] = True
        guide_z = pyro.sample(guide_name, fn, infer=infer)
        common_batch_shape = guide_z.shape[:-1]

        model_zs = {}
        pos = 0
        for site in self.prototype_sites:
            name = site["name"]
            fn = site["fn"]

            # Extract slice from packed sample.
            size = self._site_sizes[name]
            batch_shape = broadcast_shape(common_batch_shape,
                                          self._site_batch_shapes[name])
            unconstrained_z = guide_z[..., pos:pos + size]
            unconstrained_z = unconstrained_z.reshape(batch_shape +
                                                      fn.event_shape)
            pos += size

            # Transform to constrained space.
            transform = biject_to(fn.support)
            z = transform(unconstrained_z)
            log_density = transform.inv.log_abs_det_jacobian(
                z, unconstrained_z)
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - z.dim() + fn.event_dim)
            delta_dist = dist.Delta(z,
                                    log_density=log_density,
                                    event_dim=fn.event_dim)

            # Replay model sample statement.
            with ExitStack() as stack:
                for frame in site["cond_indep_stack"]:
                    plate = self.guide.plate(frame.name)
                    if plate not in runtime._PYRO_STACK:
                        stack.enter_context(plate)
                model_zs[name] = pyro.sample(name, delta_dist)

        return guide_z, model_zs
Example #14
 def __init__(self, mu, sigma, *args, **kwargs):
     torch_dist = torch.distributions.Normal(mu, sigma)
     x_shape = torch.Size(
         broadcast_shape(mu.size(), sigma.size(), strict=True))
     event_dim = 1
     super(Normal, self).__init__(torch_dist, x_shape, event_dim, *args,
                                  **kwargs)
Example #15
    def __init__(
        self,
        total_count,
        logits,
        multiplicative_noise_scale,
        *,
        num_quad_points=8,
        validate_args=None,
    ):
        if num_quad_points < 1:
            raise ValueError("num_quad_points must be positive.")

        total_count, logits, multiplicative_noise_scale = broadcast_all(
            total_count, logits, multiplicative_noise_scale)

        self.quad_points, self.log_weights = get_quad_rule(
            num_quad_points, logits)
        quad_logits = (
            logits.unsqueeze(-1) +
            multiplicative_noise_scale.unsqueeze(-1) * self.quad_points)
        self.nb_dist = NegativeBinomial(total_count=total_count.unsqueeze(-1),
                                        logits=quad_logits)

        self.multiplicative_noise_scale = multiplicative_noise_scale
        self.total_count = total_count
        self.logits = logits
        self.num_quad_points = num_quad_points

        batch_shape = broadcast_shape(multiplicative_noise_scale.shape,
                                      self.nb_dist.batch_shape[:-1])
        event_shape = torch.Size()

        super().__init__(batch_shape, event_shape, validate_args)
Example #16
 def __init__(self, leaf_times, rate_grid, *, validate_args=None):
     batch_shape = broadcast_shape(leaf_times.shape[:-1],
                                   rate_grid.shape[:-1])
     event_shape = (leaf_times.size(-1) - 1, )
     self.leaf_times = leaf_times
     self.rate_grid = rate_grid
     super().__init__(batch_shape, event_shape, validate_args=validate_args)
Example #17
    def sample(self, sample_shape=torch.Size([])):
        """
        :param ~torch.Size sample_shape: Sample shape, last dimension must be
            ``num_steps`` and must be broadcastable to
            ``(batch_size, num_steps)``. ``batch_size`` must be an int, not a tuple.
        """
        # shape: batch_size x num_steps x categorical_size
        shape = broadcast_shape(
            torch.Size(list(self.batch_shape) + [1, 1]),
            torch.Size(list(sample_shape) + [1]),
            torch.Size((1, 1, self.event_shape[-1])),
        )
        # state: batch_size x state_dim
        state = OneHotCategorical(logits=self.initial_logits).sample()
        # sample: batch_size x num_steps x categorical_size
        sample = torch.zeros(shape)
        for i in range(shape[-2]):
            # batch_size x 1 x state_dim @
            # batch_size x state_dim x categorical_size
            obs_logits = torch.matmul(state.unsqueeze(-2),
                                      self.observation_logits).squeeze(-2)
            sample[:, i, :] = OneHotCategorical(logits=obs_logits).sample()
            # batch_size x 1 x state_dim @
            # batch_size x state_dim x state_dim
            trans_logits = torch.matmul(state.unsqueeze(-2),
                                        self.transition_logits).squeeze(-2)
            state = OneHotCategorical(logits=trans_logits).sample()

        return sample
Example #18
 def __init__(self,
              initial_logits,
              transition_logits,
              observation_dist,
              validate_args=None):
     if initial_logits.dim() < 1:
         raise ValueError(
             "expected initial_logits to have at least one dim, "
             "actual shape = {}".format(initial_logits.shape))
     if transition_logits.dim() < 2:
         raise ValueError(
             "expected transition_logits to have at least two dims, "
             "actual shape = {}".format(transition_logits.shape))
     if len(observation_dist.batch_shape) < 1:
         raise ValueError(
             "expected observation_dist to have at least one batch dim, "
             "actual .batch_shape = {}".format(
                 observation_dist.batch_shape))
     shape = broadcast_shape(initial_logits.shape[:-1] + (1, ),
                             transition_logits.shape[:-2],
                             observation_dist.batch_shape[:-1])
     batch_shape, time_shape = shape[:-1], shape[-1:]
     event_shape = time_shape + observation_dist.event_shape
     self.initial_logits = initial_logits - initial_logits.logsumexp(
         -1, True)
     self.transition_logits = transition_logits - transition_logits.logsumexp(
         -1, True)
     self.observation_dist = observation_dist
     super(DiscreteHMM, self).__init__(batch_shape,
                                       event_shape,
                                       validate_args=validate_args)
Example #19
def _make_phylogeny(leaf_times, coal_times):
    assert leaf_times.size(-1) == 1 + coal_times.size(-1)

    # Expand shapes to match.
    N = leaf_times.size(-1)
    batch_shape = broadcast_shape(leaf_times.shape[:-1], coal_times.shape[:-1])
    if leaf_times.shape[:-1] != batch_shape:
        leaf_times = leaf_times.expand(batch_shape + (N, ))
    if coal_times.shape[:-1] != batch_shape:
        coal_times = coal_times.expand(batch_shape + (N - 1, ))

    # Combine N sampling events (leaf_times) plus N-1 coalescent events
    # (coal_times) into a pair (times, signs) of arrays of length 2N-1, where
    # leaf sample sign is +1 and coalescent sign is -1.
    times = torch.cat([coal_times, leaf_times], dim=-1)
    signs = torch.linspace(1.5 - N, N - 0.5,
                           2 * N - 1).sign()  # e.g. [-1, -1, +1, +1, +1]

    # Sort the events reverse-ordered in time, i.e. latest to earliest.
    times, index = times.sort(dim=-1, descending=True)
    signs = signs[index]
    inv_index = index.new_empty(index.shape)
    inv_index.scatter_(-1, index, torch.arange(2 * N - 1).expand_as(index))

    # Compute the number n of lineages preceding each event, then the binomial
    # coefficients that will multiply the base coalescence rate.
    lineages = signs.cumsum(-1)
    binomial = lineages * (lineages - 1) / 2

    # Compute the binomial coefficient following each coalescent event.
    coal_index = inv_index[..., :N - 1]
    coal_binomial = binomial.gather(-1, coal_index - 1)

    return _Phylogeny(times, signs, lineages, binomial, coal_binomial)
Example #20
def test_studentt_hmm_shape(
    init_shape,
    trans_mat_shape,
    trans_dist_shape,
    obs_mat_shape,
    obs_dist_shape,
    hidden_dim,
    obs_dim,
):
    init_dist = random_studentt(init_shape + (hidden_dim, )).to_event(1)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_studentt(trans_dist_shape + (hidden_dim, )).to_event(1)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_studentt(obs_dist_shape + (obs_dim, )).to_event(1)
    d = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    shape = broadcast_shape(
        init_shape + (1, ),
        trans_mat_shape,
        trans_dist_shape,
        obs_mat_shape,
        obs_dist_shape,
    )
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim, )
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    x = d.rsample()
    assert x.shape == d.shape()
    x = d.rsample((6, ))
    assert x.shape == (6, ) + d.shape()
    x = d.expand((6, 5)).rsample()
    assert x.shape == (6, 5) + d.event_shape
Example #21
def matrix_and_mvn_to_gaussian(matrix, mvn):
    """
    Convert a noisy affine function to a Gaussian. The noisy affine function is defined as::

        y = x @ matrix + mvn.sample()

    :param ~torch.Tensor matrix: A matrix with rightmost shape ``(x_dim, y_dim)``.
    :param ~torch.distributions.MultivariateNormal mvn: A multivariate normal distribution.
    :return: A Gaussian with broadcasted batch shape and ``.dim() == x_dim + y_dim``.
    :rtype: ~pyro.ops.gaussian.Gaussian
    """
    assert (isinstance(mvn, torch.distributions.MultivariateNormal) or
            (isinstance(mvn, torch.distributions.Independent) and
             isinstance(mvn.base_dist, torch.distributions.Normal)))
    assert isinstance(matrix, torch.Tensor)
    x_dim, y_dim = matrix.shape[-2:]
    assert mvn.event_shape == (y_dim,)
    batch_shape = broadcast_shape(matrix.shape[:-2], mvn.batch_shape)
    matrix = matrix.expand(batch_shape + (x_dim, y_dim))
    mvn = mvn.expand(batch_shape)

    # Handle diagonal normal distributions as an efficient special case.
    if isinstance(mvn, torch.distributions.Independent):
        return AffineNormal(matrix, mvn.base_dist.loc, mvn.base_dist.scale)

    y_gaussian = mvn_to_gaussian(mvn)
    result = _matrix_and_gaussian_to_gaussian(matrix, y_gaussian)
    assert result.batch_shape == batch_shape
    assert result.dim() == x_dim + y_dim
    return result
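
A sketch of the shape bookkeeping described in the docstring, assuming the function is importable as pyro.ops.gaussian.matrix_and_mvn_to_gaussian:

import torch
from pyro.ops.gaussian import matrix_and_mvn_to_gaussian

matrix = torch.randn(5, 1, 2, 3)  # x_dim=2, y_dim=3, batch shape (5, 1)
mvn = torch.distributions.MultivariateNormal(torch.zeros(4, 3), torch.eye(3))
g = matrix_and_mvn_to_gaussian(matrix, mvn)
assert g.batch_shape == (5, 4)    # broadcast of (5, 1) and (4,)
assert g.dim() == 2 + 3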
Example #22
    def __init__(self, mask, component0, component1, validate_args=None):
        if not torch.is_tensor(mask) or mask.dtype != torch.bool:
            raise ValueError(
                'Expected mask to be a BoolTensor but got {}'.format(
                    type(mask)))
        if component0.event_shape != component1.event_shape:
            raise ValueError(
                'components event_shape disagree: {} vs {}'.format(
                    component0.event_shape, component1.event_shape))
        batch_shape = broadcast_shape(mask.shape, component0.batch_shape,
                                      component1.batch_shape)
        if mask.shape != batch_shape:
            mask = mask.expand(batch_shape)
        if component0.batch_shape != batch_shape:
            component0 = component0.expand(batch_shape)
        if component1.batch_shape != batch_shape:
            component1 = component1.expand(batch_shape)

        self.mask = mask
        self.component0 = component0
        self.component1 = component1
        super().__init__(batch_shape, component0.event_shape, validate_args)

        # We need to disable _validate_sample on each component since samples are only valid on the
        # component from which they are drawn. Instead we perform validation using a MaskedConstraint.
        self.component0._validate_args = False
        self.component1._validate_args = False
Example #23
 def log_prob(self, value):
     if self._mask is False:
         shape = broadcast_shape(self.base_dist.batch_shape,
                                 value.shape[:value.dim() - self.event_dim])
         return torch.zeros((), device=value.device).expand(shape)
     if self._mask is True:
         return self.base_dist.log_prob(value)
     return scale_and_mask(self.base_dist.log_prob(value), mask=self._mask)
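
A sketch of the mask-is-False branch, assuming .mask(...) on a Pyro distribution wraps it in this MaskedDistribution and that torch and pyro.distributions (as dist) are imported as in the other examples:

d = dist.Normal(torch.zeros(3), 1.0).mask(False)
lp = d.log_prob(torch.ones(2, 1))
assert lp.shape == (2, 3)  # broadcast of batch shape (3,) and value shape (2, 1)
assert (lp == 0).all()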
Example #24
 def batch_shape(self):
     return broadcast_shape(
         self.log_normalizer.shape,
         self.info_vec.shape[:-1],
         self.precision.shape[:-2],
         self.alpha.shape,
         self.beta.shape,
     )
Example #25
 def infer_shapes(
     loc, covariance_matrix=None, precision_matrix=None, scale_tril=None
 ):
     batch_shape, event_shape = loc[:-1], loc[-1:]
     for matrix in [covariance_matrix, precision_matrix, scale_tril]:
         if matrix is not None:
             batch_shape = broadcast_shape(batch_shape, matrix[:-2])
     return batch_shape, event_shape
Example #26
 def score_parts(self, value):
     shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim])
     log_prob, score_function, entropy_term = self.base_dist.score_parts(value)
     log_prob = sum_rightmost(log_prob, self.reinterpreted_batch_ndims).expand(shape)
     if not isinstance(score_function, numbers.Number):
         score_function = sum_rightmost(score_function, self.reinterpreted_batch_ndims).expand(shape)
     if not isinstance(entropy_term, numbers.Number):
         entropy_term = sum_rightmost(entropy_term, self.reinterpreted_batch_ndims).expand(shape)
     return ScoreParts(log_prob, score_function, entropy_term)
Example #27
def eager_multinomial(total_count, probs, value):
    # Multinomial.log_prob() supports inhomogeneous total_count only by
    # avoiding passing total_count to the constructor.
    inputs, (total_count, probs, value) = align_tensors(total_count, probs, value)
    shape = broadcast_shape(total_count.shape + (1,), probs.shape, value.shape)
    probs = Tensor(probs.expand(shape), inputs)
    value = Tensor(value.expand(shape), inputs)
    total_count = Number(total_count.max().item())  # Used by distributions validation code.
    return Multinomial.eager_log_prob(total_count=total_count, probs=probs, value=value)
Example #28
def test_support_shape(dist):
    for idx in range(dist.get_num_test_data()):
        dist_params = dist.get_dist_params(idx)
        d = dist.pyro_dist(**dist_params)
        assert d.support.event_dim == d.event_dim
        x = dist.get_test_data(idx)
        ok = d.support.check(x)
        assert ok.shape == broadcast_shape(d.batch_shape, x.shape[:x.dim() - d.event_dim])
        assert ok.all()
Example #29
 def __init__(self, leaf_times, rate=1., *, validate_args=None):
     rate = torch.as_tensor(rate,
                            dtype=leaf_times.dtype,
                            device=leaf_times.device)
     batch_shape = broadcast_shape(rate.shape, leaf_times.shape[:-1])
     event_shape = (leaf_times.size(-1) - 1, )
     self.leaf_times = leaf_times
     self.rate = rate
     super().__init__(batch_shape, event_shape, validate_args=validate_args)
Example #30
def _gather(tensor, dim, index):
    """
    Like :func:`torch.gather` but broadcasts.
    """
    if dim != -1:
        raise NotImplementedError
    shape = broadcast_shape(tensor.shape[:-1], index.shape[:-1]) + (-1, )
    tensor = tensor.expand(shape)
    index = index.expand(shape)
    return tensor.gather(dim, index)
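
A quick sketch of the broadcasting behavior with hypothetical shapes (torch imported as usual):

tensor = torch.randn(3, 1, 5)
index = torch.zeros(1, 4, 2, dtype=torch.long)
out = _gather(tensor, -1, index)
assert out.shape == (3, 4, 2)  # batch dims (3, 1) and (1, 4) broadcast to (3, 4)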
Example #31
    def forward(self, *input_args):
        # we have a single object
        if len(input_args) == 1:
            # regardless of type,
            # we don't care about single objects
            # we just index into the object
            input_args = input_args[0]

        # don't concat things that are just single objects
        if torch.is_tensor(input_args):
            return input_args
        else:
            if self.allow_broadcast:
                shape = broadcast_shape(*[s.shape[:-1] for s in input_args]) + (-1,)
                input_args = [s.expand(shape) for s in input_args]
            return torch.cat(input_args, dim=-1)
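
The broadcast-and-concatenate step can be sketched standalone, assuming broadcast_shape comes from pyro.distributions.util:

a = torch.randn(5, 1, 3)  # trailing dim is the feature dim
b = torch.randn(4, 2)
shape = broadcast_shape(a.shape[:-1], b.shape[:-1]) + (-1,)  # (5, 4, -1)
out = torch.cat([t.expand(shape) for t in (a, b)], dim=-1)
assert out.shape == (5, 4, 5)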
Example #32
 def get_scipy_batch_logpdf(self, idx):
     if not self.scipy_arg_fn:
         return
     dist_params = self.get_dist_params(idx, wrap_tensor=False)
     dist_params_wrapped = self.get_dist_params(idx)
     dist_params = self._convert_logits_to_ps(dist_params)
     test_data = self.get_test_data(idx, wrap_tensor=False)
     test_data_wrapped = self.get_test_data(idx)
     shape = broadcast_shape(self.pyro_dist(**dist_params_wrapped).shape(), test_data_wrapped.size())
     log_prob = []
     for i in range(len(test_data)):
         batch_params = {}
         for k in dist_params:
             param = np.broadcast_to(dist_params[k], shape)
             batch_params[k] = param[i]
         args, kwargs = self.scipy_arg_fn(**batch_params)
         if self.is_discrete:
             log_prob.append(self.scipy_dist.logpmf(test_data[i], *args, **kwargs))
         else:
             log_prob.append(self.scipy_dist.logpdf(test_data[i], *args, **kwargs))
     return log_prob
Example #33
def _log_prob_shape(dist, x_size=torch.Size()):
    event_dims = len(dist.event_shape)
    expected_shape = broadcast_shape(dist.shape(), x_size, strict=True)
    if event_dims > 0:
        expected_shape = expected_shape[:-event_dims]
    return expected_shape
Example #34
def test_broadcast_shape(shapes):
    assert broadcast_shape(*shapes) == np.broadcast(*map(np.empty, shapes)).shape
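
Concretely, broadcast_shape follows NumPy/PyTorch broadcasting rules, so the test above checks equalities like:

assert broadcast_shape((3, 1), (1, 4)) == (3, 4)
assert broadcast_shape((2, 3, 1), (3, 5), ()) == (2, 3, 5)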
Example #35
def test_broadcast_shape_error(shapes):
    with pytest.raises((ValueError, RuntimeError)):
        broadcast_shape(*shapes)
Example #36
def test_broadcast_shape_strict(shapes):
    assert broadcast_shape(*shapes, strict=True) == np.broadcast(*map(np.empty, shapes)).shape
Example #37
def test_broadcast_shape_strict_error(shapes):
    with pytest.raises(ValueError):
        broadcast_shape(*shapes, strict=True)
Example #38
 def log_prob(self, value):
     shape = broadcast_shape(self.batch_shape, value.shape[:value.dim() - self.event_dim])
     return sum_rightmost(self.base_dist.log_prob(value), self.reinterpreted_batch_ndims).expand(shape)