Example No. 1
def find_domain(op, *domains):
    r"""
    Finds the :class:`Domain` resulting when applying ``op`` to ``domains``.

    :param callable op: An operation.
    :param Domain \*domains: One or more input domains.
    """
    assert callable(op), op
    assert all(isinstance(arg, Domain) for arg in domains)
    if len(domains) == 1:
        dtype = domains[0].dtype
        shape = domains[0].shape
        if op is ops.log or op is ops.exp:
            dtype = 'real'
        elif isinstance(op, ops.ReshapeOp):
            shape = op.shape
        elif isinstance(op, ops.AssociativeOp):
            shape = ()
        return Domain(shape, dtype)

    lhs, rhs = domains
    if isinstance(op, ops.GetitemOp):
        dtype = lhs.dtype
        shape = lhs.shape[:op.offset] + lhs.shape[1 + op.offset:]
        return Domain(shape, dtype)
    elif op == ops.matmul:
        assert lhs.shape and rhs.shape
        if len(rhs.shape) == 1:
            assert lhs.shape[-1] == rhs.shape[-1]
            shape = lhs.shape[:-1]
        elif len(lhs.shape) == 1:
            assert lhs.shape[-1] == rhs.shape[-2]
            shape = rhs.shape[:-2] + rhs.shape[-1:]
        else:
            assert lhs.shape[-1] == rhs.shape[-2]
            shape = broadcast_shape(lhs.shape[:-1], rhs.shape[:-2] +
                                    (1, )) + rhs.shape[-1:]
        return Domain(shape, 'real')

    if lhs.dtype == 'real' or rhs.dtype == 'real':
        dtype = 'real'
    elif op in (ops.add, ops.mul, ops.pow, ops.max, ops.min):
        dtype = op(lhs.dtype - 1, rhs.dtype - 1) + 1
    elif op in (ops.and_, ops.or_, ops.xor):
        dtype = 2
    elif lhs.dtype == rhs.dtype:
        dtype = lhs.dtype
    else:
        raise NotImplementedError('TODO')

    if lhs.shape == rhs.shape:
        shape = lhs.shape
    else:
        shape = broadcast_shape(lhs.shape, rhs.shape)
    return Domain(shape, dtype)
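
Both the unary and binary branches rely on NumPy-style broadcasting for the
output shape. A minimal sketch of the rule that broadcast_shape implements
(illustrative only, not funsor's actual code):

def broadcast_shape_sketch(*shapes):
    # Right-align all shapes, then reconcile sizes dimension by dimension:
    # a size of 1 defers to the other operand; unequal non-1 sizes conflict.
    ndim = max(len(s) for s in shapes)
    result = [1] * ndim
    for shape in shapes:
        for i, size in enumerate(shape, ndim - len(shape)):
            if result[i] == 1:
                result[i] = size
            elif size not in (1, result[i]):
                raise ValueError("shapes are not broadcastable")
    return tuple(result)

assert broadcast_shape_sketch((2, 1), (3,)) == (2, 3)
assert broadcast_shape_sketch((5, 1, 4), (1, 3, 1)) == (5, 3, 4)
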
Example No. 2
def _eager_contract_tensors(reduced_vars, terms, backend):
    iter_symbols = map(opt_einsum.get_symbol, itertools.count())
    symbols = defaultdict(functools.partial(next, iter_symbols))

    inputs = OrderedDict()
    einsum_inputs = []
    operands = []
    for term in terms:
        inputs.update(term.inputs)
        einsum_inputs.append("".join(symbols[k] for k in term.inputs) +
                             "".join(symbols[i - len(term.shape)]
                                     for i, size in enumerate(term.shape)
                                     if size != 1))

        # Squeeze absent event dims to be compatible with einsum.
        data = term.data
        batch_shape = data.shape[:len(data.shape) - len(term.shape)]
        event_shape = tuple(size for size in term.shape if size != 1)
        data = data.reshape(batch_shape + event_shape)
        operands.append(data)

    for k in reduced_vars:
        del inputs[k]
    batch_shape = tuple(v.size for v in inputs.values())
    event_shape = broadcast_shape(*(term.shape for term in terms))
    einsum_output = (
        "".join(symbols[k] for k in inputs) +
        "".join(symbols[dim]
                for dim in range(-len(event_shape), 0) if dim in symbols))
    equation = ",".join(einsum_inputs) + "->" + einsum_output
    data = opt_einsum.contract(equation, *operands, backend=backend)
    data = data.reshape(batch_shape + event_shape)
    return Tensor(data, inputs)
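
Each named input and each non-singleton event dimension receives one einsum
symbol, so reducing a variable amounts to omitting its symbol from the output
spec. A small standalone sketch of the same pattern, assuming numpy as the
backend:

import numpy as np
import opt_einsum

x = np.random.rand(4, 3)   # one batch input "i" (size 4), event shape (3,)
y = np.random.rand(4)      # the same batch input "i", empty event shape
# "i" -> 'a', the event dim -> 'b'; reducing over "i" drops 'a' from the output.
out = opt_einsum.contract("ab,a->b", x, y)
assert out.shape == (3,)
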
Example No. 3
def eager_multinomial(total_count, probs, value):
    # Multinomial.log_prob() supports inhomogeneous total_count only by
    # avoiding passing total_count to the constructor.
    inputs, (total_count, probs, value) = align_tensors(total_count, probs, value)
    shape = broadcast_shape(total_count.shape + (1,), probs.shape, value.shape)
    probs = Tensor(probs.expand(shape), inputs)
    value = Tensor(value.expand(shape), inputs)
    total_count = Number(total_count.max().item())  # Used by distributions validation code.
    return Multinomial.eager_log_prob(total_count, probs, value)
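
The trailing (1,) right-pads total_count so its batch shape aligns with the
extra event dimension carried by probs and value. A shape-only illustration
(a sketch using torch.broadcast_shapes in place of funsor's broadcast_shape):

import torch
total_count = torch.tensor([3., 5.])   # batch shape (2,)
probs = torch.rand(2, 4)               # batch shape (2,) + event shape (4,)
shape = torch.broadcast_shapes(total_count.shape + (1,), probs.shape)
assert shape == (2, 4)
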
Example No. 4
    def __init__(self,
                 initial_dist,
                 transition_dist,
                 observation_dist,
                 validate_args=None):
        assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
        assert isinstance(transition_dist,
                          torch.distributions.MultivariateNormal)
        assert isinstance(observation_dist,
                          torch.distributions.MultivariateNormal)
        hidden_dim = initial_dist.event_shape[0]
        assert transition_dist.event_shape[0] == hidden_dim + hidden_dim
        obs_dim = observation_dist.event_shape[0] - hidden_dim
        shape = broadcast_shape(initial_dist.batch_shape + (1, ),
                                transition_dist.batch_shape,
                                observation_dist.batch_shape)
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + (obs_dim, )

        # Convert distributions to funsors.
        init = dist_to_funsor(initial_dist)(value="state")
        trans = mvn_to_funsor(
            transition_dist, ("time", ),
            OrderedDict([("state", Reals[hidden_dim]),
                         ("state(time=1)", Reals[hidden_dim])]))
        obs = mvn_to_funsor(
            observation_dist, ("time", ),
            OrderedDict([("state(time=1)", Reals[hidden_dim]),
                         ("value", Reals[obs_dim])]))

        # Construct the joint funsor.
        # Compare with pyro.distributions.hmm.GaussianMRF.log_prob().
        with interpretation(lazy):
            time = Variable("time", Bint[time_shape[0]])
            value = Variable("value", Reals[time_shape[0], obs_dim])
            logp_oh = trans + obs(value=value["time"])
            logp_oh = MarkovProduct(ops.logaddexp, ops.add, logp_oh, time,
                                    {"state": "state(time=1)"})
            logp_oh += init
            logp_oh = logp_oh.reduce(ops.logaddexp,
                                     frozenset({"state", "state(time=1)"}))
            logp_h = trans + obs.reduce(ops.logaddexp, "value")
            logp_h = MarkovProduct(ops.logaddexp, ops.add, logp_h, time,
                                   {"state": "state(time=1)"})
            logp_h += init
            logp_h = logp_h.reduce(ops.logaddexp,
                                   frozenset({"state", "state(time=1)"}))
            funsor_dist = logp_oh - logp_h

        dtype = "real"
        super(GaussianMRF, self).__init__(funsor_dist, batch_shape,
                                          event_shape, dtype, validate_args)
        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
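
The final subtraction implements the MRF marginal likelihood as a ratio of
normalizers: writing f(state, value) for the joint log-factor built from
init, trans, and obs,

    log p(value) = log ∫ exp f(state, value) dstate            (logp_oh)
                 - log ∫∫ exp f(state, value) dstate dvalue    (logp_h)

so logp_oh integrates out the hidden states with the observation clamped,
while logp_h additionally integrates out the observation to obtain the
normalizing constant.
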
Example No. 5
def test_switching_linear_hmm_shape(init_cat_shape, init_mvn_shape,
                                    trans_cat_shape, trans_mat_shape, trans_mvn_shape,
                                    obs_mat_shape, obs_mvn_shape):
    hidden_dim, obs_dim = obs_mat_shape[-2:]
    assert trans_mat_shape[-2:] == (hidden_dim, hidden_dim)

    init_logits = torch.randn(init_cat_shape)
    init_mvn = random_mvn(init_mvn_shape, hidden_dim)
    trans_logits = torch.randn(trans_cat_shape)
    trans_matrix = torch.randn(trans_mat_shape)
    trans_mvn = random_mvn(trans_mvn_shape, hidden_dim)
    obs_matrix = torch.randn(obs_mat_shape)
    obs_mvn = random_mvn(obs_mvn_shape, obs_dim)

    init_shape = broadcast_shape(init_cat_shape, init_mvn_shape)
    shape = broadcast_shape(init_shape[:-1] + (1, init_shape[-1]),
                            trans_cat_shape[:-1],
                            trans_mat_shape[:-2],
                            trans_mvn_shape,
                            obs_mat_shape[:-2],
                            obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-2], shape[-2:-1]
    expected_event_shape = time_shape + (obs_dim,)

    actual_dist = SwitchingLinearHMM(init_logits, init_mvn,
                                     trans_logits, trans_matrix, trans_mvn,
                                     obs_matrix, obs_mvn)
    assert actual_dist.event_shape == expected_event_shape
    assert actual_dist.batch_shape == expected_batch_shape

    data = obs_mvn.expand(shape).sample()[..., 0, :]
    actual_log_prob = actual_dist.log_prob(data)
    assert actual_log_prob.shape == expected_batch_shape
    check_expand(actual_dist, data)

    final_cat, final_mvn = actual_dist.filter(data)
    assert isinstance(final_cat, dist.Categorical)
    assert isinstance(final_mvn, dist.MultivariateNormal)
    assert final_cat.batch_shape == actual_dist.batch_shape
    assert final_mvn.batch_shape == actual_dist.batch_shape + final_cat.logits.shape[-1:]
Example No. 6
    def __init__(self,
                 initial_dist,
                 transition_matrix,
                 transition_dist,
                 observation_matrix,
                 observation_dist,
                 validate_args=None):
        assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
        assert isinstance(transition_matrix, torch.Tensor)
        assert isinstance(transition_dist,
                          torch.distributions.MultivariateNormal)
        assert isinstance(observation_matrix, torch.Tensor)
        assert isinstance(observation_dist,
                          torch.distributions.MultivariateNormal)
        hidden_dim, obs_dim = observation_matrix.shape[-2:]
        assert obs_dim >= hidden_dim // 2, "obs_dim must be at least half of hidden_dim"
        assert initial_dist.event_shape == (hidden_dim, )
        assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
        assert transition_dist.event_shape == (hidden_dim, )
        assert observation_dist.event_shape == (obs_dim, )
        shape = broadcast_shape(initial_dist.batch_shape + (1, ),
                                transition_matrix.shape[:-2],
                                transition_dist.batch_shape,
                                observation_matrix.shape[:-2],
                                observation_dist.batch_shape)
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + (obs_dim, )

        # Convert distributions to funsors.
        init = dist_to_funsor(initial_dist)(value="state")
        trans = matrix_and_mvn_to_funsor(transition_matrix, transition_dist,
                                         ("time", ), "state", "state(time=1)")
        obs = matrix_and_mvn_to_funsor(observation_matrix, observation_dist,
                                       ("time", ), "state(time=1)", "value")
        dtype = "real"

        # Construct the joint funsor.
        with interpretation(lazy):
            value = Variable("value", Reals[time_shape[0], obs_dim])
            result = trans + obs(value=value["time"])
            result = MarkovProduct(ops.logaddexp, ops.add, result, "time",
                                   {"state": "state(time=1)"})
            result = init + result.reduce(ops.logaddexp, "state(time=1)")
            funsor_dist = result.reduce(ops.logaddexp, "state")

        super(GaussianHMM, self).__init__(funsor_dist, batch_shape,
                                          event_shape, dtype, validate_args)
        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
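
The trailing (1,) on initial_dist.batch_shape reserves the rightmost batch
slot for time, which the per-step transition and observation factors then
broadcast into. A shape-only sketch (again using torch.broadcast_shapes for
illustration):

import torch
init_batch = (5,)       # initial_dist.batch_shape: no time dependence
trans_batch = (7,)      # transition_dist.batch_shape: indexed by time
shape = torch.broadcast_shapes(init_batch + (1,), trans_batch)
assert shape == (5, 7)  # batch_shape == (5,), time_shape == (7,)
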
Example No. 7
def test_gaussian_mrf_log_prob(init_shape, trans_shape, obs_shape, hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_dist = random_mvn(trans_shape, hidden_dim + hidden_dim)
    obs_dist = random_mvn(obs_shape, hidden_dim + obs_dim)

    actual_dist = GaussianMRF(init_dist, trans_dist, obs_dist)
    expected_dist = dist.GaussianMRF(init_dist, trans_dist, obs_dist)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    data = obs_dist.expand(batch_shape).sample()[..., hidden_dim:]
    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=1e-4)
    check_expand(actual_dist, data)
Example No. 8
def test_discrete_normal_log_prob(init_shape, trans_shape, obs_shape, state_dim):
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    loc = torch.randn(obs_shape + (state_dim,))
    scale = torch.randn(obs_shape + (state_dim,)).exp()
    obs_dist = dist.Normal(loc, scale)

    actual_dist = DiscreteHMM(init_logits, trans_logits, obs_dist)
    expected_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    data = obs_dist.expand(batch_shape + (state_dim,)).sample()
    data = data[(slice(None),) * len(batch_shape) + (0,)]
    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, rtol=5e-5)
    check_expand(actual_dist, data)
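
The two data lines are a recurring trick in these tests: the observation
distribution is expanded with a trailing per-state axis so it can be sampled
once, then index 0 is taken along that axis; the index tuple
(slice(None),) * len(batch_shape) + (0,) selects component 0 immediately
after the batch dims, leaving data whose shape matches the HMM's batch and
event shapes.
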
Example No. 9
def test_discrete_categorical_log_prob(init_shape, trans_shape, obs_shape, state_dim):
    obs_dim = 4
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    obs_logits = torch.randn(obs_shape + (state_dim, obs_dim))
    obs_dist = dist.Categorical(logits=obs_logits)

    actual_dist = DiscreteHMM(init_logits, trans_logits, obs_dist)
    expected_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    data = obs_dist.expand(batch_shape + (state_dim,)).sample()
    data = data[(slice(None),) * len(batch_shape) + (0,)]
    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob)
    check_expand(actual_dist, data)
Example No. 10
    def __init__(self,
                 initial_logits,
                 transition_logits,
                 observation_dist,
                 validate_args=None):
        assert isinstance(initial_logits, torch.Tensor)
        assert isinstance(transition_logits, torch.Tensor)
        assert isinstance(observation_dist, torch.distributions.Distribution)
        assert initial_logits.dim() >= 1
        assert transition_logits.dim() >= 2
        assert len(observation_dist.batch_shape) >= 1
        shape = broadcast_shape(initial_logits.shape[:-1] + (1, ),
                                transition_logits.shape[:-2],
                                observation_dist.batch_shape[:-1])
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + observation_dist.event_shape
        self._has_rsample = observation_dist.has_rsample

        # Normalize.
        initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        transition_logits = transition_logits - transition_logits.logsumexp(
            -1, True)

        # Convert tensors and distributions to funsors.
        init = tensor_to_funsor(initial_logits, ("state", ))
        trans = tensor_to_funsor(transition_logits,
                                 ("time", "state", "state(time=1)"))
        obs = dist_to_funsor(observation_dist, ("time", "state(time=1)"))
        dtype = obs.inputs["value"].dtype

        # Construct the joint funsor.
        with interpretation(lazy):
            # TODO perform math here once sequential_sum_product has been
            #   implemented as a first-class funsor.
            funsor_dist = Variable("value",
                                   obs.inputs["value"])  # a bogus value
            # Until funsor_dist is defined, we save factors for hand-computation in .log_prob().
            self._init = init
            self._trans = trans
            self._obs = obs

        super(DiscreteHMM, self).__init__(funsor_dist, batch_shape,
                                          event_shape, dtype, validate_args)
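
Subtracting the logsumexp along the rightmost dimension turns each row of
logits into a proper log-probability vector. A quick check (a sketch,
assuming torch):

import torch
logits = torch.randn(3, 4)
normalized = logits - logits.logsumexp(-1, keepdim=True)
assert torch.allclose(normalized.exp().sum(-1), torch.ones(3))
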
Example No. 11
def eager_multinomial(total_count, probs, value):
    # Multinomial.log_prob() supports inhomogeneous total_count only by
    # avoiding passing total_count to the constructor.
    inputs, (total_count, probs,
             value) = align_tensors(total_count, probs, value)
    shape = broadcast_shape(total_count.shape + (1, ), probs.shape,
                            value.shape)
    probs = Tensor(ops.expand(probs, shape), inputs)
    value = Tensor(ops.expand(value, shape), inputs)
    if get_backend() == "torch":
        total_count = Number(
            ops.amax(total_count,
                     None).item())  # Used by distributions validation code.
    else:
        total_count = Tensor(ops.expand(total_count, shape[:-1]), inputs)
    backend_dist = import_module(
        BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    return backend_dist.Multinomial.eager_log_prob(total_count, probs,
                                                   value)  # noqa: F821
Example No. 12
def test_discrete_mvn_log_prob(init_shape, trans_shape, obs_shape, state_dim):
    event_size = 4
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    loc = torch.randn(obs_shape + (state_dim, event_size))
    cov = torch.randn(obs_shape + (state_dim, event_size, 2 * event_size))
    cov = cov.matmul(cov.transpose(-1, -2))
    scale_tril = torch.cholesky(cov)
    obs_dist = dist.MultivariateNormal(loc, scale_tril=scale_tril)

    actual_dist = DiscreteHMM(init_logits, trans_logits, obs_dist)
    expected_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    data = obs_dist.expand(batch_shape + (state_dim,)).sample()
    data = data[(slice(None),) * len(batch_shape) + (0,)]
    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob)
    check_expand(actual_dist, data)
Example No. 13
def test_gaussian_hmm_log_prob(init_shape, trans_mat_shape, trans_mvn_shape,
                               obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)

    actual_dist = GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    expected_dist = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    shape = broadcast_shape(init_shape + (1,),
                            trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape)
    data = obs_dist.expand(shape).sample()
    assert data.shape == actual_dist.shape()

    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
    check_expand(actual_dist, data)
Example No. 14
def test_gaussian_hmm_log_prob_null_dynamics(init_shape, trans_mat_shape, trans_mvn_shape,
                                             obs_mvn_shape, hidden_dim):
    obs_dim = hidden_dim
    init_dist = random_mvn(init_shape, hidden_dim)

    # impose null dynamics
    trans_mat = torch.zeros(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim, diag=True)

    # trivial observation matrix (hidden_dim = obs_dim)
    obs_mat = torch.eye(hidden_dim)
    obs_dist = random_mvn(obs_mvn_shape, obs_dim, diag=True)

    actual_dist = GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    expected_dist = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    shape = broadcast_shape(init_shape + (1,),
                            trans_mat_shape, trans_mvn_shape,
                            obs_mvn_shape)
    data = obs_dist.expand(shape).sample()
    assert data.shape == actual_dist.shape()

    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)

    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
    check_expand(actual_dist, data)

    obs_cov = obs_dist.covariance_matrix.diagonal(dim1=-1, dim2=-2)
    trans_cov = trans_dist.covariance_matrix.diagonal(dim1=-1, dim2=-2)
    sum_scale = (obs_cov + trans_cov).sqrt()
    sum_loc = trans_dist.loc + obs_dist.loc

    analytic_log_prob = dist.Normal(sum_loc, sum_scale).log_prob(data).sum(-1).sum(-1)
    assert_close(analytic_log_prob, actual_log_prob, atol=1.0e-5)
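
The analytic check follows because with a zero transition matrix the hidden
state is pure transition noise, x_t = eps_t^trans, and the identity
observation matrix gives y_t = x_t + eps_t^obs; hence each coordinate of y_t
is Normal(trans_loc + obs_loc, trans_var + obs_var), which is exactly the
sum_loc and sum_scale computed above.
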
Example No. 15
def matrix_and_mvn_to_funsor(matrix,
                             mvn,
                             event_dims=(),
                             x_name="value_x",
                             y_name="value_y"):
    """
    Convert a noisy affine function to a Gaussian. The noisy affine function is
    defined as::

        y = x @ matrix + mvn.sample()

    The result is a non-normalized Gaussian funsor with two real inputs,
    ``x_name`` and ``y_name``, corresponding to a conditional distribution of
    real vector ``y`` given real vector ``x``.

    :param torch.Tensor matrix: A matrix with rightmost shape ``(x_size, y_size)``.
    :param mvn: A multivariate normal distribution with
        ``event_shape == (y_size,)``.
    :type mvn: torch.distributions.MultivariateNormal or
        torch.distributions.Independent of torch.distributions.Normal
    :param tuple event_dims: A tuple of names for rightmost dimensions.
        These will be assigned to ``result.inputs`` of type ``Bint``.
    :param str x_name: The name of the ``x`` random variable.
    :param str y_name: The name of the ``y`` random variable.
    :return: A funsor with given ``real_inputs`` and possibly additional
        Bint inputs.
    :rtype: funsor.terms.Funsor
    """
    assert (isinstance(mvn, torch.distributions.MultivariateNormal)
            or (isinstance(mvn, torch.distributions.Independent)
                and isinstance(mvn.base_dist, torch.distributions.Normal)))
    assert isinstance(matrix, torch.Tensor)
    x_size, y_size = matrix.shape[-2:]
    assert mvn.event_shape == (y_size, )

    # Handle diagonal normal distributions as an efficient special case.
    if isinstance(mvn, torch.distributions.Independent):
        return AffineNormal(
            tensor_to_funsor(matrix, event_dims, 2),
            tensor_to_funsor(mvn.base_dist.loc, event_dims, 1),
            tensor_to_funsor(mvn.base_dist.scale, event_dims, 1),
            Variable(x_name, Reals[x_size]), Variable(y_name, Reals[y_size]))

    info_vec = mvn.loc.unsqueeze(-1).cholesky_solve(mvn.scale_tril).squeeze(-1)
    log_prob = (-0.5 * y_size * math.log(2 * math.pi) -
                mvn.scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1) - 0.5 *
                (info_vec * mvn.loc).sum(-1))

    batch_shape = broadcast_shape(matrix.shape[:-2], mvn.batch_shape)
    P_yy = mvn.precision_matrix.expand(batch_shape + (y_size, y_size))
    neg_P_xy = matrix.matmul(P_yy)
    P_xy = -neg_P_xy
    P_yx = P_xy.transpose(-1, -2)
    P_xx = neg_P_xy.matmul(matrix.transpose(-1, -2))
    precision = torch.cat(
        [torch.cat([P_xx, P_xy], -1),
         torch.cat([P_yx, P_yy], -1)], -2)
    info_y = info_vec.expand(batch_shape + (y_size, ))
    info_x = -matrix.matmul(info_y.unsqueeze(-1)).squeeze(-1)
    info_vec = torch.cat([info_x, info_y], -1)

    info_vec = tensor_to_funsor(info_vec, event_dims, 1)
    precision = tensor_to_funsor(precision, event_dims, 2)
    inputs = info_vec.inputs.copy()
    inputs[x_name] = Reals[x_size]
    inputs[y_name] = Reals[y_size]
    return tensor_to_funsor(log_prob, event_dims) + Gaussian(
        info_vec.data, precision.data, inputs)
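
The block construction encodes the joint exponent of (x, y) under
y = x @ matrix + noise: P_yy is the noise precision, P_xy = -matrix @ P_yy,
and P_xx = matrix @ P_yy @ matrix.T. A numeric sanity check of that algebra
(a standalone sketch, assuming torch):

import torch

torch.manual_seed(0)
x_size, y_size = 2, 3
M = torch.randn(x_size, y_size)
loc = torch.randn(y_size)
A = torch.randn(y_size, y_size)
Sigma = A @ A.T + torch.eye(y_size)   # a positive-definite covariance
P_yy = torch.inverse(Sigma)
P_xy = -M @ P_yy
P_xx = M @ P_yy @ M.T
precision = torch.cat([torch.cat([P_xx, P_xy], -1),
                       torch.cat([P_xy.T, P_yy], -1)], -2)
info = torch.cat([-M @ (P_yy @ loc), P_yy @ loc])

x, y = torch.randn(x_size), torch.randn(y_size)
xy = torch.cat([x, y])
resid = y - x @ M - loc
# The exponent of N(y | x @ M + loc, Sigma) equals the information-form
# quadratic plus a term constant in both x and y.
lhs = -0.5 * resid @ P_yy @ resid
rhs = -0.5 * xy @ precision @ xy + info @ xy - 0.5 * loc @ P_yy @ loc
assert torch.allclose(lhs, rhs, atol=1e-5)
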
Example No. 16
    def _eager_subs_real(self, subs, remaining_subs):
        # Broadcast all component tensors.
        subs = OrderedDict(subs)
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, int_inputs),
            Tensor(self.precision, int_inputs)
        ]
        tensors.extend(subs.values())
        int_inputs, tensors = align_tensors(*tensors)
        batch_dim = len(tensors[0].shape) - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (info_vec, precision), values = tensors[:2], tensors[2:]
        offsets, event_size = _compute_offsets(self.inputs)
        slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
                  for k, offset in offsets.items()]

        # Expand all substituted values.
        values = OrderedDict(zip(subs, values))
        for k, value in values.items():
            value = value.reshape(value.shape[:batch_dim] + (-1, ))
            if not get_tracing_state():
                assert value.shape[-1] == self.inputs[k].num_elements
            values[k] = ops.expand(value, batch_shape + value.shape[-1:])

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
            # Form the concatenated value.
            value = BlockVector(batch_shape + (event_size, ))
            for k, i in slices:
                if k in values:
                    value[..., i] = values[k]
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = _vv(value, info_vec - 0.5 * _mv(precision, value))

            result = Tensor(result, int_inputs)
            assert result.output == Real
            return Subs(result, remaining_subs) if remaining_subs else result

        # Perform a partial substitution of a subset of real variables, resulting in a Joint.
        # We split real inputs into two sets: a for the preserved and b for the substituted.
        b = frozenset(k for k, v in subs.items())
        a = frozenset(k for k, d in self.inputs.items()
                      if d.dtype == 'real' and k not in b)
        prec_aa = ops.cat(
            -2, *[
                ops.cat(
                    -1,
                    *[precision[..., i1, i2] for k2, i2 in slices if k2 in a])
                for k1, i1 in slices if k1 in a
            ])
        prec_ab = ops.cat(
            -2, *[
                ops.cat(
                    -1,
                    *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
                for k1, i1 in slices if k1 in a
            ])
        prec_bb = ops.cat(
            -2, *[
                ops.cat(
                    -1,
                    *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
                for k1, i1 in slices if k1 in b
            ])
        info_a = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in a])
        info_b = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in b])
        value_b = ops.cat(-1, *[values[k] for k, i in slices if k in b])
        info_vec = info_a - _mv(prec_ab, value_b)
        log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
        precision = ops.expand(prec_aa, info_vec.shape + info_vec.shape[-1:])
        inputs = int_inputs.copy()
        for k, d in self.inputs.items():
            if k not in subs:
                inputs[k] = d
        result = Gaussian(info_vec, precision, inputs) + Tensor(
            log_scale, int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result
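
Substituting the observed block b into an information-form Gaussian is exact
conditioning: partitioning info_vec and precision over the preserved block a
and the substituted block b,

    info_a'    = info_a - P_ab @ value_b
    precision' = P_aa
    log_scale  = value_b . (info_b - 0.5 * P_bb @ value_b)

which is exactly what the lines computing info_vec, log_scale, and precision
above the final Gaussian construction evaluate.
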
Example No. 17
    def __init__(self,
                 initial_logits,
                 initial_mvn,
                 transition_logits,
                 transition_matrix,
                 transition_mvn,
                 observation_matrix,
                 observation_mvn,
                 exact=False,
                 validate_args=None):
        assert isinstance(initial_logits, torch.Tensor)
        assert isinstance(initial_mvn, torch.distributions.MultivariateNormal)
        assert isinstance(transition_logits, torch.Tensor)
        assert isinstance(transition_matrix, torch.Tensor)
        assert isinstance(transition_mvn,
                          torch.distributions.MultivariateNormal)
        assert isinstance(observation_matrix, torch.Tensor)
        assert isinstance(observation_mvn,
                          torch.distributions.MultivariateNormal)
        hidden_cardinality = initial_logits.size(-1)
        hidden_dim, obs_dim = observation_matrix.shape[-2:]
        assert obs_dim >= hidden_dim // 2, "obs_dim must be at least half of hidden_dim"
        assert initial_mvn.event_shape[0] == hidden_dim
        assert transition_logits.size(-1) == hidden_cardinality
        assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
        assert transition_mvn.event_shape[0] == hidden_dim
        assert observation_mvn.event_shape[0] == obs_dim
        init_shape = broadcast_shape(initial_logits.shape,
                                     initial_mvn.batch_shape)
        shape = broadcast_shape(init_shape[:-1] + (1, init_shape[-1]),
                                transition_logits.shape[:-1],
                                transition_matrix.shape[:-2],
                                transition_mvn.batch_shape,
                                observation_matrix.shape[:-2],
                                observation_mvn.batch_shape)
        assert shape[-1] == hidden_cardinality
        batch_shape, time_shape = shape[:-2], shape[-2:-1]
        event_shape = time_shape + (obs_dim, )

        # Normalize.
        initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        transition_logits = transition_logits - transition_logits.logsumexp(
            -1, True)

        # Convert tensors and distributions to funsors.
        init = (tensor_to_funsor(initial_logits, ("class", )) +
                dist_to_funsor(initial_mvn, ("class", ))(value="state"))
        trans = (tensor_to_funsor(transition_logits,
                                  ("time", "class", "class(time=1)")) +
                 matrix_and_mvn_to_funsor(transition_matrix, transition_mvn,
                                          ("time", "class(time=1)"), "state",
                                          "state(time=1)"))
        obs = matrix_and_mvn_to_funsor(observation_matrix, observation_mvn,
                                       ("time", "class(time=1)"),
                                       "state(time=1)", "value")
        if "class(time=1)" not in set(trans.inputs).union(obs.inputs):
            raise ValueError(
                "neither transition nor observation depend on discrete state")
        dtype = "real"

        # Construct the joint funsor.
        with interpretation(lazy):
            # TODO perform math here once sequential_sum_product has been
            #   implemented as a first-class funsor.
            funsor_dist = Variable("value",
                                   obs.inputs["value"])  # a bogus value
            # Until funsor_dist is defined, we save factors for hand-computation in .log_prob().
            self._init = init
            self._trans = trans
            self._obs = obs

        super(SwitchingLinearHMM,
              self).__init__(funsor_dist, batch_shape, event_shape, dtype,
                             validate_args)
        self.exact = exact