Ejemplo n.º 1
0
def test_mcmc_model_side_enumeration(model, temperature):
    """Smoke-test ``infer_discrete`` on a model conditioned on stand-in MCMC draws.

    No MCMC is run. The values used for conditioning are drawn from the
    enumerated model itself, so they have the wrong distribution but the
    right type for the checks below.
    """
    # Fake inference: trace the enumerated model with only the "loc" and
    # "scale" sites exposed.
    enumerated_model = handlers.enum(infer.config_enumerate(model))
    exposed_model = handlers.block(enumerated_model, expose=["loc", "scale"])
    mcmc_trace = handlers.trace(exposed_model).get_trace()

    # Keep the value recorded at every sample site of that trace.
    mcmc_data = {}
    for name, site in mcmc_trace.nodes.items():
        if site["type"] == "sample":
            mcmc_data[name] = site["value"]

    # MAP-estimate the discretes, conditioned on the sampled continuous latents.
    # TODO support replayed sites in infer_discrete, i.e. use
    # handlers.replay(infer.config_enumerate(model), mcmc_trace) instead.
    conditioned_model = handlers.condition(
        infer.config_enumerate(model), mcmc_data)
    discrete_model = infer.infer_discrete(
        conditioned_model, temperature=temperature)
    actual_trace = handlers.trace(discrete_model).get_trace()

    # The site names must match those of a plain trace of the model, and the
    # "scale" funsor value must not carry "z1" as an input.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace.nodes) == set(expected_trace.nodes)
    scale_value = actual_trace.nodes["scale"]["funsor"]["value"]
    assert "z1" not in scale_value.inputs
Ejemplo n.º 2
0
def test_svi_model_side_enumeration(model, temperature):
    """Smoke-test ``infer_discrete`` on a model conditioned on an untrained guide.

    The guide is initialized but never trained, so the values used for
    conditioning have the wrong distribution but the right type for the
    checks below.
    """
    # Fake inference: build an AutoNormal guide over the enumerated model with
    # only the "loc" and "scale" sites exposed.
    exposed_model = handlers.block(
        infer.config_enumerate(model), expose=["loc", "scale"])
    guide = AutoNormal(handlers.enum(exposed_model))
    guide()  # Initialize but don't bother to train.
    guide_trace = handlers.trace(guide).get_trace()

    # Keep the value recorded at every sample site of the guide trace.
    guide_data = {}
    for name, site in guide_trace.nodes.items():
        if site["type"] == "sample":
            guide_data[name] = site["value"]

    # MAP-estimate the discretes, conditioned on the sampled continuous latents.
    # TODO support replayed sites in infer_discrete, i.e. use
    # handlers.replay(infer.config_enumerate(model), guide_trace) instead.
    conditioned_model = handlers.condition(
        infer.config_enumerate(model), guide_data)
    discrete_model = infer.infer_discrete(
        conditioned_model, temperature=temperature)
    actual_trace = handlers.trace(discrete_model).get_trace()

    # The site names must match those of a plain trace of the model, and the
    # "scale" funsor value must not carry "z1" as an input.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace.nodes) == set(expected_trace.nodes)
    scale_value = actual_trace.nodes["scale"]["funsor"]["value"]
    assert "z1" not in scale_value.inputs
Ejemplo n.º 3
0
def _guide_from_model(model):
    """Derive a guide from ``model`` under the "contrib.funsor" backend.

    The model is configured for parallel enumeration and wrapped in
    ``handlers.block`` with a predicate that selects messages flagged
    ``is_observed``. If setting this up raises ``KeyError``, ``model`` is
    returned unchanged.
    """

    def _is_observed(msg):
        # Messages without the flag count as not observed.
        return msg.get("is_observed", False)

    try:
        with pyro_backend("contrib.funsor"):
            enumerated = infer.config_enumerate(model, default="parallel")
            return handlers.block(enumerated, _is_observed)
    except KeyError:  # for test collection without funsor
        return model
Ejemplo n.º 4
0
def main(args):
    """Train the model selected by ``args.model`` and report its losses.

    The function loads the JSB chorales data via ``poly``, fits point
    estimates of the ``probs_*`` sites with SVI, and logs the per-observation
    loss on the training and test splits, followed by the parameter count.

    Args:
        args: Parsed options. Attributes read here: ``cuda``, ``model``,
            ``truncate``, ``seed``, ``print_shapes``, ``batch_size``,
            ``learning_rate``, ``tmc``, ``jit``, ``funsor``,
            ``tmc_num_samples``, ``time_compilation`` and ``num_steps``.

    Raises:
        NotImplementedError: If ``args.tmc`` and ``args.jit`` are set while
            ``args.funsor`` is not.
    """
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logging.info('Loading data')
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info('-' * 40)
    model = models[args.model]
    # Decide the plate nesting now, while ``model`` is still the raw model
    # function. Further down it is rebound to a functools.partial, after which
    # ``model is model_0`` can never be true and model_0 would silently be
    # given a nesting of 2.
    max_plate_nesting = 1 if model is model_0 else 2
    logging.info('Training {} on {} sequences'.format(
        model.__name__, len(data['train']['sequences'])))
    sequences = data['train']['sequences']
    lengths = data['train']['sequence_lengths']

    # find all the notes that are present at least once in the training set
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        handlers.block(model,
                       expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        # Enumeration dims start just to the left of the plate dims:
        # -2 for model_0, -3 for every other model.
        first_available_dim = -(max_plate_nesting + 1)
        guide_trace = handlers.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = handlers.trace(
            handlers.replay(handlers.enum(model, first_available_dim),
                            guide_trace)).get_trace(sequences,
                                                    lengths,
                                                    args=args,
                                                    batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Bind non-PyTorch parameters to make these functions jittable.
    model = functools.partial(model, args=args)
    guide = functools.partial(guide, args=args)

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting
    # (1 for model_0, 2 for every other model; computed above).
    optimizer = optim.Adam({'lr': args.learning_rate})
    if args.tmc:
        if args.jit and not args.funsor:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        Elbo = infer.JitTraceTMC_ELBO if args.jit else infer.TraceTMC_ELBO
        elbo = Elbo(max_plate_nesting=max_plate_nesting)
        # Sites configured for parallel enumeration get the TMC sampling
        # options; every other site gets an empty config.
        tmc_model = handlers.infer_config(
            model,
            lambda msg: {
                "num_samples": args.tmc_num_samples,
                "expand": False,
            } if msg["infer"].get("enumerate", None) == "parallel" else {})
        svi = infer.SVI(tmc_model, guide, optimizer, elbo)
    else:
        Elbo = infer.JitTraceEnum_ELBO if args.jit else infer.TraceEnum_ELBO
        elbo = Elbo(max_plate_nesting=max_plate_nesting,
                    strict_enumeration_warning=True,
                    jit_options={"time_compilation": args.time_compilation})
        svi = infer.SVI(model, guide, optimizer, elbo)

    # We'll train on small minibatches.
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(sequences, lengths, batch_size=args.batch_size)
        logging.info('{: >5d}\t{}'.format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug('time to compile: {} s.'.format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model,
                           guide,
                           sequences,
                           lengths,
                           batch_size=sequences.shape[0],
                           include_prior=False)
    logging.info('training loss = {}'.format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info('-' * 40)
    logging.info('Evaluating on {} test sequences'.format(
        len(data['test']['sequences'])))
    sequences = data['test']['sequences'][..., present_notes]
    lengths = data['test']['sequence_lengths']
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          batch_size=sequences.shape[0],
                          include_prior=False)
    logging.info('test loss = {}'.format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info('model_{} capacity = {} parameters'.format(
        args.model, capacity))
Ejemplo n.º 5
0
 def compute_probs(self) -> "tuple[torch.Tensor, torch.Tensor]":
     """Estimate probabilities for ``z`` and ``theta``, one mini-batch at a time.

     For each (n, f) mini-batch the guide is traced under a 50-particle plate
     with enumeration, the model is replayed against that trace, and the
     resulting log-probabilities are normalised and reduced into the
     returned tensors.

     While running, ``self.n``, ``self.f``, ``self.nbatch_size`` and
     ``self.fbatch_size`` are overwritten with the current mini-batch. They
     are reset on exit (``n`` and ``f`` to ``None``, the batch sizes to their
     values at entry), including when an exception is raised part-way.

     Returns:
         A pair ``(z_probs, theta_probs)``. ``z_probs`` has shape
         ``(Nt, F, Q)`` and ``theta_probs`` has shape ``(K, Nt, F, Q)``. Only
         the first ``N = sum(self.data.is_ontarget)`` rows along the ``Nt``
         axis are written; the rest stay zero.
     """
     z_probs = torch.zeros(self.data.Nt, self.data.F, self.Q)
     theta_probs = torch.zeros(self.K, self.data.Nt, self.data.F, self.Q)
     # Saved so the configured batch sizes can be restored on exit.
     nbatch_size = self.nbatch_size
     fbatch_size = self.fbatch_size
     N = sum(self.data.is_ontarget)
     # Sites whose log-probs are summed: m_k*, x_k*, y_k*, then z and theta.
     params = [
         f"{prefix}_k{k}" for prefix in ("m", "x", "y") for k in range(self.K)
     ]
     params += ["z", "theta"]
     # Leading dims of the summed log-prob tensor, per the original author's
     # notes: 0 - theta, 1 - z, 2.. - one dim per m_k (2 - m_1, 3 - m_0).
     theta_dims = (0,)
     z_dims = (1,)
     m_dims = tuple(range(2, self.K + 2))
     qdx = torch.arange(self.Q)
     try:
         for ndx in torch.split(torch.arange(N), nbatch_size):
             for fdx in torch.split(torch.arange(self.data.F), fbatch_size):
                 # Point the model/guide at the current mini-batch.
                 self.n = ndx
                 self.f = fdx
                 self.nbatch_size = len(ndx)
                 self.fbatch_size = len(fdx)
                 with torch.no_grad(), pyro.plate(
                     "particles", size=50, dim=-4
                 ), handlers.enum(first_available_dim=-5):
                     guide_tr = handlers.trace(self.guide).get_trace()
                     model_tr = handlers.trace(
                         handlers.replay(
                             handlers.block(self.model, hide=["data"]),
                             trace=guide_tr,
                         )
                     ).get_trace()
                 model_tr.compute_log_prob()
                 guide_tr.compute_log_prob()
                 # p(z, theta, phi)
                 logp = 0
                 for name in params:
                     logp = logp + model_tr.nodes[name]["unscaled_log_prob"]
                 # p(z, theta | phi) = p(z, theta, phi) - p(z, theta, phi).sum(z, theta)
                 logp = logp - logp.logsumexp(z_dims + theta_dims)
                 m_log_probs = [
                     guide_tr.nodes[f"m_k{k}"]["unscaled_log_prob"]
                     for k in range(self.K)
                 ]
                 expectation = reduce(lambda x, y: x + y, m_log_probs) + logp
                 # average over m
                 result = expectation.logsumexp(m_dims)
                 # marginalize theta
                 z_logits = result.logsumexp(theta_dims)
                 # Take index 1 along the z dim and average over particles.
                 z_probs[ndx[:, None, None], fdx[:, None], qdx] = (
                     z_logits[1].exp().mean(-4)
                 )
                 # marginalize z
                 theta_logits = result.logsumexp(z_dims)
                 # Drop index 0 along the theta dim; average over particles.
                 theta_probs[:, ndx[:, None, None], fdx[:, None], qdx] = (
                     theta_logits[1:].exp().mean(-4)
                 )
     finally:
         # Always undo the temporary mini-batch state, even on failure.
         self.n = None
         self.f = None
         self.nbatch_size = nbatch_size
         self.fbatch_size = fbatch_size
     return z_probs, theta_probs