예제 #1
0
    def log_prob(self, value):
        """
        Score a one-hot encoded observation sequence, summing out the hidden
        state at every step in log space.

        :param ~torch.Tensor value: One-hot encoded observation. Must be
            real-valued (float) and broadcastable to
            ``(batch_size, num_steps, categorical_size)`` where
            ``categorical_size`` is the dimension of the categorical output.
            Missing data is represented by zeros, i.e.
            ``value[batch, step, :] == tensor([0, ..., 0])``.
            Variable length observation sequences can be handled by padding
            the sequence with zeros at the end.
        :returns: The tensor left after the time dimension has been
            eliminated and both remaining state dimensions have been reduced
            with ``logsumexp``. Presumably one log-probability per batch
            element -- TODO confirm against the shapes of the logits
            attributes.
        :raises AssertionError: If the last dimension of ``value`` differs
            from ``self.event_shape[1]``.
        """

        assert self.event_shape[1] == value.shape[-1]

        # Project the observations onto the transposed observation logits.
        # An all-zero row of ``value`` yields an all-zero row here, which is
        # how missing/padded steps end up contributing nothing.
        emission_table = self.observation_logits.transpose(-2, -1)
        obs_scores = torch.matmul(value, emission_table)

        # The first step is paired with the initial factor; every later step
        # is paired with a transition factor.
        first_scores = obs_scores[..., 0, :]
        later_scores = obs_scores[..., 1:, :].unsqueeze(-2)

        # One combined transition/observation factor per step after the first.
        step_factors = self.transition_logits.unsqueeze(-3) + later_scores

        # Collapse the time dimension into a single factor.
        chain = _sequential_logmatmulexp(step_factors)

        # Fold in the initial factor and the first observation.
        per_state = self.initial_logits + first_scores + chain.logsumexp(-1)

        # Sum out the one remaining state dimension.
        return per_state.logsumexp(-1)
예제 #2
0
def test_sequential_logmatmulexp(batch_shape, state_dim, num_steps):
    """
    Compare ``_sequential_logmatmulexp`` with an explicit chained contraction
    evaluated by ``opt_einsum`` using the ``pyro.ops.einsum.torch_log``
    backend.

    The arguments are presumably supplied by pytest parametrization defined
    outside this snippet -- TODO confirm. ``batch_shape`` is concatenated
    with a tuple, ``state_dim`` is the size of the two trailing dimensions,
    and ``num_steps`` is the size of the dimension that gets eliminated.
    """
    matrix_shape = (state_dim, state_dim)
    logits = torch.randn(batch_shape + (num_steps,) + matrix_shape)
    actual = _sequential_logmatmulexp(logits)
    assert actual.shape == batch_shape + matrix_shape

    # Check against einsum.
    # A single shared iterator hands out symbols so that batch and state
    # dimensions never receive the same letter.
    fresh = map(opt_einsum.get_symbol, range(1000))
    batch_symbols = ''.join(next(fresh) for _ in batch_shape)
    state_symbols = [next(fresh) for _ in range(num_steps + 1)]

    # Each step's matrix maps one state symbol to the next one in the chain.
    terms = [batch_symbols + src + dst
             for src, dst in zip(state_symbols, state_symbols[1:])]
    output = batch_symbols + state_symbols[0] + state_symbols[-1]
    equation = ','.join(terms) + '->' + output

    expected = opt_einsum.contract(
        equation, *logits.unbind(-3), backend='pyro.ops.einsum.torch_log')
    assert_close(actual, expected)