Example #1
    def testing():

        with markov():
            v1 = to_data(
                Tensor(jnp.ones(2), OrderedDict([("1", bint(2))]), 'real'))
            print(1, v1.shape)  # shapes should alternate
            assert v1.shape == (2, )

            with markov():
                v2 = to_data(
                    Tensor(jnp.ones(2), OrderedDict([("2", bint(2))]), 'real'))
                print(2, v2.shape)  # shapes should alternate
                assert v2.shape == (2, 1)

                with markov():
                    v3 = to_data(
                        Tensor(jnp.ones(2), OrderedDict([("3", bint(2))]),
                               'real'))
                    print(3, v3.shape)  # shapes should alternate
                    assert v3.shape == (2, )

                    with markov():
                        v4 = to_data(
                            Tensor(jnp.ones(2), OrderedDict([("4", bint(2))]),
                                   'real'))
                        print(4, v4.shape)  # shapes should alternate

                        assert v4.shape == (2, 1)
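For context: the JAX-based snippets here assume roughly the imports below. The exact module paths for `markov`, `to_data`, and `to_funsor` vary across funsor/NumPyro versions, so treat this sketch as an assumption rather than the canonical preamble.

    from collections import OrderedDict

    import jax.numpy as jnp

    from funsor import bint, reals
    from funsor.tensor import Tensor
    # assumption: the effect handlers come from NumPyro's funsor integration
    from numpyro.contrib.funsor import markov, to_data, to_funsor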
Example #2
    def model(data):
        log_prob = funsor.to_funsor(0.)

        trans = dist.Categorical(probs=funsor.Tensor(
            trans_probs,
            inputs=OrderedDict([('prev', funsor.bint(args.hidden_dim))]),
        ))

        emit = dist.Categorical(probs=funsor.Tensor(
            emit_probs,
            inputs=OrderedDict([('latent', funsor.bint(args.hidden_dim))]),
        ))

        x_curr = funsor.Number(0, args.hidden_dim)
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t),
                                     funsor.bint(args.hidden_dim))
            log_prob += trans(prev=x_prev, value=x_curr)

            if not args.lazy and isinstance(x_prev, funsor.Variable):
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            log_prob += emit(latent=x_curr, value=funsor.Tensor(y, dtype=2))

        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob
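In the eager branch, `log_prob.reduce(ops.logaddexp, x_prev.name)` is the classic HMM forward-step elimination: a logsumexp over the axis named by the previous state. A minimal numpy sketch of one such step (illustrative, independent of funsor):

    import numpy as np

    hidden_dim = 3
    alpha = np.random.randn(hidden_dim)              # log-weights over x_prev
    trans = np.random.randn(hidden_dim, hidden_dim)  # log p(x_curr | x_prev), axis 0 = x_prev

    # reduce(ops.logaddexp, x_prev.name) corresponds to logsumexp over that axis
    alpha_next = np.logaddexp.reduce(alpha[:, None] + trans, axis=0)
    assert alpha_next.shape == (hidden_dim,)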
Example #3
 def expand_inputs(self, name, size):
     if name in self.funsor_dist.inputs:
         assert self.funsor_dist.inputs[name] == funsor.bint(int(size))
         return self
     inputs = OrderedDict([(name, funsor.bint(int(size)))])
     if self.sample_inputs:
         inputs.update(self.sample_inputs)
     return Distribution(self.funsor_dist, sample_inputs=inputs)
Example #4
 def expand_inputs(self, name, size):
     if name in self.funsor_dist.inputs:
         assert self.funsor_dist.inputs[name] == funsor.bint(int(size))
         return self
     inputs = OrderedDict([(name, funsor.bint(int(size)))])
     funsor_dist = self.funsor_dist + funsor.torch.Tensor(
         torch.zeros(size), inputs)
     return Distribution(funsor_dist)
Example #5
def test_align():
    x = Array(np.random.randn(2, 3, 4),
              OrderedDict([
                  ('i', bint(2)),
                  ('j', bint(3)),
                  ('k', bint(4)),
              ]))
    y = x.align(('j', 'k', 'i'))
    assert isinstance(y, Array)
    assert tuple(y.inputs) == ('j', 'k', 'i')
    for i in range(2):
        for j in range(3):
            for k in range(4):
                assert x(i=i, j=j, k=k) == y(i=i, j=j, k=k)
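On the raw data, `align` is just a transpose into the requested input order; a small numpy sketch of what the test verifies:

    import numpy as np

    data = np.random.randn(2, 3, 4)           # named axes: i, j, k
    aligned = np.transpose(data, (1, 2, 0))   # axis order j, k, i, as in x.align(('j', 'k', 'i'))
    assert aligned.shape == (3, 4, 2)
    assert aligned[1, 2, 0] == data[0, 1, 2]  # same entry, permuted indices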
Example #6
 def testing():
     for i in markov(range(5)):
         v1 = to_data(Tensor(jnp.ones(2), OrderedDict([(str(i), bint(2))]), 'real'))
         v2 = to_data(Tensor(jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real'))
         fv1 = to_funsor(v1, reals())
         fv2 = to_funsor(v2, reals())
         print(i, v1.shape)  # shapes should alternate
         if i % 2 == 0:
             assert v1.shape == (2,)
         else:
             assert v1.shape == (2, 1, 1)
         assert v2.shape == (2, 1)
         print(i, fv1.inputs)
         print('a', v2.shape)  # shapes should stay the same
         print('a', fv2.inputs)
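Why v1's shape alternates: under `markov`, only the current and previous step's enumeration dims are kept live, so step i's variable reuses the slot freed two steps earlier, while the persistent name 'a' holds dim -2 throughout. A toy sketch of that bookkeeping (illustrating the pattern, not the library's actual dim allocator):

    slots = [-1, -3]  # the two slots markov cycles through; -2 stays pinned by 'a'
    for i in range(5):
        dim = slots[i % 2]
        # dim -1 gives v1.shape == (2,); dim -3 gives (2, 1, 1) -- matching the asserts above
        print(i, dim)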
Example #7
    def get_tensors_and_dists(self):
        # normalize the transition probabilities
        trans_logits = self.transition_logits - self.transition_logits.logsumexp(
            dim=-1, keepdim=True)
        trans_probs = funsor.Tensor(
            trans_logits,
            OrderedDict([("s", funsor.bint(self.num_components))]))

        trans_mvn = torch.distributions.MultivariateNormal(
            torch.zeros(self.hidden_dim),
            self.log_transition_noise.exp().diag_embed())
        obs_mvn = torch.distributions.MultivariateNormal(
            torch.zeros(self.obs_dim),
            self.log_obs_noise.exp().diag_embed())

        event_dims = ("s",) if self.fine_transition_matrix or self.fine_transition_noise else ()
        x_trans_dist = matrix_and_mvn_to_funsor(self.transition_matrix,
                                                trans_mvn, event_dims, "x", "y")
        event_dims = ("s",) if self.fine_observation_matrix or self.fine_observation_noise else ()
        y_dist = matrix_and_mvn_to_funsor(self.observation_matrix, obs_mvn,
                                          event_dims, "x", "y")

        return trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist
Example #8
    def model(data):
        log_prob = funsor.Number(0.)

        # s is the discrete latent state,
        # x is the continuous latent state,
        # y is the observed state.
        s_curr = funsor.Tensor(torch.tensor(0), dtype=2)
        x_curr = funsor.Tensor(torch.tensor(0.))
        for t, y in enumerate(data):
            s_prev = s_curr
            x_prev = x_curr

            # A delayed sample statement.
            s_curr = funsor.Variable('s_{}'.format(t), funsor.bint(2))
            log_prob += dist.Categorical(trans_probs[s_prev], value=s_curr)

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
            log_prob += dist.Normal(x_prev, trans_noise[s_curr], value=x_curr)

            # Marginalize out previous delayed sample statements.
            if t > 0:
                log_prob = log_prob.reduce(ops.logaddexp,
                                           {s_prev.name, x_prev.name})

            # An observe statement.
            log_prob += dist.Normal(x_curr, emit_noise, value=y)

        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob
Example #9
 def testing():
     for i in markov(range(12)):
         if i % 4 == 0:
             v2 = to_data(Tensor(jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real'))
             fv2 = to_funsor(v2, reals())
             assert v2.shape == (2,)
             print('a', v2.shape)
             print('a', fv2.inputs)
Example #10
    def log_prob(self, data):
        trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = \
            self.get_tensors_and_dists()

        log_prob = funsor.Number(0.)

        s_vars = {
            -1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)
        }
        x_vars = {}

        for t, y in enumerate(data):
            # construct free variables for s_t and x_t
            s_vars[t] = funsor.Variable(f's_{t}',
                                        funsor.bint(self.num_components))
            x_vars[t] = funsor.Variable(f'x_{t}',
                                        funsor.reals(self.hidden_dim))

            # incorporate the discrete switching dynamics
            log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]),
                                         value=s_vars[t])

            # incorporate the prior term p(x_t | x_{t-1})
            if t == 0:
                log_prob += self.x_init_mvn(value=x_vars[t])
            else:
                log_prob += x_trans_dist(s=s_vars[t],
                                         x=x_vars[t - 1],
                                         y=x_vars[t])

            # do a moment-matching reduction. at this point log_prob depends on (moment_matching_lag + 1)-many
            # pairs of free variables.
            if t > self.moment_matching_lag - 1:
                log_prob = log_prob.reduce(
                    ops.logaddexp,
                    frozenset([
                        s_vars[t - self.moment_matching_lag].name,
                        x_vars[t - self.moment_matching_lag].name
                    ]))

            # incorporate the observation p(y_t | x_t, s_t)
            log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

        T = data.shape[0]
        # reduce any remaining free variables
        for t in range(self.moment_matching_lag):
            log_prob = log_prob.reduce(
                ops.logaddexp,
                frozenset([
                    s_vars[T - self.moment_matching_lag + t].name,
                    x_vars[T - self.moment_matching_lag + t].name
                ]))

        # assert that we've reduced all the free variables in log_prob
        assert not log_prob.inputs, 'unexpected free variables remain'

        # return the PyTorch tensor behind log_prob (which we can directly differentiate)
        return log_prob.data
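Under a moment-matching interpretation, the `reduce` above collapses the mixture over the lagged discrete state into a single Gaussian by matching first and second moments. A minimal 1-D illustration of those moment identities (not funsor's implementation):

    import numpy as np

    w = np.array([0.2, 0.5, 0.3])     # mixture weights over the eliminated discrete state
    mu = np.array([-1.0, 0.0, 2.0])   # component means
    var = np.array([0.5, 1.0, 0.25])  # component variances

    m = (w * mu).sum()                    # matched mean: E[x]
    v = (w * (var + mu**2)).sum() - m**2  # matched variance: E[x^2] - E[x]^2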
Example #11
def main(args):
    # Generate fake data.
    data = funsor.Tensor(torch.randn(100),
                         inputs=OrderedDict([('data', funsor.bint(100))]),
                         output=funsor.reals())

    # Train.
    optim = pyro.Adam({'lr': args.learning_rate})
    svi = pyro.SVI(model, pyro.deferred(guide), optim, pyro.elbo)
    for step in range(args.steps):
        svi.step(data)
Example #12
 def __init__(self, name, size, subsample_size=None, dim=None):
     self.name = name
     self.size = size
     self.subsample_size = size if subsample_size is None else subsample_size
     if dim is not None and dim >= 0:
         raise ValueError('dim arg must be negative.')
     self.dim = dim
     self._indices = funsor.Tensor(
         funsor.ops.new_arange(funsor.tensor.get_default_prototype(),
                               self.size),
         OrderedDict([(self.name, funsor.bint(self.size))]), self.size)
     super(plate, self).__init__(None)
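For reference, `funsor.ops.new_arange(prototype, n)` builds an arange on the same backend as the prototype tensor, so with the PyTorch backend the stored `_indices` wraps data equivalent to this sketch:

    import torch

    size = 5
    indices = torch.arange(size)  # [0, 1, 2, 3, 4]: one index per plate position
    # wrapped as funsor.Tensor(indices, OrderedDict([(name, funsor.bint(size))]), size),
    # i.e. a discrete-valued funsor mapping each plate position to its own index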
Example #13
 def __init__(self, name, size, subsample_size=None, dim=None):
     self.name = name
     self.size = size
     if dim is not None and dim >= 0:
         raise ValueError('dim arg must be negative.')
     self.dim, indices = OrigPlateMessenger._subsample(
         self.name, self.size, subsample_size, dim)
     self.subsample_size = indices.shape[0]
     self._indices = funsor.Tensor(
         indices,
         OrderedDict([(self.name, funsor.bint(self.subsample_size))]),
         self.subsample_size)
     super(plate, self).__init__(None)
Example #14
def test_indexing():
    data = np.random.normal(size=(4, 5))
    inputs = OrderedDict([('i', bint(4)), ('j', bint(5))])
    x = Array(data, inputs)
    check_funsor(x, inputs, reals(), data)

    assert x() is x
    assert x(k=3) is x
    check_funsor(x(1), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(1, 2), {}, reals(), data[1, 2])
    check_funsor(x(1, 2, k=3), {}, reals(), data[1, 2])
    check_funsor(x(1, j=2), {}, reals(), data[1, 2])
    check_funsor(x(1, j=2, k=3), (), reals(), data[1, 2])
    check_funsor(x(1, k=3), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(i=1), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(i=1, j=2), (), reals(), data[1, 2])
    check_funsor(x(i=1, j=2, k=3), (), reals(), data[1, 2])
    check_funsor(x(i=1, k=3), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(j=2), {'i': bint(4)}, reals(), data[:, 2])
    check_funsor(x(j=2, k=3), {'i': bint(4)}, reals(), data[:, 2])
Example #15
def tensor_to_funsor(value, cond_indep_stack, output):
    assert isinstance(value, torch.Tensor)
    event_shape = output.shape
    batch_shape = value.shape[:value.dim() - len(event_shape)]
    if torch._C._get_tracing_state():
        with funsor.tensor.ignore_jit_warnings():
            batch_shape = tuple(map(int, batch_shape))
    inputs = OrderedDict()
    data = value
    for dim, size in enumerate(batch_shape):
        if size == 1:
            data = data.squeeze(dim - value.dim())
        else:
            frame = cond_indep_stack[dim - len(batch_shape)]
            assert size == frame.size, (size, frame)
            inputs[frame.name] = funsor.bint(int(size))
    value = funsor.tensor.Tensor(data, inputs, output.dtype)
    assert value.output == output
    return value
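To make the squeeze-and-name loop concrete, a hand trace under simple assumptions: `output = funsor.reals()` (scalar) and `value.shape == (2, 1)`, with the size-2 plate frame registered at `cond_indep_stack[-2]`:

    # batch_shape == (2, 1), event_shape == ()
    # dim 0, size 2: kept; frame = cond_indep_stack[0 - 2], which names it, say, 'data'
    # dim 1, size 1: squeezed via data.squeeze(1 - 2), i.e. squeeze(-1)
    # result: data.shape == (2,), inputs == OrderedDict(data=bint(2))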
Example #16
def generate_HMM_dataset(model, args):
    """ Generates a sequence of observations from a given funsor model
    """

    data = [
        funsor.Variable('y_{}'.format(t), funsor.bint(args.hidden_dim))
        for t in range(args.time_steps)
    ]

    log_prob = model(data)
    var = list(log_prob.inputs)  # names of all free variables
    # TODO: move sample to model definition, to avoid memory explosion
    r = log_prob.sample(frozenset(var))
    data = torch.tensor([
        d.point.data for d in r.deltas if d.name.startswith('y')
    ])

    return data
Example #17
def test_advanced_indexing_lazy(output_shape):
    x = Array(np.random.normal(size=(2, 3, 4) + output_shape),
              OrderedDict([
                  ('i', bint(2)),
                  ('j', bint(3)),
                  ('k', bint(4)),
              ]))
    u = Variable('u', bint(2))
    v = Variable('v', bint(3))
    with interpretation(lazy):
        i = Number(1, 2) - u
        j = Number(2, 3) - v
        k = u + v

    expected_data = np.empty((2, 3) + output_shape)
    i_data = funsor.numpy.materialize(i).data.astype(np.int64)
    j_data = funsor.numpy.materialize(j).data.astype(np.int64)
    k_data = funsor.numpy.materialize(k).data.astype(np.int64)
    for u in range(2):
        for v in range(3):
            expected_data[u, v] = x.data[i_data[u], j_data[v], k_data[u, v]]
    expected = Array(expected_data,
                     OrderedDict([
                         ('u', bint(2)),
                         ('v', bint(3)),
                     ]))

    assert_equiv(expected, x(i, j, k))
    assert_equiv(expected, x(i=i, j=j, k=k))

    assert_equiv(expected, x(i=i, j=j)(k=k))
    assert_equiv(expected, x(j=j, k=k)(i=i))
    assert_equiv(expected, x(k=k, i=i)(j=j))

    assert_equiv(expected, x(i=i)(j=j, k=k))
    assert_equiv(expected, x(j=j)(k=k, i=i))
    assert_equiv(expected, x(k=k)(i=i, j=j))

    assert_equiv(expected, x(i=i)(j=j)(k=k))
    assert_equiv(expected, x(i=i)(k=k)(j=j))
    assert_equiv(expected, x(j=j)(i=i)(k=k))
    assert_equiv(expected, x(j=j)(k=k)(i=i))
    assert_equiv(expected, x(k=k)(i=i)(j=j))
    assert_equiv(expected, x(k=k)(j=j)(i=i))
Example #18
    def process_message(self, msg):
        if msg["type"] != "sample" or \
                msg.get("done", False) or msg["is_observed"] or msg["infer"].get("expand", False) or \
                msg["infer"].get("enumerate") != "parallel" or (not msg["fn"].has_enumerate_support):
            if msg["type"] == "control_flow":
                msg["kwargs"]["enum"] = True
            return super().process_message(msg)

        if msg["infer"].get("num_samples", None) is not None:
            raise NotImplementedError("TODO implement multiple sampling")

        if msg["infer"].get("expand", False):
            raise NotImplementedError("expand=True not implemented")

        size = msg["fn"].enumerate_support(expand=False).shape[0]
        raw_value = jnp.arange(0, size)
        funsor_value = funsor.Tensor(
            raw_value, OrderedDict([(msg["name"], funsor.bint(size))]), size)

        msg["value"] = to_data(funsor_value)
        msg["done"] = True
Example #19
def test_advanced_indexing_shape():
    I, J, M, N = 4, 4, 2, 3
    x = Array(np.random.normal(size=(I, J)),
              OrderedDict([
                  ('i', bint(I)),
                  ('j', bint(J)),
              ]))
    m = Array(np.array([2, 3]), OrderedDict([('m', bint(M))]), I)
    n = Array(np.array([0, 1, 1]), OrderedDict([('n', bint(N))]), J)
    assert x.data.shape == (I, J)

    check_funsor(x(i=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(i=m, j=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(i=m, j=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(i=m, k=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(i=n), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(i=n, k=m), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(j=m), {'i': bint(I), 'm': bint(M)}, reals())
    check_funsor(x(j=m, i=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(j=m, i=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(j=m, k=m), {'i': bint(I), 'm': bint(M)}, reals())
    check_funsor(x(j=n), {'i': bint(I), 'n': bint(N)}, reals())
    check_funsor(x(j=n, k=m), {'i': bint(I), 'n': bint(N)}, reals())
    check_funsor(x(m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(m, j=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, j=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, k=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(m, n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(n), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(n, k=m), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(n, m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(n, m, k=m), {'m': bint(M), 'n': bint(N)}, reals())
Example #20
def test_to_data_error():
    data = np.zeros((3, 3))
    x = Array(data, OrderedDict(i=bint(3)))
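    # x still has the free input 'i'; with no name_to_dim mapping in scope,
    # to_data cannot assign it a positional batch dim, hence the expected ValueError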
    with pytest.raises(ValueError):
        funsor.to_data(x)
Example #21
def test_advanced_indexing_array(output_shape):
    #      u   v
    #     / \ / \
    #    i   j   k
    #     \  |  /
    #      \ | /
    #        x
    output = reals(*output_shape)
    x = random_array(
        OrderedDict([
            ('i', bint(2)),
            ('j', bint(3)),
            ('k', bint(4)),
        ]), output)
    i = random_array(OrderedDict([
        ('u', bint(5)),
    ]), bint(2))
    j = random_array(OrderedDict([
        ('v', bint(6)),
        ('u', bint(5)),
    ]), bint(3))
    k = random_array(OrderedDict([
        ('v', bint(6)),
    ]), bint(4))

    expected_data = np.empty((5, 6) + output_shape)
    for u in range(5):
        for v in range(6):
            expected_data[u, v] = x.data[i.data[u], j.data[v, u], k.data[v]]
    expected = Array(expected_data,
                     OrderedDict([
                         ('u', bint(5)),
                         ('v', bint(6)),
                     ]))

    assert_equiv(expected, x(i, j, k))
    assert_equiv(expected, x(i=i, j=j, k=k))

    assert_equiv(expected, x(i=i, j=j)(k=k))
    assert_equiv(expected, x(j=j, k=k)(i=i))
    assert_equiv(expected, x(k=k, i=i)(j=j))

    assert_equiv(expected, x(i=i)(j=j, k=k))
    assert_equiv(expected, x(j=j)(k=k, i=i))
    assert_equiv(expected, x(k=k)(i=i, j=j))

    assert_equiv(expected, x(i=i)(j=j)(k=k))
    assert_equiv(expected, x(i=i)(k=k)(j=j))
    assert_equiv(expected, x(j=j)(i=i)(k=k))
    assert_equiv(expected, x(j=j)(k=k)(i=i))
    assert_equiv(expected, x(k=k)(i=i)(j=j))
    assert_equiv(expected, x(k=k)(j=j)(i=i))
Example #22
    def filter_and_predict(self, data, smoothing=False):
        trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = \
            self.get_tensors_and_dists()

        log_prob = funsor.Number(0.)

        s_vars = {
            -1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)
        }
        x_vars = {-1: None}

        predictive_x_dists, predictive_y_dists, filtering_dists = [], [], []
        test_LLs = []

        for t, y in enumerate(data):
            s_vars[t] = funsor.Variable(f's_{t}',
                                        funsor.bint(self.num_components))
            x_vars[t] = funsor.Variable(f'x_{t}',
                                        funsor.reals(self.hidden_dim))

            log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]),
                                         value=s_vars[t])

            if t == 0:
                log_prob += self.x_init_mvn(value=x_vars[t])
            else:
                log_prob += x_trans_dist(s=s_vars[t],
                                         x=x_vars[t - 1],
                                         y=x_vars[t])

            if t > 0:
                log_prob = log_prob.reduce(
                    ops.logaddexp,
                    frozenset([s_vars[t - 1].name, x_vars[t - 1].name]))

            # do 1-step prediction and compute test LL
            if t > 0:
                predictive_x_dists.append(log_prob)
                _log_prob = log_prob - log_prob.reduce(ops.logaddexp)
                predictive_y_dist = y_dist(s=s_vars[t],
                                           x=x_vars[t]) + _log_prob
                test_LLs.append(
                    predictive_y_dist(y=y).reduce(ops.logaddexp).data.item())
                predictive_y_dist = predictive_y_dist.reduce(
                    ops.logaddexp, frozenset([f"x_{t}", f"s_{t}"]))
                predictive_y_dists.append(
                    funsor_to_mvn(predictive_y_dist, 0, ()))

            log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

            # save filtering dists for forward-backward smoothing
            if smoothing:
                filtering_dists.append(log_prob)

        # do the backward recursion using previously computed ingredients
        if smoothing:
            # seed the backward recursion with the filtering distribution at t=T
            smoothing_dists = [filtering_dists[-1]]
            T = data.size(0)

            s_vars = {
                t: funsor.Variable(f's_{t}', funsor.bint(self.num_components))
                for t in range(T)
            }
            x_vars = {
                t: funsor.Variable(f'x_{t}', funsor.reals(self.hidden_dim))
                for t in range(T)
            }

            # do the backward recursion.
            # let p[t|t-1] be the predictive distribution at time step t.
            # let p[t|t] be the filtering distribution at time step t.
            # let f[t] denote the prior (transition) density at time step t.
            # then the smoothing distribution p[t|T] at time step t is
            # given by the following recursion.
            # p[t-1|T] = p[t-1|t-1] <p[t|T] f[t] / p[t|t-1]>
            # where <...> denotes integration of the latent variables at time step t.
            for t in reversed(range(T - 1)):
                integral = smoothing_dists[-1] - predictive_x_dists[t]
                integral += dist.Categorical(trans_probs(s=s_vars[t]),
                                             value=s_vars[t + 1])
                integral += x_trans_dist(s=s_vars[t],
                                         x=x_vars[t],
                                         y=x_vars[t + 1])
                integral = integral.reduce(
                    ops.logaddexp,
                    frozenset([s_vars[t + 1].name, x_vars[t + 1].name]))
                smoothing_dists.append(filtering_dists[t] + integral)

        # compute predictive test MSE and predictive variances
        predictive_means = torch.stack([d.mean for d in predictive_y_dists])  # shape: (T-1, ydim)
        predictive_vars = torch.stack([
            d.covariance_matrix.diagonal(dim1=-1, dim2=-2)
            for d in predictive_y_dists
        ])
        predictive_mse = (predictive_means - data[1:, :]).pow(2.0).mean(-1)

        if smoothing:
            # compute smoothed mean function
            smoothing_dists = [
                funsor_to_cat_and_mvn(d, 0, (f"s_{t}", ))
                for t, d in enumerate(reversed(smoothing_dists))
            ]
            means = torch.stack([d[1].mean for d in smoothing_dists])  # shape: (T, 2, xdim)
            means = torch.matmul(means.unsqueeze(-2),
                                 self.observation_matrix).squeeze(-2)  # shape: (T, 2, ydim)

            probs = torch.stack([d[0].logits for d in smoothing_dists]).exp()
            probs = probs / probs.sum(-1, keepdim=True)  # shape: (T, 2)

            smoothing_means = (probs.unsqueeze(-1) * means).sum(-2)  # shape: (T, ydim)
            smoothing_probs = probs[:, 1]

            return predictive_mse, torch.tensor(np.array(test_LLs)), predictive_means, predictive_vars, \
                smoothing_means, smoothing_probs
        else:
            return predictive_mse, torch.tensor(np.array(test_LLs))
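For reference, the backward recursion spelled out in the comments above can be written as one equation in the comments' own notation, with the angle brackets made explicit as an integral over the step-t latents $z_t = (s_t, x_t)$:

$$
p[t-1 \mid T] \;=\; p[t-1 \mid t-1] \int \frac{p[t \mid T]\, f[t]}{p[t \mid t-1]}\, dz_t
$$

In the code, that integral is the `reduce(ops.logaddexp, ...)` over `s_vars[t + 1].name` and `x_vars[t + 1].name` (a sum for the discrete state, a Gaussian integral for the continuous one).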