Example 1
def test_mvn_to_funsor(batch_shape, event_shape, event_sizes):
    event_size = sum(event_sizes)
    mvn = random_mvn(batch_shape + event_shape, event_size)
    int_inputs = OrderedDict(
        (k, bint(size)) for k, size in zip("abc", event_shape))
    real_inputs = OrderedDict(
        (k, reals(size)) for k, size in zip("xyz", event_sizes))

    f = mvn_to_funsor(mvn, tuple(int_inputs), real_inputs)
    assert isinstance(f, Funsor)
    for k, d in int_inputs.items():
        if d.num_elements == 1:
            assert k not in f.inputs
        else:
            assert k in f.inputs
            assert f.inputs[k] == d
    for k, d in real_inputs.items():
        assert k in f.inputs
        assert f.inputs[k] == d

    value = mvn.sample()
    subs = {}
    beg = 0
    for k, d in real_inputs.items():
        end = beg + d.num_elements
        subs[k] = tensor_to_funsor(value[..., beg:end], tuple(int_inputs), 1)
        beg = end
    actual_log_prob = f(**subs)
    expected_log_prob = tensor_to_funsor(mvn.log_prob(value),
                                         tuple(int_inputs))
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
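The helpers above (mvn_to_funsor, tensor_to_funsor, and friends) appear to come from funsor's funsor.pyro.convert module. A minimal sketch of the pattern this test exercises, under that assumption (all names below are illustrative):

# Sketch only; assumes funsor.pyro.convert and funsor.domains as in
# funsor's own test suite.
from collections import OrderedDict

import torch
import pyro.distributions as dist
from funsor.domains import reals
from funsor.pyro.convert import mvn_to_funsor, tensor_to_funsor

mvn = dist.MultivariateNormal(torch.zeros(5), torch.eye(5))
# Split the 5-dim event into two named real inputs "x" (size 2) and "y" (size 3).
f = mvn_to_funsor(mvn, (), OrderedDict([("x", reals(2)), ("y", reals(3))]))
value = mvn.sample()
log_prob = f(x=tensor_to_funsor(value[..., :2], (), 1),
             y=tensor_to_funsor(value[..., 2:], (), 1))
# log_prob should agree with mvn.log_prob(value) up to numerical tolerance.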
Example 2
def test_dist_to_funsor_bernoulli(batch_shape):
    logits = torch.randn(batch_shape)
    d = dist.Bernoulli(logits=logits)
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    actual_log_prob = f(value=tensor_to_funsor(value))
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob)
Example 3
def test_dist_to_funsor_normal(batch_shape):
    loc = torch.randn(batch_shape)
    scale = torch.randn(batch_shape).exp()
    d = dist.Normal(loc, scale)
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    actual_log_prob = f(value=tensor_to_funsor(value))
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob, rtol=1e-5)
Example 4
def test_dist_to_funsor_independent(batch_shape, event_shape):
    loc = torch.randn(batch_shape + event_shape)
    scale = torch.randn(batch_shape + event_shape).exp()
    d = dist.Normal(loc, scale).to_event(len(event_shape))
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    funsor_value = tensor_to_funsor(value, event_output=len(event_shape))
    actual_log_prob = f(value=funsor_value)
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob, rtol=1e-5)
Example 5
def test_dist_to_funsor_mvn(batch_shape, event_size):
    loc = torch.randn(batch_shape + (event_size, ))
    cov = torch.randn(batch_shape + (event_size, 2 * event_size))
    cov = cov.matmul(cov.transpose(-1, -2))  # random positive-definite covariance
    scale_tril = torch.cholesky(cov)
    d = dist.MultivariateNormal(loc, scale_tril=scale_tril)
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    actual_log_prob = f(value=tensor_to_funsor(value, event_output=1))
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob)
Example 6
def test_dist_to_funsor_masked(batch_shape):
    loc = torch.randn(batch_shape)
    scale = torch.randn(batch_shape).exp()
    mask = torch.bernoulli(torch.full(batch_shape, 0.5)).byte()
    d = dist.Normal(loc, scale).mask(mask)
    assert isinstance(d, MaskedDistribution)
    f = dist_to_funsor(d)
    assert isinstance(f, Funsor)

    value = d.sample()
    actual_log_prob = f(value=tensor_to_funsor(value))
    expected_log_prob = tensor_to_funsor(d.log_prob(value))
    assert_close(actual_log_prob, expected_log_prob)
Example 7
    def __init__(self, logits, validate_args=None):
        batch_shape = logits.shape[:-1]
        event_shape = torch.Size()
        funsor_dist = tensor_to_funsor(logits, ("value", ))
        dtype = int(logits.size(-1))
        super(Categorical, self).__init__(funsor_dist, batch_shape,
                                          event_shape, dtype, validate_args)
Example 8
    def filter(self, value):
        """
        Compute posterior over final state given a sequence of observations.

        :param ~torch.Tensor value: A sequence of observations.
        :return: A posterior distribution over latent states at the final time
            step, represented as a pair ``(cat, mvn)``, where ``cat`` is a
            :class:`~pyro.distributions.Categorical` distribution over mixture
            components and ``mvn`` is a
            :class:`~pyro.distributions.MultivariateNormal` with rightmost
            batch dimension ranging over mixture components. This can then be
            used to initialize a sequential Pyro model for prediction.
        :rtype: tuple
        """
        ndims = max(len(self.batch_shape), value.dim() - 2)
        time = Variable("time", Bint[self.event_shape[0]])
        value = tensor_to_funsor(value, ("time", ), 1)

        seq_sum_prod = naive_sequential_sum_product if self.exact else sequential_sum_product
        with interpretation(eager if self.exact else moment_matching):
            logp = self._trans + self._obs(value=value)
            logp = seq_sum_prod(ops.logaddexp, ops.add, logp, time, {
                "class": "class(time=1)",
                "state": "state(time=1)"
            })
            logp += self._init
            logp = logp.reduce(ops.logaddexp, frozenset(["class", "state"]))

        cat, mvn = funsor_to_cat_and_mvn(logp, ndims, ("class(time=1)", ))
        cat = cat.expand(self.batch_shape)
        mvn = mvn.expand(self.batch_shape + cat.logits.shape[-1:])
        return cat, mvn
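One way the returned ``(cat, mvn)`` pair might be consumed downstream, sketched as a hypothetical helper (collapse_filter is illustrative, not part of this class):

import torch

def collapse_filter(hmm, data):
    """Hypothetical helper: collapse the filtered mixture over components
    to a single posterior-weighted mean for the final latent state."""
    cat, mvn = hmm.filter(data)
    probs = cat.probs  # (batch_shape, num_components)
    means = mvn.mean   # (batch_shape, num_components, hidden_dim)
    # Posterior-weighted average over mixture components.
    return (probs.unsqueeze(-1) * means).sum(-2)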
Example 9
def test_matrix_and_mvn_to_funsor(batch_shape, event_shape, x_size, y_size):
    matrix = torch.randn(batch_shape + event_shape + (x_size, y_size))
    y_mvn = random_mvn(batch_shape + event_shape, y_size)
    xy_mvn = random_mvn(batch_shape + event_shape, x_size + y_size)
    int_inputs = OrderedDict(
        (k, bint(size)) for k, size in zip("abc", event_shape))
    real_inputs = OrderedDict([("x", reals(x_size)), ("y", reals(y_size))])

    f = (matrix_and_mvn_to_funsor(matrix, y_mvn, tuple(int_inputs), "x", "y") +
         mvn_to_funsor(xy_mvn, tuple(int_inputs), real_inputs))
    assert isinstance(f, Funsor)
    for k, d in int_inputs.items():
        if d.num_elements == 1:
            assert k not in f.inputs
        else:
            assert k in f.inputs
            assert f.inputs[k] == d
    assert f.inputs["x"] == reals(x_size)
    assert f.inputs["y"] == reals(y_size)

    xy = torch.randn(x_size + y_size)
    x, y = xy[:x_size], xy[x_size:]
    y_pred = x.unsqueeze(-2).matmul(matrix).squeeze(-2)
    actual_log_prob = f(x=x, y=y)
    expected_log_prob = tensor_to_funsor(
        xy_mvn.log_prob(xy) + y_mvn.log_prob(y - y_pred), tuple(int_inputs))
    assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=1e-4)
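The semantics being checked: matrix_and_mvn_to_funsor(matrix, y_mvn, ..., "x", "y") encodes the log-density of "y" given "x", with mean x @ matrix and noise distributed as y_mvn. Isolating just that piece, under the same funsor.pyro.convert assumption as above:

import torch
import pyro.distributions as dist
from funsor.pyro.convert import matrix_and_mvn_to_funsor

x_size, y_size = 2, 3
matrix = torch.randn(x_size, y_size)
noise = dist.MultivariateNormal(torch.zeros(y_size), torch.eye(y_size))

f = matrix_and_mvn_to_funsor(matrix, noise, (), "x", "y")
x, y = torch.randn(x_size), torch.randn(y_size)
actual = f(x=x, y=y)  # log p(y | x), a scalar funsor
expected = noise.log_prob(y - x @ matrix)
# actual and expected should agree up to numerical tolerance.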
Example 10
def test_dist_to_funsor_categorical(batch_shape, cardinality):
    logits = torch.randn(batch_shape + (cardinality, ))
    logits -= logits.logsumexp(dim=-1, keepdim=True)
    d = dist.Categorical(logits=logits)
    f = dist_to_funsor(d)
    assert isinstance(f, Tensor)
    expected = tensor_to_funsor(logits, ("value", ))
    assert_close(f, expected)
Example 11
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        ndims = max(len(self.batch_shape), value.dim() - self.event_dim)
        value = tensor_to_funsor(value,
                                 event_output=self.event_dim,
                                 dtype=self.dtype)
        log_prob = self.funsor_dist(value=value)
        log_prob = funsor_to_tensor(log_prob, ndims=ndims)
        return log_prob
Example 12
    def __init__(self,
                 initial_logits,
                 transition_logits,
                 observation_dist,
                 validate_args=None):
        assert isinstance(initial_logits, torch.Tensor)
        assert isinstance(transition_logits, torch.Tensor)
        assert isinstance(observation_dist, torch.distributions.Distribution)
        assert initial_logits.dim() >= 1
        assert transition_logits.dim() >= 2
        assert len(observation_dist.batch_shape) >= 1
        shape = broadcast_shape(initial_logits.shape[:-1] + (1, ),
                                transition_logits.shape[:-2],
                                observation_dist.batch_shape[:-1])
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + observation_dist.event_shape
        self._has_rsample = observation_dist.has_rsample

        # Normalize.
        initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        transition_logits = transition_logits - transition_logits.logsumexp(
            -1, True)

        # Convert tensors and distributions to funsors.
        init = tensor_to_funsor(initial_logits, ("state", ))
        trans = tensor_to_funsor(transition_logits,
                                 ("time", "state", "state(time=1)"))
        obs = dist_to_funsor(observation_dist, ("time", "state(time=1)"))
        dtype = obs.inputs["value"].dtype

        # Construct the joint funsor.
        with interpretation(lazy):
            # TODO perform math here once sequential_sum_product has been
            #   implemented as a first-class funsor.
            funsor_dist = Variable("value",
                                   obs.inputs["value"])  # a bogus value
            # Until funsor_dist is defined, we save factors for hand-computation in .log_prob().
            self._init = init
            self._trans = trans
            self._obs = obs

        super(DiscreteHMM, self).__init__(funsor_dist, batch_shape,
                                          event_shape, dtype, validate_args)
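A construction sketch with concrete (hypothetical) sizes; the log_prob in Example 17 then consumes the factors saved above:

import torch
import pyro.distributions as dist

S, T = 3, 5  # hidden states, time steps (hypothetical sizes)
hmm = DiscreteHMM(
    initial_logits=torch.zeros(S),           # uniform initial state
    transition_logits=torch.zeros(T, S, S),  # time-inhomogeneous, uniform
    observation_dist=dist.Normal(torch.randn(T, S), 1.0),
)
value = torch.randn(T)          # one scalar observation per time step
log_prob = hmm.log_prob(value)  # a scalar tensor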
Example 13
    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(type(self), _instance)
        batch_shape = torch.Size(batch_shape)
        funsor_dist = self.funsor_dist + tensor_to_funsor(
            torch.zeros(batch_shape))
        super(type(self), new).__init__(funsor_dist,
                                        batch_shape,
                                        self.event_shape,
                                        self.dtype,
                                        validate_args=False)
        new.validate_args = self.__dict__.get('_validate_args')
        return new
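The zeros-funsor trick used by this expand, in isolation: adding a funsor of zeros with the target batch shape leaves the values unchanged but unions in the new batch dimensions (same funsor.pyro.convert assumption as above):

import torch
from funsor.pyro.convert import tensor_to_funsor

f = tensor_to_funsor(torch.randn(3))          # batch shape (3,)
g = f + tensor_to_funsor(torch.zeros(2, 3))   # same values, batch now (2, 3)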
Example 14
def test_matrix_and_mvn_to_funsor_diag(batch_shape, x_size, y_size):
    matrix = torch.randn(batch_shape + (x_size, y_size))
    loc = torch.randn(batch_shape + (y_size, ))
    scale = torch.randn(batch_shape + (y_size, )).exp()

    normal = dist.Normal(loc, scale).to_event(1)
    actual = matrix_and_mvn_to_funsor(matrix, normal)
    assert isinstance(actual, AffineNormal)

    mvn = dist.MultivariateNormal(loc, scale_tril=scale.diag_embed())
    expected = matrix_and_mvn_to_funsor(matrix, mvn)

    y = tensor_to_funsor(torch.randn(batch_shape + (y_size, )), (), 1)
    actual_like = actual(value_y=y)
    expected_like = expected(value_y=y)
    assert_close(actual_like, expected_like, atol=1e-4, rtol=1e-4)

    x = tensor_to_funsor(torch.randn(batch_shape + (x_size, )), (), 1)
    actual_norm = actual_like(value_x=x)
    expected_norm = expected_like(value_x=x)
    assert_close(actual_norm, expected_norm, atol=1e-4, rtol=1e-4)
Example 15
    def _forward_funsor(self, features, trip_counts):
        total_hours = len(features)
        observed_hours, num_origins, num_destins = trip_counts.shape
        assert observed_hours == total_hours
        assert num_origins == self.num_stations
        assert num_destins == self.num_stations
        n = self.num_stations
        gate_rate = funsor.Variable("gate_rate_t",
                                    reals(observed_hours, 2 * n * n))["time"]

        @funsor.torch.function(reals(2 * n * n), (reals(n, n, 2), reals(n, n)))
        def unpack_gate_rate(gate_rate):
            batch_shape = gate_rate.shape[:-1]
            gate, rate = gate_rate.reshape(batch_shape + (2, n, n)).unbind(-3)
            gate = gate.sigmoid().clamp(min=0.01, max=0.99)
            rate = bounded_exp(rate, bound=1e4)
            gate = torch.stack((1 - gate, gate), dim=-1)
            return gate, rate

        # Create a Gaussian latent dynamical system.
        init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist = \
            self._dynamics(features[:observed_hours])
        init = dist_to_funsor(init_dist)(value="state")
        trans = matrix_and_mvn_to_funsor(trans_matrix, trans_dist, ("time", ),
                                         "state", "state(time=1)")
        obs = matrix_and_mvn_to_funsor(obs_matrix, obs_dist, ("time", ),
                                       "state(time=1)", "gate_rate")

        # Compute dynamic prior over gate_rate.
        prior = trans + obs(gate_rate=gate_rate)
        prior = MarkovProduct(ops.logaddexp, ops.add, prior, "time",
                              {"state": "state(time=1)"})
        prior += init
        prior = prior.reduce(ops.logaddexp, {"state", "state(time=1)"})

        # Compute zero-inflated Poisson likelihood.
        gate, rate = unpack_gate_rate(gate_rate)
        likelihood = fdist.Categorical(gate["origin", "destin"], value="gated")
        trip_counts = tensor_to_funsor(trip_counts,
                                       ("time", "origin", "destin"))
        likelihood += funsor.Stack(
            "gated",
            (fdist.Poisson(rate["origin", "destin"], value=trip_counts),
             fdist.Delta(0, value=trip_counts)))
        likelihood = likelihood.reduce(ops.logaddexp, "gated")
        likelihood = likelihood.reduce(ops.add, {"time", "origin", "destin"})

        assert set(prior.inputs) == {"gate_rate_t"}, prior.inputs
        assert set(likelihood.inputs) == {"gate_rate_t"}, likelihood.inputs
        return prior, likelihood
Example 16
    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(SwitchingLinearHMM, _instance)
        batch_shape = torch.Size(batch_shape)
        new._init = self._init + tensor_to_funsor(torch.zeros(batch_shape))
        new._trans = self._trans
        new._obs = self._obs
        new.exact = self.exact
        super(SwitchingLinearHMM, new).__init__(self.funsor_dist,
                                                batch_shape,
                                                self.event_shape,
                                                self.dtype,
                                                validate_args=False)
        new.validate_args = self.__dict__.get('_validate_args')
        return new
Example 17
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        ndims = max(len(self.batch_shape), value.dim() - self.event_dim)
        time = Variable("time", Bint[self.event_shape[0]])
        value = tensor_to_funsor(value, ("time", ),
                                 event_output=self.event_dim - 1,
                                 dtype=self.dtype)

        # Compare with pyro.distributions.hmm.DiscreteHMM.log_prob().
        obs = self._obs(value=value)
        result = self._trans + obs
        result = sequential_sum_product(ops.logaddexp, ops.add, result, time,
                                        {"state": "state(time=1)"})
        result = self._init + result.reduce(ops.logaddexp, "state(time=1)")
        result = result.reduce(ops.logaddexp, "state")

        result = funsor_to_tensor(result, ndims=ndims)
        return result
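The workhorse here is sequential_sum_product, a parallel-scan forward algorithm that composes the T time slices of a transition funsor in the (logaddexp, add) semiring. A standalone sketch, assuming it lives in funsor.sum_product and that Bint is importable from funsor.domains, as in funsor itself:

import torch
from funsor import ops
from funsor.domains import Bint
from funsor.pyro.convert import tensor_to_funsor
from funsor.sum_product import sequential_sum_product
from funsor.terms import Variable

T, S = 5, 3
trans = tensor_to_funsor(torch.randn(T, S, S),
                         ("time", "state", "state(time=1)"))
time = Variable("time", Bint[T])
# Composes the T transition slices; the result has only "state" and
# "state(time=1)" inputs, like a T-step log-transition matrix.
total = sequential_sum_product(ops.logaddexp, ops.add, trans, time,
                               {"state": "state(time=1)"})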
Example 18
    def log_prob(self, value):
        ndims = max(len(self.batch_shape), value.dim() - 2)
        time = Variable("time", Bint[self.event_shape[0]])
        value = tensor_to_funsor(value, ("time", ), 1)

        seq_sum_prod = naive_sequential_sum_product if self.exact else sequential_sum_product
        with interpretation(eager if self.exact else moment_matching):
            result = self._trans + self._obs(value=value)
            result = seq_sum_prod(ops.logaddexp, ops.add, result, time, {
                "class": "class(time=1)",
                "state": "state(time=1)"
            })
            result += self._init
            result = result.reduce(
                ops.logaddexp,
                frozenset(["class", "state", "class(time=1)",
                           "state(time=1)"]))

            result = funsor_to_tensor(result, ndims=ndims)
            return result
Example 19
def test_funsor_to_cat_and_mvn(batch_shape, event_shape, int_size, real_size):
    logits = torch.randn(batch_shape + event_shape + (int_size, ))
    expected_cat = dist.Categorical(logits=logits)
    expected_mvn = random_mvn(batch_shape + event_shape + (int_size, ),
                              real_size)
    event_dims = tuple("abc"[:len(event_shape)]) + ("component", )
    ndims = len(expected_cat.batch_shape)

    funsor_ = (tensor_to_funsor(logits, event_dims) +
               dist_to_funsor(expected_mvn, event_dims)(value="value"))
    assert isinstance(funsor_, Funsor)

    actual_cat, actual_mvn = funsor_to_cat_and_mvn(funsor_, ndims, event_dims)
    assert isinstance(actual_cat, dist.Categorical)
    assert isinstance(actual_mvn, dist.MultivariateNormal)
    assert actual_cat.batch_shape == expected_cat.batch_shape
    assert actual_mvn.batch_shape == expected_mvn.batch_shape
    assert_close(actual_cat.logits, expected_cat.logits, atol=1e-4, rtol=None)
    assert_close(actual_mvn.loc, expected_mvn.loc, atol=1e-4, rtol=None)
    assert_close(actual_mvn.precision_matrix,
                 expected_mvn.precision_matrix,
                 atol=1e-4,
                 rtol=None)
Example 20
def test_tensor_funsor_tensor(batch_shape, event_shape, event_output):
    event_inputs = ("foo", "bar", "baz")[:len(event_shape) - event_output]
    t = torch.randn(batch_shape + event_shape)
    f = tensor_to_funsor(t, event_inputs, event_output)
    t2 = funsor_to_tensor(f, t.dim(), event_inputs)
    assert_close(t2, t)
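A concrete instance of the round trip above, with explicit shapes (same funsor.pyro.convert assumption as the other sketches):

import torch
from funsor.pyro.convert import funsor_to_tensor, tensor_to_funsor

t = torch.randn(2, 3, 4, 5)  # batch (2,) + event (3, 4, 5)
# "foo" and "bar" name the first two event dims; event_output=1 keeps the
# trailing dim (size 5) as the funsor's output shape.
f = tensor_to_funsor(t, ("foo", "bar"), 1)
t2 = funsor_to_tensor(f, t.dim(), ("foo", "bar"))
assert torch.allclose(t2, t)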
Example 21
    def __init__(self,
                 initial_logits,
                 initial_mvn,
                 transition_logits,
                 transition_matrix,
                 transition_mvn,
                 observation_matrix,
                 observation_mvn,
                 exact=False,
                 validate_args=None):
        assert isinstance(initial_logits, torch.Tensor)
        assert isinstance(initial_mvn, torch.distributions.MultivariateNormal)
        assert isinstance(transition_logits, torch.Tensor)
        assert isinstance(transition_matrix, torch.Tensor)
        assert isinstance(transition_mvn,
                          torch.distributions.MultivariateNormal)
        assert isinstance(observation_matrix, torch.Tensor)
        assert isinstance(observation_mvn,
                          torch.distributions.MultivariateNormal)
        hidden_cardinality = initial_logits.size(-1)
        hidden_dim, obs_dim = observation_matrix.shape[-2:]
        assert obs_dim >= hidden_dim // 2, "obs_dim must be at least half of hidden_dim"
        assert initial_mvn.event_shape[0] == hidden_dim
        assert transition_logits.size(-1) == hidden_cardinality
        assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
        assert transition_mvn.event_shape[0] == hidden_dim
        assert observation_mvn.event_shape[0] == obs_dim
        init_shape = broadcast_shape(initial_logits.shape,
                                     initial_mvn.batch_shape)
        shape = broadcast_shape(init_shape[:-1] + (1, init_shape[-1]),
                                transition_logits.shape[:-1],
                                transition_matrix.shape[:-2],
                                transition_mvn.batch_shape,
                                observation_matrix.shape[:-2],
                                observation_mvn.batch_shape)
        assert shape[-1] == hidden_cardinality
        batch_shape, time_shape = shape[:-2], shape[-2:-1]
        event_shape = time_shape + (obs_dim, )

        # Normalize.
        initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        transition_logits = transition_logits - transition_logits.logsumexp(
            -1, True)

        # Convert tensors and distributions to funsors.
        init = (tensor_to_funsor(initial_logits, ("class", )) +
                dist_to_funsor(initial_mvn, ("class", ))(value="state"))
        trans = (tensor_to_funsor(transition_logits,
                                  ("time", "class", "class(time=1)")) +
                 matrix_and_mvn_to_funsor(transition_matrix, transition_mvn,
                                          ("time", "class(time=1)"), "state",
                                          "state(time=1)"))
        obs = matrix_and_mvn_to_funsor(observation_matrix, observation_mvn,
                                       ("time", "class(time=1)"),
                                       "state(time=1)", "value")
        if "class(time=1)" not in set(trans.inputs).union(obs.inputs):
            raise ValueError(
                "neither transition nor observation depend on discrete state")
        dtype = "real"

        # Construct the joint funsor.
        with interpretation(lazy):
            # TODO perform math here once sequential_sum_product has been
            #   implemented as a first-class funsor.
            funsor_dist = Variable("value",
                                   obs.inputs["value"])  # a bogus value
            # Until funsor_dist is defined, we save factors for hand-computation in .log_prob().
            self._init = init
            self._trans = trans
            self._obs = obs

        super(SwitchingLinearHMM,
              self).__init__(funsor_dist, batch_shape, event_shape, dtype,
                             validate_args)
        self.exact = exact
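Finally, a construction sketch with hypothetical sizes, analogous to the DiscreteHMM sketch after Example 12; the shapes follow the broadcasting rules asserted above, with the time and class dims rightmost in every batch shape:

import torch
import pyro.distributions as dist

K, H, O, T = 2, 4, 3, 5  # classes, hidden dim, obs dim, time steps
eye = torch.eye(H)
hmm = SwitchingLinearHMM(
    initial_logits=torch.zeros(K),
    initial_mvn=dist.MultivariateNormal(torch.zeros(H), eye),
    transition_logits=torch.zeros(T, K, K),
    transition_matrix=torch.randn(T, K, H, H) * 0.1,
    transition_mvn=dist.MultivariateNormal(torch.zeros(T, K, H),
                                           scale_tril=eye),
    observation_matrix=torch.randn(T, K, H, O),
    observation_mvn=dist.MultivariateNormal(torch.zeros(T, K, O),
                                            scale_tril=torch.eye(O)),
)
log_prob = hmm.log_prob(torch.randn(T, O))  # a scalar tensor
cat, mvn = hmm.filter(torch.randn(T, O))    # posterior over the final state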