Example #1
def main(args):
    # Declare parameters.
    trans_probs = funsor.Tensor(
        torch.tensor([[0.9, 0.1], [0.1, 0.9]], requires_grad=True))
    trans_noise = funsor.Tensor(
        torch.tensor(
            [
                0.1,  # low noise component
                1.0,  # high noise component
            ],
            requires_grad=True))
    emit_noise = funsor.Tensor(torch.tensor(0.5, requires_grad=True))
    params = [trans_probs.data, trans_noise.data, emit_noise.data]

    # A Gaussian HMM model.
    @funsor.interpreter.interpretation(funsor.terms.moment_matching)
    def model(data):
        log_prob = funsor.Number(0.)

        # s is the discrete latent state,
        # x is the continuous latent state,
        # y is the observed state.
        s_curr = funsor.Tensor(torch.tensor(0), dtype=2)
        x_curr = funsor.Tensor(torch.tensor(0.))
        for t, y in enumerate(data):
            s_prev = s_curr
            x_prev = x_curr

            # A delayed sample statement.
            s_curr = funsor.Variable('s_{}'.format(t), funsor.bint(2))
            log_prob += dist.Categorical(trans_probs[s_prev], value=s_curr)

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
            log_prob += dist.Normal(x_prev, trans_noise[s_curr], value=x_curr)

            # Marginalize out previous delayed sample statements.
            if t > 0:
                log_prob = log_prob.reduce(ops.logaddexp,
                                           {s_prev.name, x_prev.name})

            # An observe statement.
            log_prob += dist.Normal(x_curr, emit_noise, value=y)

        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob

    # Train model parameters.
    torch.manual_seed(0)
    data = torch.randn(args.time_steps)
    optim = torch.optim.Adam(params, lr=args.learning_rate)
    for step in range(args.train_steps):
        optim.zero_grad()
        log_prob = model(data)
        assert not log_prob.inputs, 'free variables remain'
        loss = -log_prob.data
        loss.backward()
        optim.step()
        if args.verbose and step % 10 == 0:
            print('step {} loss = {}'.format(step, loss.item()))
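# A hypothetical driver for the example above (not part of the original source):
# the flag names are guessed from the attributes main() reads on `args`
# (time_steps, train_steps, learning_rate, verbose) and may differ from the
# original script's CLI.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description="Switching-noise Gaussian HMM")
    parser.add_argument('-t', '--time-steps', default=10, type=int)
    parser.add_argument('-n', '--train-steps', default=101, type=int)
    parser.add_argument('-lr', '--learning-rate', default=0.05, type=float)
    parser.add_argument('-v', '--verbose', action='store_true')
    main(parser.parse_args())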
Example #3
    def model(data):
        log_prob = funsor.to_funsor(0.)

        trans = dist.Categorical(probs=funsor.Tensor(
            trans_probs,
            inputs=OrderedDict([('prev', funsor.bint(args.hidden_dim))]),
        ))

        emit = dist.Categorical(probs=funsor.Tensor(
            emit_probs,
            inputs=OrderedDict([('latent', funsor.bint(args.hidden_dim))]),
        ))

        x_curr = funsor.Number(0, args.hidden_dim)
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t),
                                     funsor.bint(args.hidden_dim))
            log_prob += trans(prev=x_prev, value=x_curr)

            if not args.lazy and isinstance(x_prev, funsor.Variable):
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            log_prob += emit(latent=x_curr, value=funsor.Tensor(y, dtype=2))

        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob
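# A hedged sketch (not from the original source) of the context this model()
# closes over; the shapes and values below are inferred from how trans_probs,
# emit_probs, args.hidden_dim, and data are used above and are illustrative only.
import torch
from collections import OrderedDict  # used by the model body above
hidden_dim = 2                                    # stands in for args.hidden_dim
trans_probs = torch.rand(hidden_dim, hidden_dim)
trans_probs = trans_probs / trans_probs.sum(-1, keepdim=True)   # row-stochastic
emit_probs = torch.rand(hidden_dim, 2)            # observations are binary (dtype=2)
emit_probs = emit_probs / emit_probs.sum(-1, keepdim=True)
data = torch.randint(2, (10,))                    # a toy observed sequence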
Example #4
def _get_support_value_tensor(funsor_dist, name, **kwargs):
    assert name in funsor_dist.inputs
    return funsor.Tensor(
        funsor.ops.new_arange(funsor_dist.data, funsor_dist.inputs[name].size),
        OrderedDict([(name, funsor_dist.inputs[name])]),
        funsor_dist.inputs[name].size,
    )
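# A small illustration (added, not from the original source) of what the helper
# above constructs: a Tensor whose single input is the variable's own name,
# whose data enumerates the support 0..size-1, and whose output dtype equals
# that size. The name "x" and size 3 here are arbitrary.
from collections import OrderedDict
import torch
import funsor

funsor.set_backend("torch")
size = 3
support = funsor.Tensor(
    torch.arange(size),                  # plays the role of ops.new_arange(...)
    OrderedDict(x=funsor.Bint[size]),    # keyed by the variable's own name
    size,                                # bounded-integer output dtype
)
assert support(x=1).data.item() == 1     # substituting the name picks that value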
Example #5
    def get_tensors_and_dists(self):
        # normalize the transition probabilities
        trans_logits = self.transition_logits - self.transition_logits.logsumexp(
            dim=-1, keepdim=True)
        trans_probs = funsor.Tensor(
            trans_logits,
            OrderedDict([("s", funsor.bint(self.num_components))]))

        trans_mvn = torch.distributions.MultivariateNormal(
            torch.zeros(self.hidden_dim),
            self.log_transition_noise.exp().diag_embed())
        obs_mvn = torch.distributions.MultivariateNormal(
            torch.zeros(self.obs_dim),
            self.log_obs_noise.exp().diag_embed())

        event_dims = ("s",) if self.fine_transition_matrix or self.fine_transition_noise else ()
        x_trans_dist = matrix_and_mvn_to_funsor(self.transition_matrix,
                                                trans_mvn, event_dims, "x", "y")
        event_dims = ("s",) if self.fine_observation_matrix or self.fine_observation_noise else ()
        y_dist = matrix_and_mvn_to_funsor(self.observation_matrix, obs_mvn,
                                          event_dims, "x", "y")

        return trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist
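    # Note (added): matrix_and_mvn_to_funsor converts an affine map plus MVN noise
    # (y = x @ matrix + noise) into a single Gaussian funsor over the named real
    # inputs "x" and "y"; any names passed via event_dims (here "s") become extra
    # batch inputs, which is how per-component transition/observation parameters
    # enter the switching model.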
Example #6
    def model(data):
        log_prob = funsor.to_funsor(0.)
        xs_curr = [funsor.Tensor(torch.tensor(0.)) for var in var_names]

        for t, y in enumerate(data):
            xs_prev = xs_curr

            # A delayed sample statement.
            xs_curr = [
                funsor.Variable(name + '_{}'.format(t), funsor.reals())
                for name in var_names
            ]

            for i, x_curr in enumerate(xs_curr):
                log_prob += dist.Normal(trans_eqs[var_names[i]](xs_prev),
                                        torch.exp(trans_noises[i]),
                                        value=x_curr)

            if t > 0:
                log_prob = log_prob.reduce(
                    ops.logaddexp,
                    frozenset([x_prev.name for x_prev in xs_prev]))

            # An observe statement.
            log_prob += dist.Normal(emit_eq(xs_curr),
                                    torch.exp(emit_noise),
                                    value=y)

        # Marginalize out all remaining delayed variables.
        return log_prob.reduce(ops.logaddexp), log_prob.gaussian
Example #7
def _enum_strategy_mixture(dist, msg):
    sample_dim_name = "{}__PARTICLES".format(msg["name"])
    sample_inputs = OrderedDict(
        {sample_dim_name: funsor.Bint[msg['infer']['num_samples']]})
    plate_names = frozenset(f.name for f in msg["cond_indep_stack"]
                            if f.vectorized)
    ancestor_names = frozenset(
        k for k, v in dist.inputs.items()
        if v.dtype != 'real' and k != msg["name"] and k not in plate_names)
    plate_inputs = OrderedDict((k, dist.inputs[k]) for k in plate_names)
    # TODO should the ancestor_indices be pyro.sampled?
    ancestor_indices = {
        # TODO make this comprehension less gross
        name: _get_support_value(
            funsor.torch.distributions.CategoricalLogits(
                # sample different ancestors for each plate slice
                logits=funsor.Tensor(
                    # TODO avoid use of torch.zeros here in favor of funsor.ops.new_zeros
                    torch.zeros((1, )).expand(
                        tuple(v.dtype for v in plate_inputs.values()) +
                        (dist.inputs[name].dtype, )),
                    plate_inputs), )(value=name).sample(name, sample_inputs),
            name)
        for name in ancestor_names
    }
    sampled_dist = dist(**ancestor_indices).sample(
        msg["name"], sample_inputs if not ancestor_indices else None)
    if ancestor_indices:  # XXX is there a better way to account for this in funsor?
        sampled_dist = sampled_dist - math.log(msg["infer"]["num_samples"])
    return sampled_dist
Example #8
def main(args):
    funsor.set_backend("torch")

    # XXX Temporary fix after https://github.com/pyro-ppl/pyro/pull/2701
    import pyro
    pyro.enable_validation(False)

    encoder = Encoder()
    decoder = Decoder()

    encode = funsor.function(Reals[28, 28], (Reals[20], Reals[20]))(encoder)
    decode = funsor.function(Reals[20], Reals[28, 28])(decoder)

    @funsor.interpretation(funsor.montecarlo.MonteCarlo())
    def loss_function(data, subsample_scale):
        # Lazily sample from the guide.
        loc, scale = encode(data)
        q = funsor.Independent(dist.Normal(loc['i'], scale['i'], value='z_i'),
                               'z', 'i', 'z_i')

        # Evaluate the model likelihood at the lazy value z.
        probs = decode('z')
        p = dist.Bernoulli(probs['x', 'y'], value=data['x', 'y'])
        p = p.reduce(ops.add, {'x', 'y'})

        # Construct an elbo. This is where sampling happens.
        elbo = funsor.Integrate(q, p - q, 'z')
        elbo = elbo.reduce(ops.add, 'batch') * subsample_scale
        loss = -elbo
        return loss

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(DATA_PATH, train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=args.batch_size,
        shuffle=True)

    encoder.train()
    decoder.train()
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(decoder.parameters()),
                           lr=1e-3)
    for epoch in range(args.num_epochs):
        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            subsample_scale = float(len(train_loader.dataset) / len(data))
            data = data[:, 0, :, :]
            data = funsor.Tensor(data, OrderedDict(batch=Bint[len(data)]))

            optimizer.zero_grad()
            loss = loss_function(data, subsample_scale)
            assert isinstance(loss, funsor.Tensor), loss.pretty()
            loss.data.backward()
            train_loss += loss.item()
            optimizer.step()
            if batch_idx % 50 == 0:
                print('  loss = {}'.format(loss.item()))
                if batch_idx and args.smoke_test:
                    return
        print('epoch {} train_loss = {}'.format(epoch, train_loss))
Example #9
    def log_prob(self, data):
        trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = \
            self.get_tensors_and_dists()

        log_prob = funsor.Number(0.)

        s_vars = {
            -1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)
        }
        x_vars = {}

        for t, y in enumerate(data):
            # construct free variables for s_t and x_t
            s_vars[t] = funsor.Variable(f's_{t}',
                                        funsor.bint(self.num_components))
            x_vars[t] = funsor.Variable(f'x_{t}',
                                        funsor.reals(self.hidden_dim))

            # incorporate the discrete switching dynamics
            log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]),
                                         value=s_vars[t])

            # incorporate the prior term p(x_t | x_{t-1})
            if t == 0:
                log_prob += self.x_init_mvn(value=x_vars[t])
            else:
                log_prob += x_trans_dist(s=s_vars[t],
                                         x=x_vars[t - 1],
                                         y=x_vars[t])

            # do a moment-matching reduction. at this point log_prob depends on (moment_matching_lag + 1)-many
            # pairs of free variables.
            if t > self.moment_matching_lag - 1:
                log_prob = log_prob.reduce(
                    ops.logaddexp,
                    frozenset([
                        s_vars[t - self.moment_matching_lag].name,
                        x_vars[t - self.moment_matching_lag].name
                    ]))

            # incorporate the observation p(y_t | x_t, s_t)
            log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

        T = data.shape[0]
        # reduce any remaining free variables
        for t in range(self.moment_matching_lag):
            log_prob = log_prob.reduce(
                ops.logaddexp,
                frozenset([
                    s_vars[T - self.moment_matching_lag + t].name,
                    x_vars[T - self.moment_matching_lag + t].name
                ]))

        # assert that we've reduced all the free variables in log_prob
        assert not log_prob.inputs, 'unexpected free variables remain'

        # return the PyTorch tensor behind log_prob (which we can directly differentiate)
        return log_prob.data
Example #10
def main(args):
    # Generate fake data.
    data = funsor.Tensor(torch.randn(100),
                         inputs=OrderedDict([('data', funsor.bint(100))]),
                         output=funsor.reals())

    # Train.
    optim = pyro.Adam({'lr': args.learning_rate})
    svi = pyro.SVI(model, pyro.deferred(guide), optim, pyro.elbo)
    for step in range(args.steps):
        svi.step(data)
Example #11
def test_bernoullilogits_enumerate_support(expand, batch_shape):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))

    logits = funsor.Tensor(rand(batch_shape), inputs, 'real')
    with interpretation(lazy):
        d = dist.BernoulliLogits(logits)
    x = d.enumerate_support(expand=expand)
    actual_log_prob = d(value='value2')(value2=x).reduce(ops.logaddexp, 'value')

    raw_dist = d.dist_class(logits=logits.data)
    raw_value = raw_dist.enumerate_support(expand=expand)
    expected_inputs = OrderedDict([('value', Bint[raw_value.shape[0]])])
    expected_inputs.update(inputs)
    expected_log_prob = funsor.Tensor(raw_dist.log_prob(raw_value), expected_inputs).reduce(ops.logaddexp, 'value')

    assert d.has_enumerate_support
    assert x.output == d.value.output
    assert set(x.inputs) == {'value'} | (set(batch_dims) if expand else set())
    assert_close(expected_log_prob, actual_log_prob)
Example #12
def main(args):
    encoder = Encoder()
    decoder = Decoder()

    encode = funsor.torch.function(reals(28, 28),
                                   (reals(20), reals(20)))(encoder)
    decode = funsor.torch.function(reals(20), reals(28, 28))(decoder)

    @funsor.interpreter.interpretation(funsor.montecarlo.monte_carlo)
    def loss_function(data, subsample_scale):
        # Lazily sample from the guide.
        loc, scale = encode(data)
        q = funsor.Independent(dist.Normal(loc['i'], scale['i'], value='z'),
                               'z', 'i')

        # Evaluate the model likelihood at the lazy value z.
        probs = decode('z')
        p = dist.Bernoulli(probs['x', 'y'], value=data['x', 'y'])
        p = p.reduce(ops.add, frozenset(['x', 'y']))

        # Construct an elbo. This is where sampling happens.
        elbo = funsor.Integrate(q, p - q, frozenset(['z']))
        elbo = elbo.reduce(ops.add, 'batch') * subsample_scale
        loss = -elbo
        return loss

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(DATA_PATH, train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=args.batch_size,
        shuffle=True)

    encoder.train()
    decoder.train()
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(decoder.parameters()),
                           lr=1e-3)
    for epoch in range(args.num_epochs):
        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            subsample_scale = float(len(train_loader.dataset) / len(data))
            data = data[:, 0, :, :]
            data = funsor.Tensor(data, OrderedDict(batch=bint(len(data))))

            optimizer.zero_grad()
            loss = loss_function(data, subsample_scale)
            assert isinstance(loss, funsor.torch.Tensor), loss.pretty()
            loss.data.backward()
            train_loss += loss.item()
            optimizer.step()
            if batch_idx % 50 == 0:
                print('  loss = {}'.format(loss.item()))
                if batch_idx and args.smoke_test:
                    return
        print('epoch {} train_loss = {}'.format(epoch, train_loss))
Example #13
    def __init__(self, name, size, subsample_size=None, dim=None):
        self.name = name
        self.size = size
        self.subsample_size = size if subsample_size is None else subsample_size
        if dim is not None and dim >= 0:
            raise ValueError('dim arg must be negative.')
        self.dim = dim
        self._indices = funsor.Tensor(
            funsor.ops.new_arange(funsor.tensor.get_default_prototype(),
                                  self.size),
            OrderedDict([(self.name, funsor.bint(self.size))]), self.size)
        super(plate, self).__init__(None)
Example #14
    def __init__(self, name, size, subsample_size=None, dim=None):
        self.name = name
        self.size = size
        if dim is not None and dim >= 0:
            raise ValueError('dim arg must be negative.')
        self.dim, indices = OrigPlateMessenger._subsample(
            self.name, self.size, subsample_size, dim)
        self.subsample_size = indices.shape[0]
        self._indices = funsor.Tensor(
            indices,
            OrderedDict([(self.name, funsor.bint(self.subsample_size))]),
            self.subsample_size)
        super(plate, self).__init__(None)
Example #15
    def __init__(self, name=None, size=None, dim=None, indices=None):
        assert dim is None or dim < 0
        super().__init__()
        # without a name or dim, treat as a "vectorize" effect and allocate a non-visible dim
        self.dim_type = DimType.GLOBAL if name is None and dim is None else DimType.VISIBLE
        self.name = name if name is not None else funsor.interpreter.gensym(
            "PLATE")
        self.size = size
        self.dim = dim
        if not hasattr(self, "_full_size"):
            self._full_size = size
        if indices is None:
            indices = funsor.ops.new_arange(
                funsor.tensor.get_default_prototype(), self.size)
        assert len(indices) == size

        self._indices = funsor.Tensor(
            indices, OrderedDict([(self.name, funsor.Bint[self.size])]),
            self._full_size)
Example #16
    def model(data):
        log_prob = funsor.to_funsor(0.)

        x_curr = funsor.Tensor(torch.tensor(0.))
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
            log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr)

            # Optionally marginalize out the previous state.
            if t > 0 and not args.lazy:
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            # An observe statement.
            log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)

        # Marginalize out all remaining delayed variables.
        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob
Example #17
    def process_message(self, msg):
        if msg["type"] != "sample" or \
                msg.get("done", False) or msg["is_observed"] or msg["infer"].get("expand", False) or \
                msg["infer"].get("enumerate") != "parallel" or (not msg["fn"].has_enumerate_support):
            if msg["type"] == "control_flow":
                msg["kwargs"]["enum"] = True
            return super().process_message(msg)

        if msg["infer"].get("num_samples", None) is not None:
            raise NotImplementedError("TODO implement multiple sampling")

        if msg["infer"].get("expand", False):
            raise NotImplementedError("expand=True not implemented")

        size = msg["fn"].enumerate_support(expand=False).shape[0]
        raw_value = jnp.arange(0, size)
        funsor_value = funsor.Tensor(
            raw_value, OrderedDict([(msg["name"], funsor.bint(size))]), size)

        msg["value"] = to_data(funsor_value)
        msg["done"] = True
Example #18
    def differentiable_loss(self, model, guide, *args, **kwargs):

        # get batched, enumerated, to_funsor-ed traces from the guide and model
        particle_plate = (plate(size=self.num_particles)
                          if self.num_particles > 1 else contextlib.ExitStack())
        first_available_dim = (-self.max_plate_nesting - 1
                               if self.max_plate_nesting else None)
        with particle_plate, enum(first_available_dim=first_available_dim):
            guide_tr = trace(guide).get_trace(*args, **kwargs)
            model_tr = trace(replay(model, trace=guide_tr)).get_trace(
                *args, **kwargs)

        # extract from traces all metadata that we will need to compute the elbo
        guide_terms = terms_from_trace(guide_tr)
        model_terms = terms_from_trace(model_tr)

        # build up a lazy expression for the elbo
        with funsor.terms.lazy:
            # identify and contract out auxiliary variables in the model with partial_sum_product
            contracted_factors, uncontracted_factors = [], []
            for f in model_terms["log_factors"]:
                if model_terms["measure_vars"].intersection(f.inputs):
                    contracted_factors.append(f)
                else:
                    uncontracted_factors.append(f)
            # incorporate the effects of subsampling and handlers.scale through a common scale factor
            contracted_costs = [
                model_terms["scale"] * f
                for f in funsor.sum_product.partial_sum_product(
                    funsor.ops.logaddexp,
                    funsor.ops.add,
                    model_terms["log_measures"] + contracted_factors,
                    plates=model_terms["plate_vars"],
                    eliminate=model_terms["measure_vars"],
                )
            ]

            # accumulate costs from model (logp) and guide (-logq)
            costs = contracted_costs + uncontracted_factors  # model costs: logp
            costs += [-f for f in guide_terms["log_factors"]]  # guide costs: -logq

            # compute expected cost
            # Cf. pyro.infer.util.Dice.compute_expectation()
            # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212
            # TODO Replace this with funsor.Expectation
            plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
            # compute the marginal logq in the guide corresponding to each cost term
            targets = dict()
            for cost in costs:
                input_vars = frozenset(cost.inputs)
                if input_vars not in targets:
                    targets[input_vars] = funsor.Tensor(
                        funsor.ops.new_zeros(
                            funsor.tensor.get_default_prototype(),
                            tuple(v.size for v in cost.inputs.values()),
                        ),
                        cost.inputs,
                        cost.dtype,
                    )
            with AdjointTape() as tape:
                logzq = funsor.sum_product.sum_product(
                    funsor.ops.logaddexp,
                    funsor.ops.add,
                    guide_terms["log_measures"] + list(targets.values()),
                    plates=plate_vars,
                    eliminate=(plate_vars | guide_terms["measure_vars"]),
                )
            marginals = tape.adjoint(funsor.ops.logaddexp, funsor.ops.add,
                                     logzq, tuple(targets.values()))
            # finally, integrate out guide variables in the elbo and all plates
            elbo = to_funsor(0, output=funsor.Real)
            for cost in costs:
                target = targets[frozenset(cost.inputs)]
                logzq_local = marginals[target].reduce(
                    funsor.ops.logaddexp,
                    frozenset(cost.inputs) - plate_vars)
                log_prob = marginals[target] - logzq_local
                elbo_term = funsor.Integrate(
                    log_prob,
                    cost,
                    guide_terms["measure_vars"] & frozenset(log_prob.inputs),
                )
                elbo += elbo_term.reduce(funsor.ops.add,
                                         plate_vars & frozenset(cost.inputs))

        # evaluate the elbo, using memoize to share tensor computation where possible
        with funsor.interpretations.memoize():
            return -to_data(apply_optimizer(elbo))
Example #19
    def filter_and_predict(self, data, smoothing=False):
        trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist = \
            self.get_tensors_and_dists()

        log_prob = funsor.Number(0.)

        s_vars = {
            -1: funsor.Tensor(torch.tensor(0), dtype=self.num_components)
        }
        x_vars = {-1: None}

        predictive_x_dists, predictive_y_dists, filtering_dists = [], [], []
        test_LLs = []

        for t, y in enumerate(data):
            s_vars[t] = funsor.Variable(f's_{t}',
                                        funsor.bint(self.num_components))
            x_vars[t] = funsor.Variable(f'x_{t}',
                                        funsor.reals(self.hidden_dim))

            log_prob += dist.Categorical(trans_probs(s=s_vars[t - 1]),
                                         value=s_vars[t])

            if t == 0:
                log_prob += self.x_init_mvn(value=x_vars[t])
            else:
                log_prob += x_trans_dist(s=s_vars[t],
                                         x=x_vars[t - 1],
                                         y=x_vars[t])

            if t > 0:
                log_prob = log_prob.reduce(
                    ops.logaddexp,
                    frozenset([s_vars[t - 1].name, x_vars[t - 1].name]))

            # do 1-step prediction and compute test LL
            if t > 0:
                predictive_x_dists.append(log_prob)
                _log_prob = log_prob - log_prob.reduce(ops.logaddexp)
                predictive_y_dist = y_dist(s=s_vars[t],
                                           x=x_vars[t]) + _log_prob
                test_LLs.append(
                    predictive_y_dist(y=y).reduce(ops.logaddexp).data.item())
                predictive_y_dist = predictive_y_dist.reduce(
                    ops.logaddexp, frozenset([f"x_{t}", f"s_{t}"]))
                predictive_y_dists.append(
                    funsor_to_mvn(predictive_y_dist, 0, ()))

            log_prob += y_dist(s=s_vars[t], x=x_vars[t], y=y)

            # save filtering dists for forward-backward smoothing
            if smoothing:
                filtering_dists.append(log_prob)

        # do the backward recursion using previously computed ingredients
        if smoothing:
            # seed the backward recursion with the filtering distribution at t=T
            smoothing_dists = [filtering_dists[-1]]
            T = data.size(0)

            s_vars = {
                t: funsor.Variable(f's_{t}', funsor.bint(self.num_components))
                for t in range(T)
            }
            x_vars = {
                t: funsor.Variable(f'x_{t}', funsor.reals(self.hidden_dim))
                for t in range(T)
            }

            # do the backward recursion.
            # let p[t|t-1] be the predictive distribution at time step t.
            # let p[t|t] be the filtering distribution at time step t.
            # let f[t] denote the prior (transition) density at time step t.
            # then the smoothing distribution p[t|T] at time step t is
            # given by the following recursion.
            # p[t-1|T] = p[t-1|t-1] <p[t|T] f[t] / p[t|t-1]>
            # where <...> denotes integration of the latent variables at time step t.
            for t in reversed(range(T - 1)):
                integral = smoothing_dists[-1] - predictive_x_dists[t]
                integral += dist.Categorical(trans_probs(s=s_vars[t]),
                                             value=s_vars[t + 1])
                integral += x_trans_dist(s=s_vars[t],
                                         x=x_vars[t],
                                         y=x_vars[t + 1])
                integral = integral.reduce(
                    ops.logaddexp,
                    frozenset([s_vars[t + 1].name, x_vars[t + 1].name]))
                smoothing_dists.append(filtering_dists[t] + integral)

        # compute predictive test MSE and predictive variances
        predictive_means = torch.stack(
            [d.mean for d in predictive_y_dists])  # T-1 ydim
        predictive_vars = torch.stack([
            d.covariance_matrix.diagonal(dim1=-1, dim2=-2)
            for d in predictive_y_dists
        ])
        predictive_mse = (predictive_means - data[1:, :]).pow(2.0).mean(-1)

        if smoothing:
            # compute smoothed mean function
            smoothing_dists = [
                funsor_to_cat_and_mvn(d, 0, (f"s_{t}", ))
                for t, d in enumerate(reversed(smoothing_dists))
            ]
            means = torch.stack([d[1].mean for d in smoothing_dists])  # T 2 xdim
            means = torch.matmul(means.unsqueeze(-2),
                                 self.observation_matrix).squeeze(-2)  # T 2 ydim

            probs = torch.stack([d[0].logits for d in smoothing_dists]).exp()
            probs = probs / probs.sum(-1, keepdim=True)  # T 2

            smoothing_means = (probs.unsqueeze(-1) * means).sum(-2)  # T ydim
            smoothing_probs = probs[:, 1]

            return predictive_mse, torch.tensor(np.array(test_LLs)), predictive_means, predictive_vars, \
                smoothing_means, smoothing_probs
        else:
            return predictive_mse, torch.tensor(np.array(test_LLs))
Example #20
def test_gaussian_funsor(batch_shape):
    # This tests sample distribution, rsample gradients, log_prob, and log_prob
    # gradients for both Pyro's and Funsor's Gaussian.
    import funsor

    funsor.set_backend("torch")
    num_samples = 100000

    # Declare unconstrained parameters.
    loc = torch.randn(batch_shape + (3, )).requires_grad_()
    t = transform_to(constraints.positive_definite)
    m = torch.randn(batch_shape + (3, 3))
    precision_unconstrained = t.inv(m @ m.transpose(-1, -2)).requires_grad_()

    # Transform to constrained space.
    log_normalizer = torch.zeros(batch_shape)
    precision = t(precision_unconstrained)
    info_vec = (precision @ loc[..., None])[..., 0]

    def check_equal(actual, expected, atol=0.01, rtol=0):
        assert_close(actual.data, expected.data, atol=atol, rtol=rtol)
        grads = torch.autograd.grad(
            (actual - expected).abs().sum(),
            [loc, precision_unconstrained],
            retain_graph=True,
        )
        for grad in grads:
            assert grad.abs().max() < atol

    entropy = dist.MultivariateNormal(loc,
                                      precision_matrix=precision).entropy()

    # Monte carlo estimate entropy via pyro.
    p_gaussian = Gaussian(log_normalizer, info_vec, precision)
    p_log_Z = p_gaussian.event_logsumexp()
    p_rsamples = p_gaussian.rsample((num_samples, ))
    pp_entropy = (p_log_Z - p_gaussian.log_density(p_rsamples)).mean(0)
    check_equal(pp_entropy, entropy)

    # Monte carlo estimate entropy via funsor.
    inputs = OrderedDict([(k, funsor.Bint[v])
                          for k, v in zip("ij", batch_shape)])
    inputs["x"] = funsor.Reals[3]
    f_gaussian = funsor.gaussian.Gaussian(mean=loc,
                                          precision=precision,
                                          inputs=inputs)
    f_log_Z = f_gaussian.reduce(funsor.ops.logaddexp, "x")
    sample_inputs = OrderedDict(particle=funsor.Bint[num_samples])
    deltas = f_gaussian.sample("x", sample_inputs)
    f_rsamples = funsor.montecarlo.extract_samples(deltas)["x"]
    ff_entropy = (f_log_Z - f_gaussian(x=f_rsamples)).reduce(
        funsor.ops.mean, "particle")
    check_equal(ff_entropy.data, entropy)

    # Check Funsor's .rsample against Pyro's .log_prob.
    pf_entropy = (p_log_Z - p_gaussian.log_density(f_rsamples.data)).mean(0)
    check_equal(pf_entropy, entropy)

    # Check Pyro's .rsample against Funsor's .log_prob.
    fp_rsamples = funsor.Tensor(p_rsamples)["particle"]
    for i in "ij"[:len(batch_shape)]:
        fp_rsamples = fp_rsamples[i]
    fp_entropy = (f_log_Z - f_gaussian(x=fp_rsamples)).reduce(
        funsor.ops.mean, "particle")
    check_equal(fp_entropy.data, entropy)
Example #21
    def compute_probs(self) -> torch.Tensor:
        theta_probs = torch.zeros(self.K, self.data.Nt, self.data.F, self.Q)
        nbatch_size = self.nbatch_size
        N = sum(self.data.is_ontarget)
        for ndx in torch.split(torch.arange(N), nbatch_size):
            self.n = ndx
            self.nbatch_size = len(ndx)
            with torch.no_grad(), \
                    pyro.plate("particles", size=5, dim=-4), \
                    handlers.enum(first_available_dim=-5):
                guide_tr = handlers.trace(self.guide).get_trace()
                model_tr = handlers.trace(
                    handlers.replay(self.model, trace=guide_tr)).get_trace()
            model_tr.compute_log_prob()
            guide_tr.compute_log_prob()

            logp = {}
            result = {}
            for fsx in ("0", f"slice(1, {self.data.F}, None)"):
                logp[fsx] = 0
                # collect log_prob terms p(z, theta, phi)
                for name in [
                        "z",
                        "theta",
                        "m_k0",
                        "m_k1",
                        "x_k0",
                        "x_k1",
                        "y_k0",
                        "y_k1",
                ]:
                    logp[fsx] += model_tr.nodes[f"{name}_f{fsx}"]["funsor"][
                        "log_prob"]
                if fsx == "0":
                    # substitute MAP values of z into p(z=z_map, theta, phi)
                    z_map = funsor.Tensor(self.z_map[ndx, 0].long(),
                                          dtype=2)["aois", "channels"]
                    logp[fsx] = logp[fsx](**{f"z_f{fsx}": z_map})
                    # compute log_measure q for given z_map
                    log_measure = (
                        guide_tr.nodes[f"m_k0_f{fsx}"]["funsor"]["log_measure"]
                        +
                        guide_tr.nodes[f"m_k1_f{fsx}"]["funsor"]["log_measure"]
                    )
                    log_measure = log_measure(**{f"z_f{fsx}": z_map})
                else:
                    # substitute MAP values of z into p(z=z_map, theta, phi)
                    z_map = funsor.Tensor(self.z_map[ndx, 1:].long(),
                                          dtype=2)["aois", "frames",
                                                   "channels"]
                    z_map_prev = funsor.Tensor(self.z_map[ndx, :-1].long(),
                                               dtype=2)["aois", "frames",
                                                        "channels"]
                    fsx_prev = f"slice(0, {self.data.F-1}, None)"
                    logp[fsx] = logp[fsx](**{
                        f"z_f{fsx}": z_map,
                        f"z_f{fsx_prev}": z_map_prev
                    })
                    # compute log_measure q for given z_map
                    log_measure = (
                        guide_tr.nodes[f"m_k0_f{fsx}"]["funsor"]["log_measure"]
                        +
                        guide_tr.nodes[f"m_k1_f{fsx}"]["funsor"]["log_measure"]
                    )
                    log_measure = log_measure(**{
                        f"z_f{fsx}": z_map,
                        f"z_f{fsx_prev}": z_map_prev
                    })
                # compute p(z_map, theta | phi) = p(z_map, theta, phi) - p(z_map, phi)
                logp[fsx] = logp[fsx] - logp[fsx].reduce(
                    funsor.ops.logaddexp, f"theta_f{fsx}")
                # average over m in p * q
                result[fsx] = (logp[fsx] + log_measure).reduce(
                    funsor.ops.logaddexp,
                    frozenset({f"m_k0_f{fsx}", f"m_k1_f{fsx}"}))
                # average over particles
                result[fsx] = result[fsx].exp().reduce(funsor.ops.mean,
                                                       "particles")
            theta_probs[:, ndx, 0] = result["0"].data[..., 1:].permute(2, 0, 1)
            theta_probs[:, ndx, 1:] = (
                result[f"slice(1, {self.data.F}, None)"].data[..., 1:].permute(
                    3, 0, 1, 2))
        self.n = None
        self.nbatch_size = nbatch_size
        return theta_probs