Example #1
File: slds.py  Project: MillerJJY/funsor
    def model(data):
        log_prob = funsor.Number(0.)

        # s is the discrete latent state,
        # x is the continuous latent state,
        # y is the observed state.
        s_curr = funsor.Tensor(torch.tensor(0), dtype=2)
        x_curr = funsor.Tensor(torch.tensor(0.))
        for t, y in enumerate(data):
            s_prev = s_curr
            x_prev = x_curr

            # A delayed sample statement.
            s_curr = funsor.Variable('s_{}'.format(t), funsor.bint(2))
            log_prob += dist.Categorical(trans_probs[s_prev], value=s_curr)

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
            log_prob += dist.Normal(x_prev, trans_noise[s_curr], value=x_curr)

            # Marginalize out previous delayed sample statements.
            if t > 0:
                log_prob = log_prob.reduce(ops.logaddexp,
                                           {s_prev.name, x_prev.name})

            # An observe statement.
            log_prob += dist.Normal(x_curr, emit_noise, value=y)

        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob
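The snippets on this page share a common preamble. A minimal sketch of the assumed imports (module paths vary across funsor releases, so treat this as an assumption rather than any project's literal setup):

import torch

import funsor
import funsor.distributions as dist
import funsor.ops as ops

# Newer funsor releases additionally require selecting a backend:
# funsor.set_backend("torch")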
Example #2
    def log_prob(self, data):
        trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = \
            self.get_tensors_and_dists()

        log_prob = funsor.Number(0.)

        s_vars = {
            -1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)
        }
        x_vars = {}

        for t, y in enumerate(data):
            # construct free variables for s_t and x_t
            s_vars[t] = funsor.Variable(f's_{t}',
                                        funsor.bint(self.num_components))
            x_vars[t] = funsor.Variable(f'x_{t}',
                                        funsor.reals(self.hidden_dim))

            # incorporate the discrete switching dynamics
            log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]),
                                         value=s_vars[t])

            # incorporate the prior term p(x_t | x_{t-1})
            if t == 0:
                log_prob += self.x_init_mvn(value=x_vars[t])
            else:
                log_prob += x_trans_dist(s=s_vars[t],
                                         x=x_vars[t - 1],
                                         y=x_vars[t])

            # Do a moment-matching reduction. At this point log_prob depends on
            # (moment_matching_lag + 1) pairs of free variables.
            if t > self.moment_matching_lag - 1:
                log_prob = log_prob.reduce(
                    ops.logaddexp,
                    frozenset([
                        s_vars[t - self.moment_matching_lag].name,
                        x_vars[t - self.moment_matching_lag].name
                    ]))

            # incorporate the observation p(y_t | x_t, s_t)
            log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

        T = data.shape[0]
        # reduce any remaining free variables
        for t in range(self.moment_matching_lag):
            log_prob = log_prob.reduce(
                ops.logaddexp,
                frozenset([
                    s_vars[T - self.moment_matching_lag + t].name,
                    x_vars[T - self.moment_matching_lag + t].name
                ]))

        # assert that we've reduced all the free variables in log_prob
        assert not log_prob.inputs, 'unexpected free variables remain'

        # return the PyTorch tensor behind log_prob (which we can directly differentiate)
        return log_prob.data
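The moment-matching schedule above keeps only a bounded window of free variables alive. A pure-Python sketch of that bookkeeping (no funsor involved; L and T are illustrative values):

# With moment_matching_lag = L, step t marginalizes the pair (s_{t-L}, x_{t-L}),
# so log_prob never depends on more than L + 1 pairs of free variables.
L, T = 2, 6
live = []
for t in range(T):
    live.append(t)          # s_t and x_t become free variables
    if t > L - 1:
        live.remove(t - L)  # reduced via ops.logaddexp in the real code
print(live)                 # [4, 5]: the L pairs still free after the loop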
Example #3
    def model(data):
        log_prob = funsor.to_funsor(0.)

        trans = dist.Categorical(probs=funsor.Tensor(
            trans_probs,
            inputs=OrderedDict([('prev', funsor.bint(args.hidden_dim))]),
        ))

        emit = dist.Categorical(probs=funsor.Tensor(
            emit_probs,
            inputs=OrderedDict([('latent', funsor.bint(args.hidden_dim))]),
        ))

        x_curr = funsor.Number(0, args.hidden_dim)
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t),
                                     funsor.bint(args.hidden_dim))
            log_prob += trans(prev=x_prev, value=x_curr)

            if not args.lazy and isinstance(x_prev, funsor.Variable):
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            log_prob += emit(latent=x_curr, value=funsor.Tensor(y, dtype=2))

        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob
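A hypothetical driver for the HMM model above; trans_probs, emit_probs, and args are supplied by the enclosing script, so the names, shapes, and values below are assumptions for illustration only:

import argparse
import torch

args = argparse.Namespace(hidden_dim=2, lazy=False)  # assumed flags
# Row-stochastic transition and emission tables (trailing dim = categories).
trans_probs = torch.randn(args.hidden_dim, args.hidden_dim).softmax(-1)
emit_probs = torch.randn(args.hidden_dim, 2).softmax(-1)
data = torch.randint(0, 2, (10,))

print(model(data))  # marginal log probability of the observed sequence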
Example #4
def one_step_prediction(p_x_tp1, t, var_names, emit_eq, emit_noise):
    """Computes p(y_{t+1}) from p(x_{t+1}). We assume y_t is scalar, so only one emit_eq"""
    log_prob = p_x_tp1

    x_tp1s = [
        funsor.Variable(name + '_{}'.format(t + 1), funsor.reals())
        for name in var_names
    ]
    y_tp1 = funsor.Variable('y_{}'.format(t + 1), funsor.reals())
    log_prob += dist.Normal(emit_eq(x_tp1s),
                            torch.exp(emit_noise),
                            value=y_tp1)
    log_prob = log_prob.reduce(ops.logaddexp,
                               frozenset([x_tp1.name for x_tp1 in x_tp1s]))

    return log_prob
Example #5
    def model(data):
        log_prob = funsor.to_funsor(0.)
        xs_curr = [funsor.Tensor(torch.tensor(0.)) for _ in var_names]

        for t, y in enumerate(data):
            xs_prev = xs_curr

            # A delayed sample statement.
            xs_curr = [
                funsor.Variable(name + '_{}'.format(t), funsor.reals())
                for name in var_names
            ]

            for i, x_curr in enumerate(xs_curr):
                log_prob += dist.Normal(trans_eqs[var_names[i]](xs_prev),
                                        torch.exp(trans_noises[i]),
                                        value=x_curr)

            if t > 0:
                log_prob = log_prob.reduce(
                    ops.logaddexp,
                    frozenset([x_prev.name for x_prev in xs_prev]))

            # An observe statement.
            log_prob += dist.Normal(emit_eq(xs_curr),
                                    torch.exp(emit_noise),
                                    value=y)

        # Marginalize out all remaining delayed variables.
        return log_prob.reduce(ops.logaddexp), log_prob.gaussian
Example #6
def next_state(p_x_t, t, var_names, trans_eqs, trans_noises):
    """Computes p(x_{t+1}) from p(x_t)"""
    log_prob = p_x_t

    x_ts = [
        funsor.Variable(name + '_{}'.format(t), funsor.reals())
        for name in var_names
    ]
    x_tp1s = [
        funsor.Variable(name + '_{}'.format(t + 1), funsor.reals())
        for name in var_names
    ]

    for i, x_tp1 in enumerate(x_tp1s):
        log_prob += dist.Normal(trans_eqs[var_names[i]](x_ts),
                                torch.exp(trans_noises[i]),
                                value=x_tp1)

    log_prob = log_prob.reduce(ops.logaddexp,
                               frozenset([x_t.name for x_t in x_ts]))
    return log_prob
Example #7
    def _pyro_sample(self, msg):
        # Eagerly convert fn and value to Funsor.
        dim_to_name = {f.dim: f.name for f in msg["cond_indep_stack"]}
        dim_to_name.update(self.preserved_plates)
        msg["fn"] = funsor.to_funsor(msg["fn"], funsor.Real, dim_to_name)
        domain = msg["fn"].inputs["value"]
        if msg["value"] is None:
            msg["value"] = funsor.Variable(msg["name"], domain)
        else:
            msg["value"] = funsor.to_funsor(msg["value"], domain, dim_to_name)

        msg["done"] = True
        msg["stop"] = True
Example #8
    def _forward_funsor(self, features, trip_counts):
        total_hours = len(features)
        observed_hours, num_origins, num_destins = trip_counts.shape
        assert observed_hours == total_hours
        assert num_origins == self.num_stations
        assert num_destins == self.num_stations
        n = self.num_stations
        gate_rate = funsor.Variable("gate_rate_t",
                                    reals(observed_hours, 2 * n * n))["time"]

        @funsor.torch.function(reals(2 * n * n), (reals(n, n, 2), reals(n, n)))
        def unpack_gate_rate(gate_rate):
            batch_shape = gate_rate.shape[:-1]
            gate, rate = gate_rate.reshape(batch_shape + (2, n, n)).unbind(-3)
            gate = gate.sigmoid().clamp(min=0.01, max=0.99)
            rate = bounded_exp(rate, bound=1e4)
            gate = torch.stack((1 - gate, gate), dim=-1)
            return gate, rate

        # Create a Gaussian latent dynamical system.
        init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist = \
            self._dynamics(features[:observed_hours])
        init = dist_to_funsor(init_dist)(value="state")
        trans = matrix_and_mvn_to_funsor(trans_matrix, trans_dist, ("time", ),
                                         "state", "state(time=1)")
        obs = matrix_and_mvn_to_funsor(obs_matrix, obs_dist, ("time", ),
                                       "state(time=1)", "gate_rate")

        # Compute dynamic prior over gate_rate.
        prior = trans + obs(gate_rate=gate_rate)
        prior = MarkovProduct(ops.logaddexp, ops.add, prior, "time",
                              {"state": "state(time=1)"})
        prior += init
        prior = prior.reduce(ops.logaddexp, {"state", "state(time=1)"})

        # Compute zero-inflated Poisson likelihood.
        gate, rate = unpack_gate_rate(gate_rate)
        likelihood = fdist.Categorical(gate["origin", "destin"], value="gated")
        trip_counts = tensor_to_funsor(trip_counts,
                                       ("time", "origin", "destin"))
        likelihood += funsor.Stack(
            "gated",
            (fdist.Poisson(rate["origin", "destin"], value=trip_counts),
             fdist.Delta(0, value=trip_counts)))
        likelihood = likelihood.reduce(ops.logaddexp, "gated")
        likelihood = likelihood.reduce(ops.add, {"time", "origin", "destin"})

        assert set(prior.inputs) == {"gate_rate_t"}, prior.inputs
        assert set(likelihood.inputs) == {"gate_rate_t"}, likelihood.inputs
        return prior, likelihood
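For intuition, the zero-inflated Poisson construction above (a Categorical gate over a Stack of a Poisson branch and a Delta(0) branch, reduced with logaddexp) matches this plain-torch computation; a sketch for a single count, not part of the original module:

import torch

def zip_log_prob(count, gate, rate):
    # count, gate, rate: broadcastable float tensors; gate = P(structural zero).
    log_poisson = torch.log1p(-gate) + torch.distributions.Poisson(rate).log_prob(count)
    log_delta = torch.log(gate) + torch.where(
        count == 0, torch.zeros_like(gate), torch.full_like(gate, -float("inf")))
    return torch.logaddexp(log_poisson, log_delta)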
Example #9
def update(p_x_tp1, t, y, var_names, emit_eq, emit_noise):
    """Computes p(x_{t+1} | y_{t+1}) from p(x_{t+1}). This is useful for iterating 1-step ahead predictions"""
    log_prob = p_x_tp1

    x_tp1s = [
        funsor.Variable(name + '_{}'.format(t + 1), funsor.reals())
        for name in var_names
    ]
    # Bayes' rule: log p(x | y) = log p(x, y) - log p(y).
    # log_prob holds the prior log p(x); adding the likelihood gives the joint.
    # torch.exp matches the log-scale noise parameterization assumed in
    # one_step_prediction above.
    log_prob += dist.Normal(emit_eq(x_tp1s),
                            torch.exp(emit_noise),
                            value=y)
    log_p_y = log_prob.reduce(ops.logaddexp,
                              frozenset([x_tp1.name for x_tp1 in x_tp1s]))

    log_p_x_y = log_prob - log_p_y
    return log_p_x_y
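Taken together, next_state (Example #6), one_step_prediction (Example #4), and update form a filtering loop. A hypothetical sketch of how they chain; p_x_0, ys, and the equation/noise parameters are assumptions about the enclosing script:

p_x_t = p_x_0  # prior over the state at t = 0, free variables named '<var>_0'
for t in range(len(ys) - 1):
    p_x_tp1 = next_state(p_x_t, t, var_names, trans_eqs, trans_noises)  # predict
    p_y_tp1 = one_step_prediction(p_x_tp1, t, var_names, emit_eq, emit_noise)
    p_x_t = update(p_x_tp1, t, ys[t + 1], var_names, emit_eq, emit_noise)  # condition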
Example #10
File: utils.py  Project: vicgalle/vis
def generate_HMM_dataset(model, args):
    """ Generates a sequence of observations from a given funsor model
    """

    data = [
        funsor.Variable('y_{}'.format(t), funsor.bint(args.hidden_dim))
        for t in range(args.time_steps)
    ]

    log_prob = model(data)
    var = list(log_prob.inputs)
    # TODO: move sample to model definition, to avoid memory explosion
    r = log_prob.sample(frozenset(var))
    data = torch.tensor([
        r.deltas[i].point.data for i in range(len(r.deltas))
        if r.deltas[i].name.startswith('y')
    ])

    return data
Example #11
    def model(data):
        log_prob = funsor.to_funsor(0.)

        x_curr = funsor.Tensor(torch.tensor(0.))
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
            log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr)

            # Optionally marginalize out the previous state.
            if t > 0 and not args.lazy:
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            # An observe statement.
            log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)

        # Marginalize out all remaining delayed variables.
        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob
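A hypothetical driver for this model; trans_noise, emit_noise, and args come from the enclosing script, so the values below are illustrative assumptions:

import argparse
import torch

args = argparse.Namespace(lazy=False)  # assumed flag
trans_noise = torch.tensor(0.1)
emit_noise = torch.tensor(0.5)
data = torch.randn(10)

print(model(data))  # a funsor.Tensor wrapping the marginal log probability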
Example #12
def log_density(model, model_args, model_kwargs, params):
    """
    Similar to :func:`numpyro.infer.util.log_density` but works for models
    with discrete latent variables. Internally, this uses :mod:`funsor`
    to marginalize discrete latent sites and evaluate the joint log probability.

    :param model: Python callable containing NumPyro primitives. Typically,
        the model has been enumerated by using
        :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

            def model(*args, **kwargs):
                ...

            log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)

    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args,
                                                    **model_kwargs)
    log_factors = []
    time_to_factors = defaultdict(list)  # log prob factors
    time_to_init_vars = defaultdict(frozenset)  # _init/... variables
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars, prod_vars = frozenset(), frozenset()
    for site in model_trace.values():
        if site['type'] == 'sample':
            value = site['value']
            intermediates = site['intermediates']
            scale = site['scale']
            if intermediates:
                log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob = funsor.to_funsor(log_prob,
                                        output=funsor.reals(),
                                        dim_to_name=dim_to_name)

            time_dim = None
            for dim, name in dim_to_name.items():
                if name.startswith("_time"):
                    time_dim = funsor.Variable(
                        name, funsor.domains.bint(site["value"].shape[dim]))
                    time_to_factors[time_dim].append(log_prob)
                    time_to_init_vars[time_dim] |= frozenset(
                        s for s in dim_to_name.values()
                        if s.startswith("_init"))
                    break
            if time_dim is None:
                log_factors.append(log_prob)

            if not site['is_observed']:
                sum_vars |= frozenset({site['name']})
            prod_vars |= frozenset(f.name for f in site['cond_indep_stack']
                                   if f.dim is not None)

    for time_dim, init_vars in time_to_init_vars.items():
        for var in init_vars:
            curr_var = "/".join(var.split("/")[1:])
            dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"]
            # i.e. the _init (i.e. prev) variable appears in dim_to_name
            if var in dim_to_name.values():
                time_to_markov_dims[time_dim] |= frozenset(dim_to_name.values())

    if len(time_to_factors) > 0:
        markov_factors = compute_markov_factors(time_to_factors,
                                                time_to_init_vars,
                                                time_to_markov_dims, sum_vars,
                                                prod_vars)
        log_factors = log_factors + markov_factors

    with funsor.interpreter.interpretation(funsor.terms.lazy):
        lazy_result = funsor.sum_product.sum_product(funsor.ops.logaddexp,
                                                     funsor.ops.add,
                                                     log_factors,
                                                     eliminate=sum_vars
                                                     | prod_vars,
                                                     plates=prod_vars)
    result = funsor.optimizer.apply_optimizer(lazy_result)
    if len(result.inputs) > 0:
        raise ValueError(
            "Expected the joint log density to be a scalar, but got {}. "
            "There seems to be something wrong at the following sites: {}.".format(
                result.data.shape,
                {k.split("__BOUND")[0] for k in result.inputs}))
    return result.data, model_trace
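A hypothetical end-to-end call following the docstring's pattern; the two-component mixture below is an illustration, not taken from the source:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as ndist
from numpyro.contrib.funsor import config_enumerate, enum

def mixture(data):
    # 'z' is a discrete site marked for enumeration; log_density marginalizes it.
    z = numpyro.sample("z", ndist.Categorical(probs=jnp.array([0.3, 0.7])))
    numpyro.sample("obs", ndist.Normal(jnp.array([-1.0, 1.0])[z], 1.0), obs=data)

log_joint, trace = log_density(enum(config_enumerate(mixture)),
                               (jnp.array(0.5),), {}, {})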
Example #13
    def filter_and_predict(self, data, smoothing=False):
        trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = \
            self.get_tensors_and_dists()

        log_prob = funsor.Number(0.)

        s_vars = {
            -1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)
        }
        x_vars = {-1: None}

        predictive_x_dists, predictive_y_dists, filtering_dists = [], [], []
        test_LLs = []

        for t, y in enumerate(data):
            s_vars[t] = funsor.Variable(f's_{t}',
                                        funsor.bint(self.num_components))
            x_vars[t] = funsor.Variable(f'x_{t}',
                                        funsor.reals(self.hidden_dim))

            log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]),
                                         value=s_vars[t])

            if t == 0:
                log_prob += self.x_init_mvn(value=x_vars[t])
            else:
                log_prob += x_trans_dist(s=s_vars[t],
                                         x=x_vars[t - 1],
                                         y=x_vars[t])

            if t > 0:
                log_prob = log_prob.reduce(
                    ops.logaddexp,
                    frozenset([s_vars[t - 1].name, x_vars[t - 1].name]))

            # do 1-step prediction and compute test LL
            if t > 0:
                predictive_x_dists.append(log_prob)
                _log_prob = log_prob - log_prob.reduce(ops.logaddexp)
                predictive_y_dist = y_dist(s=s_vars[t],
                                           x=x_vars[t]) + _log_prob
                test_LLs.append(
                    predictive_y_dist(y=y).reduce(ops.logaddexp).data.item())
                predictive_y_dist = predictive_y_dist.reduce(
                    ops.logaddexp, frozenset([f"x_{t}", f"s_{t}"]))
                predictive_y_dists.append(
                    funsor_to_mvn(predictive_y_dist, 0, ()))

            log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

            # save filtering dists for forward-backward smoothing
            if smoothing:
                filtering_dists.append(log_prob)

        # do the backward recursion using previously computed ingredients
        if smoothing:
            # seed the backward recursion with the filtering distribution at t=T
            smoothing_dists = [filtering_dists[-1]]
            T = data.size(0)

            s_vars = {
                t: funsor.Variable(f's_{t}', funsor.bint(self.num_components))
                for t in range(T)
            }
            x_vars = {
                t: funsor.Variable(f'x_{t}', funsor.reals(self.hidden_dim))
                for t in range(T)
            }

            # do the backward recursion.
            # let p[t|t-1] be the predictive distribution at time step t.
            # let p[t|t] be the filtering distribution at time step t.
            # let f[t] denote the prior (transition) density at time step t.
            # then the smoothing distribution p[t|T] at time step t is
            # given by the following recursion.
            # p[t-1|T] = p[t-1|t-1] <p[t|T] f[t] / p[t|t-1]>
            # where <...> denotes integration of the latent variables at time step t.
            for t in reversed(range(T - 1)):
                integral = smoothing_dists[-1] - predictive_x_dists[t]
                integral += dist.Categorical(trans_probs(s=s_vars[t]),
                                             value=s_vars[t + 1])
                integral += x_trans_dist(s=s_vars[t],
                                         x=x_vars[t],
                                         y=x_vars[t + 1])
                integral = integral.reduce(
                    ops.logaddexp,
                    frozenset([s_vars[t + 1].name, x_vars[t + 1].name]))
                smoothing_dists.append(filtering_dists[t] + integral)

        # compute predictive test MSE and predictive variances
        predictive_means = torch.stack(
            [d.mean for d in predictive_y_dists])  # shape (T-1, ydim)
        predictive_vars = torch.stack([
            d.covariance_matrix.diagonal(dim1=-1, dim2=-2)
            for d in predictive_y_dists
        ])
        predictive_mse = (predictive_means - data[1:, :]).pow(2.0).mean(-1)

        if smoothing:
            # compute smoothed mean function
            smoothing_dists = [
                funsor_to_cat_and_mvn(d, 0, (f"s_{t}", ))
                for t, d in enumerate(reversed(smoothing_dists))
            ]
            means = torch.stack(
                [d[1].mean for d in smoothing_dists])  # shape (T, 2, xdim)
            means = torch.matmul(
                means.unsqueeze(-2),
                self.observation_matrix).squeeze(-2)  # shape (T, 2, ydim)

            probs = torch.stack([d[0].logits for d in smoothing_dists]).exp()
            probs = probs / probs.sum(-1, keepdim=True)  # T 2

            smoothing_means = (probs.unsqueeze(-1) * means).sum(-2)  # T ydim
            smoothing_probs = probs[:, 1]

            return predictive_mse, torch.tensor(np.array(test_LLs)), predictive_means, predictive_vars, \
                smoothing_means, smoothing_probs
        else:
            return predictive_mse, torch.tensor(np.array(test_LLs))
Example #14
    def process_message(self, msg):
        if msg["type"] == "sample":
            if msg["value"] is None:
                # Create a delayed sample.
                msg["value"] = funsor.Variable(msg["name"], msg["fn"].output)
Example #15
def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
    """Helper function to compute elbo and extract its components from execution traces."""
    model = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    log_factors = []
    time_to_factors = defaultdict(list)  # log prob factors
    time_to_init_vars = defaultdict(frozenset)  # _PREV_... variables
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars, prod_vars = frozenset(), frozenset()
    history = 1
    log_measures = {}
    for site in model_trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                log_prob = site["fn"].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob_factor = funsor.to_funsor(
                log_prob, output=funsor.Real, dim_to_name=dim_to_name
            )

            time_dim = None
            for dim, name in dim_to_name.items():
                if name.startswith("_time"):
                    time_dim = funsor.Variable(name, funsor.Bint[log_prob.shape[dim]])
                    time_to_factors[time_dim].append(log_prob_factor)
                    history = max(
                        history, max(_get_shift(s) for s in dim_to_name.values())
                    )
                    time_to_init_vars[time_dim] |= frozenset(
                        s for s in dim_to_name.values() if s.startswith("_PREV_")
                    )
                    break
            if time_dim is None:
                log_factors.append(log_prob_factor)

            if not site["is_observed"]:
                log_measures[site["name"]] = log_prob_factor
                sum_vars |= frozenset({site["name"]})

            prod_vars |= frozenset(
                f.name for f in site["cond_indep_stack"] if f.dim is not None
            )

    for time_dim, init_vars in time_to_init_vars.items():
        for var in init_vars:
            curr_var = _shift_name(var, -_get_shift(var))
            dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"]
            if var in dim_to_name.values():  # i.e. _PREV_* (i.e. prev) in dim_to_name
                time_to_markov_dims[time_dim] |= frozenset(
                    name for name in dim_to_name.values()
                )

    if len(time_to_factors) > 0:
        markov_factors = compute_markov_factors(
            time_to_factors,
            time_to_init_vars,
            time_to_markov_dims,
            sum_vars,
            prod_vars,
            history,
            sum_op,
            prod_op,
        )
        log_factors = log_factors + markov_factors

    with funsor.interpretations.lazy:
        lazy_result = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            log_factors,
            eliminate=sum_vars | prod_vars,
            plates=prod_vars,
        )
    result = funsor.optimizer.apply_optimizer(lazy_result)
    if len(result.inputs) > 0:
        raise ValueError(
            "Expected the joint log density is a scalar, but got {}. "
            "There seems to be something wrong at the following sites: {}.".format(
                result.data.shape, {k.split("__BOUND")[0] for k in result.inputs}
            )
        )
    return result, model_trace, log_measures