def log_prob(self, value):
    """
    :param ~torch.Tensor value: One-hot encoded observation. Must be
        real-valued (float) and broadcastable to
        ``(batch_size, num_steps, categorical_size)`` where
        ``categorical_size`` is the dimension of the categorical output.
        Missing data is represented by zeros, i.e.
        ``value[batch, step, :] == tensor([0, ..., 0])``. Variable-length
        observation sequences can be handled by padding the sequence with
        zeros at the end.
    """
    assert value.shape[-1] == self.event_shape[1]

    # Combine observation and transition factors.
    value_logits = torch.matmul(
        value, torch.transpose(self.observation_logits, -2, -1))
    result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]

    # Eliminate the time dimension.
    result = _sequential_logmatmulexp(result)

    # Combine the initial factor and marginalize out the final state.
    result = self.initial_logits + value_logits[..., 0, :] + result.logsumexp(-1)

    # Marginalize out the initial state.
    result = result.logsumexp(-1)
    return result
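
# A small self-contained check (an addition, not part of the original source)
# of the missing-data convention documented above: an all-zero one-hot row
# yields an all-zero row of ``value_logits``, i.e. an observation log-factor
# of 0 that marginalizes the step out rather than conditioning on it. The
# shapes and names below are illustrative only.
import torch

state_dim, categorical_size = 3, 4
observation_logits = torch.randn(state_dim, categorical_size)
value = torch.cat([torch.eye(categorical_size)[:2],   # two observed steps
                   torch.zeros(1, categorical_size)])  # one missing step
value_logits = torch.matmul(value, torch.transpose(observation_logits, -2, -1))
assert (value_logits[-1] == 0).all()  # missing step contributes log(1) = 0
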
def test_sequential_logmatmulexp(batch_shape, state_dim, num_steps):
    logits = torch.randn(batch_shape + (num_steps, state_dim, state_dim))
    actual = _sequential_logmatmulexp(logits)
    assert actual.shape == batch_shape + (state_dim, state_dim)

    # Check against einsum.
    operands = list(logits.unbind(-3))
    symbol = (opt_einsum.get_symbol(i) for i in range(1000))
    batch_symbols = ''.join(next(symbol) for _ in batch_shape)
    state_symbols = [next(symbol) for _ in range(num_steps + 1)]
    equation = (','.join(batch_symbols + state_symbols[t] + state_symbols[t + 1]
                         for t in range(num_steps)) +
                '->' + batch_symbols + state_symbols[0] + state_symbols[-1])
    expected = opt_einsum.contract(equation, *operands,
                                   backend='pyro.ops.einsum.torch_log')
    assert_close(actual, expected)
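
# The helper under test is not shown above; the following is a minimal sketch
# (an addition, not the library's code) of one way to implement it: reduce the
# time dimension by contracting adjacent pairs of matrices in log space, so a
# length-T stack needs only O(log T) rounds of batched matmuls, which
# parallelizes much better than a left-to-right scan.
import torch

def _logmatmulexp(x, y):
    """Numerically stable log(exp(x) @ exp(y)), shifting by row/column maxima."""
    x_shift = x.max(-1, keepdim=True).values
    y_shift = y.max(-2, keepdim=True).values
    xy = torch.matmul((x - x_shift).exp(), (y - y_shift).exp()).log()
    return xy + x_shift + y_shift

def _sequential_logmatmulexp(logits):
    """Contract a (..., T, n, n) stack of log-matrices to the ordered
    (..., n, n) log-matrix-product over the time dimension."""
    batch_shape, state_dim = logits.shape[:-3], logits.size(-1)
    while logits.size(-3) > 1:
        time = logits.size(-3)
        even_time = time // 2 * 2
        # Pair up adjacent matrices and contract each pair in one batched call.
        pairs = logits[..., :even_time, :, :].reshape(
            batch_shape + (even_time // 2, 2, state_dim, state_dim))
        contracted = _logmatmulexp(pairs[..., 0, :, :], pairs[..., 1, :, :])
        if time % 2:
            # Carry an odd trailing matrix over to the next round.
            contracted = torch.cat([contracted, logits[..., -1:, :, :]], dim=-3)
        logits = contracted
    return logits.squeeze(-3)

# Quick self-check of the sketch against a naive left-to-right reduction.
T, n = 7, 3
mats = torch.randn(T, n, n)
naive = mats[0]
for m in mats[1:]:
    naive = _logmatmulexp(naive, m)
torch.testing.assert_close(_sequential_logmatmulexp(mats), naive)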