def test_block_full(self):
    model_trace = poutine.trace(poutine.block(self.model)).get_trace()
    guide_trace = poutine.trace(poutine.block(self.guide)).get_trace()
    for name in model_trace.nodes.keys():
        assert model_trace.nodes[name]["type"] in ("args", "return")
    for name in guide_trace.nodes.keys():
        assert guide_trace.nodes[name]["type"] in ("args", "return")
def test_block_full_hide_expose(self):
    try:
        poutine.block(self.model,
                      hide=self.partial_sample_sites.keys(),
                      expose=self.partial_sample_sites.keys())()
        assert False
    except AssertionError:
        assert True
def test_block_full_expose(self):
    model_trace = poutine.trace(poutine.block(self.model, expose=self.model_sites)).get_trace()
    guide_trace = poutine.trace(poutine.block(self.guide, expose=self.guide_sites)).get_trace()
    for name in self.model_sites:
        assert name in model_trace
    for name in self.guide_sites:
        assert name in guide_trace
def test_guide_list(auto_class):

    def model():
        pyro.sample("x", dist.Normal(0., 1.))
        pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5)))

    guide = AutoGuideList(model)
    guide.add(auto_class(poutine.block(model, expose=["x"]), prefix="auto_x"))
    guide.add(auto_class(poutine.block(model, expose=["y"]), prefix="auto_y"))
    guide()
def test_block_partial_expose(self):
    model_trace = poutine.trace(
        poutine.block(self.model, expose=self.partial_sample_sites.keys())).get_trace()
    guide_trace = poutine.trace(
        poutine.block(self.guide, expose=self.partial_sample_sites.keys())).get_trace()
    for name in self.full_sample_sites.keys():
        if name in self.partial_sample_sites:
            assert name in model_trace
            assert name in guide_trace
        else:
            assert name not in model_trace
            assert name not in guide_trace
def _setup_prototype(self, *args, **kwargs):
    # run the model so we can inspect its structure
    model = config_enumerate(self.model, default="parallel")
    self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(*args, **kwargs)
    self.prototype_trace = prune_subsample_sites(self.prototype_trace)
    if self.master is not None:
        self.master()._check_prototype(self.prototype_trace)

    self._discrete_sites = []
    self._cond_indep_stacks = {}
    self._iaranges = {}
    for name, site in self.prototype_trace.nodes.items():
        if site["type"] != "sample" or site["is_observed"]:
            continue
        if site["infer"].get("enumerate") != "parallel":
            raise NotImplementedError('Expected sample site "{}" to be discrete and '
                                      'configured for parallel enumeration'.format(name))

        # collect discrete sample sites
        fn = site["fn"]
        Dist = type(fn)
        if Dist in (dist.Bernoulli, dist.Categorical, dist.OneHotCategorical):
            params = [("probs", fn.probs.detach().clone(), fn.arg_constraints["probs"])]
        else:
            raise NotImplementedError("{} is not supported".format(Dist.__name__))
        self._discrete_sites.append((site, Dist, params))

        # collect independence contexts
        self._cond_indep_stacks[name] = site["cond_indep_stack"]
        for frame in site["cond_indep_stack"]:
            if frame.vectorized:
                self._iaranges[frame.name] = frame
            else:
                raise NotImplementedError("AutoDiscreteParallel does not support pyro.irange")
def _gen_weighted_samples(self, *args, **kwargs):
    for tr, log_w in poutine.block(self.trace_dist._traces)(*args, **kwargs):
        if self.sites == "_RETURN":
            val = tr.nodes["_RETURN"]["value"]
        else:
            val = {name: tr.nodes[name]["value"] for name in self.sites}
        yield (val, log_w)
def test_trace_data(self):
    tr1 = poutine.trace(
        poutine.block(self.model, expose_types=["sample"])).get_trace()
    tr2 = poutine.trace(
        poutine.condition(self.model, data=tr1)).get_trace()
    assert tr2.nodes["latent2"]["type"] == "sample" and \
        tr2.nodes["latent2"]["is_observed"]
    assert tr2.nodes["latent2"]["value"] is tr1.nodes["latent2"]["value"]
def test_block_tutorial_case(self):
    model_trace = poutine.trace(self.model).get_trace()
    guide_trace = poutine.trace(
        poutine.block(self.guide, hide_types=["observe"])).get_trace()

    assert "latent1" in model_trace
    assert "latent1" in guide_trace
    assert "obs" in model_trace
    assert "obs" not in guide_trace
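# A hypothetical model/guide pair consistent with the assertions above (the actual
# fixtures for this test are not shown here): "obs" is an observed sample statement,
# so hide_types=["observe"] removes it from the guide's trace while the latent site
# stays visible.
import torch
import pyro
import pyro.distributions as dist


def tutorial_model():
    latent1 = pyro.sample("latent1", dist.Normal(0., 1.))
    pyro.sample("obs", dist.Normal(latent1, 1.), obs=torch.tensor(0.3))


def tutorial_guide():
    pyro.sample("latent1", dist.Normal(0., 1.))
    # an observe statement in the guide, included only to mirror the test's setup
    pyro.sample("obs", dist.Normal(0., 1.), obs=torch.tensor(0.3))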
def __call__(self, *args, **kwargs):
    # if first time
    if self.compiled is None:
        # param capture
        with poutine.block():
            with poutine.trace(param_only=True) as first_param_capture:
                self.fn(*args, **kwargs)

        self._param_names = list(set(first_param_capture.trace.nodes.keys()))

        weakself = weakref.ref(self)

        @torch.jit.compile(**self._jit_options)
        def compiled(unconstrained_params, *args):
            self = weakself()
            constrained_params = {}
            for name, unconstrained_param in zip(self._param_names, unconstrained_params):
                constrained_param = pyro.param(name)  # assume param has been initialized
                assert constrained_param.unconstrained() is unconstrained_param
                constrained_params[name] = constrained_param
            return poutine.replay(
                self.fn, params=constrained_params)(*args, **kwargs)

        self.compiled = compiled

    param_list = [pyro.param(name).unconstrained() for name in self._param_names]

    with poutine.block(hide=self._param_names):
        with poutine.trace(param_only=True) as param_capture:
            ret = self.compiled(param_list, *args, **kwargs)

    new_params = filter(lambda name: name not in self._param_names,
                        param_capture.trace.nodes.keys())
    for name in new_params:
        # enforce uniqueness
        if name not in self._param_names:
            self._param_names.append(name)

    return ret
def __call__(self, *args, **kwargs):
    traces, logits = [], []
    for tr, logit in poutine.block(self._traces)(*args, **kwargs):
        traces.append(tr)
        logits.append(logit)
    logits = torch.stack(logits).squeeze()
    logits -= util.log_sum_exp(logits)
    if not isinstance(logits, torch.autograd.Variable):
        logits = Variable(logits)
    ix = dist.categorical(logits=logits, one_hot=False)
    return traces[ix.data[0]]
def test_discrete_parallel(continuous_class):
    K = 2
    data = torch.tensor([0., 1., 10., 11., 12.])

    def model(data):
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
        locs = pyro.sample('locs', dist.Normal(0, 10).expand_by([K]).independent(1))
        scale = pyro.sample('scale', dist.LogNormal(0, 1))
        with pyro.iarange('data', len(data)):
            weights = weights.expand(torch.Size((len(data),)) + weights.shape)
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    guide = AutoGuideList(model)
    guide.add(continuous_class(poutine.block(model, hide=["assignment"])))
    guide.add(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))

    elbo = TraceEnum_ELBO(max_iarange_nesting=1)
    loss = elbo.loss_and_grads(model, guide, data)
    assert np.isfinite(loss), loss
def run(self, *args, **kwargs):
    """
    Calls `self._traces` to populate execution traces from a stochastic Pyro model.

    :param args: optional args taken by `self._traces`.
    :param kwargs: optional keyword args taken by `self._traces`.
    """
    self._init()
    for tr, logit in poutine.block(self._traces)(*args, **kwargs):
        self.exec_traces.append(tr)
        self.log_weights.append(logit)
    self._categorical = Categorical(logits=torch.tensor(self.log_weights))
    return self
def __init__(self, model, guide=None, num_samples=None):
    """
    Constructor. Defaults to num_samples = 10 and guide = model.
    """
    super(Importance, self).__init__()
    if num_samples is None:
        num_samples = 10
        logger.warn("num_samples not provided, defaulting to {}".format(num_samples))
    if guide is None:
        # propose from the prior by making a guide from the model by hiding observes
        guide = poutine.block(model, hide_types=["observe"])
    self.num_samples = num_samples
    self.model = model
    self.guide = guide
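# Minimal usage sketch for the prior-proposal default above, on a hypothetical coin
# model; assumes the classic pyro.infer.Importance API in which run() populates
# exec_traces and log_weights, as in the run() methods shown elsewhere in this section.
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Importance


def coin_model(data):
    p = pyro.sample("p", dist.Beta(2., 2.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Bernoulli(p), obs=data)
    return p


data = torch.tensor([1., 1., 0., 1.])
# guide=None, so proposals are drawn from the prior with observe statements blocked
importance = Importance(coin_model, guide=None, num_samples=100)
importance.run(data)
print(len(importance.exec_traces), len(importance.log_weights))  # 100 100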
def _setup_prototype(self, *args, **kwargs):
    # run the model so we can inspect its structure
    self.prototype_trace = poutine.block(poutine.trace(self.model).get_trace)(*args, **kwargs)
    self.prototype_trace = prune_subsample_sites(self.prototype_trace)
    if self.master is not None:
        self.master()._check_prototype(self.prototype_trace)

    self._iaranges = {}
    for name, site in self.prototype_trace.nodes.items():
        if site["type"] != "sample" or site["is_observed"]:
            continue
        for frame in site["cond_indep_stack"]:
            if frame.vectorized:
                self._iaranges[frame.name] = frame
            else:
                raise NotImplementedError("AutoGuideList does not support pyro.irange")
def run(self, *args, **kwargs):
    """
    Calls `self._traces` to populate execution traces from a stochastic Pyro model.

    :param args: optional args taken by `self._traces`.
    :param kwargs: optional keyword args taken by `self._traces`.
    """
    self._reset()
    with poutine.block():
        for i, vals in enumerate(self._traces(*args, **kwargs)):
            if len(vals) == 2:
                chain_id = 0
                tr, logit = vals
            else:
                tr, logit, chain_id = vals
                assert chain_id < self.num_chains
            self.exec_traces.append(tr)
            self.log_weights.append(logit)
            self.chain_ids.append(chain_id)
            self._idx_by_chain[chain_id].append(i)
    self._categorical = Categorical(logits=torch.tensor(self.log_weights))
    return self
def predict(self, x, trg_y):
    """Predict labels.

    Args:
      x (torch.utils.data.DataLoader):
      trg_y (np.array): array for storing the predicted labels

    Returns:
      np.array: predicted labels

    """
    # resize `self._test_wbench` if necessary
    n_instances = x[0].shape[0]
    self._resize_wbench(n_instances)
    self._test_wbench *= 0
    # print("self._test_wbench:", repr(self._test_wbench))
    with poutine.block():
        with torch.no_grad():
            for wbench_i in self._test_wbench:
                wbench_i[:n_instances] = self.guide(*x)
    mean = np.mean(self._test_wbench, axis=0)
    trg_y[:] = np.argmax(mean[:n_instances], axis=-1)
    return trg_y
def main(_argv):
    transition_alphas = torch.tensor([[10., 90.],
                                      [90., 10.]])
    emission_alphas = torch.tensor([[[30., 20., 5.]],
                                    [[5., 10., 100.]]])
    lengths = torch.randint(10, 30, (10000,))
    trace = poutine.trace(model).get_trace(transition_alphas, emission_alphas, lengths)
    obs_sequences = [site['value'] for name, site in trace.nodes.items()
                     if name.startswith("element_")]
    obs_sequences = torch.stack(obs_sequences, dim=-2)
    guide = AutoDelta(
        poutine.block(model, hide_fn=lambda site: site['name'].startswith('state')),
        init_loc_fn=init_to_sample)
    svi = SVI(model, guide, Adam(dict(lr=0.1)), JitTraceEnum_ELBO())
    total = 1000
    with tqdm.trange(total) as t:
        for i in t:
            loss = svi.step(0.5 * torch.ones((2, 2), dtype=torch.float),
                            0.3 * torch.ones((2, 1, 3), dtype=torch.float),
                            lengths, obs_sequences)
            t.set_description_str(f"SVI ({i}/{total}): {loss}")
    median = guide.median()
    print("Transition probs: ", median['transition_probs'].detach().numpy())
    print("Emission probs: ", median['emission_probs'].squeeze().detach().numpy())
def posterior_entropy(y_dist, design):
    # Important that y_dist is sampled *within* the function
    y = pyro.sample("conditioning_y", y_dist)
    y_dict = {label: y[i, ...] for i, label in enumerate(observation_labels)}
    conditioned_model = pyro.condition(model, data=y_dict)
    # Here just using SVI to run the MAP optimization
    guide.train()
    SVI(conditioned_model,
        guide=guide,
        loss=loss,
        optim=optim,
        num_steps=num_steps,
        num_samples=1).run(design)
    # Recover the entropy
    with poutine.block():
        final_loss = loss(conditioned_model, guide, design)
        guide.finalize(final_loss, target_labels)
        entropy = mean_field_entropy(guide, [design], whitelist=target_labels)
    return entropy
def relbo(model, guide, *args, **kwargs):

    approximation = kwargs.pop('approximation', None)

    # Run the guide with the arguments passed to SVI.step() and trace the execution,
    # i.e. record all the calls to Pyro primitives like sample() and param().
    guide_trace = trace(guide).get_trace(*args, **kwargs)

    # Now run the model with the same arguments and trace the execution. Because
    # model is being run with replay, whenever we encounter a sample site in the
    # model, instead of sampling from the corresponding distribution in the model,
    # we instead reuse the corresponding sample from the guide. In probabilistic
    # terms, this means our loss is constructed as an expectation w.r.t. the joint
    # distribution defined by the guide.
    model_trace = trace(replay(model, guide_trace)).get_trace(*args, **kwargs)

    approximation_trace = trace(
        replay(block(approximation, expose=["obs"]), guide_trace)).get_trace(*args, **kwargs)

    # We will accumulate the various terms of the ELBO in `elbo`.
    elbo = 0.
    # Loop over all the sample sites in the model and add the corresponding
    # log p(z) term to the ELBO. Note that this will also include any observed
    # data, i.e. sample sites with the keyword `obs=...`.
    elbo = elbo + model_trace.log_prob_sum()
    # Loop over all the sample sites in the guide and add the corresponding
    # -log q(z) term to the ELBO.
    elbo = elbo - guide_trace.log_prob_sum()
    elbo = elbo - approximation_trace.log_prob_sum()

    # Return (-elbo) since by convention we do gradient descent on a loss and
    # the ELBO is a lower bound that needs to be maximized.
    if elbo < 10e-8 and PRINT_TRACES:
        print('Guide trace')
        print(guide_trace.log_prob_sum())
        print('Model trace')
        print(model_trace.log_prob_sum())
        print('Approximation trace')
        print(approximation_trace.log_prob_sum())

    return -elbo
def pol_att(self, x, y, t, e):
    num_samples = self.num_samples
    if not torch._C._get_tracing_state():
        assert x.dim() == 2 and x.size(-1) == self.feature_dim

    dataloader = [x]
    print("Evaluating {} minibatches".format(len(dataloader)))
    result_pol = []
    result_eatt = []
    for x in dataloader:
        # x = self.whiten(x)
        with pyro.plate("num_particles", num_samples, dim=-2):
            with poutine.trace() as tr, poutine.block(hide=["y", "t"]):
                self.guide(x)
            with poutine.do(data=dict(t=torch.zeros(()))):
                y0 = poutine.replay(self.model.y_mean, tr.trace)(x)
            with poutine.do(data=dict(t=torch.ones(()))):
                y1 = poutine.replay(self.model.y_mean, tr.trace)(x)
        ite = (y1 - y0).mean(0)
        ite[t > 0] = -ite[t > 0]
        eatt = torch.abs(torch.mean(ite[(t + e) > 1]))
        pols = []
        for s in range(num_samples):
            pols.append(policy_val(ypred1=y1[s], ypred0=y0[s], y=y, t=t))
        pol = torch.stack(pols).mean(0)
        if not torch._C._get_tracing_state():
            print("batch eATT = {:0.6g}".format(eatt))
            print("batch RPOL = {:0.6g}".format(pol))
        result_pol.append(pol)
        result_eatt.append(eatt)
    return torch.stack(result_pol), torch.stack(result_eatt)
def _guess_max_plate_nesting(self, model, guide, *args, **kwargs):
    """
    Guesses max_plate_nesting by running the (model, guide) pair once
    without enumeration. This optimistically assumes static model structure.
    """
    # Ignore validation to allow model-enumerated sites absent from the guide.
    with poutine.block():
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    sites = [site
             for trace in (model_trace, guide_trace)
             for site in trace.nodes.values()
             if site["type"] == "sample"]

    # Validate shapes now, since shape constraints will be weaker once
    # max_plate_nesting is changed from float('inf') to some finite value.
    # Here we know the traces are not enumerated, but later we'll need to
    # allow broadcasting of dims to the left of max_plate_nesting.
    if is_validation_enabled():
        guide_trace.compute_log_prob()
        model_trace.compute_log_prob()
        for site in sites:
            check_site_shape(site, max_plate_nesting=float('inf'))

    dims = [frame.dim
            for site in sites
            for frame in site["cond_indep_stack"]
            if frame.vectorized]
    self.max_plate_nesting = -min(dims) if dims else 0
    if self.vectorize_particles and self.num_particles > 1:
        self.max_plate_nesting += 1
    logging.info('Guessed max_plate_nesting = {}'.format(self.max_plate_nesting))
def belief_policy_model_guide(belief, t, discount=1.0, discount_factor=0.95, max_depth=10):
    # prior weights are uniform
    weights = pyro.param("action_weights", torch.ones(len(actions)),
                         constraint=dist.constraints.simplex)
    state = states[pyro.sample("s%d" % t, belief)]
    history = ""  # we can start from an empty history

    # This is just to generate the other variables
    with poutine.block(hide=["a%d" % t]):
        # Need to hide the 'a%d' site generated by the policy_model;
        # we don't care about it, because that's what we are inferring --
        # we call pyro.sample("a%d" % t, ...) ourselves next.
        history_policy_model(state, history, t,
                             discount=discount,
                             discount_factor=discount_factor,
                             max_depth=max_depth)

    # We eventually generate actions based on the weights
    action = pyro.sample("a%d" % t, dist.Categorical(weights))
def align_samples(samples, model, particle_dim):
    """
    Unsqueeze stacked samples such that their particle dims all align.

    This traces ``model`` to determine the ``event_dim`` of each site.
    """
    assert particle_dim < 0

    sample = {name: value[0] for name, value in samples.items()}
    with poutine.block(), poutine.trace() as tr, poutine.condition(data=sample):
        model()

    samples = samples.copy()
    for name, value in samples.items():
        event_dim = tr.trace.nodes[name]["fn"].event_dim
        pad = event_dim - particle_dim - value.dim()
        if pad < 0:
            raise ValueError("Cannot align samples, try moving particle_dim left")
        if pad > 0:
            shape = value.shape[:1] + (1,) * pad + value.shape[1:]
            samples[name] = value.reshape(shape)
    return samples
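# Hypothetical example of the alignment above (toy model and sample names are
# illustrative): a scalar site and a vector-valued site are stacked over a leading
# sample dimension, then padded so that the sample dimension lands at batch dim -3,
# to the left of each site's event dims.
import torch
import pyro
import pyro.distributions as dist


def toy_model():
    pyro.sample("mu", dist.Normal(0., 1.))
    pyro.sample("w", dist.Normal(torch.zeros(4), 1.).to_event(1))


samples = {"mu": torch.randn(10), "w": torch.randn(10, 4)}
aligned = align_samples(samples, toy_model, particle_dim=-3)
print(aligned["mu"].shape, aligned["w"].shape)
# torch.Size([10, 1, 1]) torch.Size([10, 1, 1, 4])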
def initialize(seed, model, data):
    global global_guide, svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()

    exposed_params = []
    # set the parameters inferred through the guide based on the kind of data
    if 'gr' in mtype:
        if dtype == 'norm':
            exposed_params = ['weights', 'concentration']
        elif dtype == 'raw':
            exposed_params = ['weights', 'alpha', 'beta']
    elif 'dim' in mtype:
        if dtype == 'norm':
            exposed_params = ['topic_weights', 'topic_concentration', 'participant_topics']
        elif dtype == 'raw':
            exposed_params = ['topic_weights', 'topic_a', 'topic_b', 'participant_topics']

    global_guide = AutoDelta(poutine.block(model, expose=exposed_params))
    svi = SVI(model, global_guide, optim, loss=elbo)
    return svi.loss(model, global_guide, data)
def main(args): pyro.set_rng_seed(0) pyro.clear_param_store() pyro.enable_validation(__debug__) # load data if args.dataset == "dipper": capture_history_file = os.path.dirname( os.path.abspath(__file__)) + '/dipper_capture_history.csv' elif args.dataset == "vole": capture_history_file = os.path.dirname( os.path.abspath(__file__)) + '/meadow_voles_capture_history.csv' else: raise ValueError("Available datasets are \'dipper\' and \'vole\'.") capture_history = torch.tensor( np.genfromtxt(capture_history_file, delimiter=',')).float()[:, 1:] N, T = capture_history.shape print( "Loaded {} capture history for {} individuals collected over {} time periods." .format(args.dataset, N, T)) if args.dataset == "dipper" and args.model in ["4", "5"]: sex_file = os.path.dirname( os.path.abspath(__file__)) + '/dipper_sex.csv' sex = torch.tensor(np.genfromtxt(sex_file, delimiter=',')).float()[:, 1] print("Loaded dipper sex data.") elif args.dataset == "vole" and args.model in ["4", "5"]: raise ValueError( "Cannot run model_{} on meadow voles data, since we lack sex " + "information for these animals.".format(args.model)) else: sex = None model = models[args.model] # we use poutine.block to only expose the continuous latent variables # in the models to AutoDiagonalNormal (all of which begin with 'phi' # or 'rho') def expose_fn(msg): return msg["name"][0:3] in ['phi', 'rho'] # we use a mean field diagonal normal variational distributions (i.e. guide) # for the continuous latent variables. guide = AutoDiagonalNormal(poutine.block(model, expose_fn=expose_fn)) # since we enumerate the discrete random variables, # we need to use TraceEnum_ELBO or TraceTMC_ELBO. optim = Adam({'lr': args.learning_rate}) if args.tmc: elbo = TraceTMC_ELBO(max_plate_nesting=1) tmc_model = poutine.infer_config(model, lambda msg: { "num_samples": args.tmc_num_samples, "expand": False } if msg["infer"].get("enumerate", None) == "parallel" else {} ) # noqa: E501 svi = SVI(tmc_model, guide, optim, elbo) else: elbo = TraceEnum_ELBO(max_plate_nesting=1, num_particles=20, vectorize_particles=True) svi = SVI(model, guide, optim, elbo) losses = [] print( "Beginning training of model_{} with Stochastic Variational Inference." .format(args.model)) for step in range(args.num_steps): loss = svi.step(capture_history, sex) losses.append(loss) if step % 20 == 0 and step > 0 or step == args.num_steps - 1: print("[iteration %03d] loss: %.3f" % (step, np.mean(losses[-20:]))) # evaluate final trained model elbo_eval = TraceEnum_ELBO(max_plate_nesting=1, num_particles=2000, vectorize_particles=True) svi_eval = SVI(model, guide, optim, elbo_eval) print("Final loss: %.4f" % svi_eval.evaluate_loss(capture_history, sex))
def auto_guide_list_x(model):
    guide = AutoGuideList(model)
    guide.add(AutoDelta(poutine.block(model, expose=["x"])))
    guide.add(AutoDiagonalNormal(poutine.block(model, hide=["x"])))
    return guide
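# Hypothetical usage of the composite guide built above: AutoDelta gives a point
# estimate for the exposed site "x" while AutoDiagonalNormal covers the remaining
# latent sites. The toy model below (sites "x", "z", "y") is illustrative only.
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam


def toy_model():
    x = pyro.sample("x", dist.Normal(0., 1.))
    z = pyro.sample("z", dist.Normal(0., 1.))
    pyro.sample("y", dist.Normal(x + z, 1.), obs=torch.tensor(0.5))


guide = auto_guide_list_x(toy_model)
svi = SVI(toy_model, guide, Adam({"lr": 0.01}), Trace_ELBO())
for _ in range(100):
    svi.step()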
    # for x in batch:
    #     ret.append(inner(observed[x], start, 0, 0, prefix=str(x))[0])
    # return ret
    return inner(observed, start, 0, 0)


if __name__ == '__main__':
    observed = ["the dog bit the man".split()]
    nonterminals = ["S", "NP", "VP"]
    preterminals = ["DT", "NN", "VB"]
    model = partial(model, start="S",
                    nonterminals=nonterminals,
                    productions=[("NP", "VP"), ("DT", "NN"), ("VB", "NN")],
                    preterminals=preterminals,
                    terminals=["the", "dog", "man", "bit"])
    expose_params = ["binary_%s" % nonterm for nonterm in nonterminals] \
        + ["unary_%s" % preterm for preterm in preterminals]
    guide = AutoContinuous(poutine.block(model, expose=expose_params))
    elbo = Trace_ELBO()
    optim = Adam({"lr": 0.001})
    svi = SVI(model, guide, optim, elbo)
    for _ in range(30):
        loss = svi.step(observed[0])
        print(loss)
def _create_autoguide( self, model, amortised, encoder_kwargs, data_transform, encoder_mode, init_loc_fn=init_to_mean(fallback=init_to_feasible), n_cat_list: list = [], encoder_instance=None, guide_class=AutoNormal, guide_kwargs: Optional[dict] = None, ): if guide_kwargs is None: guide_kwargs = dict() if not amortised: if getattr(model, "discrete_variables", None) is not None: model = poutine.block(model, hide=model.discrete_variables) if issubclass(guide_class, poutine.messenger.Messenger): # messenger guides don't need create_plates function _guide = guide_class( model, init_loc_fn=init_loc_fn, **guide_kwargs, ) else: _guide = guide_class( model, init_loc_fn=init_loc_fn, **guide_kwargs, create_plates=self.model.create_plates, ) else: encoder_kwargs = encoder_kwargs if isinstance( encoder_kwargs, dict) else dict() n_hidden = encoder_kwargs[ "n_hidden"] if "n_hidden" in encoder_kwargs.keys() else 200 if isinstance(data_transform, np.ndarray): # add extra info about gene clusters as input to NN self.register_buffer( "gene_clusters", torch.tensor(data_transform.astype("float32"))) n_in = model.n_vars + data_transform.shape[1] data_transform = self._data_transform_clusters() elif data_transform == "log1p": # use simple log1p transform data_transform = torch.log1p n_in = self.model.n_vars elif (isinstance(data_transform, dict) and "var_std" in list(data_transform.keys()) and "var_mean" in list(data_transform.keys())): # use data transform by scaling n_in = model.n_vars self.register_buffer( "var_mean", torch.tensor( data_transform["var_mean"].astype("float32").reshape( (1, n_in))), ) self.register_buffer( "var_std", torch.tensor( data_transform["var_std"].astype("float32").reshape( (1, n_in))), ) data_transform = self._data_transform_scale() else: # use custom data transform data_transform = data_transform n_in = model.n_vars amortised_vars = model.list_obs_plate_vars() if len(amortised_vars["input"]) >= 2: encoder_kwargs["n_cat_list"] = n_cat_list amortised_vars["input_transform"][0] = data_transform if getattr(model, "discrete_variables", None) is not None: model = poutine.block(model, hide=model.discrete_variables) _guide = AutoAmortisedHierarchicalNormalMessenger( model, amortised_plate_sites=amortised_vars, n_in=n_in, n_hidden=n_hidden, encoder_kwargs=encoder_kwargs, encoder_mode=encoder_mode, encoder_instance=encoder_instance, **guide_kwargs, ) return _guide
        len(mix_params), device=x.device))
    beta_scale = pyro.param(
        'beta_resp_scale',
        torch.tril(1. * torch.eye(len(mix_params), len(mix_params),
                                  device=x.device)),
        constraint=constraints.lower_cholesky)
    pyro.sample(
        "beta_resp",
        dist.MultivariateNormal(beta_loc, scale_tril=beta_scale).to_event(1))


# In[11]:

guide = AutoGuideList(model)
guide.add(AutoDiagonalNormal(poutine.block(model, expose=['theta', 'L_omega'])))
guide.add(my_local_guide)  # automatically wrapped in an AutoCallable


# # Run variational inference

# In[12]:

# prepare data for running inference
train_x = torch.tensor(alt_attributes, dtype=torch.float)
train_x = train_x.cuda()
train_y = torch.tensor(true_choices, dtype=torch.int)
train_y = train_y.cuda()

alt_av_cuda = torch.from_numpy(alt_availability)
alt_av_cuda = alt_av_cuda.cuda()

alt_av_mat = alt_availability.copy()
alt_av_mat[np.where(alt_av_mat == 0)] = -1e9
def trainVI( data, hidden_dim, learning_rate=1e-03, n_iters=500, trace_embeddings_interval=20, writer=None, moved_points={}, sigma_fix=1e-3, ): """Train (Deep) PPCA model with VI. Using tensorboardX SummaryWriter to log to tensorboard. Logging the internal result by setting "trace_embeddings_interval" to the expected interval (50, 100, ...). Set it to value larger than "n_iters" to disable interval logging. """ pyro.enable_validation(True) pyro.set_rng_seed(0) pyro.clear_param_store() model = partial( ppca_model, hidden_dim=hidden_dim, z_dim=2, moved_points=moved_points, sigma_fix=sigma_fix, ) guide = AutoGuideList(model) guide.add(AutoDiagonalNormal(model=poutine.block(model, expose=["sigma"]))) guide.add(AutoDiagonalNormal(model=poutine.block(model, expose=["Z"]), prefix="qZ")) optim = Adam({"lr": learning_rate}) svi = SVI(model, guide, optim, loss=Trace_ELBO()) fig_title = f"lr={learning_rate}/hidden-dim={hidden_dim}" metric = DRMetric(X=data, Y=None) data = torch.tensor(data, dtype=torch.float) for n_iter in tqdm(range(n_iters)): loss = svi.step(data) if writer and n_iter % 10 == 0: writer.add_scalar("train_vi/loss", loss, n_iter) if writer and n_iter % trace_embeddings_interval == 0: z2d_loc = pyro.param("qZ_loc").reshape(-1, 2).data.numpy() auc_rnx = metric.update(Y=z2d_loc).auc_rnx() writer.add_scalar("metrics/auc_rnx", auc_rnx, n_iter) fig = get_fig_plot_z2d(z2d_loc, fig_title + f", auc_rnx={auc_rnx:.3f}") writer.add_figure("train_vi/z2d", fig, n_iter) # show named rvs print("List params and their size: ") for p_name, p_val in pyro.get_param_store().items(): print(p_name, p_val.shape) z2d_loc = pyro.param("qZ_loc").reshape(-1, 2).data.numpy() z2d_scale = pyro.param("qZ_scale").reshape(-1, 2).data.numpy() if writer: fig = get_fig_plot_z2d(z2d_loc, fig_title) writer.add_figure("train_vi/z2d", fig, n_iters) return z2d_loc, z2d_scale
def main(args): if args.cuda: torch.set_default_tensor_type("torch.cuda.FloatTensor") logging.info("Loading data") data = poly.load_data(poly.JSB_CHORALES) logging.info("-" * 40) model = models[args.model] logging.info("Training {} on {} sequences".format( model.__name__, len(data["train"]["sequences"]))) sequences = data["train"]["sequences"] lengths = data["train"]["sequence_lengths"] # find all the notes that are present at least once in the training set present_notes = (sequences == 1).sum(0).sum(0) > 0 # remove notes that are never played (we remove 37/88 notes) sequences = sequences[..., present_notes] if args.truncate: lengths = lengths.clamp(max=args.truncate) sequences = sequences[:, :args.truncate] num_observations = float(lengths.sum()) pyro.set_rng_seed(args.seed) pyro.clear_param_store() # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing # out the hidden state x. This is accomplished via an automatic guide that # learns point estimates of all of our conditional probability tables, # named probs_*. guide = AutoDelta( poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_"))) # To help debug our tensor shapes, let's print the shape of each site's # distribution, value, and log_prob tensor. Note this information is # automatically printed on most errors inside SVI. if args.print_shapes: first_available_dim = -2 if model is model_0 else -3 guide_trace = poutine.trace(guide).get_trace( sequences, lengths, args=args, batch_size=args.batch_size) model_trace = poutine.trace( poutine.replay(poutine.enum(model, first_available_dim), guide_trace)).get_trace(sequences, lengths, args=args, batch_size=args.batch_size) logging.info(model_trace.format_shapes()) # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting. # All of our models have two plates: "data" and "tones". optim = Adam({"lr": args.learning_rate}) if args.tmc: if args.jit: raise NotImplementedError( "jit support not yet added for TraceTMC_ELBO") elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2) tmc_model = poutine.infer_config( model, lambda msg: { "num_samples": args.tmc_num_samples, "expand": False } if msg["infer"].get("enumerate", None) == "parallel" else {}, ) # noqa: E501 svi = SVI(tmc_model, guide, optim, elbo) else: Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO elbo = Elbo( max_plate_nesting=1 if model is model_0 else 2, strict_enumeration_warning=(model is not model_7), jit_options={"time_compilation": args.time_compilation}, ) svi = SVI(model, guide, optim, elbo) # We'll train on small minibatches. logging.info("Step\tLoss") for step in range(args.num_steps): loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size) logging.info("{: >5d}\t{}".format(step, loss / num_observations)) if args.jit and args.time_compilation: logging.debug("time to compile: {} s.".format( elbo._differentiable_loss.compile_time)) # We evaluate on the entire training dataset, # excluding the prior term so our results are comparable across models. train_loss = elbo.loss(model, guide, sequences, lengths, args, include_prior=False) logging.info("training loss = {}".format(train_loss / num_observations)) # Finally we evaluate on the test dataset. 
logging.info("-" * 40) logging.info("Evaluating on {} test sequences".format( len(data["test"]["sequences"]))) sequences = data["test"]["sequences"][..., present_notes] lengths = data["test"]["sequence_lengths"] if args.truncate: lengths = lengths.clamp(max=args.truncate) num_observations = float(lengths.sum()) # note that since we removed unseen notes above (to make the problem a bit easier and for # numerical stability) this test loss may not be directly comparable to numbers # reported on this dataset elsewhere. test_loss = elbo.loss(model, guide, sequences, lengths, args=args, include_prior=False) logging.info("test loss = {}".format(test_loss / num_observations)) # We expect models with higher capacity to perform better, # but eventually overfit to the training set. capacity = sum( value.reshape(-1).size(0) for value in pyro.get_param_store().values()) logging.info("{} capacity = {} parameters".format(model.__name__, capacity))
def fit_advi_iterative(self, n=3, method='advi', n_type='restart', n_iter=None, learning_rate=None, progressbar=True, num_workers=2, train_proportion=None, stratify_cv=None, l2_weight=False, sample_scaling_weight=0.5, checkpoints=None, checkpoint_dir='./checkpoints', tracking=False): r""" Train posterior using ADVI method. (maximising likehood of the data and minimising KL-divergence of posterior to prior) :param n: number of independent initialisations :param method: to allow for potential use of SVGD or MCMC (currently only ADVI implemented). :param n_type: type of repeated initialisation: 'restart' to pick different initial value, 'cv' for molecular cross-validation - splits counts into n datasets, for now, only n=2 is implemented 'bootstrap' for fitting the model to multiple downsampled datasets. Run `mod.bootstrap_data()` to generate variants of data :param n_iter: number of iterations, supersedes self.n_iter :param train_proportion: if not None, which proportion of cells to use for training and which for validation. :param checkpoints: int, list of int's or None, number of checkpoints to save while model training or list of iterations to save checkpoints on :param checkpoint_dir: str, directory to save checkpoints in :param tracking: bool, track all latent variables during training - if True makes training 2 times slower :return: None """ # initialise parameter store self.svi = {} self.hist = {} self.guide_i = {} self.samples = {} self.node_samples = {} if tracking: self.logp_hist = {} if n_iter is None: n_iter = self.n_iter if type(checkpoints) is int: if n_iter < checkpoints: checkpoints = n_iter checkpoints = np.linspace(0, n_iter, checkpoints + 1, dtype=int)[1:] self.checkpoints = list(checkpoints) else: self.checkpoints = checkpoints self.checkpoint_dir = checkpoint_dir self.n_type = n_type self.l2_weight = l2_weight self.sample_scaling_weight = sample_scaling_weight self.train_proportion = train_proportion if stratify_cv is not None: self.stratify_cv = stratify_cv if train_proportion is not None: self.validation_hist = {} self.training_hist = {} if tracking: self.logp_hist_val = {} self.logp_hist_train = {} if learning_rate is None: learning_rate = self.learning_rate if np.isin(n_type, ['bootstrap']): if self.X_data_sample is None: self.bootstrap_data(n=n) elif np.isin(n_type, ['cv']): self.generate_cv_data() # cv data added to self.X_data_sample init_names = ['init_' + str(i + 1) for i in np.arange(n)] for i, name in enumerate(init_names): ################### Initialise parameters & optimiser ################### # initialise Variational distribution = guide if method is 'advi': self.guide_i[name] = AutoGuideList(self.model) normal_guide_block = poutine.block( self.model, expose_all=True, hide_all=False, hide=self.point_estim + flatten_iterable(self.custom_guides.keys())) self.guide_i[name].append( AutoNormal(normal_guide_block, init_loc_fn=init_to_mean)) self.guide_i[name].append( AutoDelta( poutine.block(self.model, hide_all=True, expose=self.point_estim))) for k, v in self.custom_guides.items(): self.guide_i[name].append(v) elif method is 'custom': self.guide_i[name] = self.guide # initialise SVI inference method self.svi[name] = SVI( self.model, self.guide_i[name], optim.ClippedAdam({ 'lr': learning_rate, # limit the gradient step from becoming too large 'clip_norm': self.total_grad_norm_constraint }), loss=JitTrace_ELBO()) pyro.clear_param_store() self.set_initial_values() # record ELBO Loss history here self.hist[name] = [] if tracking: self.logp_hist[name] = 
defaultdict(list) if train_proportion is not None: self.validation_hist[name] = [] if tracking: self.logp_hist_val[name] = defaultdict(list) ################### Select data for this iteration ################### if np.isin(n_type, ['cv', 'bootstrap']): X_data = self.X_data_sample[i].astype(self.data_type) else: X_data = self.X_data.astype(self.data_type) ################### Training / validation split ################### # split into training and validation if train_proportion is not None: idx = np.arange(len(X_data)) train_idx, val_idx = train_test_split( idx, train_size=train_proportion, shuffle=True, stratify=self.stratify_cv) extra_data_val = { k: torch.FloatTensor(v[val_idx]).to(self.device) for k, v in self.extra_data.items() } extra_data_train = { k: torch.FloatTensor(v[train_idx]) for k, v in self.extra_data.items() } x_data_val = torch.FloatTensor(X_data[val_idx]).to(self.device) x_data = torch.FloatTensor(X_data[train_idx]) else: # just convert data to CPU tensors x_data = torch.FloatTensor(X_data) extra_data_train = { k: torch.FloatTensor(v) for k, v in self.extra_data.items() } ################### Move data to cuda - FULL data ################### # if not minibatch do this: if self.minibatch_size is None: # move tensors to CUDA x_data = x_data.to(self.device) for k in extra_data_train.keys(): extra_data_train[k] = extra_data_train[k].to(self.device) # extra_data_train = {k: v.to(self.device) for k, v in extra_data_train.items()} ################### MINIBATCH data ################### else: # create minibatches dataset = MiniBatchDataset(x_data, extra_data_train, return_idx=True) loader = DataLoader(dataset, batch_size=self.minibatch_size, num_workers=0) # TODO num_workers ################### Training the model ################### # start training in epochs epochs_iterator = tqdm(range(n_iter)) for epoch in epochs_iterator: if self.minibatch_size is None: ################### Training FULL data ################### iter_loss = self.step_train(name, x_data, extra_data_train) self.hist[name].append(iter_loss) # save data for posterior sampling self.x_data = x_data self.extra_data_train = extra_data_train if tracking: guide_tr, model_tr = self.step_trace( name, x_data, extra_data_train) self.logp_hist[name]['guide'].append( guide_tr.log_prob_sum().item()) self.logp_hist[name]['model'].append( model_tr.log_prob_sum().item()) for k, v in model_tr.nodes.items(): if "log_prob_sum" in v: self.logp_hist[name][k].append( v["log_prob_sum"].item()) else: ################### Training MINIBATCH data ################### aver_loss = [] if tracking: aver_logp_guide = [] aver_logp_model = [] aver_logp = defaultdict(list) for batch in loader: x_data_batch, extra_data_batch = batch x_data_batch = x_data_batch.to(self.device) extra_data_batch = { k: v.to(self.device) for k, v in extra_data_batch.items() } loss = self.step_train(name, x_data_batch, extra_data_batch) if tracking: guide_tr, model_tr = self.step_trace( name, x_data_batch, extra_data_batch) aver_logp_guide.append( guide_tr.log_prob_sum().item()) aver_logp_model.append( model_tr.log_prob_sum().item()) for k, v in model_tr.nodes.items(): if "log_prob_sum" in v: aver_logp[k].append( v["log_prob_sum"].item()) aver_loss.append(loss) iter_loss = np.sum(aver_loss) # save data for posterior sampling self.x_data = x_data_batch self.extra_data_train = extra_data_batch self.hist[name].append(iter_loss) if tracking: iter_logp_guide = np.sum(aver_logp_guide) iter_logp_model = np.sum(aver_logp_model) self.logp_hist[name]['guide'].append(iter_logp_guide) 
self.logp_hist[name]['model'].append(iter_logp_model) for k, v in aver_logp.items(): self.logp_hist[name][k].append(np.sum(v)) if self.checkpoints is not None: if (epoch + 1) in self.checkpoints: self.save_checkpoint(epoch + 1, prefix=name) ################### Evaluating cross-validation loss ################### if train_proportion is not None: iter_loss_val = self.step_eval_loss( name, x_data_val, extra_data_val) if tracking: guide_tr, model_tr = self.step_trace( name, x_data_val, extra_data_val) self.logp_hist_val[name]['guide'].append( guide_tr.log_prob_sum().item()) self.logp_hist_val[name]['model'].append( model_tr.log_prob_sum().item()) for k, v in model_tr.nodes.items(): if "log_prob_sum" in v: self.logp_hist_val[name][k].append( v["log_prob_sum"].item()) self.validation_hist[name].append(iter_loss_val) epochs_iterator.set_description(f'ELBO Loss: ' + '{:.4e}'.format(iter_loss) \ + ': Val loss: ' + '{:.4e}'.format(iter_loss_val)) else: epochs_iterator.set_description('ELBO Loss: ' + '{:.4e}'.format(iter_loss)) if epoch % 20 == 0: torch.cuda.empty_cache() if train_proportion is not None: # rescale loss self.validation_hist[name] = [ i / (1 - train_proportion) for i in self.validation_hist[name] ] self.hist[name] = [ i / train_proportion for i in self.hist[name] ] # reassing the main loss to be displayed self.training_hist[name] = self.hist[name] self.hist[name] = self.validation_hist[name] if tracking: for k, v in self.logp_hist[name].items(): self.logp_hist[name][k] = [ i / train_proportion for i in self.logp_hist[name][k] ] self.logp_hist_val[name][k] = [ i / (1 - train_proportion) for i in self.logp_hist_val[name][k] ] self.logp_hist_train[name] = self.logp_hist[name] self.logp_hist[name] = self.logp_hist_val[name] if self.verbose: print(plt.plot(np.log10(self.hist[name][0:])))
def fit_advi_iterative_simple( self, n: int = 3, method='advi', n_type='restart', n_iter=None, learning_rate=None, progressbar=True, ): r""" Find posterior using ADVI (deprecated) (maximising likehood of the data and minimising KL-divergence of posterior to prior) :param n: number of independent initialisations :param method: which approximation of the posterior (guide) to use?. * ``'advi'`` - Univariate normal approximation (pyro.infer.autoguide.AutoDiagonalNormal) * ``'custom'`` - Custom guide using conjugate posteriors :return: self.svi dictionary with svi pyro objects for each n, and sefl.elbo dictionary storing training history. """ # Pass data to pyro / pytorch self.x_data = torch.tensor(self.X_data.astype( self.data_type)) # .double() # initialise parameter store self.svi = {} self.hist = {} self.guide_i = {} self.samples = {} self.node_samples = {} self.n_type = n_type if n_iter is None: n_iter = self.n_iter if learning_rate is None: learning_rate = self.learning_rate if np.isin(n_type, ['bootstrap']): if self.X_data_sample is None: self.bootstrap_data(n=n) elif np.isin(n_type, ['cv']): self.generate_cv_data() # cv data added to self.X_data_sample init_names = ['init_' + str(i + 1) for i in np.arange(n)] for i, name in enumerate(init_names): # initialise Variational distributiion = guide if method is 'advi': self.guide_i[name] = AutoGuideList(self.model) self.guide_i[name].append( AutoNormal(poutine.block(self.model, expose_all=True, hide_all=False, hide=self.point_estim), init_loc_fn=init_to_mean)) self.guide_i[name].append( AutoDelta( poutine.block(self.model, hide_all=True, expose=self.point_estim))) elif method is 'custom': self.guide_i[name] = self.guide # initialise SVI inference method self.svi[name] = SVI( self.model, self.guide_i[name], optim.ClippedAdam({ 'lr': learning_rate, # limit the gradient step from becoming too large 'clip_norm': self.total_grad_norm_constraint }), loss=JitTrace_ELBO()) pyro.clear_param_store() # record ELBO Loss history here self.hist[name] = [] # pick dataset depending on the training mode and move to GPU if np.isin(n_type, ['cv', 'bootstrap']): self.x_data = torch.tensor(self.X_data_sample[i].astype( self.data_type)) else: self.x_data = torch.tensor(self.X_data.astype(self.data_type)) if self.use_cuda: # move tensors and modules to CUDA self.x_data = self.x_data.cuda() # train for n_iter it_iterator = tqdm(range(n_iter)) for it in it_iterator: hist = self.svi[name].step(self.x_data) it_iterator.set_description('ELBO Loss: ' + str(np.round(hist, 3))) self.hist[name].append(hist) # if it % 50 == 0 & self.verbose: # logging.info("Elbo loss: {}".format(hist)) if it % 500 == 0: torch.cuda.empty_cache()
def step(self, state):
    with poutine.block(hide_types=["observe"]):
        super().step(state.copy())
def _quantized_model(self): """ Quantized vectorized model used for parallel-scan enumerated inference. This method is called only outside particle_plate. """ C = len(self.compartments) T = self.duration Q = self.num_quant_bins R_shape = getattr(self.population, "shape", ()) # Region shape. # Sample global parameters and auxiliary variables. params = self.global_model() auxiliary, non_compartmental = self._sample_auxiliary() # Manually enumerate. curr, logp = quantize_enumerate(auxiliary, min=0, max=self.population, num_quant_bins=self.num_quant_bins) curr = OrderedDict(zip(self.compartments, curr.unbind(0))) logp = OrderedDict(zip(self.compartments, logp.unbind(0))) curr.update(non_compartmental) # Truncate final value from the right then pad initial value onto the left. init = self.initialize(params) prev = {} for name, value in init.items(): if name in self.compartments: if isinstance(value, torch.Tensor): value = value[..., None] # Because curr is enumerated on the right. prev[name] = cat2(value, curr[name][:-1], dim=-3 if self.is_regional else -2) else: # non-compartmental prev[name] = cat2(init[name], curr[name][:-1], dim=-curr[name].dim()) # Reshape to support broadcasting, similar to EnumMessenger. def enum_reshape(tensor, position): assert tensor.size(-1) == Q assert tensor.dim() <= self.max_plate_nesting + 2 tensor = tensor.permute(tensor.dim() - 1, *range(tensor.dim() - 1)) shape = [Q] + [1] * (position + self.max_plate_nesting - (tensor.dim() - 2)) shape.extend(tensor.shape[1:]) return tensor.reshape(shape) for e, name in enumerate(self.compartments): curr[name] = enum_reshape(curr[name], e) logp[name] = enum_reshape(logp[name], e) prev[name] = enum_reshape(prev[name], e + C) # Enable approximate inference by using aux as a non-enumerated proxy # for enumerated compartment values. for name in self.approximate: aux = auxiliary[self.compartments.index(name)] curr[name + "_approx"] = aux prev[name + "_approx"] = cat2(init[name], aux[:-1], dim=-2 if self.is_regional else -1) # Record transition factors. with poutine.block(), poutine.trace() as tr: with self.time_plate: t = slice(0, T, 1) # Used to slice data tensors. self._transition_bwd(params, prev, curr, t) tr.trace.compute_log_prob() for name, site in tr.trace.nodes.items(): if site["type"] == "sample": log_prob = site["log_prob"] if log_prob.dim() <= self.max_plate_nesting: # Not enumerated. pyro.factor("transition_" + name, site["log_prob_sum"]) continue if self.is_regional and log_prob.shape[-1:] != R_shape: # Poor man's tensor variable elimination. log_prob = log_prob.expand(log_prob.shape[:-1] + R_shape) / R_shape[0] logp[name] = site["log_prob"] # Manually perform variable elimination. logp = reduce(operator.add, logp.values()) logp = logp.reshape(Q ** C, Q ** C, T, -1) # prev, curr, T, batch logp = logp.permute(3, 2, 0, 1).squeeze(0) # batch, T, prev, curr logp = pyro.distributions.hmm._sequential_logmatmulexp(logp) # batch, prev, curr logp = logp.reshape(-1, Q ** C * Q ** C).logsumexp(-1).sum() warn_if_nan(logp) pyro.factor("transition", logp) # Apply final likelihood. prev = {name: prev[name + "_approx"] for name in self.approximate} curr = {name: curr[name + "_approx"] for name in self.approximate} with _disallow_latent_variables(".finalize()"): self.finalize(params, prev, curr) self._clear_plates()
def guide_autodelta(self):
    self.guide = AutoDelta(poutine.block(self.model,
                                         expose=['weights', 'locs', 'scale']))
def guide_autodiagnorm(self):
    self.guide = AutoDiagonalNormal(poutine.block(self.model,
                                                  expose=['weights', 'locs', 'scale']))
def guide(self, MAP=False, *args, **kwargs):
    return AutoDelta(poutine.block(self.model,
                                   expose=['pi', 'norm_factor', 'cnv_probs']),
                     init_loc_fn=self.init_fn())
def _contstructSVI(self, optimizationStrategy, total_steps): """ Constructs a list of SVI functions use to train model. Takes optimizationStrategy as input, which must be in ['Full','PosteriorOnly','GuideOnly'] """ #generate random sample just to build the param_store test_data = self.datasetSampler.GenerateRandomTrainingSample(2) if self.useCuda: test_data = self._sendDataToGPU(test_data) self.phenotypeModel.guide(*test_data) param_store = pyro.get_param_store() if optimizationStrategy == 'Full': _model = self.phenotypeModel.model _guide = self.phenotypeModel.guide else: if optimizationStrategy == 'DecoderOnly': _model = self.phenotypeModel.model _guide = poutine.block( self.phenotypeModel.guide, hide=[ x for x in param_store.get_all_param_names() if ('encoder' in x) ]) else: _model = poutine.block( self.phenotypeModel.model, hide=[ x for x in param_store.get_all_param_names() if ('decoder' in x) ]) _guide = self.phenotypeModel.guide #initialize the SVI class AdamWOptimArgs = {'weight_decay': self.AdamW_Weight_Decay} scheduler = OneCycleLR({ 'optimizer': AdamW, 'optim_args': AdamWOptimArgs, 'max_lr': self.maxLearningRate, 'total_steps': total_steps, 'pct_start': self.OneCycleParams['pctCycleIncrease'], 'div_factor': self.OneCycleParams['initLRDivisionFactor'], 'final_div_factor': self.OneCycleParams['finalLRDivisionFactor'] }) return { 'svi': SVI(_model, _guide, scheduler, loss=Trace_ELBO(num_particles=self.numParticles)), 'scheduler': scheduler }
def guide_multivariatenormal(self):
    self.guide = AutoMultivariateNormal(poutine.block(self.model,
                                                      expose=['weights', 'locs', 'scale']))
def main(args): if args.cuda: torch.set_default_tensor_type('torch.cuda.FloatTensor') logging.info('Loading data') data = poly.load_data(poly.JSB_CHORALES) logging.info('-' * 40) model = models[args.model] logging.info('Training {} on {} sequences'.format( model.__name__, len(data['train']['sequences']))) sequences = data['train']['sequences'] lengths = data['train']['sequence_lengths'] # find all the notes that are present at least once in the training set present_notes = ((sequences == 1).sum(0).sum(0) > 0) # remove notes that are never played (we remove 37/88 notes) sequences = sequences[..., present_notes] if args.truncate: lengths.clamp_(max=args.truncate) sequences = sequences[:, :args.truncate] num_observations = float(lengths.sum()) pyro.set_rng_seed(0) pyro.clear_param_store() pyro.enable_validation(True) # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing # out the hidden state x. This is accomplished via an automatic guide that # learns point estimates of all of our conditional probability tables, # named probs_*. guide = AutoDelta( poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_"))) # To help debug our tensor shapes, let's print the shape of each site's # distribution, value, and log_prob tensor. Note this information is # automatically printed on most errors inside SVI. if args.print_shapes: first_available_dim = -2 if model is model_0 else -3 guide_trace = poutine.trace(guide).get_trace( sequences, lengths, args=args, batch_size=args.batch_size) model_trace = poutine.trace( poutine.replay(poutine.enum(model, first_available_dim), guide_trace)).get_trace(sequences, lengths, args=args, batch_size=args.batch_size) logging.info(model_trace.format_shapes()) # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting. # All of our models have two plates: "data" and "tones". Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2) optim = Adam({'lr': args.learning_rate}) svi = SVI(model, guide, optim, elbo) # We'll train on small minibatches. logging.info('Step\tLoss') for step in range(args.num_steps): loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size) logging.info('{: >5d}\t{}'.format(step, loss / num_observations)) # We evaluate on the entire training dataset, # excluding the prior term so our results are comparable across models. train_loss = elbo.loss(model, guide, sequences, lengths, args, include_prior=False) logging.info('training loss = {}'.format(train_loss / num_observations)) # Finally we evaluate on the test dataset. logging.info('-' * 40) logging.info('Evaluating on {} test sequences'.format( len(data['test']['sequences']))) sequences = data['test']['sequences'][..., present_notes] lengths = data['test']['sequence_lengths'] if args.truncate: lengths.clamp_(max=args.truncate) num_observations = float(lengths.sum()) # note that since we removed unseen notes above (to make the problem a bit easier and for # numerical stability) this test loss may not be directly comparable to numbers # reported on this dataset elsewhere. test_loss = elbo.loss(model, guide, sequences, lengths, args=args, include_prior=False) logging.info('test loss = {}'.format(test_loss / num_observations)) # We expect models with higher capacity to perform better, # but eventually overfit to the training set. 
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info('{} capacity = {} parameters'.format(model.__name__, capacity))
    tr = poutine.trace(guide).get_trace(torch.tensor(2.))
    exp_msg = r"Error while computing score_parts at site 'test_site':\s*" \
        r"The value argument must be within the support"
    with pytest.raises(ValueError, match=exp_msg):
        tr.compute_score_parts()


def _model(a=torch.tensor(1.), b=torch.tensor(1.)):
    latent = pyro.sample("latent", dist.Beta(a, b))
    return pyro.sample("test_site", dist.Bernoulli(latent), obs=torch.tensor(1))


@pytest.mark.parametrize('wrapper', [
    lambda fn: poutine.block(fn),
    lambda fn: poutine.condition(fn, {'latent': 0.9}),
    lambda fn: poutine.enum(fn, -1),
    lambda fn: poutine.replay(fn, poutine.trace(fn).get_trace()),
])
def test_pickling(wrapper):
    wrapped = wrapper(_model)
    # default protocol cannot serialize torch.Size objects
    # (see https://github.com/pytorch/pytorch/issues/20823)
    deserialized = pickle.loads(
        pickle.dumps(wrapped, protocol=pickle.HIGHEST_PROTOCOL))
    obs = torch.tensor(0.5)
    pyro.set_rng_seed(0)
    actual_trace = poutine.trace(deserialized).get_trace(obs)
    pyro.set_rng_seed(0)
    expected_trace = poutine.trace(wrapped).get_trace(obs)
def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs): # For internal use by infer_discrete. # Create an enumerated trace. with poutine.block(), EnumerateMessenger(first_available_dim): enum_trace = poutine.trace(model).get_trace(*args, **kwargs) enum_trace = prune_subsample_sites(enum_trace) enum_trace.compute_log_prob() enum_trace.pack_tensors() plate_to_symbol = enum_trace.plate_to_symbol # Collect a set of query sample sites to which the backward algorithm will propagate. log_probs = OrderedDict() sum_dims = set() queries = [] for node in enum_trace.nodes.values(): if node["type"] == "sample": ordinal = frozenset(plate_to_symbol[f.name] for f in node["cond_indep_stack"] if f.vectorized) log_prob = node["packed"]["log_prob"] log_probs.setdefault(ordinal, []).append(log_prob) sum_dims.update(log_prob._pyro_dims) for frame in node["cond_indep_stack"]: if frame.vectorized: sum_dims.remove(plate_to_symbol[frame.name]) # Note we mark all sample sites with require_backward to gather # enumerated sites and adjust cond_indep_stack of all sample sites. if not node["is_observed"]: queries.append(log_prob) require_backward(log_prob) # Run forward-backward algorithm, collecting the ordinal of each connected component. ring = _make_ring(temperature) log_probs = contract_tensor_tree(log_probs, sum_dims, ring=ring) # run forward algorithm query_to_ordinal = {} pending = object() # a constant value for pending queries for query in queries: query._pyro_backward_result = pending for ordinal, terms in log_probs.items(): for term in terms: if hasattr(term, "_pyro_backward"): term._pyro_backward() # run backward algorithm # Note: this is quadratic in number of ordinals for query in queries: if query not in query_to_ordinal and query._pyro_backward_result is not pending: query_to_ordinal[query] = ordinal # Construct a collapsed trace by gathering and adjusting cond_indep_stack. collapsed_trace = poutine.Trace() for node in enum_trace.nodes.values(): if node["type"] == "sample" and not node["is_observed"]: # TODO move this into a Leaf implementation somehow new_node = { "type": "sample", "name": node["name"], "is_observed": False, "infer": node["infer"].copy(), "cond_indep_stack": node["cond_indep_stack"], "value": node["value"], } log_prob = node["packed"]["log_prob"] if hasattr(log_prob, "_pyro_backward_result"): # Adjust the cond_indep_stack. ordinal = query_to_ordinal[log_prob] new_node["cond_indep_stack"] = tuple( f for f in node["cond_indep_stack"] if not f.vectorized or plate_to_symbol[f.name] in ordinal) # Gather if node depended on an enumerated value. sample = log_prob._pyro_backward_result if sample is not None: new_value = packed.pack(node["value"], node["infer"]["_dim_to_symbol"]) for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims): if dim in new_value._pyro_dims: index._pyro_dims = sample._pyro_dims[1:] new_value = packed.gather(new_value, index, dim) new_node["value"] = packed.unpack(new_value, enum_trace.symbol_to_dim) collapsed_trace.add_node(node["name"], **new_node) # Replay the model against the collapsed trace. with SamplePosteriorMessenger(trace=collapsed_trace): return model(*args, **kwargs)
def main(args): # setup hyperparameters for the model hypers = { 'expected_sparsity': max(1.0, args.num_dimensions / 10), 'alpha1': 3.0, 'beta1': 1.0, 'alpha2': 3.0, 'beta2': 1.0, 'alpha3': 1.0, 'c': 1.0 } P = args.num_dimensions S = args.active_dimensions Q = args.quadratic_dimensions # generate artificial dataset X, Y, expected_thetas, expected_quad_dims = get_data(N=args.num_data, P=P, S=S, Q=Q, sigma_obs=args.sigma) loss_fn = Trace_ELBO().differentiable_loss # We initialize the AutoDelta guide (for MAP estimation) with args.num_trials many # initial parameters sampled from the vicinity of the median of the prior distribution # and then continue optimizing with the best performing initialization. init_losses = [] for restart in range(args.num_restarts): pyro.clear_param_store() pyro.set_rng_seed(restart) guide = AutoDelta(model, init_loc_fn=init_loc_fn) with torch.no_grad(): init_losses.append(loss_fn(model, guide, X, Y, hypers).item()) pyro.set_rng_seed(np.argmin(init_losses)) pyro.clear_param_store() guide = AutoDelta(model, init_loc_fn=init_loc_fn) # Instead of using pyro.infer.SVI and pyro.optim we instead construct our own PyTorch # optimizer and take charge of gradient-based optimization ourselves. with poutine.block(), poutine.trace(param_only=True) as param_capture: guide(X, Y, hypers) params = list( [pyro.param(name).unconstrained() for name in param_capture.trace]) adam = Adam(params, lr=args.lr) report_frequency = 50 print("Beginning MAP optimization...") # the optimization loop for step in range(args.num_steps): loss = loss_fn(model, guide, X, Y, hypers) / args.num_data loss.backward() adam.step() adam.zero_grad() # we manually reduce the learning rate according to this schedule if step in [100, 300, 700, 900]: adam.param_groups[0]['lr'] *= 0.2 if step % report_frequency == 0 or step == args.num_steps - 1: print("[step %04d] loss: %.5f" % (step, loss)) print("Expected singleton thetas:\n", expected_thetas.data.numpy()) # we do the final computation using double precision median = guide.median() # == mode for MAP inference active_dims, active_quad_dims = \ compute_posterior_stats(X.double(), Y.double(), median['msq'].double(), median['lambda'].double(), median['eta1'].double(), median['xisq'].double(), torch.tensor(hypers['c']).double(), median['sigma'].double()) expected_active_dims = np.arange(S).tolist() tp_singletons = len(set(active_dims) & set(expected_active_dims)) fp_singletons = len(set(active_dims) - set(expected_active_dims)) fn_singletons = len(set(expected_active_dims) - set(active_dims)) singleton_stats = (tp_singletons, fp_singletons, fn_singletons) tp_quads = len(set(active_quad_dims) & set(expected_quad_dims)) fp_quads = len(set(active_quad_dims) - set(expected_quad_dims)) fn_quads = len(set(expected_quad_dims) - set(active_quad_dims)) quad_stats = (tp_quads, fp_quads, fn_quads) # We report how well we did, i.e. did we recover the sparse set of coefficients # that we expected for our artificial dataset? print("[SUMMARY STATS]") print("Singletons (true positive, false positive, false negative): " + "(%d, %d, %d)" % singleton_stats) print("Quadratic (true positive, false positive, false negative): " + "(%d, %d, %d)" % quad_stats)
        ax.add_artist(ell)

    # Save figure
    fig.savefig(figname)


if __name__ == "__main__":
    pyro.enable_validation(True)
    pyro.set_rng_seed(42)

    # Create our model with a fixed number of components
    K = 2

    data = get_samples()

    global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs', 'scales']))
    global_guide = config_enumerate(global_guide, 'parallel')
    _, svi = initialize(data)

    true_colors = [0] * 100 + [1] * 100
    plot(data, colors=true_colors, figname='pyro_init.png')

    for i in range(151):
        svi.step(data)

        if i % 50 == 0:
            locs = pyro.param('locs')
            scales = pyro.param('scales')
            weights = pyro.param('weights')
            assignment_probs = pyro.param('assignment_probs')