Code example #1
def test_to_funsor(shape, dtype):
    t = np.random.normal(size=shape).astype(dtype)
    f = funsor.to_funsor(t)
    assert isinstance(f, Array)
    assert funsor.to_funsor(t, reals(*shape)) is f
    with pytest.raises(ValueError):
        funsor.to_funsor(t, reals(5, *shape))
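As a quick orientation, to_funsor also works with the torch backend used in most snippets below. A minimal round-trip sketch (hedged: assumes a funsor release with the older funsor.reals()/funsor.bint() domain API, like the code above):

import torch
import funsor

f = funsor.to_funsor(torch.randn(2, 3))
assert not f.inputs                    # all dims land in the output shape
assert f.output == funsor.reals(2, 3)
assert funsor.to_funsor(f) is f        # Funsors pass through unchanged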
Code example #2
def testing():
    for i in markov(range(5)):
        v1 = to_data(Tensor(jnp.ones(2), OrderedDict([(str(i), bint(2))]), 'real'))
        v2 = to_data(Tensor(jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real'))
        fv1 = to_funsor(v1, reals())
        fv2 = to_funsor(v2, reals())
        print(i, v1.shape)  # shapes should alternate
        if i % 2 == 0:
            assert v1.shape == (2,)
        else:
            assert v1.shape == (2, 1, 1)
        assert v2.shape == (2, 1)
        print(i, fv1.inputs)
        print('a', v2.shape)  # shapes should stay the same
        print('a', fv2.inputs)
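For orientation: in core funsor, to_data inverts to_funsor only for terms with no free inputs; terms with named inputs rely on surrounding effect handlers (here markov) to allocate batch dims for those names, which is exactly what the alternating shapes above exercise. A minimal core-funsor sketch (torch backend assumed):

import torch
import funsor

f = funsor.to_funsor(torch.randn(2))
assert not f.inputs                      # no named inputs to allocate dims for
assert funsor.to_data(f).shape == (2,)   # round-trips to the raw tensor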
Code example #3
File: slds.py (Project: MillerJJY/funsor)
    def model(data):
        log_prob = funsor.Number(0.)

        # s is the discrete latent state,
        # x is the continuous latent state,
        # y is the observed state.
        s_curr = funsor.Tensor(torch.tensor(0), dtype=2)
        x_curr = funsor.Tensor(torch.tensor(0.))
        for t, y in enumerate(data):
            s_prev = s_curr
            x_prev = x_curr

            # A delayed sample statement.
            s_curr = funsor.Variable('s_{}'.format(t), funsor.bint(2))
            log_prob += dist.Categorical(trans_probs[s_prev], value=s_curr)

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
            log_prob += dist.Normal(x_prev, trans_noise[s_curr], value=x_curr)

            # Marginalize out previous delayed sample statements.
            if t > 0:
                log_prob = log_prob.reduce(ops.logaddexp,
                                           {s_prev.name, x_prev.name})

            # An observe statement.
            log_prob += dist.Normal(x_curr, emit_noise, value=y)

        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob
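The pattern above, in a self-contained miniature (a sketch, assuming the same torch-backed funsor release with funsor.distributions imported as dist; the numbers are made up): build a joint log-density over one discrete and one continuous free Variable, then marginalize both. Since each factor is a normalized density, the fully reduced result is log 1 = 0 up to floating point:

from collections import OrderedDict

import torch
import funsor
import funsor.ops as ops
import funsor.distributions as dist

s = funsor.Variable('s', funsor.bint(2))    # a delayed discrete sample
x = funsor.Variable('x', funsor.reals())    # a delayed continuous sample

probs = funsor.Tensor(torch.tensor([0.4, 0.6]))              # p(s)
locs = funsor.Tensor(torch.tensor([-1., 1.]),
                     OrderedDict([('s', funsor.bint(2))]))   # loc as a function of s

log_prob = dist.Categorical(probs, value=s) + dist.Normal(locs, 1., value=x)
assert set(log_prob.inputs) == {'s', 'x'}

log_Z = log_prob.reduce(ops.logaddexp, frozenset(['s', 'x']))
assert not log_Z.inputs    # a scalar; ~0. since both factors are normalized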
Code example #4
    def model(data):
        log_prob = funsor.to_funsor(0.)
        xs_curr = [funsor.Tensor(torch.tensor(0.)) for var in var_names]

        for t, y in enumerate(data):
            xs_prev = xs_curr

            # A delayed sample statement.
            xs_curr = [
                funsor.Variable(name + '_{}'.format(t), funsor.reals())
                for name in var_names
            ]

            for i, x_curr in enumerate(xs_curr):
                log_prob += dist.Normal(trans_eqs[var_names[i]](xs_prev),
                                        torch.exp(trans_noises[i]),
                                        value=x_curr)

            if t > 0:
                log_prob = log_prob.reduce(
                    ops.logaddexp,
                    frozenset([x_prev.name for x_prev in xs_prev]))

            # An observe statement.
            log_prob += dist.Normal(emit_eq(xs_curr),
                                    torch.exp(emit_noise),
                                    value=y)

        # Marginalize out all remaining delayed variables.
        return log_prob.reduce(ops.logaddexp), log_prob.gaussian
Code example #5
def one_step_prediction(p_x_tp1, t, var_names, emit_eq, emit_noise):
    """Computes p(y_{t+1}) from p(x_{t+1}). We assume y_t is scalar, so only one emit_eq"""
    log_prob = p_x_tp1

    x_tp1s = [
        funsor.Variable(name + '_{}'.format(t + 1), funsor.reals())
        for name in var_names
    ]
    y_tp1 = funsor.Variable('y_{}'.format(t + 1), funsor.reals())
    log_prob += dist.Normal(emit_eq(x_tp1s),
                            torch.exp(emit_noise),
                            value=y_tp1)
    log_prob = log_prob.reduce(ops.logaddexp,
                               frozenset([x_tp1.name for x_tp1 in x_tp1s]))

    return log_prob
Code example #6
def testing():
    for i in markov(range(12)):
        if i % 4 == 0:
            v2 = to_data(Tensor(jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real'))
            fv2 = to_funsor(v2, reals())
            assert v2.shape == (2,)
            print('a', v2.shape)
            print('a', fv2.inputs)
Code example #7
    def log_prob(self, data):
        trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = \
            self.get_tensors_and_dists()

        log_prob = funsor.Number(0.)

        s_vars = {
            -1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)
        }
        x_vars = {}

        for t, y in enumerate(data):
            # construct free variables for s_t and x_t
            s_vars[t] = funsor.Variable(f's_{t}',
                                        funsor.bint(self.num_components))
            x_vars[t] = funsor.Variable(f'x_{t}',
                                        funsor.reals(self.hidden_dim))

            # incorporate the discrete switching dynamics
            log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]),
                                         value=s_vars[t])

            # incorporate the prior term p(x_t | x_{t-1})
            if t == 0:
                log_prob += self.x_init_mvn(value=x_vars[t])
            else:
                log_prob += x_trans_dist(s=s_vars[t],
                                         x=x_vars[t - 1],
                                         y=x_vars[t])

            # Do a moment-matching reduction. At this point log_prob depends on
            # (moment_matching_lag + 1) pairs of free variables.
            if t > self.moment_matching_lag - 1:
                log_prob = log_prob.reduce(
                    ops.logaddexp,
                    frozenset([
                        s_vars[t - self.moment_matching_lag].name,
                        x_vars[t - self.moment_matching_lag].name
                    ]))

            # incorporate the observation p(y_t | x_t, s_t)
            log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

        T = data.shape[0]
        # reduce any remaining free variables
        for t in range(self.moment_matching_lag):
            log_prob = log_prob.reduce(
                ops.logaddexp,
                frozenset([
                    s_vars[T - self.moment_matching_lag + t].name,
                    x_vars[T - self.moment_matching_lag + t].name
                ]))

        # assert that we've reduced all the free variables in log_prob
        assert not log_prob.inputs, 'unexpected free variables remain'

        # return the PyTorch tensor behind log_prob (which we can directly differentiate)
        return log_prob.data
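A side note on the s_vars[-1] initialization above: a scalar funsor.Tensor with a bint dtype acts as a constant index, so substituting it selects a slice without introducing a free input. A small sketch (same API assumptions; trans_probs here is a made-up 2-state transition table):

from collections import OrderedDict

import torch
import funsor

trans_probs = funsor.Tensor(torch.tensor([[0.9, 0.1], [0.2, 0.8]]),
                            OrderedDict([('s', funsor.bint(2))]))

s_init = funsor.Tensor(torch.tensor(0), dtype=2)   # a constant in bint(2)
assert s_init.output == funsor.bint(2) and not s_init.inputs

row0 = trans_probs(s=s_init)                 # selects row 0; no free inputs
s_0 = funsor.Variable('s_0', funsor.bint(2))
rows = trans_probs(s=s_0)                    # keeps 's_0' as a free input
assert not row0.inputs and 's_0' in rows.inputs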
Code example #8
File: minipyro.py (Project: pangyyyyy/funsor)
def compiled(*params_and_args):
    unconstrained_params = params_and_args[:len(self._param_trace)]
    args = params_and_args[len(self._param_trace):]
    for name, unconstrained_param in zip(self._param_trace, unconstrained_params):
        constrained_param = param(name)  # assume param has been initialized
        assert constrained_param.data.unconstrained() is unconstrained_param
        self._param_trace[name]["value"] = constrained_param
    result = replay(self.fn, guide_trace=self._param_trace)(*args)
    assert not result.inputs
    assert result.output == funsor.reals()
    return funsor.to_data(result)
Code example #9
def main(args):
    # Generate fake data.
    data = funsor.Tensor(torch.randn(100),
                         inputs=OrderedDict([('data', funsor.bint(100))]),
                         output=funsor.reals())

    # Train.
    optim = pyro.Adam({'lr': args.learning_rate})
    svi = pyro.SVI(model, pyro.deferred(guide), optim, pyro.elbo)
    for step in range(args.steps):
        svi.step(data)
Code example #10
def next_state(p_x_t, t, var_names, trans_eqs, trans_noises):
    """Computes p(x_{t+1}) from p(x_t)"""
    log_prob = p_x_t

    x_ts = [
        funsor.Variable(name + '_{}'.format(t), funsor.reals())
        for name in var_names
    ]
    x_tp1s = [
        funsor.Variable(name + '_{}'.format(t + 1), funsor.reals())
        for name in var_names
    ]

    for i, x_tp1 in enumerate(x_tp1s):
        log_prob += dist.Normal(trans_eqs[var_names[i]](x_ts),
                                torch.exp(trans_noises[i]),
                                value=x_tp1)

    log_prob = log_prob.reduce(ops.logaddexp,
                               frozenset([x_t.name for x_t in x_ts]))
    return log_prob
Code example #11
def test_advanced_indexing_array(output_shape):
    #      u   v
    #     / \ / \
    #    i   j   k
    #     \  |  /
    #      \ | /
    #        x
    output = reals(*output_shape)
    x = random_array(
        OrderedDict([
            ('i', bint(2)),
            ('j', bint(3)),
            ('k', bint(4)),
        ]), output)
    i = random_array(OrderedDict([
        ('u', bint(5)),
    ]), bint(2))
    j = random_array(OrderedDict([
        ('v', bint(6)),
        ('u', bint(5)),
    ]), bint(3))
    k = random_array(OrderedDict([
        ('v', bint(6)),
    ]), bint(4))

    expected_data = np.empty((5, 6) + output_shape)
    for u in range(5):
        for v in range(6):
            expected_data[u, v] = x.data[i.data[u], j.data[v, u], k.data[v]]
    expected = Array(expected_data,
                     OrderedDict([
                         ('u', bint(5)),
                         ('v', bint(6)),
                     ]))

    assert_equiv(expected, x(i, j, k))
    assert_equiv(expected, x(i=i, j=j, k=k))

    assert_equiv(expected, x(i=i, j=j)(k=k))
    assert_equiv(expected, x(j=j, k=k)(i=i))
    assert_equiv(expected, x(k=k, i=i)(j=j))

    assert_equiv(expected, x(i=i)(j=j, k=k))
    assert_equiv(expected, x(j=j)(k=k, i=i))
    assert_equiv(expected, x(k=k)(i=i, j=j))

    assert_equiv(expected, x(i=i)(j=j)(k=k))
    assert_equiv(expected, x(i=i)(k=k)(j=j))
    assert_equiv(expected, x(j=j)(i=i)(k=k))
    assert_equiv(expected, x(j=j)(k=k)(i=i))
    assert_equiv(expected, x(k=k)(i=i)(j=j))
    assert_equiv(expected, x(k=k)(j=j)(i=i))
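The same advanced-indexing semantics hold for the torch-backed funsor.Tensor. A compact sketch (names invented here): substituting a funsor with its own named inputs consumes the indexed input and introduces the indexer's inputs.

from collections import OrderedDict

import torch
import funsor

x = funsor.Tensor(torch.randn(2, 3),
                  OrderedDict([('i', funsor.bint(2)), ('j', funsor.bint(3))]))
i = funsor.Tensor(torch.tensor([0, 1, 1, 0]),
                  OrderedDict([('u', funsor.bint(4))]), dtype=2)

y = x(i=i)                          # index input 'i' pointwise by the funsor i
assert set(y.inputs) == {'u', 'j'}  # 'i' is consumed, 'u' is introduced
assert y.output == funsor.reals()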
Code example #12
def param(name,
          init_value=None,
          constraint=torch.distributions.constraints.real,
          event_dim=None):
    cond_indep_stack = {}
    output = None
    if init_value is not None:
        if event_dim is None:
            event_dim = init_value.dim()
        output = funsor.reals(*init_value.shape[init_value.dim() - event_dim:])

    def fn(init_value, constraint):
        if name in PARAM_STORE:
            unconstrained_value, constraint = PARAM_STORE[name]
        else:
            # Initialize with a constrained value.
            assert init_value is not None
            with torch.no_grad():
                constrained_value = init_value.detach()
                unconstrained_value = torch.distributions.transform_to(
                    constraint).inv(constrained_value)
            unconstrained_value.requires_grad_()
            unconstrained_value._funsor_metadata = (cond_indep_stack, output)
            PARAM_STORE[name] = unconstrained_value, constraint

        # Transform from unconstrained space to constrained space.
        constrained_value = torch.distributions.transform_to(constraint)(
            unconstrained_value)
        constrained_value.unconstrained = weakref.ref(unconstrained_value)
        return tensor_to_funsor(constrained_value,
                                *unconstrained_value._funsor_metadata)

    # If there are no active Messengers, we just create the parameter value and return it as expected:
    if not PYRO_STACK:
        return fn(init_value, constraint)

    # Otherwise, we initialize a message...
    initial_msg = {
        "type": "param",
        "name": name,
        "fn": fn,
        "args": (init_value, constraint),
        "value": None,
        "cond_indep_stack":
        cond_indep_stack,  # maps dim to CondIndepStackFrame
        "output": output,
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    assert isinstance(msg["value"], funsor.Funsor)
    return msg["value"]
Code example #13
File: eeg_slds.py (Project: lawrencechen0921/funsor)
    def __init__(self,
                 num_components,   # the number of switching states K
                 hidden_dim,       # the dimension of the continuous latent space
                 obs_dim,          # the dimension of the continuous outputs
                 fine_transition_matrix=True,    # controls whether the transition matrix depends on s_t
                 fine_transition_noise=False,    # controls whether the transition noise depends on s_t
                 fine_observation_matrix=False,  # controls whether the observation matrix depends on s_t
                 fine_observation_noise=False,   # controls whether the observation noise depends on s_t
                 moment_matching_lag=1):         # controls the expense of the moment matching approximation

        self.num_components = num_components
        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
        self.moment_matching_lag = moment_matching_lag
        self.fine_transition_noise = fine_transition_noise
        self.fine_observation_matrix = fine_observation_matrix
        self.fine_observation_noise = fine_observation_noise
        self.fine_transition_matrix = fine_transition_matrix

        assert moment_matching_lag > 0
        assert fine_transition_noise or fine_observation_matrix or fine_observation_noise or fine_transition_matrix, \
            "The continuous dynamics need to be coupled to the discrete dynamics in at least one way [use at " + \
            "least one of the arguments --ftn --ftm --fon --fom]"

        super(SLDS, self).__init__()

        # initialize the various parameters of the model
        self.transition_logits = nn.Parameter(0.1 * torch.randn(num_components, num_components))
        if fine_transition_matrix:
            transition_matrix = torch.eye(hidden_dim) + 0.05 * torch.randn(num_components, hidden_dim, hidden_dim)
        else:
            transition_matrix = torch.eye(hidden_dim) + 0.05 * torch.randn(hidden_dim, hidden_dim)
        self.transition_matrix = nn.Parameter(transition_matrix)
        if fine_transition_noise:
            self.log_transition_noise = nn.Parameter(0.1 * torch.randn(num_components, hidden_dim))
        else:
            self.log_transition_noise = nn.Parameter(0.1 * torch.randn(hidden_dim))
        if fine_observation_matrix:
            self.observation_matrix = nn.Parameter(0.3 * torch.randn(num_components, hidden_dim, obs_dim))
        else:
            self.observation_matrix = nn.Parameter(0.3 * torch.randn(hidden_dim, obs_dim))
        if fine_observation_noise:
            self.log_obs_noise = nn.Parameter(0.1 * torch.randn(num_components, obs_dim))
        else:
            self.log_obs_noise = nn.Parameter(0.1 * torch.randn(obs_dim))

        # define the prior distribution p(x_0) over the continuous latent at the initial time step t=0
        x_init_mvn = torch.distributions.MultivariateNormal(torch.zeros(self.hidden_dim), torch.eye(self.hidden_dim))
        self.x_init_mvn = mvn_to_funsor(x_init_mvn, real_inputs=OrderedDict([('x_0', funsor.reals(self.hidden_dim))]))
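For reference, mvn_to_funsor converts a MultivariateNormal into a log-density funsor over named real inputs. A sketch (the import path funsor.pyro.convert is an assumption based on the funsor release these snippets target):

from collections import OrderedDict

import torch
import funsor
from funsor.pyro.convert import mvn_to_funsor  # assumed import path

mvn = torch.distributions.MultivariateNormal(torch.zeros(3), torch.eye(3))
f = mvn_to_funsor(mvn, real_inputs=OrderedDict([('x_0', funsor.reals(3))]))
assert 'x_0' in f.inputs   # a log-density funsor in the named input 'x_0'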
Code example #14
def update(p_x_tp1, t, y, var_names, emit_eq, emit_noise):
    """Computes p(x_{t+1} | y_{t+1}) from p(x_{t+1}). This is useful for iterating 1-step ahead predictions"""
    log_prob = p_x_tp1

    x_tp1s = [
        funsor.Variable(name + '_{}'.format(t + 1), funsor.reals())
        for name in var_names
    ]

    # Condition on the observation: log_prob becomes log p(x_{t+1}, y_{t+1}).
    log_prob += dist.Normal(emit_eq(x_tp1s), emit_noise, value=y)
    log_p_y = log_prob.reduce(ops.logaddexp,
                              frozenset([x_tp1.name for x_tp1 in x_tp1s]))

    # Bayes' rule: log p(x_{t+1} | y_{t+1}) = log p(x_{t+1}, y_{t+1}) - log p(y_{t+1}).
    log_p_x_y = log_prob - log_p_y
    return log_p_x_y
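Putting next_state (Code example #10), one_step_prediction (Code example #5), and update together: a hypothetical usage sketch with invented one-dimensional dynamics (var_names, trans_eqs, trans_noises, emit_eq, and emit_noise below are all made up; note that next_state and one_step_prediction exponentiate their noise parameters while update uses emit_noise directly):

import torch
import funsor
import funsor.distributions as dist

var_names = ['z']                           # hypothetical single state variable
trans_eqs = {'z': lambda xs: 0.9 * xs[0]}   # linear dynamics
trans_noises = [torch.tensor(0.)]           # exp(0.) -> transition scale 1.0
emit_eq = lambda xs: xs[0]
emit_noise = torch.tensor(0.)               # exp(0.) -> emission scale 1.0

# p(x_0): a Gaussian funsor over the named input 'z_0'.
p_x_0 = dist.Normal(0., 1., value=funsor.Variable('z_0', funsor.reals()))

p_x_1 = next_state(p_x_0, 0, var_names, trans_eqs, trans_noises)
assert set(p_x_1.inputs) == {'z_1'}

p_y_1 = one_step_prediction(p_x_1, 0, var_names, emit_eq, emit_noise)
assert set(p_y_1.inputs) == {'y_1'}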
Code example #15
File: kalman_filter.py (Project: MillerJJY/funsor)
    def model(data):
        log_prob = funsor.to_funsor(0.)

        x_curr = funsor.Tensor(torch.tensor(0.))
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
            log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr)

            # Optionally marginalize out the previous state.
            if t > 0 and not args.lazy:
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            # An observe statement.
            log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)

        # Marginalize out all remaining delayed variables.
        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob
Code example #16
def test_advanced_indexing_shape():
    I, J, M, N = 4, 4, 2, 3
    x = Array(np.random.normal(size=(I, J)),
              OrderedDict([
                  ('i', bint(I)),
                  ('j', bint(J)),
              ]))
    m = Array(np.array([2, 3]), OrderedDict([('m', bint(M))]), I)
    n = Array(np.array([0, 1, 1]), OrderedDict([('n', bint(N))]), J)
    assert x.data.shape == (I, J)

    check_funsor(x(i=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(i=m, j=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(i=m, j=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(i=m, k=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(i=n), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(i=n, k=m), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(j=m), {'i': bint(I), 'm': bint(M)}, reals())
    check_funsor(x(j=m, i=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(j=m, i=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(j=m, k=m), {'i': bint(I), 'm': bint(M)}, reals())
    check_funsor(x(j=n), {'i': bint(I), 'n': bint(N)}, reals())
    check_funsor(x(j=n, k=m), {'i': bint(I), 'n': bint(N)}, reals())
    check_funsor(x(m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(m, j=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, j=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, k=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(m, n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(n), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(n, k=m), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(n, m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(n, m, k=m), {'m': bint(M), 'n': bint(N)}, reals())
Code example #17
def test_indexing():
    data = np.random.normal(size=(4, 5))
    inputs = OrderedDict([('i', bint(4)), ('j', bint(5))])
    x = Array(data, inputs)
    check_funsor(x, inputs, reals(), data)

    assert x() is x
    assert x(k=3) is x
    check_funsor(x(1), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(1, 2), {}, reals(), data[1, 2])
    check_funsor(x(1, 2, k=3), {}, reals(), data[1, 2])
    check_funsor(x(1, j=2), {}, reals(), data[1, 2])
    check_funsor(x(1, j=2, k=3), (), reals(), data[1, 2])
    check_funsor(x(1, k=3), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(i=1), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(i=1, j=2), (), reals(), data[1, 2])
    check_funsor(x(i=1, j=2, k=3), (), reals(), data[1, 2])
    check_funsor(x(i=1, k=3), {'j': bint(5)}, reals(), data[1])
    check_funsor(x(j=2), {'i': bint(4)}, reals(), data[:, 2])
    check_funsor(x(j=2, k=3), {'i': bint(4)}, reals(), data[:, 2])
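In short, substitution is the uniform access pattern: calling a funsor binds inputs positionally (in input order) or by name, and names it does not know are ignored. A tiny sketch with the torch backend (same API assumptions as above):

from collections import OrderedDict

import torch
import funsor

x = funsor.Tensor(torch.randn(4, 5),
                  OrderedDict([('i', funsor.bint(4)), ('j', funsor.bint(5))]))

assert x() is x             # empty substitution is a no-op
assert x(k=3) is x          # unknown names are ignored
y = x(j=2)                  # bind 'j'; 'i' stays free
assert set(y.inputs) == {'i'}
z = x(1, 2)                 # positional binding follows input order: i=1, j=2
assert not z.inputs and z.output == funsor.reals()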
Code example #18
    def filter_and_predict(self, data, smoothing=False):
        trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = \
            self.get_tensors_and_dists()

        log_prob = funsor.Number(0.)

        s_vars = {
            -1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)
        }
        x_vars = {-1: None}

        predictive_x_dists, predictive_y_dists, filtering_dists = [], [], []
        test_LLs = []

        for t, y in enumerate(data):
            s_vars[t] = funsor.Variable(f's_{t}',
                                        funsor.bint(self.num_components))
            x_vars[t] = funsor.Variable(f'x_{t}',
                                        funsor.reals(self.hidden_dim))

            log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]),
                                         value=s_vars[t])

            if t == 0:
                log_prob += self.x_init_mvn(value=x_vars[t])
            else:
                log_prob += x_trans_dist(s=s_vars[t],
                                         x=x_vars[t - 1],
                                         y=x_vars[t])

            if t > 0:
                log_prob = log_prob.reduce(
                    ops.logaddexp,
                    frozenset([s_vars[t - 1].name, x_vars[t - 1].name]))

            # do 1-step prediction and compute test LL
            if t > 0:
                predictive_x_dists.append(log_prob)
                _log_prob = log_prob - log_prob.reduce(ops.logaddexp)
                predictive_y_dist = y_dist(s=s_vars[t],
                                           x=x_vars[t]) + _log_prob
                test_LLs.append(
                    predictive_y_dist(y=y).reduce(ops.logaddexp).data.item())
                predictive_y_dist = predictive_y_dist.reduce(
                    ops.logaddexp, frozenset([f"x_{t}", f"s_{t}"]))
                predictive_y_dists.append(
                    funsor_to_mvn(predictive_y_dist, 0, ()))

            log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

            # save filtering dists for forward-backward smoothing
            if smoothing:
                filtering_dists.append(log_prob)

        # do the backward recursion using previously computed ingredients
        if smoothing:
            # seed the backward recursion with the filtering distribution at t=T
            smoothing_dists = [filtering_dists[-1]]
            T = data.size(0)

            s_vars = {
                t: funsor.Variable(f's_{t}', funsor.bint(self.num_components))
                for t in range(T)
            }
            x_vars = {
                t: funsor.Variable(f'x_{t}', funsor.reals(self.hidden_dim))
                for t in range(T)
            }

            # do the backward recursion.
            # let p[t|t-1] be the predictive distribution at time step t.
            # let p[t|t] be the filtering distribution at time step t.
            # let f[t] denote the prior (transition) density at time step t.
            # then the smoothing distribution p[t|T] at time step t is
            # given by the following recursion.
            # p[t-1|T] = p[t-1|t-1] <p[t|T] f[t] / p[t|t-1]>
            # where <...> denotes integration of the latent variables at time step t.
            for t in reversed(range(T - 1)):
                integral = smoothing_dists[-1] - predictive_x_dists[t]
                integral += dist.Categorical(trans_probs(s=s_vars[t]),
                                             value=s_vars[t + 1])
                integral += x_trans_dist(s=s_vars[t],
                                         x=x_vars[t],
                                         y=x_vars[t + 1])
                integral = integral.reduce(
                    ops.logaddexp,
                    frozenset([s_vars[t + 1].name, x_vars[t + 1].name]))
                smoothing_dists.append(filtering_dists[t] + integral)

        # compute predictive test MSE and predictive variances
        predictive_means = torch.stack(
            [d.mean for d in predictive_y_dists])  # (T-1, ydim)
        predictive_vars = torch.stack([
            d.covariance_matrix.diagonal(dim1=-1, dim2=-2)
            for d in predictive_y_dists
        ])
        predictive_mse = (predictive_means - data[1:, :]).pow(2.0).mean(-1)

        if smoothing:
            # compute smoothed mean function
            smoothing_dists = [
                funsor_to_cat_and_mvn(d, 0, (f"s_{t}", ))
                for t, d in enumerate(reversed(smoothing_dists))
            ]
            means = torch.stack([d[1].mean for d in smoothing_dists])  # (T, 2, xdim)
            means = torch.matmul(means.unsqueeze(-2),
                                 self.observation_matrix).squeeze(-2)  # (T, 2, ydim)

            probs = torch.stack([d[0].logits for d in smoothing_dists]).exp()
            probs = probs / probs.sum(-1, keepdim=True)  # T 2

            smoothing_means = (probs.unsqueeze(-1) * means).sum(-2)  # T ydim
            smoothing_probs = probs[:, 1]

            return predictive_mse, torch.tensor(np.array(test_LLs)), predictive_means, predictive_vars, \
                smoothing_means, smoothing_probs
        else:
            return predictive_mse, torch.tensor(np.array(test_LLs))
Code example #19
def log_density(model, model_args, model_kwargs, params):
    """
    Similar to :func:`numpyro.infer.util.log_density` but works for models
    with discrete latent variables. Internally, this uses :mod:`funsor`
    to marginalize discrete latent sites and evaluate the joint log probability.

    :param model: Python callable containing NumPyro primitives. Typically,
        the model has been enumerated by using
        :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

            def model(*args, **kwargs):
                ...

            log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)

    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args,
                                                    **model_kwargs)
    log_factors = []
    time_to_factors = defaultdict(list)  # log prob factors
    time_to_init_vars = defaultdict(frozenset)  # _init/... variables
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars, prod_vars = frozenset(), frozenset()
    for site in model_trace.values():
        if site['type'] == 'sample':
            value = site['value']
            intermediates = site['intermediates']
            scale = site['scale']
            if intermediates:
                log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob = funsor.to_funsor(log_prob,
                                        output=funsor.reals(),
                                        dim_to_name=dim_to_name)

            time_dim = None
            for dim, name in dim_to_name.items():
                if name.startswith("_time"):
                    time_dim = funsor.Variable(
                        name, funsor.domains.bint(site["value"].shape[dim]))
                    time_to_factors[time_dim].append(log_prob)
                    time_to_init_vars[time_dim] |= frozenset(
                        s for s in dim_to_name.values()
                        if s.startswith("_init"))
                    break
            if time_dim is None:
                log_factors.append(log_prob)

            if not site['is_observed']:
                sum_vars |= frozenset({site['name']})
            prod_vars |= frozenset(f.name for f in site['cond_indep_stack']
                                   if f.dim is not None)

    for time_dim, init_vars in time_to_init_vars.items():
        for var in init_vars:
            curr_var = "/".join(var.split("/")[1:])
            dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"]
            if var in dim_to_name.values():  # i.e. _init (i.e. prev) is in dim_to_name
                time_to_markov_dims[time_dim] |= frozenset(dim_to_name.values())

    if len(time_to_factors) > 0:
        markov_factors = compute_markov_factors(time_to_factors,
                                                time_to_init_vars,
                                                time_to_markov_dims, sum_vars,
                                                prod_vars)
        log_factors = log_factors + markov_factors

    with funsor.interpreter.interpretation(funsor.terms.lazy):
        lazy_result = funsor.sum_product.sum_product(
            funsor.ops.logaddexp, funsor.ops.add, log_factors,
            eliminate=sum_vars | prod_vars, plates=prod_vars)
    result = funsor.optimizer.apply_optimizer(lazy_result)
    if len(result.inputs) > 0:
        raise ValueError(
            "Expected the joint log density to be a scalar, but got {}. "
            "There seems to be something wrong at the following sites: {}.".format(
                result.data.shape,
                {k.split("__BOUND")[0] for k in result.inputs}))
    return result.data, model_trace
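The lazy-then-optimize pattern at the end is worth isolating. A minimal sketch (same release assumptions; two made-up factors sharing the variable 'b'): build the sum-product expression under the lazy interpretation, then let apply_optimizer pick an elimination order before anything is evaluated.

from collections import OrderedDict

import torch
import funsor
import funsor.interpreter
import funsor.ops
import funsor.optimizer
import funsor.sum_product
import funsor.terms

f = funsor.Tensor(torch.randn(2, 3),
                  OrderedDict([('a', funsor.bint(2)), ('b', funsor.bint(3))]))
g = funsor.Tensor(torch.randn(3), OrderedDict([('b', funsor.bint(3))]))

with funsor.interpreter.interpretation(funsor.terms.lazy):
    lazy_result = funsor.sum_product.sum_product(
        funsor.ops.logaddexp, funsor.ops.add, [f, g],
        eliminate=frozenset(['a', 'b']), plates=frozenset())

result = funsor.optimizer.apply_optimizer(lazy_result)
assert not result.inputs   # a scalar log-sum-exp over 'a' and 'b'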