def assert_ok(model, guide=None, max_plate_nesting=None, **kwargs):
    """
    Assert that enumeration runs...
    """
    with pyro_backend("pyro"):
        pyro.clear_param_store()

    if guide is None:
        guide = lambda **kwargs: None  # noqa: E731

    q_pyro, q_funsor = LifoQueue(), LifoQueue()
    q_pyro.put(Trace())
    q_funsor.put(Trace())

    while not q_pyro.empty() and not q_funsor.empty():
        with pyro_backend("pyro"):
            with handlers.enum(first_available_dim=-max_plate_nesting - 1):
                guide_tr_pyro = handlers.trace(
                    handlers.queue(
                        guide,
                        q_pyro,
                        escape_fn=iter_discrete_escape,
                        extend_fn=iter_discrete_extend,
                    )).get_trace(**kwargs)
                tr_pyro = handlers.trace(
                    handlers.replay(model, trace=guide_tr_pyro)).get_trace(**kwargs)

        with pyro_backend("contrib.funsor"):
            with handlers.enum(first_available_dim=-max_plate_nesting - 1):
                guide_tr_funsor = handlers.trace(
                    handlers.queue(
                        guide,
                        q_funsor,
                        escape_fn=iter_discrete_escape,
                        extend_fn=iter_discrete_extend,
                    )).get_trace(**kwargs)
                tr_funsor = handlers.trace(
                    handlers.replay(model, trace=guide_tr_funsor)).get_trace(**kwargs)

        # make sure all dimensions were cleaned up
        assert _DIM_STACK.local_frame is _DIM_STACK.global_frame
        assert (not _DIM_STACK.global_frame.name_to_dim
                and not _DIM_STACK.global_frame.dim_to_name)
        assert _DIM_STACK.outermost is None

        tr_pyro = prune_subsample_sites(tr_pyro.copy())
        tr_funsor = prune_subsample_sites(tr_funsor.copy())
        _check_traces(tr_pyro, tr_funsor)
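

# Usage sketch (illustrative only): the toy model below and its invocation are
# assumptions, not part of the original test suite. It exercises assert_ok on a
# single enumerated Bernoulli site inside one plate, comparing the "pyro" and
# "contrib.funsor" backends.
def _example_assert_ok_usage():
    from pyroapi import distributions as dist  # assumed import, as in these tests

    def toy_model(data):
        probs = pyro.param("probs", torch.tensor(0.25))
        with pyro.plate("plate", len(data)):
            x = pyro.sample(
                "x", dist.Bernoulli(probs), infer={"enumerate": "parallel"})
            pyro.sample("obs", dist.Normal(x, 1.0), obs=data)

    assert_ok(toy_model, max_plate_nesting=1, data=torch.randn(5))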


def compute_probs(self) -> Tuple[torch.Tensor, torch.Tensor]:  # requires: from typing import Tuple
    z_probs = torch.zeros(self.data.Nt, self.data.F)
    theta_probs = torch.zeros(self.K, self.data.Nt, self.data.F)
    nbatch_size = self.nbatch_size
    fbatch_size = self.fbatch_size
    N = sum(self.data.is_ontarget)
    for ndx in torch.split(torch.arange(N), nbatch_size):
        for fdx in torch.split(torch.arange(self.data.F), fbatch_size):
            self.n = ndx
            self.f = fdx
            self.nbatch_size = len(ndx)
            self.fbatch_size = len(fdx)
            with torch.no_grad(), pyro.plate(
                "particles", size=25, dim=-3
            ), handlers.enum(first_available_dim=-4):
                guide_tr = handlers.trace(self.guide).get_trace()
                model_tr = handlers.trace(
                    handlers.replay(self.model, trace=guide_tr)).get_trace()
            model_tr.compute_log_prob()
            guide_tr.compute_log_prob()
            # enumeration dims: 0 - theta, 1 - z, 2 - m_1, 3 - m_0
            # p(z, theta, phi)
            logp = 0
            for name in ["z", "theta", "m_0", "m_1", "x_0", "x_1", "y_0", "y_1"]:
                logp = logp + model_tr.nodes[name]["unscaled_log_prob"]
            # p(z, theta | phi) = p(z, theta, phi) - p(z, theta, phi).sum(z, theta)
            logp = logp - logp.logsumexp((0, 1))
            expectation = (
                guide_tr.nodes["m_0"]["unscaled_log_prob"]
                + guide_tr.nodes["m_1"]["unscaled_log_prob"]
                + logp
            )
            # average over m
            result = expectation.logsumexp((2, 3))
            # marginalize theta
            z_logits = result.logsumexp(0)
            z_probs[ndx[:, None], fdx] = z_logits[1].exp().mean(-3)
            # marginalize z
            theta_logits = result.logsumexp(1)
            theta_probs[:, ndx[:, None], fdx] = theta_logits[1:].exp().mean(-3)
    self.n = None
    self.f = None
    self.nbatch_size = nbatch_size
    self.fbatch_size = fbatch_size
    return z_probs, theta_probs
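

# A toy, self-contained illustration (the shapes here are assumptions, not the
# model's) of the log-space algebra used in compute_probs: normalize the joint
# over the enumerated (theta, z) dims, add the guide's log q(m), sum out m,
# then read off the z and theta marginals.
def _toy_marginalization():
    # dims: 0 - theta (3), 1 - z (2), 2 - m_1 (2), 3 - m_0 (2),
    # then particles, AOIs, frames
    logp = torch.randn(3, 2, 2, 2, 25, 4, 6)
    logp = logp - logp.logsumexp((0, 1))  # log p(z, theta | phi)
    log_q_m = torch.randn(2, 2, 25, 4, 6)  # log q(m_1, m_0)
    expectation = log_q_m + logp
    result = expectation.logsumexp((2, 3))  # sum out m_1, m_0
    z_prob = result.logsumexp(0)[1].exp().mean(-3)  # P(z = 1) per AOI/frame
    theta_prob = result.logsumexp(1)[1:].exp().mean(-3)  # P(theta = k), k >= 1
    return z_prob, theta_prob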


def main(args):
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logging.info('Loading data')
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info('-' * 40)
    model = models[args.model]
    logging.info('Training {} on {} sequences'.format(
        model.__name__, len(data['train']['sequences'])))
    sequences = data['train']['sequences']
    lengths = data['train']['sequence_lengths']

    # find all the notes that are present at least once in the training set
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())

    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        handlers.block(model,
                       expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = handlers.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = handlers.trace(
            handlers.replay(handlers.enum(model, first_available_dim),
                            guide_trace)).get_trace(
                                sequences, lengths, args=args,
                                batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Bind non-PyTorch parameters to make these functions jittable.
    model = functools.partial(model, args=args)
    guide = functools.partial(guide, args=args)

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optimizer = optim.Adam({'lr': args.learning_rate})
    if args.tmc:
        if args.jit and not args.funsor:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        Elbo = infer.JitTraceTMC_ELBO if args.jit else infer.TraceTMC_ELBO
        elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2)
        tmc_model = handlers.infer_config(
            model,
            lambda msg: {"num_samples": args.tmc_num_samples, "expand": False}
            if msg["infer"].get("enumerate", None) == "parallel" else {})  # noqa: E501
        svi = infer.SVI(tmc_model, guide, optimizer, elbo)
    else:
        Elbo = infer.JitTraceEnum_ELBO if args.jit else infer.TraceEnum_ELBO
        elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2,
                    strict_enumeration_warning=True,
                    jit_options={"time_compilation": args.time_compilation})
        svi = infer.SVI(model, guide, optimizer, elbo)

    # We'll train on small minibatches.
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(sequences, lengths, batch_size=args.batch_size)
        logging.info('{: >5d}\t{}'.format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug('time to compile: {} s.'.format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model, guide, sequences, lengths,
                           batch_size=sequences.shape[0], include_prior=False)
    logging.info('training loss = {}'.format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info('-' * 40)
    logging.info('Evaluating on {} test sequences'.format(
        len(data['test']['sequences'])))
    sequences = data['test']['sequences'][..., present_notes]
    lengths = data['test']['sequence_lengths']
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit
    # easier and for numerical stability) this test loss may not be directly
    # comparable to numbers reported on this dataset elsewhere.
    test_loss = elbo.loss(model, guide, sequences, lengths,
                          batch_size=sequences.shape[0], include_prior=False)
    logging.info('test loss = {}'.format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info('model_{} capacity = {} parameters'.format(
        args.model, capacity))
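

# A sketch of the command-line entry point this example assumes. The flag names
# below are inferred from the attributes read off `args` in main(); the short
# options and defaults are illustrative, not authoritative.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='MAP Baum-Welch learning of an HMM on the JSB chorales')
    parser.add_argument('-m', '--model', default='1', type=str)
    parser.add_argument('-n', '--num-steps', default=50, type=int)
    parser.add_argument('-b', '--batch-size', default=8, type=int)
    parser.add_argument('-lr', '--learning-rate', default=0.05, type=float)
    parser.add_argument('-t', '--truncate', type=int)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--jit', action='store_true')
    parser.add_argument('--funsor', action='store_true')
    parser.add_argument('--tmc', action='store_true')
    parser.add_argument('--tmc-num-samples', default=10, type=int)
    parser.add_argument('--time-compilation', action='store_true')
    parser.add_argument('--print-shapes', action='store_true')
    args = parser.parse_args()
    main(args)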


def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2, history,
                           use_replay):
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        with handlers.enum():
            enum_model = infer.config_enumerate(model, default="parallel")
            # sequential factors
            trace = handlers.trace(enum_model).get_trace(
                weeks_data, days_data, history, False)

            # vectorized trace
            if use_replay:
                guide_trace = handlers.trace(_guide_from_model(model)).get_trace(
                    weeks_data, days_data, history, True)
                vectorized_trace = handlers.trace(
                    handlers.replay(model, trace=guide_trace)).get_trace(
                        weeks_data, days_data, history, True)
            else:
                vectorized_trace = handlers.trace(enum_model).get_trace(
                    weeks_data, days_data, history, True)

            factors = list()
            # sequential weeks factors
            for i in range(len(weeks_data)):
                for v in vars1:
                    factors.append(
                        trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"])
            # sequential days factors
            for j in range(len(days_data)):
                for v in vars2:
                    factors.append(
                        trace.nodes["{}_{}".format(v, j)]["funsor"]["log_prob"])

            vectorized_factors = list()
            # vectorized weeks factors
            for i in range(history):
                for v in vars1:
                    vectorized_factors.append(
                        vectorized_trace.nodes["{}_{}".format(
                            v, i)]["funsor"]["log_prob"])
            for i in range(history, len(weeks_data)):
                for v in vars1:
                    vectorized_factors.append(
                        vectorized_trace.nodes["{}_{}".format(
                            v, slice(history, len(weeks_data)))]["funsor"]["log_prob"](
                                **{"weeks": i - history},
                                **{
                                    "{}_{}".format(
                                        k, slice(history - j, len(weeks_data) - j)):
                                    "{}_{}".format(k, i - j)
                                    for j in range(history + 1)
                                    for k in vars1
                                }))
            # vectorized days factors
            for i in range(history):
                for v in vars2:
                    vectorized_factors.append(
                        vectorized_trace.nodes["{}_{}".format(
                            v, i)]["funsor"]["log_prob"])
            for i in range(history, len(days_data)):
                for v in vars2:
                    vectorized_factors.append(
                        vectorized_trace.nodes["{}_{}".format(
                            v, slice(history, len(days_data)))]["funsor"]["log_prob"](
                                **{"days": i - history},
                                **{
                                    "{}_{}".format(
                                        k, slice(history - j, len(days_data) - j)):
                                    "{}_{}".format(k, i - j)
                                    for j in range(history + 1)
                                    for k in vars2
                                }))

            # assert correct factors
            for f1, f2 in zip(factors, vectorized_factors):
                assert_close(f2, f1.align(tuple(f2.inputs)))

            # assert correct step
            expected_measure_vars = frozenset()
            actual_weeks_step = vectorized_trace.nodes["weeks"]["value"]
            # expected step: assume that all but the last var is markov
            expected_weeks_step = frozenset()
            for v in vars1[:-1]:
                v_step = (
                    tuple("{}_{}".format(v, i) for i in range(history))
                    + tuple("{}_{}".format(v, slice(j, len(weeks_data) - history + j))
                            for j in range(history + 1)))
                expected_weeks_step |= frozenset({v_step})
                # grab measure_vars, found only at sites that are not replayed
                if not use_replay:
                    expected_measure_vars |= frozenset(v_step)

            actual_days_step = vectorized_trace.nodes["days"]["value"]
            # expected step: assume that all but the last var is markov
            expected_days_step = frozenset()
            for v in vars2[:-1]:
                v_step = (
                    tuple("{}_{}".format(v, i) for i in range(history))
                    + tuple("{}_{}".format(v, slice(j, len(days_data) - history + j))
                            for j in range(history + 1)))
                expected_days_step |= frozenset({v_step})
                # grab measure_vars, found only at sites that are not replayed
                if not use_replay:
                    expected_measure_vars |= frozenset(v_step)

            assert actual_weeks_step == expected_weeks_step
            assert actual_days_step == expected_days_step

            # check measure_vars
            actual_measure_vars = terms_from_trace(vectorized_trace)["measure_vars"]
            assert actual_measure_vars == expected_measure_vars


def compute_probs(self) -> Tuple[torch.Tensor, torch.Tensor]:  # requires: from typing import Tuple
    z_probs = torch.zeros(self.data.Nt, self.data.F, self.Q)
    theta_probs = torch.zeros(self.K, self.data.Nt, self.data.F, self.Q)
    nbatch_size = self.nbatch_size
    fbatch_size = self.fbatch_size
    N = sum(self.data.is_ontarget)
    params = ["m", "x", "y"]
    params = list(map(lambda x: [f"{x}_k{i}" for i in range(self.K)], params))
    params = list(itertools.chain(*params))
    params += ["z", "theta"]
    params = list(map(lambda x: [f"{x}_q{i}" for i in range(self.Q)], params))
    params = list(itertools.chain(*params))
    theta_dims = tuple(i for i in range(0, self.Q * 2, 2))
    z_dims = tuple(i for i in range(1, self.Q * 2, 2))
    m_dims = tuple(i for i in range(self.Q * 2, self.Q * (self.K + 2)))
    for ndx in torch.split(torch.arange(N), nbatch_size):
        for fdx in torch.split(torch.arange(self.data.F), fbatch_size):
            self.n = ndx
            self.f = fdx
            self.nbatch_size = len(ndx)
            self.fbatch_size = len(fdx)
            with torch.no_grad(), pyro.plate(
                "particles", size=5, dim=-3
            ), handlers.enum(first_available_dim=-4):
                guide_tr = handlers.trace(self.guide).get_trace()
                model_tr = handlers.trace(
                    handlers.replay(self.model, trace=guide_tr)).get_trace()
            model_tr.compute_log_prob()
            guide_tr.compute_log_prob()
            # enumeration dims for theta, z, and m are given by the
            # theta_dims, z_dims, and m_dims tuples computed above
            # p(z, theta, phi)
            logp = 0
            for name in params:
                logp = logp + model_tr.nodes[name]["unscaled_log_prob"]
            # p(z, theta | phi) = p(z, theta, phi) - p(z, theta, phi).sum(z, theta)
            logp = logp - logp.logsumexp(z_dims + theta_dims)
            m_log_probs = [
                guide_tr.nodes[f"m_k{k}_q{q}"]["unscaled_log_prob"]
                for k in range(self.K)
                for q in range(self.Q)
            ]
            expectation = reduce(lambda x, y: x + y, m_log_probs) + logp
            # average over m
            result = expectation.logsumexp(m_dims)
            # marginalize theta
            z_logits = result.logsumexp(theta_dims)
            a = z_logits.exp().mean(-3)
            for q in range(self.Q):
                sum_dims = tuple(i for i in range(self.Q) if i != q)
                # marginalize the other channels' z dims without overwriting a
                a_q = a.sum(sum_dims) if sum_dims else a
                z_probs[ndx[:, None], fdx, q] = a_q[1]
            # marginalize z
            b = result.logsumexp(z_dims)
            for q in range(self.Q):
                sum_dims = tuple(i for i in range(self.Q) if i != q)
                # marginalize the other channels' theta dims without overwriting b
                b_q = b.logsumexp(sum_dims) if sum_dims else b
                theta_probs[:, ndx[:, None], fdx, q] = b_q[1:].exp().mean(-3)
    self.n = None
    self.f = None
    self.nbatch_size = nbatch_size
    self.fbatch_size = fbatch_size
    return z_probs, theta_probs


def compute_probs(self) -> torch.Tensor:
    theta_probs = torch.zeros(self.K, self.data.Nt, self.data.F, self.Q)
    nbatch_size = self.nbatch_size
    N = sum(self.data.is_ontarget)
    for ndx in torch.split(torch.arange(N), nbatch_size):
        self.n = ndx
        self.nbatch_size = len(ndx)
        with torch.no_grad(), pyro.plate(
            "particles", size=5, dim=-4
        ), handlers.enum(first_available_dim=-5):
            guide_tr = handlers.trace(self.guide).get_trace()
            model_tr = handlers.trace(
                handlers.replay(self.model, trace=guide_tr)).get_trace()
        model_tr.compute_log_prob()
        guide_tr.compute_log_prob()
        logp = {}
        result = {}
        for fsx in ("0", f"slice(1, {self.data.F}, None)"):
            logp[fsx] = 0
            # collect log_prob terms p(z, theta, phi)
            for name in [
                    "z", "theta", "m_k0", "m_k1", "x_k0", "x_k1", "y_k0", "y_k1"
            ]:
                logp[fsx] += model_tr.nodes[f"{name}_f{fsx}"]["funsor"]["log_prob"]
            if fsx == "0":
                # substitute MAP values of z into p(z=z_map, theta, phi)
                z_map = funsor.Tensor(
                    self.z_map[ndx, 0].long(), dtype=2)["aois", "channels"]
                logp[fsx] = logp[fsx](**{f"z_f{fsx}": z_map})
                # compute log_measure q for given z_map
                log_measure = (
                    guide_tr.nodes[f"m_k0_f{fsx}"]["funsor"]["log_measure"]
                    + guide_tr.nodes[f"m_k1_f{fsx}"]["funsor"]["log_measure"])
                log_measure = log_measure(**{f"z_f{fsx}": z_map})
            else:
                # substitute MAP values of z into p(z=z_map, theta, phi)
                z_map = funsor.Tensor(
                    self.z_map[ndx, 1:].long(),
                    dtype=2)["aois", "frames", "channels"]
                z_map_prev = funsor.Tensor(
                    self.z_map[ndx, :-1].long(),
                    dtype=2)["aois", "frames", "channels"]
                fsx_prev = f"slice(0, {self.data.F-1}, None)"
                logp[fsx] = logp[fsx](**{
                    f"z_f{fsx}": z_map,
                    f"z_f{fsx_prev}": z_map_prev
                })
                # compute log_measure q for given z_map
                log_measure = (
                    guide_tr.nodes[f"m_k0_f{fsx}"]["funsor"]["log_measure"]
                    + guide_tr.nodes[f"m_k1_f{fsx}"]["funsor"]["log_measure"])
                log_measure = log_measure(**{
                    f"z_f{fsx}": z_map,
                    f"z_f{fsx_prev}": z_map_prev
                })
            # compute p(z_map, theta | phi) = p(z_map, theta, phi) - p(z_map, phi)
            logp[fsx] = logp[fsx] - logp[fsx].reduce(
                funsor.ops.logaddexp, f"theta_f{fsx}")
            # average over m in p * q
            result[fsx] = (logp[fsx] + log_measure).reduce(
                funsor.ops.logaddexp, frozenset({f"m_k0_f{fsx}", f"m_k1_f{fsx}"}))
            # average over particles
            result[fsx] = result[fsx].exp().reduce(funsor.ops.mean, "particles")
        theta_probs[:, ndx, 0] = result["0"].data[..., 1:].permute(2, 0, 1)
        theta_probs[:, ndx, 1:] = (
            result[f"slice(1, {self.data.F}, None)"].data[..., 1:].permute(
                3, 0, 1, 2))
    self.n = None
    self.nbatch_size = nbatch_size
    return theta_probs
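

# A minimal, self-contained sketch (toy shapes and names, not the model's) of
# the funsor idioms used in compute_probs above: name tensor dims, substitute a
# value into a named input, marginalize named variables in log-space with
# funsor.ops.logaddexp, and average in probability space with funsor.ops.mean.
def _funsor_reduce_example():
    funsor.set_backend("torch")
    # joint log-density with named inputs "particles", "m", "theta"
    logp = funsor.Tensor(torch.randn(5, 2, 3))["particles", "m", "theta"]
    # condition on theta = 1 by substitution
    logp_theta1 = logp(theta=1)
    # marginalize m in log-space, then average over particles in prob-space
    marginal = logp_theta1.reduce(funsor.ops.logaddexp, "m")
    return marginal.exp().reduce(funsor.ops.mean, "particles")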