def test_mcmc_model_side_enumeration(model, temperature):
    # Perform fake inference.
    # Draw from prior rather than trying to sample from mcmc posterior.
    # This has the wrong distribution but the right type for tests.
    mcmc_trace = handlers.trace(
        handlers.block(handlers.enum(infer.config_enumerate(model)),
                       expose=["loc", "scale"])).get_trace()
    mcmc_data = {
        name: site["value"]
        for name, site in mcmc_trace.nodes.items() if site["type"] == "sample"
    }

    # MAP estimate discretes, conditioned on posterior-sampled continuous latents.
    actual_trace = handlers.trace(
        infer.infer_discrete(
            # TODO support replayed sites in infer_discrete.
            # handlers.replay(infer.config_enumerate(model), mcmc_trace),
            handlers.condition(infer.config_enumerate(model), mcmc_data),
            temperature=temperature,
        ),
    ).get_trace()

    # Check site names and shapes.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace.nodes) == set(expected_trace.nodes)
    assert "z1" not in actual_trace.nodes["scale"]["funsor"]["value"].inputs
def test_svi_model_side_enumeration(model, temperature):
    # Perform fake inference.
    # This has the wrong distribution but the right type for tests.
    guide = AutoNormal(
        handlers.enum(
            handlers.block(infer.config_enumerate(model),
                           expose=["loc", "scale"])))
    guide()  # Initialize, but don't bother to train.
    guide_trace = handlers.trace(guide).get_trace()
    guide_data = {
        name: site["value"]
        for name, site in guide_trace.nodes.items() if site["type"] == "sample"
    }

    # MAP estimate discretes, conditioned on posterior-sampled continuous latents.
    actual_trace = handlers.trace(
        infer.infer_discrete(
            # TODO support replayed sites in infer_discrete.
            # handlers.replay(infer.config_enumerate(model), guide_trace)
            handlers.condition(infer.config_enumerate(model), guide_data),
            temperature=temperature,
        )).get_trace()

    # Check site names and shapes.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace.nodes) == set(expected_trace.nodes)
    assert "z1" not in actual_trace.nodes["scale"]["funsor"]["value"].inputs
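# The `model` fixture and `temperature` parameter used by the two tests above
# are not shown in this excerpt. `infer_discrete`'s temperature is typically
# parametrized over 0 (MAP) and 1 (posterior sampling), e.g.
# @pytest.mark.parametrize("temperature", [0, 1]).
# A minimal, hypothetical stand-in with the site names the tests reference
# ("loc", "scale", and a discrete "z1") might look like `toy_model` below;
# the distributions and the observed site "x" are illustrative only, and the
# real fixture in the test module may differ.
def toy_model():
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    scale = pyro.sample("scale", dist.LogNormal(0.0, 1.0))
    z1 = pyro.sample("z1", dist.Categorical(torch.ones(3) / 3.0))
    pyro.sample("x", dist.Normal(loc + z1.float(), scale), obs=torch.tensor(0.5))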
def _guide_from_model(model):
    try:
        with pyro_backend("contrib.funsor"):
            return handlers.block(
                infer.config_enumerate(model, default="parallel"),
                lambda msg: msg.get("is_observed", False))
    except KeyError:  # for test collection without funsor
        return model
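# Hedged usage sketch (not part of the source): `_guide_from_model` returns the
# model with its discrete sites marked for parallel enumeration and its
# observed sites blocked, so it can stand in as a guide in enumeration tests.
# `toy_obs_model`, its site names, and the `pyro`/`dist` aliases are
# illustrative assumptions.
def _example_guide_usage():
    def toy_obs_model():
        z = pyro.sample("z", dist.Categorical(torch.ones(2) / 2.0))
        pyro.sample("x", dist.Normal(z.float(), 1.0), obs=torch.tensor(0.0))

    guide = _guide_from_model(toy_obs_model)
    with pyro_backend("contrib.funsor"):
        guide_trace = handlers.trace(guide).get_trace()
    # The observed site "x" is blocked, so only the latent "z" is recorded.
    return guide_trace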
def main(args):
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logging.info('Loading data')
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info('-' * 40)
    model = models[args.model]
    logging.info('Training {} on {} sequences'.format(
        model.__name__, len(data['train']['sequences'])))
    sequences = data['train']['sequences']
    lengths = data['train']['sequence_lengths']

    # find all the notes that are present at least once in the training set
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())

    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        handlers.block(model,
                       expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = handlers.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = handlers.trace(
            handlers.replay(handlers.enum(model, first_available_dim),
                            guide_trace)).get_trace(
                sequences, lengths, args=args, batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Bind non-PyTorch parameters to make these functions jittable.
    model = functools.partial(model, args=args)
    guide = functools.partial(guide, args=args)

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optimizer = optim.Adam({'lr': args.learning_rate})
    if args.tmc:
        if args.jit and not args.funsor:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        Elbo = infer.JitTraceTMC_ELBO if args.jit else infer.TraceTMC_ELBO
        elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2)
        tmc_model = handlers.infer_config(model, lambda msg: {
            "num_samples": args.tmc_num_samples,
            "expand": False
        } if msg["infer"].get("enumerate", None) == "parallel" else {})  # noqa: E501
        svi = infer.SVI(tmc_model, guide, optimizer, elbo)
    else:
        Elbo = infer.JitTraceEnum_ELBO if args.jit else infer.TraceEnum_ELBO
        elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2,
                    strict_enumeration_warning=True,
                    jit_options={"time_compilation": args.time_compilation})
        svi = infer.SVI(model, guide, optimizer, elbo)

    # We'll train on small minibatches.
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(sequences, lengths, batch_size=args.batch_size)
        logging.info('{: >5d}\t{}'.format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug('time to compile: {} s.'.format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model,
                           guide,
                           sequences,
                           lengths,
                           batch_size=sequences.shape[0],
                           include_prior=False)
    logging.info('training loss = {}'.format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info('-' * 40)
    logging.info('Evaluating on {} test sequences'.format(
        len(data['test']['sequences'])))
    sequences = data['test']['sequences'][..., present_notes]
    lengths = data['test']['sequence_lengths']
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit
    # easier and for numerical stability) this test loss may not be directly
    # comparable to numbers reported on this dataset elsewhere.
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          batch_size=sequences.shape[0],
                          include_prior=False)
    logging.info('test loss = {}'.format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info('model_{} capacity = {} parameters'.format(
        args.model, capacity))
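# Hedged sketch of the command-line wiring `main` expects. The flag names and
# defaults below are inferred from the attributes read above (args.model,
# args.num_steps, args.tmc, ...) and may differ from the original script.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Variational HMM variants on the JSB chorales dataset")
    parser.add_argument("-m", "--model", default="1", type=str)
    parser.add_argument("-n", "--num-steps", default=50, type=int)
    parser.add_argument("-b", "--batch-size", default=8, type=int)
    parser.add_argument("-lr", "--learning-rate", default=0.05, type=float)
    parser.add_argument("-t", "--truncate", type=int)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--jit", action="store_true")
    parser.add_argument("--funsor", action="store_true")
    parser.add_argument("--tmc", action="store_true")
    parser.add_argument("--tmc-num-samples", default=10, type=int)
    parser.add_argument("--time-compilation", action="store_true")
    parser.add_argument("--print-shapes", action="store_true")
    args = parser.parse_args()
    main(args)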
def compute_probs(self) -> Tuple[torch.Tensor, torch.Tensor]:
    z_probs = torch.zeros(self.data.Nt, self.data.F, self.Q)
    theta_probs = torch.zeros(self.K, self.data.Nt, self.data.F, self.Q)
    nbatch_size = self.nbatch_size
    fbatch_size = self.fbatch_size
    N = sum(self.data.is_ontarget)
    params = ["m", "x", "y"]
    params = list(map(lambda x: [f"{x}_k{i}" for i in range(self.K)], params))
    params = list(itertools.chain(*params))
    params += ["z", "theta"]
    theta_dims = tuple(i for i in range(0, 2, 2))
    z_dims = tuple(i for i in range(1, 2, 2))
    m_dims = tuple(i for i in range(2, self.K + 2))
    for ndx in torch.split(torch.arange(N), nbatch_size):
        for fdx in torch.split(torch.arange(self.data.F), fbatch_size):
            self.n = ndx
            self.f = fdx
            self.nbatch_size = len(ndx)
            self.fbatch_size = len(fdx)
            qdx = torch.arange(self.Q)
            with torch.no_grad(), pyro.plate(
                "particles", size=50, dim=-4
            ), handlers.enum(first_available_dim=-5):
                guide_tr = handlers.trace(self.guide).get_trace()
                model_tr = handlers.trace(
                    handlers.replay(
                        handlers.block(self.model, hide=["data"]), trace=guide_tr
                    )
                ).get_trace()
            model_tr.compute_log_prob()
            guide_tr.compute_log_prob()
            # 0 - theta
            # 1 - z
            # 2 - m_1
            # 3 - m_0
            # p(z, theta, phi)
            logp = 0
            for name in params:
                logp = logp + model_tr.nodes[name]["unscaled_log_prob"]
            # p(z, theta | phi) = p(z, theta, phi) - p(z, theta, phi).sum(z, theta)
            logp = logp - logp.logsumexp(z_dims + theta_dims)
            m_log_probs = [
                guide_tr.nodes[f"m_k{k}"]["unscaled_log_prob"]
                for k in range(self.K)
            ]
            expectation = reduce(lambda x, y: x + y, m_log_probs) + logp
            # average over m
            result = expectation.logsumexp(m_dims)
            # marginalize theta
            z_logits = result.logsumexp(theta_dims)
            z_probs[ndx[:, None, None], fdx[:, None], qdx] = (
                z_logits[1].exp().mean(-4)
            )
            # marginalize z
            theta_logits = result.logsumexp(z_dims)
            theta_probs[:, ndx[:, None, None], fdx[:, None], qdx] = (
                theta_logits[1:].exp().mean(-4)
            )
    self.n = None
    self.f = None
    self.nbatch_size = nbatch_size
    self.fbatch_size = fbatch_size
    return z_probs, theta_probs
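# Hedged usage note (not from the source): `compute_probs` returns per-spot
# posterior probabilities for z and theta, so a caller holding an instance of
# the owning model class (here a hypothetical `model`) might store them and
# derive hard assignments by thresholding, e.g.
#
#     z_probs, theta_probs = model.compute_probs()
#     z_map = z_probs > 0.5  # hypothetical decision threshold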