Пример #1
0
def test_mcmc_model_side_enumeration(model, temperature):
    """Check ``infer_discrete`` on a model conditioned on stand-in MCMC samples.

    The values for the exposed sites ("loc" and "scale") are obtained by
    tracing the enumerated model once, the model is conditioned on them, and
    ``infer_discrete`` is run with the given ``temperature``. The resulting
    trace must contain the same site names as a plain trace of ``model``.
    """
    # Perform fake inference: draw from the prior rather than sampling from an
    # MCMC posterior. This has the wrong distribution but the right type for
    # tests.
    prior_model = handlers.block(
        handlers.enum(infer.config_enumerate(model)),
        expose=["loc", "scale"],
    )
    mcmc_trace = handlers.trace(prior_model).get_trace()
    mcmc_data = {}
    for name, site in mcmc_trace.nodes.items():
        if site["type"] == "sample":
            mcmc_data[name] = site["value"]

    # MAP estimate discretes, conditioned on posterior sampled continuous
    # latents.
    # TODO support replayed sites in infer_discrete, i.e. replaying
    # mcmc_trace instead of conditioning on mcmc_data.
    conditioned_model = handlers.condition(infer.config_enumerate(model),
                                           mcmc_data)
    discrete_model = infer.infer_discrete(conditioned_model,
                                          temperature=temperature)
    actual_trace = handlers.trace(discrete_model).get_trace()

    # Check site names and shapes.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace.nodes) == set(expected_trace.nodes)
    scale_value = actual_trace.nodes["scale"]["funsor"]["value"]
    assert "z1" not in scale_value.inputs
Пример #2
0
def test_trace_handler(model, backend):
    """Check that ``handlers.trace`` can record the named model under ``backend``.

    ``model`` is a key into ``MODELS``; the looked-up factory returns a mapping
    with a 'model' callable and optional 'model_args' / 'model_kwargs'.
    The test is skipped when the package backing ``backend`` is not installed.
    """
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend), handlers.seed(rng_seed=2):
        spec = MODELS[model]()
        model_fn = spec['model']
        model_args = spec.get('model_args', ())
        model_kwargs = spec.get('model_kwargs', {})
        # should be implemented
        traced = handlers.trace(model_fn)
        traced.get_trace(*model_args, **model_kwargs)
Пример #3
0
def test_svi_model_side_enumeration(model, temperature):
    """Check ``infer_discrete`` on a model conditioned on an untrained guide's samples.

    An ``AutoNormal`` guide is built over the exposed sites ("loc" and
    "scale"), initialized once, and traced; the model is conditioned on the
    traced sample values and passed through ``infer_discrete`` with the given
    ``temperature``. The resulting trace must contain the same site names as a
    plain trace of ``model``.
    """
    # Perform fake inference.
    # This has the wrong distribution but the right type for tests.
    exposed_model = handlers.block(infer.config_enumerate(model),
                                   expose=["loc", "scale"])
    guide = AutoNormal(handlers.enum(exposed_model))
    guide()  # Initialize but don't bother to train.
    guide_trace = handlers.trace(guide).get_trace()
    guide_data = {}
    for name, site in guide_trace.nodes.items():
        if site["type"] == "sample":
            guide_data[name] = site["value"]

    # MAP estimate discretes, conditioned on posterior sampled continuous
    # latents.
    # TODO support replayed sites in infer_discrete, i.e. replaying
    # guide_trace instead of conditioning on guide_data.
    conditioned_model = handlers.condition(infer.config_enumerate(model),
                                           guide_data)
    discrete_model = infer.infer_discrete(conditioned_model,
                                          temperature=temperature)
    actual_trace = handlers.trace(discrete_model).get_trace()

    # Check site names and shapes.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace.nodes) == set(expected_trace.nodes)
    scale_value = actual_trace.nodes["scale"]["funsor"]["value"]
    assert "z1" not in scale_value.inputs
Пример #4
0
def test_register_backend(model):
    """Register a backend named "foo" and trace the named model under it.

    All three aliases ("infer", "optim", "pyro") of the new backend point at
    ``pyro.contrib.minipyro``. ``model`` is a key into ``MODELS``. The test is
    skipped when ``pyro`` is not installed.
    """
    pytest.importorskip("pyro")
    minipyro = "pyro.contrib.minipyro"
    aliases = {
        "infer": minipyro,
        "optim": minipyro,
        "pyro": minipyro,
    }
    register_backend("foo", aliases)
    with pyro_backend("foo"):
        spec = MODELS[model]()
        model_fn = spec['model']
        model_args = spec.get('model_args', ())
        model_kwargs = spec.get('model_kwargs', {})
        traced = handlers.trace(model_fn)
        traced.get_trace(*model_args, **model_kwargs)
Пример #5
0
def test_distribution_3(temperature):
    """Compare ``infer_discrete`` samples with the exact joint over (z1, z2[0], z2[1]).

    The expected probabilities come from tracing the model conditioned on each
    of the eight discrete assignments. With a nonzero ``temperature`` the
    empirical frequencies over the "particles" plate are compared with the
    normalized joint; with ``temperature == 0`` the single result must be the
    arg-max assignment and reproduce the maximal joint probability.
    """
    #       +---------+  +---------------+
    #  z1 --|--> x1   |  |  z2 ---> x2   |
    #       |       3 |  |             2 |
    #       +---------+  +---------------+
    num_particles = 10000
    data = [torch.tensor([-1.0, -1.0, 0.0]), torch.tensor([-1.0, 1.0])]

    @infer.config_enumerate
    def model(z1=None, z2=None):
        p = pyro.param("p", torch.tensor([0.25, 0.75]))
        loc = pyro.param("loc", torch.tensor([-1.0, 1.0]))
        z1 = pyro.sample("z1", dist.Categorical(p), obs=z1)
        with pyro.plate("data[0]", 3):
            pyro.sample("x1", dist.Normal(loc[z1], 1.0), obs=data[0])
        with pyro.plate("data[1]", 2):
            z2 = pyro.sample("z2", dist.Categorical(p), obs=z2)
            pyro.sample("x2", dist.Normal(loc[z2], 1.0), obs=data[1])

    first_available_dim = -3
    if temperature == 0:
        vectorized_model = model
    else:
        # Draw many samples at once by wrapping the model in a particle plate.
        particles = pyro.plate("particles", size=num_particles, dim=-2)
        vectorized_model = particles(model)
    sampled_model = infer.infer_discrete(vectorized_model, first_available_dim,
                                         temperature)
    sampled_trace = handlers.trace(sampled_model).get_trace()

    # One conditioned trace per discrete assignment.
    conditioned_traces = {}
    for z1 in [0, 1]:
        for z20 in [0, 1]:
            for z21 in [0, 1]:
                conditioned_traces[(z1, z20, z21)] = handlers.trace(
                    model).get_trace(z1=torch.tensor(z1),
                                     z2=torch.tensor([z20, z21]))

    # Check joint posterior over (z1, z2[0], z2[1]).
    actual_probs = torch.empty(2, 2, 2)
    expected_probs = torch.empty(2, 2, 2)
    z1_samples = sampled_trace.nodes["z1"]["value"]
    z2_samples = sampled_trace.nodes["z2"]["value"]
    for (z1, z20, z21), tr in conditioned_traces.items():
        expected_probs[z1, z20, z21] = tr.log_prob_sum().exp()
        hits = ((z1_samples == z1)
                & (z2_samples[..., :1] == z20)
                & (z2_samples[..., 1:] == z21))
        actual_probs[z1, z20, z21] = hits.float().mean()

    if not temperature:
        # MAP case: all mass is expected on the arg-max assignment.
        expected_max, argmax = expected_probs.reshape(-1).max(0)
        actual_max = sampled_trace.log_prob_sum().exp()
        assert_equal(expected_max, actual_max, prec=1e-5)
        expected_probs[:] = 0
        expected_probs.reshape(-1)[argmax] = 1
    else:
        expected_probs = expected_probs / expected_probs.sum()
    assert_equal(expected_probs.reshape(-1),
                 actual_probs.reshape(-1),
                 prec=1e-2)
Пример #6
0
 def compute_probs(self) -> "tuple[torch.Tensor, torch.Tensor]":
     """Estimate per-(n, f) probabilities for ``z`` and ``theta``.

     The first ``N = sum(self.data.is_ontarget)`` rows and all
     ``self.data.F`` columns are processed in minibatches of at most
     ``self.nbatch_size`` x ``self.fbatch_size``. For each minibatch the
     guide is traced under a 25-particle plate with parallel enumeration,
     the model is replayed against it, and the enumerated log-probabilities
     are normalised and marginalised.

     While running, ``self.n``, ``self.f``, ``self.nbatch_size`` and
     ``self.fbatch_size`` are overwritten with the current minibatch
     (presumably read by ``self.guide`` / ``self.model``). On exit, whether
     normal or by exception, ``self.n`` and ``self.f`` are reset to ``None``
     and both batch sizes are restored to their values on entry.

     Returns:
         A tuple ``(z_probs, theta_probs)`` with shapes
         ``(Nt, F)`` and ``(K, Nt, F)``; entries outside the processed
         minibatches stay zero.
     """
     z_probs = torch.zeros(self.data.Nt, self.data.F)
     theta_probs = torch.zeros(self.K, self.data.Nt, self.data.F)
     nbatch_size = self.nbatch_size
     fbatch_size = self.fbatch_size
     N = sum(self.data.is_ontarget)
     try:
         for ndx in torch.split(torch.arange(N), nbatch_size):
             for fdx in torch.split(torch.arange(self.data.F), fbatch_size):
                 self.n = ndx
                 self.f = fdx
                 self.nbatch_size = len(ndx)
                 self.fbatch_size = len(fdx)
                 with torch.no_grad(), pyro.plate(
                         "particles", size=25,
                         dim=-3), handlers.enum(first_available_dim=-4):
                     guide_tr = handlers.trace(self.guide).get_trace()
                     model_tr = handlers.trace(
                         handlers.replay(self.model,
                                         trace=guide_tr)).get_trace()
                 model_tr.compute_log_prob()
                 guide_tr.compute_log_prob()
                 # Leading (enumeration) axes of the log-prob tensors:
                 # 0 - theta
                 # 1 - z
                 # 2 - m_1
                 # 3 - m_0
                 # p(z, theta, phi)
                 logp = 0
                 for name in [
                         "z", "theta", "m_0", "m_1", "x_0", "x_1", "y_0",
                         "y_1"
                 ]:
                     logp = logp + model_tr.nodes[name]["unscaled_log_prob"]
                 # p(z, theta | phi) = p(z, theta, phi) - p(z, theta, phi).sum(z, theta)
                 logp = logp - logp.logsumexp((0, 1))
                 expectation = (guide_tr.nodes["m_0"]["unscaled_log_prob"] +
                                guide_tr.nodes["m_1"]["unscaled_log_prob"] +
                                logp)
                 # average over m
                 result = expectation.logsumexp((2, 3))
                 # marginalize theta
                 z_logits = result.logsumexp(0)
                 # keep the z == 1 slice and average over the particle axis
                 z_probs[ndx[:, None], fdx] = z_logits[1].exp().mean(-3)
                 # marginalize z
                 theta_logits = result.logsumexp(1)
                 # drop the theta == 0 slice and average over the particle axis
                 theta_probs[:, ndx[:, None],
                             fdx] = theta_logits[1:].exp().mean(-3)
     finally:
         # Always undo the temporary minibatch state, even when tracing fails,
         # so a failed call does not leave the object with shrunken batch sizes.
         self.n = None
         self.f = None
         self.nbatch_size = nbatch_size
         self.fbatch_size = fbatch_size
     return z_probs, theta_probs
Пример #7
0
def assert_ok(model, guide=None, max_plate_nesting=None, **kwargs):
    """
    Assert that enumeration runs and that both backends agree.

    The guide is run through ``handlers.queue`` and the model replayed against
    it, once under the "pyro" backend and once under "contrib.funsor", for as
    long as both queues still hold traces. After every round the funsor
    dimension stack must be back in its initial state and the two pruned
    model traces are compared with ``_check_traces``.

    :param model: model callable, invoked with ``**kwargs``.
    :param guide: guide callable; defaults to a no-op guide.
    :param max_plate_nesting: enumeration starts at dim
        ``-max_plate_nesting - 1``. NOTE(review): the ``None`` default cannot
        be negated, so callers appear to be expected to always pass a value.
    :param kwargs: forwarded to both the guide and the model.
    """
    with pyro_backend("pyro"):
        pyro.clear_param_store()

    if guide is None:

        def guide(**kwargs):
            return None

    def _replayed_model_trace(backend, queue):
        # Trace the queued guide, then replay the model against that trace,
        # both inside the requested backend with parallel enumeration enabled.
        with pyro_backend(backend):
            with handlers.enum(first_available_dim=-max_plate_nesting - 1):
                queued_guide = handlers.queue(
                    guide,
                    queue,
                    escape_fn=iter_discrete_escape,
                    extend_fn=iter_discrete_extend,
                )
                guide_tr = handlers.trace(queued_guide).get_trace(**kwargs)
                replayed = handlers.replay(model, trace=guide_tr)
                return handlers.trace(replayed).get_trace(**kwargs)

    q_pyro = LifoQueue()
    q_funsor = LifoQueue()
    q_pyro.put(Trace())
    q_funsor.put(Trace())

    while not (q_pyro.empty() or q_funsor.empty()):
        tr_pyro = _replayed_model_trace("pyro", q_pyro)
        tr_funsor = _replayed_model_trace("contrib.funsor", q_funsor)

        # make sure all dimensions were cleaned up
        global_frame = _DIM_STACK.global_frame
        assert _DIM_STACK.local_frame is _DIM_STACK.global_frame
        assert (not global_frame.name_to_dim
                and not _DIM_STACK.global_frame.dim_to_name)
        assert _DIM_STACK.outermost is None

        tr_pyro = prune_subsample_sites(tr_pyro.copy())
        tr_funsor = prune_subsample_sites(tr_funsor.copy())
        _check_traces(tr_pyro, tr_funsor)
Пример #8
0
def test_distribution_1(temperature):
    """Compare ``infer_discrete`` samples of ``z`` with its exact posterior.

    The exact values come from tracing the model conditioned on ``z = 0`` and
    ``z = 1``. With a nonzero ``temperature`` the mean of the sampled ``z``
    over the "particles" plate must match the posterior probability of
    ``z = 1``; with ``temperature == 0`` the result must be the more likely
    value and reproduce the maximal joint log-probability.
    """
    #      +-------+
    #  z --|--> x  |
    #      +-------+
    num_particles = 10000
    data = torch.tensor([1.0, 2.0, 3.0])

    @infer.config_enumerate
    def model(z=None):
        p = pyro.param("p", torch.tensor([0.75, 0.25]))
        iz = pyro.sample("z", dist.Categorical(p), obs=z)
        z = torch.tensor([0.0, 1.0])[iz]
        logger.info("z.shape = {}".format(z.shape))
        with pyro.plate("data", 3):
            pyro.sample("x", dist.Normal(z, 1.0), obs=data)

    first_available_dim = -3
    if temperature == 0:
        vectorized_model = model
    else:
        # Draw many samples at once by wrapping the model in a particle plate.
        particles = pyro.plate("particles", size=num_particles, dim=-2)
        vectorized_model = particles(model)
    sampled_model = infer.infer_discrete(vectorized_model, first_available_dim,
                                         temperature)
    sampled_trace = handlers.trace(sampled_model).get_trace()

    # One conditioned trace per value of z (float keys; 0 and 1 hash alike).
    conditioned_traces = {}
    for z in [0.0, 1.0]:
        conditioned_traces[z] = handlers.trace(model).get_trace(
            z=torch.tensor(z).long())

    # Check posterior over z.
    actual_z_mean = sampled_trace.nodes["z"]["value"].float().mean()
    log_joint_0 = conditioned_traces[0].log_prob_sum()
    log_joint_1 = conditioned_traces[1].log_prob_sum()
    if not temperature:
        # MAP case: z must equal the more likely value exactly.
        tolerance = 1e-5
        expected_z_mean = (log_joint_1 > log_joint_0).float()
        expected_max = max(t.log_prob_sum()
                           for t in conditioned_traces.values())
        actual_max = sampled_trace.log_prob_sum()
        assert_equal(expected_max, actual_max, prec=1e-5)
    else:
        # Sampling case: p(z = 1 | x) written as a logistic of the log-odds.
        tolerance = 1e-2
        expected_z_mean = 1 / (1 + (log_joint_0 - log_joint_1).exp())
    assert_equal(actual_z_mean, expected_z_mean, prec=tolerance)
Пример #9
0
def main(args):
    """Train the model selected by ``args.model`` and log train/test losses.

    Loads the ``poly.JSB_CHORALES`` dataset, drops notes that never occur in
    the training sequences, fits point estimates of the ``probs_*`` sites with
    SVI (``TraceEnum_ELBO`` or, with ``args.tmc``, ``TraceTMC_ELBO``), and then
    logs the per-observation loss on the training and test splits together
    with the total number of parameters in the param store.

    Attributes read from ``args``: ``cuda``, ``model``, ``truncate``, ``seed``,
    ``print_shapes``, ``batch_size``, ``learning_rate``, ``tmc``, ``jit``,
    ``funsor``, ``tmc_num_samples``, ``time_compilation``, ``num_steps``.

    Raises:
        NotImplementedError: if ``args.tmc`` and ``args.jit`` are set while
            ``args.funsor`` is not.
    """
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logging.info('Loading data')
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info('-' * 40)
    model = models[args.model]
    logging.info('Training {} on {} sequences'.format(
        model.__name__, len(data['train']['sequences'])))
    sequences = data['train']['sequences']
    lengths = data['train']['sequence_lengths']

    # find all the notes that are present at least once in the training set
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    # normalizer for all losses logged below
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        handlers.block(model,
                       expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = handlers.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = handlers.trace(
            handlers.replay(handlers.enum(model, first_available_dim),
                            guide_trace)).get_trace(sequences,
                                                    lengths,
                                                    args=args,
                                                    batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Bind non-PyTorch parameters to make these functions jittable.
    model = functools.partial(model, args=args)
    guide = functools.partial(guide, args=args)

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    # NOTE(review): `model` was rebound to a functools.partial just above, so
    # `model is model_0` is always False from here on and max_plate_nesting
    # evaluates to 2 in both branches below - confirm this is intended.
    optimizer = optim.Adam({'lr': args.learning_rate})
    if args.tmc:
        if args.jit and not args.funsor:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        Elbo = infer.JitTraceTMC_ELBO if args.jit else infer.TraceTMC_ELBO
        elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2)
        # Sites configured for parallel enumeration get sampled
        # (num_samples, no expansion) instead; other sites are left untouched.
        tmc_model = handlers.infer_config(model, lambda msg: {
            "num_samples": args.tmc_num_samples,
            "expand": False
        } if msg["infer"].get("enumerate", None) == "parallel" else {}
                                          )  # noqa: E501
        svi = infer.SVI(tmc_model, guide, optimizer, elbo)
    else:
        Elbo = infer.JitTraceEnum_ELBO if args.jit else infer.TraceEnum_ELBO
        elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2,
                    strict_enumeration_warning=True,
                    jit_options={"time_compilation": args.time_compilation})
        svi = infer.SVI(model, guide, optimizer, elbo)

    # We'll train on small minibatches.
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(sequences, lengths, batch_size=args.batch_size)
        logging.info('{: >5d}\t{}'.format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug('time to compile: {} s.'.format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model,
                           guide,
                           sequences,
                           lengths,
                           batch_size=sequences.shape[0],
                           include_prior=False)
    logging.info('training loss = {}'.format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info('-' * 40)
    logging.info('Evaluating on {} test sequences'.format(
        len(data['test']['sequences'])))
    # keep only the notes that were seen in the training set
    sequences = data['test']['sequences'][..., present_notes]
    lengths = data['test']['sequence_lengths']
    # NOTE(review): unlike the training split, only `lengths` is clamped here;
    # the test `sequences` tensor itself is not sliced to args.truncate.
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          batch_size=sequences.shape[0],
                          include_prior=False)
    logging.info('test loss = {}'.format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    # capacity = total number of scalar entries across all stored parameters
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info('model_{} capacity = {} parameters'.format(
        args.model, capacity))
Пример #10
0
def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2, history,
                           use_replay):
    """Compare sequential and vectorized traces of a model with two sequences.

    The model is traced twice under the "contrib.funsor" backend with
    enumeration: once with its last positional argument ``False`` (labelled
    "sequential" below) and once with ``True`` (labelled "vectorized"). With
    ``use_replay`` the vectorized trace is produced by replaying the model
    against a guide built by ``_guide_from_model``.

    Checks performed:
      * every per-step ``log_prob`` funsor of the sequential trace equals the
        matching slice of the vectorized trace;
      * the values recorded at the "weeks" and "days" sites equal the expected
        sets of site-name tuples built from ``vars1`` / ``vars2``;
      * ``terms_from_trace(...)["measure_vars"]`` equals the expected set
        (empty when ``use_replay`` is set).

    ``vars1`` / ``vars2`` are the site-name prefixes used with ``weeks_data``
    / ``days_data``; ``history`` is the number of leading steps that keep
    individual (non-sliced) site names in the vectorized trace.
    """
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        with handlers.enum():
            enum_model = infer.config_enumerate(model, default="parallel")
            # sequential factors
            trace = handlers.trace(enum_model).get_trace(
                weeks_data, days_data, history, False)

            # vectorized trace
            if use_replay:
                guide_trace = handlers.trace(
                    _guide_from_model(model)).get_trace(
                        weeks_data, days_data, history, True)
                vectorized_trace = handlers.trace(
                    handlers.replay(model, trace=guide_trace)).get_trace(
                        weeks_data, days_data, history, True)
            else:
                vectorized_trace = handlers.trace(enum_model).get_trace(
                    weeks_data, days_data, history, True)

        factors = list()
        # sequential weeks factors
        for i in range(len(weeks_data)):
            for v in vars1:
                factors.append(trace.nodes["{}_{}".format(
                    v, i)]["funsor"]["log_prob"])
        # sequential days factors
        for j in range(len(days_data)):
            for v in vars2:
                factors.append(trace.nodes["{}_{}".format(
                    v, j)]["funsor"]["log_prob"])

        vectorized_factors = list()
        # vectorized weeks factors
        # the first `history` steps are looked up under their own site names
        for i in range(history):
            for v in vars1:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, i)]["funsor"]["log_prob"])
        # later steps live in one sliced site: select step i via the "weeks"
        # input and rename each sliced input to its per-step site name
        for i in range(history, len(weeks_data)):
            for v in vars1:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, slice(history,
                                 len(weeks_data)))]["funsor"]["log_prob"](**{
                                     "weeks":
                                     i - history
                                 }, **{
                                     "{}_{}".format(
                                         k,
                                         slice(history - j,
                                               len(weeks_data) - j)):
                                     "{}_{}".format(k, i - j)
                                     for j in range(history + 1) for k in vars1
                                 }))
        # vectorized days factors
        # same construction as for weeks, using the "days" input
        for i in range(history):
            for v in vars2:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, i)]["funsor"]["log_prob"])
        for i in range(history, len(days_data)):
            for v in vars2:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, slice(history,
                                 len(days_data)))]["funsor"]["log_prob"](**{
                                     "days":
                                     i - history
                                 }, **{
                                     "{}_{}".format(
                                         k,
                                         slice(history - j,
                                               len(days_data) - j)):
                                     "{}_{}".format(k, i - j)
                                     for j in range(history + 1) for k in vars2
                                 }))

        # assert correct factors
        # align the sequential factor to the vectorized factor's input order
        for f1, f2 in zip(factors, vectorized_factors):
            assert_close(f2, f1.align(tuple(f2.inputs)))

        # assert correct step

        expected_measure_vars = frozenset()
        actual_weeks_step = vectorized_trace.nodes["weeks"]["value"]
        # expected step: assume that all but the last var is markov
        expected_weeks_step = frozenset()
        for v in vars1[:-1]:
            v_step = tuple("{}_{}".format(v, i) for i in range(history)) \
                     + tuple("{}_{}".format(v, slice(j, len(weeks_data)-history+j)) for j in range(history+1))
            expected_weeks_step |= frozenset({v_step})
            # grab measure_vars, found only at sites that are not replayed
            if not use_replay:
                expected_measure_vars |= frozenset(v_step)

        actual_days_step = vectorized_trace.nodes["days"]["value"]
        # expected step: assume that all but the last var is markov
        expected_days_step = frozenset()
        for v in vars2[:-1]:
            v_step = tuple("{}_{}".format(v, i) for i in range(history)) \
                     + tuple("{}_{}".format(v, slice(j, len(days_data)-history+j)) for j in range(history+1))
            expected_days_step |= frozenset({v_step})
            # grab measure_vars, found only at sites that are not replayed
            if not use_replay:
                expected_measure_vars |= frozenset(v_step)

        assert actual_weeks_step == expected_weeks_step
        assert actual_days_step == expected_days_step

        # check measure_vars
        actual_measure_vars = terms_from_trace(
            vectorized_trace)["measure_vars"]
        assert actual_measure_vars == expected_measure_vars
Пример #11
0
    def compute_probs(self) -> "tuple[torch.Tensor, torch.Tensor]":
        """Estimate per-(n, f, q) probabilities for ``z`` and ``theta``.

        The first ``N = sum(self.data.is_ontarget)`` rows and all
        ``self.data.F`` columns are processed in minibatches of at most
        ``self.nbatch_size`` x ``self.fbatch_size``. For each minibatch the
        guide is traced under a 5-particle plate with parallel enumeration,
        the model is replayed against it, and the enumerated
        log-probabilities are normalised and marginalised separately for
        each ``q`` in ``range(self.Q)``.

        While running, ``self.n``, ``self.f``, ``self.nbatch_size`` and
        ``self.fbatch_size`` are overwritten with the current minibatch
        (presumably read by ``self.guide`` / ``self.model``). On exit,
        whether normal or by exception, ``self.n`` and ``self.f`` are reset
        to ``None`` and both batch sizes are restored to their entry values.

        Returns:
            A tuple ``(z_probs, theta_probs)`` with shapes ``(Nt, F, Q)``
            and ``(K, Nt, F, Q)``; entries outside the processed minibatches
            stay zero.
        """
        z_probs = torch.zeros(self.data.Nt, self.data.F, self.Q)
        theta_probs = torch.zeros(self.K, self.data.Nt, self.data.F, self.Q)
        nbatch_size = self.nbatch_size
        fbatch_size = self.fbatch_size
        N = sum(self.data.is_ontarget)

        # Model site names: "<m|x|y>_k<i>_q<j>" plus "z_q<j>" and "theta_q<j>".
        params = [f"{x}_k{i}" for x in ("m", "x", "y") for i in range(self.K)]
        params += ["z", "theta"]
        params = [f"{x}_q{i}" for x in params for i in range(self.Q)]

        # Leading (enumeration) axes: theta and z alternate over the first
        # 2 * Q axes, followed by K * Q axes for the m sites.
        theta_dims = tuple(range(0, self.Q * 2, 2))
        z_dims = tuple(range(1, self.Q * 2, 2))
        m_dims = tuple(range(self.Q * 2, self.Q * (self.K + 2)))
        try:
            for ndx in torch.split(torch.arange(N), nbatch_size):
                for fdx in torch.split(torch.arange(self.data.F),
                                       fbatch_size):
                    self.n = ndx
                    self.f = fdx
                    self.nbatch_size = len(ndx)
                    self.fbatch_size = len(fdx)
                    with torch.no_grad(), pyro.plate(
                            "particles", size=5,
                            dim=-3), handlers.enum(first_available_dim=-4):
                        guide_tr = handlers.trace(self.guide).get_trace()
                        model_tr = handlers.trace(
                            handlers.replay(self.model,
                                            trace=guide_tr)).get_trace()
                    model_tr.compute_log_prob()
                    guide_tr.compute_log_prob()
                    # p(z, theta, phi)
                    logp = 0
                    for name in params:
                        logp = logp + model_tr.nodes[name]["unscaled_log_prob"]
                    # p(z, theta | phi) = p(z, theta, phi) - p(z, theta, phi).sum(z, theta)
                    logp = logp - logp.logsumexp(z_dims + theta_dims)
                    m_log_probs = [
                        guide_tr.nodes[f"m_k{k}_q{q}"]["unscaled_log_prob"]
                        for k in range(self.K) for q in range(self.Q)
                    ]
                    expectation = reduce(lambda x, y: x + y,
                                         m_log_probs) + logp
                    # average over m
                    result = expectation.logsumexp(m_dims)

                    # marginalize theta, then average over the particle axis
                    z_joint = result.logsumexp(theta_dims).exp().mean(-3)
                    # marginalize z (still in log space, particles kept)
                    theta_joint = result.logsumexp(z_dims)
                    for q in range(self.Q):
                        # Axes of the other q's. Each q reduces from the
                        # untouched joint; reusing the previous iteration's
                        # result would reduce over the wrong axes when Q > 1.
                        # NOTE(review): assumes remaining axis i belongs to
                        # q == i - confirm against the enumeration order.
                        other_dims = tuple(
                            i for i in range(self.Q) if i != q)
                        if other_dims:
                            z_marginal = z_joint.sum(other_dims)
                            theta_marginal = theta_joint.logsumexp(other_dims)
                        else:
                            z_marginal = z_joint
                            theta_marginal = theta_joint
                        # keep the z == 1 slice
                        z_probs[ndx[:, None], fdx, q] = z_marginal[1]
                        # drop the theta == 0 slice, average over particles
                        theta_probs[:, ndx[:, None], fdx,
                                    q] = theta_marginal[1:].exp().mean(-3)
        finally:
            # Always undo the temporary minibatch state, even when tracing
            # fails, so a failed call does not leave shrunken batch sizes.
            self.n = None
            self.f = None
            self.nbatch_size = nbatch_size
            self.fbatch_size = fbatch_size
        return z_probs, theta_probs
Пример #12
0
    def compute_probs(self) -> torch.Tensor:
        """Estimate ``theta`` probabilities given the stored MAP values of ``z``.

        The first ``N = sum(self.data.is_ontarget)`` rows are processed in
        minibatches of at most ``self.nbatch_size``. For each minibatch the
        guide is traced under a 5-particle plate with parallel enumeration
        and the model is replayed against it. The funsor log-probabilities
        are then handled separately for the first frame (site suffix
        ``_f0``) and for the remaining frames (suffix
        ``_fslice(1, F, None)``): ``self.z_map`` is substituted for ``z``,
        the result is normalised over ``theta``, reduced over the ``m``
        sites and averaged over particles.

        While running, ``self.n`` and ``self.nbatch_size`` are overwritten
        with the current minibatch (presumably read by ``self.guide`` /
        ``self.model``). On exit, whether normal or by exception, ``self.n``
        is reset to ``None`` and ``self.nbatch_size`` is restored.

        Returns:
            ``theta_probs`` of shape ``(K, Nt, F, Q)``; rows outside the
            processed minibatches stay zero.
        """
        theta_probs = torch.zeros(self.K, self.data.Nt, self.data.F, self.Q)
        nbatch_size = self.nbatch_size
        N = sum(self.data.is_ontarget)
        try:
            for ndx in torch.split(torch.arange(N), nbatch_size):
                self.n = ndx
                self.nbatch_size = len(ndx)
                with torch.no_grad(), pyro.plate(
                        "particles", size=5,
                        dim=-4), handlers.enum(first_available_dim=-5):
                    guide_tr = handlers.trace(self.guide).get_trace()
                    model_tr = handlers.trace(
                        handlers.replay(self.model,
                                        trace=guide_tr)).get_trace()
                model_tr.compute_log_prob()
                guide_tr.compute_log_prob()

                logp = {}
                result = {}
                for fsx in ("0", f"slice(1, {self.data.F}, None)"):
                    logp[fsx] = 0
                    # collect log_prob terms p(z, theta, phi)
                    for name in [
                            "z",
                            "theta",
                            "m_k0",
                            "m_k1",
                            "x_k0",
                            "x_k1",
                            "y_k0",
                            "y_k1",
                    ]:
                        logp[fsx] += model_tr.nodes[f"{name}_f{fsx}"][
                            "funsor"]["log_prob"]

                    # MAP values of z to substitute, keyed by funsor input.
                    if fsx == "0":
                        z_map = funsor.Tensor(self.z_map[ndx, 0].long(),
                                              dtype=2)["aois", "channels"]
                        z_subs = {f"z_f{fsx}": z_map}
                    else:
                        # Later frames also need z of the preceding frame.
                        z_map = funsor.Tensor(self.z_map[ndx, 1:].long(),
                                              dtype=2)["aois", "frames",
                                                       "channels"]
                        z_map_prev = funsor.Tensor(
                            self.z_map[ndx, :-1].long(),
                            dtype=2)["aois", "frames", "channels"]
                        fsx_prev = f"slice(0, {self.data.F-1}, None)"
                        z_subs = {
                            f"z_f{fsx}": z_map,
                            f"z_f{fsx_prev}": z_map_prev
                        }
                    # substitute MAP values of z into p(z=z_map, theta, phi)
                    logp[fsx] = logp[fsx](**z_subs)
                    # compute log_measure q for given z_map
                    log_measure = (
                        guide_tr.nodes[f"m_k0_f{fsx}"]["funsor"]["log_measure"]
                        +
                        guide_tr.nodes[f"m_k1_f{fsx}"]["funsor"]["log_measure"]
                    )
                    log_measure = log_measure(**z_subs)

                    # compute p(z_map, theta | phi) = p(z_map, theta, phi) - p(z_map, phi)
                    logp[fsx] = logp[fsx] - logp[fsx].reduce(
                        funsor.ops.logaddexp, f"theta_f{fsx}")
                    # average over m in p * q
                    result[fsx] = (logp[fsx] + log_measure).reduce(
                        funsor.ops.logaddexp,
                        frozenset({f"m_k0_f{fsx}", f"m_k1_f{fsx}"}))
                    # average over particles
                    result[fsx] = result[fsx].exp().reduce(funsor.ops.mean,
                                                           "particles")
                # Drop the theta == 0 entry of the last axis and move that
                # axis to the front to match theta_probs' leading K axis.
                theta_probs[:, ndx, 0] = result["0"].data[..., 1:].permute(
                    2, 0, 1)
                theta_probs[:, ndx, 1:] = (
                    result[f"slice(1, {self.data.F}, None)"].data[
                        ..., 1:].permute(3, 0, 1, 2))
        finally:
            # Always undo the temporary minibatch state, even when tracing
            # fails, so a failed call does not leave a shrunken batch size.
            self.n = None
            self.nbatch_size = nbatch_size
        return theta_probs