def predict(self, forecast=0):
    """
    Predict latent variables and optionally forecast forward.

    This may be run only after :meth:`fit_mcmc` and draws the same
    ``num_samples`` as passed to :meth:`fit_mcmc`.

    :param int forecast: The number of time steps to forecast forward.
    :returns: A dictionary mapping sample site name (or compartment name)
        to a tensor whose first dimension corresponds to sample batching.
    :rtype: dict
    """
    if self.num_quant_bins > 1:
        _require_double_precision()
    if not self.samples:
        raise RuntimeError("Missing samples, try running .fit_mcmc() first")
    samples = self.samples
    num_samples = len(next(iter(samples.values())))
    particle_plate = pyro.plate("particles", num_samples,
                                dim=-1 - self.max_plate_nesting)

    # Sample discrete auxiliary variables conditioned on the continuous
    # variables sampled by _quantized_model. This samples only time steps
    # [0:duration]. Here infer_discrete runs a forward-filter
    # backward-sample algorithm.
    logger.info("Predicting latent variables for {} time steps..."
                .format(self.duration))
    model = self._sequential_model
    model = poutine.condition(model, samples)
    model = particle_plate(model)
    if not self.relaxed:
        model = infer_discrete(model,
                               first_available_dim=-2 - self.max_plate_nesting)
    trace = poutine.trace(model).get_trace()
    samples = OrderedDict((name, site["value"].expand(site["fn"].shape()))
                          for name, site in trace.nodes.items()
                          if site["type"] == "sample"
                          if not site_is_subsample(site)
                          if not site_is_factor(site))
    assert all(v.size(0) == num_samples for v in samples.values()), \
        {k: tuple(v.shape) for k, v in samples.items()}

    # Optionally forecast with the forward _generative_model. This samples
    # time steps [duration:duration+forecast].
    if forecast:
        logger.info("Forecasting {} steps ahead...".format(forecast))
        model = self._generative_model
        model = poutine.condition(model, samples)
        model = particle_plate(model)
        trace = poutine.trace(model).get_trace(forecast)
        samples = OrderedDict((name, site["value"])
                              for name, site in trace.nodes.items()
                              if site["type"] == "sample"
                              if not site_is_subsample(site)
                              if not site_is_factor(site))
        self._concat_series(samples, trace, forecast)
    assert all(v.size(0) == num_samples for v in samples.values()), \
        {k: tuple(v.shape) for k, v in samples.items()}
    return samples
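# Usage sketch (not part of the original source): how ``predict`` is typically
# chained after MCMC inference on a CompartmentalModel subclass. The class name
# ``MyModel`` and its constructor arguments are hypothetical placeholders.
#
#     model = MyModel(population=1000, data=observed_counts)
#     model.fit_mcmc(num_samples=100)        # sets model.samples
#     samples = model.predict(forecast=30)   # latents plus a 30-step forecast
#     # Every returned tensor is batched over the 100 posterior samples:
#     assert all(v.size(0) == 100 for v in samples.values())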
def _adjust_to_data(self, trace, data_trace):
    subsampled_idxs = dict()
    for name, site in trace.iter_stochastic_nodes():
        # Adjust subsample sites.
        if site_is_subsample(site):
            site["fn"] = data_trace.nodes[name]["fn"]
            site["value"] = data_trace.nodes[name]["value"]
        # Adjust sites under conditionally independent stacks.
        orig_cis_stack = site["cond_indep_stack"]
        site["cond_indep_stack"] = data_trace.nodes[name]["cond_indep_stack"]
        assert len(orig_cis_stack) == len(site["cond_indep_stack"])
        site["fn"] = data_trace.nodes[name]["fn"]
        for ocis, cis in zip(orig_cis_stack, site["cond_indep_stack"]):
            # Select random sub-indices to replay values under conditionally
            # independent stacks. Otherwise, we would assume a dependence of
            # indices between the training data and the prediction data.
            assert ocis.name == cis.name
            assert not site_is_subsample(site)
            batch_dim = cis.dim - site["fn"].event_dim
            subsampled_idxs[cis.name] = subsampled_idxs.get(
                cis.name,
                torch.randint(0, ocis.size, (cis.size,),
                              device=site["value"].device))
            site["value"] = site["value"].index_select(
                batch_dim, subsampled_idxs[cis.name])
def save_visualization(trace, graph_output):
    """
    DEPRECATED Use :func:`pyro.infer.inspect.render_model()` instead.

    Take a trace generated by poutine.trace with `graph_type='dense'`
    and render the graph with the output saved to file.

    - non-reparameterized stochastic nodes are salmon
    - reparameterized stochastic nodes are half salmon, half grey
    - observation nodes are green

    :param pyro.poutine.Trace trace: a trace to be visualized
    :param graph_output: the graph will be saved to graph_output.pdf
    :type graph_output: str

    Example::

        trace = pyro.poutine.trace(model, graph_type="dense").get_trace()
        save_visualization(trace, 'output')
    """
    warnings.warn(
        "`save_visualization` function is deprecated and will be removed in "
        "a future version.")
    import graphviz

    g = graphviz.Digraph()

    for label, node in trace.nodes.items():
        if site_is_subsample(node):
            continue
        shape = "ellipse"
        if label in trace.stochastic_nodes and label not in trace.reparameterized_nodes:
            fillcolor = "salmon"
        elif label in trace.reparameterized_nodes:
            fillcolor = "lightgrey;.5:salmon"
        elif label in trace.observation_nodes:
            fillcolor = "darkolivegreen3"
        else:
            # only visualize RVs
            continue
        g.node(label, label=label, shape=shape, style="filled", fillcolor=fillcolor)

    for label1, label2 in trace.edges:
        if site_is_subsample(trace.nodes[label1]):
            continue
        if site_is_subsample(trace.nodes[label2]):
            continue
        g.edge(label1, label2)

    g.render(graph_output, view=False, cleanup=True)
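# A minimal sketch of the replacement suggested by the deprecation note above,
# assuming a Pyro version that exports ``pyro.render_model`` (the public alias
# of pyro.infer.inspect.render_model). The toy model is illustrative only.
import torch
import pyro
import pyro.distributions as dist


def _toy_model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("N", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)


# Renders the model graph directly; pass ``filename`` to save it to disk.
# graph = pyro.render_model(_toy_model, model_args=(torch.zeros(5),),
#                           filename="output.pdf")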
def _non_compartmental(self):
    """
    A dict mapping name -> (distribution, is_regional) for all
    non-compartmental sites in :meth:`transition`.

    For simple models this is often empty; for time-heterogeneous models
    this may contain time-local latent variables.
    """
    # Trace a simple invocation of .transition().
    with torch.no_grad(), poutine.block():
        params = self.global_model()
        prev = self.initialize(params)
        for name in self.approximate:
            prev[name + "_approx"] = prev[name]
        curr = prev.copy()
        with poutine.trace() as tr:
            self.transition(params, curr, 0)
        flows = self.compute_flows(prev, curr, 0)

    # Extract latent variables that are not compartmental flows.
    result = OrderedDict()
    for name, site in tr.trace.iter_stochastic_nodes():
        if name in flows or site_is_subsample(site):
            continue
        assert name.endswith("_0"), name
        name = name[:-2]
        assert name in self.series, name
        # TODO This supports only the region_plate. For full plate support,
        # this could be replaced by a self.plate() method as in EasyGuide.
        is_regional = any(f.name == "region" for f in site["cond_indep_stack"])
        result[name] = site["fn"], is_regional
    return result
def posterior_replay(model, posterior_samples, *args, **kwargs):
    r"""
    Given a model and samples from the posterior (potentially with conjugate sites
    collapsed), return a `dict` of samples from the posterior with conjugate sites
    uncollapsed. Note that this can also be used to generate samples from the
    posterior predictive distribution.

    :param model: Python callable.
    :param dict posterior_samples: posterior samples keyed by site name.
    :param args: arguments to `model`.
    :param kwargs: keyword arguments to `model`.
    :return: `dict` of samples from the posterior.
    """
    posterior_samples = posterior_samples.copy()
    num_samples = kwargs.pop("num_samples", None)
    assert posterior_samples or num_samples, \
        "`num_samples` must be provided if `posterior_samples` is empty."
    if num_samples is None:
        num_samples = list(posterior_samples.values())[0].shape[0]

    return_samples = defaultdict(list)
    for i in range(num_samples):
        conditioned_nodes = {k: v[i] for k, v in posterior_samples.items()}
        collapsed_trace = poutine.trace(
            poutine.condition(collapse_conjugate(model), conditioned_nodes)
        ).get_trace(*args, **kwargs)
        trace = poutine.trace(
            uncollapse_conjugate(model, collapsed_trace)
        ).get_trace(*args, **kwargs)
        for name, site in trace.iter_stochastic_nodes():
            if not site_is_subsample(site):
                return_samples[name].append(site["value"])
    return {k: torch.stack(v) for k, v in return_samples.items()}
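# Usage sketch (assumptions marked): ``conjugate_model`` is a Pyro model whose
# conjugate sites have been collapsed with ``collapse_conjugate``, and
# ``mcmc_samples`` is a dict of posterior samples over the remaining sites,
# e.g. from MCMC.get_samples(). Both names are placeholders, not part of the
# original source.
#
#     uncollapsed = posterior_replay(conjugate_model, mcmc_samples, data)
#     # ``uncollapsed`` now also contains samples for the collapsed sites.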
def save_visualization(trace, graph_output):
    """
    :param pyro.poutine.Trace trace: a trace to be visualized
    :param graph_output: the graph will be saved to graph_output.pdf
    :type graph_output: str

    Take a trace generated by poutine.trace with `graph_type='dense'`
    and render the graph with the output saved to file.

    - non-reparameterized stochastic nodes are salmon
    - reparameterized stochastic nodes are half salmon, half grey
    - observation nodes are green

    Example::

        trace = pyro.poutine.trace(model, graph_type="dense").get_trace()
        save_visualization(trace, 'output')
    """
    g = graphviz.Digraph()

    for label, node in trace.nodes.items():
        if site_is_subsample(node):
            continue
        shape = 'ellipse'
        if label in trace.stochastic_nodes and label not in trace.reparameterized_nodes:
            fillcolor = 'salmon'
        elif label in trace.reparameterized_nodes:
            fillcolor = 'lightgrey;.5:salmon'
        elif label in trace.observation_nodes:
            fillcolor = 'darkolivegreen3'
        else:
            # only visualize RVs
            continue
        g.node(label, label=label, shape=shape, style='filled', fillcolor=fillcolor)

    for label1, label2 in trace.edges:
        if site_is_subsample(trace.nodes[label1]):
            continue
        if site_is_subsample(trace.nodes[label2]):
            continue
        g.edge(label1, label2)

    g.render(graph_output, view=False, cleanup=True)
def full_mass(self):
    """
    A list of a single tuple of the names of global random variables.
    """
    with torch.no_grad(), poutine.block(), poutine.trace() as tr:
        self.global_model()
    return [tuple(name for name, site in tr.trace.iter_stochastic_nodes()
                  if not site_is_subsample(site))]
def get_iarange_stacks(trace):
    """
    This builds a dict mapping site name to a set of iarange stacks. Each
    iarange stack is a list of :class:`CondIndepStackFrame`\\ s corresponding
    to an :class:`iarange`. This information is used by :class:`Trace_ELBO`
    and :class:`TraceGraph_ELBO`.
    """
    return {name: [f for f in node["cond_indep_stack"] if f.vectorized]
            for name, node in trace.nodes.items()
            if node["type"] == "sample" and not site_is_subsample(node)}
def get_plate_stacks(trace):
    """
    This builds a dict mapping site name to a set of plate stacks. Each
    plate stack is a list of :class:`CondIndepStackFrame`\\ s corresponding
    to a :class:`plate`. This information is used by :class:`Trace_ELBO`
    and :class:`TraceGraph_ELBO`.
    """
    return {name: [f for f in node["cond_indep_stack"] if f.vectorized]
            for name, node in trace.nodes.items()
            if node["type"] == "sample" and not site_is_subsample(node)}
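# A minimal sketch of what get_plate_stacks returns: for each sample site, the
# list of vectorized plate frames enclosing it. The toy model below is for
# illustration only and is not part of the original module.
import pyro
import pyro.distributions as dist
from pyro import poutine


def _toy_model():
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", 3):
        pyro.sample("x", dist.Normal(loc, 1.))


_trace = poutine.trace(_toy_model).get_trace()
_stacks = get_plate_stacks(_trace)
# Roughly: {"loc": [], "x": [CondIndepStackFrame(name="data", ...)]}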
def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        approx = funsor.montecarlo.MonteCarlo()
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    with block(), enum(first_available_dim=first_available_dim):
        # XXX replay against an empty Trace to ensure densities are not double-counted
        model_tr = trace(replay(model, trace=Trace())).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            terms["log_factors"] + terms["log_measures"],
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.nodes.items():
        if node["type"] != "sample" or site_is_subsample(node):
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            node["funsor"] = {"value": node["funsor"]["value"](**sample_subs)}
        else:
            node["funsor"]["log_measure"] = approx_factors[node["funsor"]["log_measure"]]
            node["funsor"]["value"] = _get_support_value(
                node["funsor"]["log_measure"], name)
            sample_subs[name] = node["funsor"]["value"]

    with replay(trace=sample_tr):
        return model(*args, **kwargs)
def _process_message(self, msg):
    if self._block:
        return
    if site_is_subsample(msg):
        return
    super()._process_message(msg)

    # Block sample statements.
    if msg["type"] == "sample":
        if isinstance(msg["fn"], Funsor) or isinstance(msg["value"], (str, Funsor)):
            msg["stop"] = True
def series(self):
    """
    A frozenset of names of sample sites that are sampled each time step.
    """
    # Trace a simple invocation of .transition().
    with torch.no_grad(), poutine.block():
        params = self.global_model()
        prev = self.initialize(params)
        for name in self.approximate:
            prev[name + "_approx"] = prev[name]
        curr = prev.copy()
        with poutine.trace() as tr:
            self.transition(params, curr, 0)
    return frozenset(re.match("(.*)_0", name).group(1)
                     for name, site in tr.trace.nodes.items()
                     if site["type"] == "sample"
                     if not site_is_subsample(site))
def is_sample_site(msg):
    if msg["type"] != "sample":
        return False
    if site_is_subsample(msg):
        return False

    # Ignore masked observations.
    if msg["is_observed"] and msg["mask"] is False:
        return False

    # Exclude deterministic sites.
    fn = msg["fn"]
    while hasattr(fn, "base_dist"):
        fn = fn.base_dist
    if type(fn).__name__ == "Delta":
        return False

    return True
def evaluate_log_posterior_density(model, posterior_samples, baseball_dataset):
    """
    Evaluate the log probability density of observing the unseen data
    (season hits) given a model and posterior distribution over the parameters.
    """
    _, test, player_names = train_test_split(baseball_dataset)
    at_bats_season, hits_season = test[:, 0], test[:, 1]
    with ignore_experimental_warning():
        trace = predictive(model, posterior_samples, at_bats_season, hits_season,
                           parallel=True, return_trace=True)
    # Use the LogSumExp trick to evaluate
    # $\log(1/num\_samples \sum_i p(new\_data | \theta^{i}))$,
    # where $\theta^{i}$ are parameter samples from the model's posterior.
    trace.compute_log_prob()
    log_joint = 0.
    for name, site in trace.nodes.items():
        if site["type"] == "sample" and not site_is_subsample(site):
            # We use `sum_rightmost(x, -1)` to take the sum of all rightmost
            # dimensions of `x` except the first dimension (which corresponds
            # to the number of posterior samples).
            site_log_prob_sum = sum_rightmost(site['log_prob'], -1)
            log_joint += site_log_prob_sum
    posterior_pred_density = torch.logsumexp(log_joint, dim=0) - math.log(log_joint.shape[0])
    logging.info("\nLog posterior predictive density")
    logging.info("--------------------------------")
    logging.info("{:.4f}\n".format(posterior_pred_density))
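# A standalone sketch of the estimator used above: the log posterior predictive
# density is approximated by the log-mean-exp of per-sample joint log
# probabilities, i.e. log(1/S * sum_i p(new_data | theta_i)). The numbers below
# are toy values, not from the baseball dataset.
import math
import torch

log_joint = torch.tensor([-10.0, -9.5, -11.2])  # one entry per posterior sample
log_pred_density = torch.logsumexp(log_joint, dim=0) - math.log(log_joint.shape[0])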
def _pyro_post_sample(self, msg):
    if site_is_subsample(msg):
        return
    name = msg["name"]
    if name not in self.trace:
        return
    model_value = msg["value"]
    guide_value = self.trace.nodes[name]["value"]
    if model_value.shape == guide_value.shape:
        msg["value"] = guide_value
        return

    # Search for a single dim with mismatched size.
    assert model_value.dim() == guide_value.dim()
    for dim in range(model_value.dim()):
        if model_value.size(dim) != guide_value.size(dim):
            break
    assert model_value.size(dim) > guide_value.size(dim)
    assert model_value.shape[dim + 1:] == guide_value.shape[dim + 1:]
    split = guide_value.size(dim)
    index = (slice(None),) * dim + (slice(split, None),)
    msg["value"] = torch.cat([guide_value, model_value[index]], dim=dim)
def get_model_relations(
    model: Callable,
    model_args: Optional[tuple] = None,
    model_kwargs: Optional[dict] = None,
):
    """
    Infer relations of RVs and plates from given model and optionally data.
    See https://github.com/pyro-ppl/pyro/issues/949 for more details.

    This returns a dictionary with keys:

    - "sample_sample" maps each downstream sample site to a list of the
      upstream sample sites on which it depends;
    - "sample_param" maps each sample site to a list of the param sites on
      which it depends;
    - "sample_dist" maps each sample site to the name of the distribution at
      that site;
    - "param_constraint" maps each param site to the string representation of
      its constraint;
    - "plate_sample" maps each plate name to a list of the sample sites within
      that plate; and
    - "observed" is a list of observed sample sites.

    For example for the model::

        def model(data):
            m = pyro.sample('m', dist.Normal(0, 1))
            sd = pyro.sample('sd', dist.LogNormal(m, 1))
            with pyro.plate('N', len(data)):
                pyro.sample('obs', dist.Normal(m, sd), obs=data)

    the relation is::

        {'sample_sample': {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']},
         'sample_dist': {'m': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
         'plate_sample': {'N': ['obs']},
         'observed': ['obs']}

    :param callable model: A model to inspect.
    :param model_args: Optional tuple of model args.
    :param model_kwargs: Optional dict of model kwargs.
    :rtype: dict
    """
    if model_args is None:
        model_args = ()
    if model_kwargs is None:
        model_kwargs = {}

    with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
        with TrackProvenance():
            trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)

    sample_sample = {}
    sample_param = {}
    sample_dist = {}
    param_constraint = {}
    plate_sample = defaultdict(list)
    observed = []

    def _get_type_from_frozenname(frozen_name):
        return trace.nodes[frozen_name]["type"]

    for name, site in trace.nodes.items():
        if site["type"] == "param":
            param_constraint[name] = str(site["kwargs"]["constraint"])

        if site["type"] != "sample" or site_is_subsample(site):
            continue

        sample_sample[name] = [
            upstream
            for upstream in get_provenance(site["fn"].log_prob(site["value"]))
            if upstream != name and _get_type_from_frozenname(upstream) == "sample"
        ]

        sample_param[name] = [
            upstream
            for upstream in get_provenance(site["fn"].log_prob(site["value"]))
            if upstream != name and _get_type_from_frozenname(upstream) == "param"
        ]

        sample_dist[name] = _get_dist_name(site["fn"])
        for frame in site["cond_indep_stack"]:
            plate_sample[frame.name].append(name)
        if site["is_observed"]:
            observed.append(name)

    def _resolve_plate_samples(plate_samples):
        for p, pv in plate_samples.items():
            pv = set(pv)
            for q, qv in plate_samples.items():
                qv = set(qv)
                if len(pv & qv) > 0 and len(pv - qv) > 0 and len(qv - pv) > 0:
                    plate_samples_ = plate_samples.copy()
                    plate_samples_[q] = pv & qv
                    plate_samples_[q + "__CLONE"] = qv - pv
                    return _resolve_plate_samples(plate_samples_)
        return plate_samples

    plate_sample = _resolve_plate_samples(plate_sample)

    # convert set to list to keep order of variables
    plate_sample = {
        k: [name for name in trace.nodes if name in v]
        for k, v in plate_sample.items()
    }

    return {
        "sample_sample": sample_sample,
        "sample_param": sample_param,
        "sample_dist": sample_dist,
        "param_constraint": param_constraint,
        "plate_sample": dict(plate_sample),
        "observed": observed,
    }
def prototype_hide_fn(msg):
    # Record only stochastic sites in the prototype_trace.
    return msg["type"] != "sample" or msg["is_observed"] or site_is_subsample(msg)
def _process_message(self, msg):
    if self._block:
        return
    if site_is_subsample(msg):
        return
    super()._process_message(msg)
def _pyro_post_sample(self, msg):
    if self._block:
        return
    if site_is_subsample(msg):
        return
    super()._pyro_post_sample(msg)
def _setup_prototype(self, *args, **kwargs) -> None:
    super()._setup_prototype(*args, **kwargs)

    self.locs = PyroModule()
    self.scales = PyroModule()
    self.white_vecs = PyroModule()
    self.prec_sqrts = PyroModule()
    self._factors = OrderedDict()
    self._plates = OrderedDict()
    self._event_numel = OrderedDict()
    self._unconstrained_event_shapes = OrderedDict()

    # Trace model dependencies.
    model = self._original_model[0]
    self._original_model = None
    self.dependencies = poutine.block(get_dependencies)(model, args, kwargs)[
        "prior_dependencies"
    ]

    # Eliminate observations with no upstream latents.
    for d, upstreams in list(self.dependencies.items()):
        if all(self.prototype_trace.nodes[u]["is_observed"] for u in upstreams):
            del self.dependencies[d]
            del self.prototype_trace.nodes[d]

    # Collect factors and plates.
    for d, site in self.prototype_trace.nodes.items():
        # Prune non-essential parts of the trace to save memory.
        pruned_site, site = site, site.copy()
        pruned_site.clear()

        # Collect factors and plates.
        if site["type"] != "sample" or site_is_subsample(site):
            continue
        assert all(f.vectorized for f in site["cond_indep_stack"])
        self._factors[d] = self._compress_site(site)
        plates = frozenset(site["cond_indep_stack"])
        if site["fn"].batch_shape != _plates_to_shape(plates):
            raise ValueError(
                f"Shape mismatch at site '{d}'. "
                "Are you missing a pyro.plate() or .to_event()?"
            )
        if site["is_observed"]:
            # Break irrelevant observation plates.
            plates &= frozenset().union(
                *(self._plates[u] for u in self.dependencies[d] if u != d)
            )
        self._plates[d] = plates

        # Create location-scale parameters, one per latent variable.
        if site["is_observed"]:
            # This may slightly overestimate, e.g. for Multinomial.
            self._event_numel[d] = site["fn"].event_shape.numel()
            # Account for broken irrelevant observation plates.
            for f in set(site["cond_indep_stack"]) - plates:
                self._event_numel[d] *= f.size
            continue
        with helpful_support_errors(site):
            init_loc = biject_to(site["fn"].support).inv(site["value"]).detach()
        batch_shape = site["fn"].batch_shape
        event_shape = init_loc.shape[len(batch_shape):]
        self._unconstrained_event_shapes[d] = event_shape
        self._event_numel[d] = event_shape.numel()
        event_dim = len(event_shape)
        deep_setattr(self.locs, d, PyroParam(init_loc, event_dim=event_dim))
        deep_setattr(
            self.scales,
            d,
            PyroParam(
                torch.full_like(init_loc, self._init_scale),
                constraint=self.scale_constraint,
                event_dim=event_dim,
            ),
        )

    # Create parameters for dependencies, one per factor.
    for d, site in self._factors.items():
        u_size = 0
        for u in self.dependencies[d]:
            if not self._factors[u]["is_observed"]:
                broken_shape = _plates_to_shape(self._plates[u] - self._plates[d])
                u_size += broken_shape.numel() * self._event_numel[u]
        d_size = self._event_numel[d]
        if site["is_observed"]:
            d_size = min(d_size, u_size)  # just an optimization
        batch_shape = _plates_to_shape(self._plates[d])

        # Create parameters of each Gaussian factor.
        white_vec = init_loc.new_zeros(batch_shape + (d_size,))
        # We initialize with noise to avoid singular gradient.
        prec_sqrt = torch.rand(
            batch_shape + (u_size, d_size),
            dtype=init_loc.dtype,
            device=init_loc.device,
        )
        prec_sqrt.sub_(0.5).mul_(self._init_scale)
        if not site["is_observed"]:
            # Initialize the [d,d] block to the identity matrix.
            prec_sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)
        deep_setattr(self.white_vecs, d, PyroParam(white_vec, event_dim=1))
        deep_setattr(self.prec_sqrts, d, PyroParam(prec_sqrt, event_dim=2))
def fit_svi(self, *,
            num_samples=100,
            num_steps=2000,
            num_particles=32,
            learning_rate=0.1,
            learning_rate_decay=0.01,
            betas=(0.8, 0.99),
            haar=True,
            init_scale=0.01,
            guide_rank=0,
            jit=False,
            log_every=200,
            **options):
    """
    Runs stochastic variational inference to generate posterior samples.

    This runs :class:`~pyro.infer.svi.SVI`, setting the ``.samples``
    attribute on completion.

    This approximate inference method is useful for quickly iterating on
    probabilistic models.

    :param int num_samples: Number of posterior samples to draw from the
        trained guide. Defaults to 100.
    :param int num_steps: Number of :class:`~pyro.infer.svi.SVI` steps.
    :param int num_particles: Number of :class:`~pyro.infer.svi.SVI`
        particles per step.
    :param float learning_rate: Learning rate for the
        :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
    :param float learning_rate_decay: Learning rate decay for the
        :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer. Note this
        is decay over the entire schedule, not per-step decay.
    :param tuple betas: Momentum parameters for the
        :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
    :param bool haar: Whether to use a Haar wavelet reparameterizer.
    :param int guide_rank: Rank of the auto normal guide. If zero (default)
        use an :class:`~pyro.infer.autoguide.AutoNormal` guide. If a
        positive integer or None, use an
        :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
        If the string "full", use an
        :class:`~pyro.infer.autoguide.AutoMultivariateNormal` guide. These
        latter two require more ``num_steps`` to fit.
    :param float init_scale: Initial scale of the
        :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
    :param bool jit: Whether to use a jit compiled ELBO.
    :param int log_every: How often to log svi losses.
    :param int heuristic_num_particles: Passed to :meth:`heuristic` as
        ``num_particles``. Defaults to 1024.
    :returns: Time series of SVI losses (useful to diagnose convergence).
    :rtype: list
    """
    # Save configuration for .predict().
    self.relaxed = True
    self.num_quant_bins = 1

    # Setup Haar wavelet transform.
    if haar:
        time_dim = -2 if self.is_regional else -1
        dims = {"auxiliary": time_dim}
        supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
        for name, (fn, is_regional) in self._non_compartmental.items():
            dims[name] = time_dim - fn.event_dim
            supports[name] = fn.support
        haar = _HaarSplitReparam(0, self.duration, dims, supports)

    # Heuristically initialize to feasible latents.
    heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
                         for k in list(options)
                         if k.startswith("heuristic_")}
    assert not options, "unrecognized options: {}".format(", ".join(options))
    init_strategy = self._heuristic(haar, **heuristic_options)

    # Configure variational inference.
    logger.info("Running inference...")
    model = self._relaxed_model
    if haar:
        model = haar.reparam(model)
    if guide_rank == 0:
        guide = AutoNormal(model, init_loc_fn=init_strategy, init_scale=init_scale)
    elif guide_rank == "full":
        guide = AutoMultivariateNormal(model, init_loc_fn=init_strategy,
                                       init_scale=init_scale)
    elif guide_rank is None or isinstance(guide_rank, int):
        guide = AutoLowRankMultivariateNormal(model, init_loc_fn=init_strategy,
                                              init_scale=init_scale,
                                              rank=guide_rank)
    else:
        raise ValueError("Invalid guide_rank: {}".format(guide_rank))
    Elbo = JitTrace_ELBO if jit else Trace_ELBO
    elbo = Elbo(max_plate_nesting=self.max_plate_nesting,
                num_particles=num_particles, vectorize_particles=True,
                ignore_jit_warnings=True)
    optim = ClippedAdam({"lr": learning_rate, "betas": betas,
                         "lrd": learning_rate_decay ** (1 / num_steps)})
    svi = SVI(model, guide, optim, elbo)

    # Run inference.
    start_time = default_timer()
    losses = []
    for step in range(1 + num_steps):
        loss = svi.step() / self.duration
        if step % log_every == 0:
            logger.info("step {} loss = {:0.4g}".format(step, loss))
        losses.append(loss)
    elapsed = default_timer() - start_time
    logger.info("SVI took {:0.1f} seconds, {:0.1f} step/sec"
                .format(elapsed, (1 + num_steps) / elapsed))

    # Draw posterior samples.
    with torch.no_grad():
        particle_plate = pyro.plate("particles", num_samples,
                                    dim=-1 - self.max_plate_nesting)
        guide_trace = poutine.trace(particle_plate(guide)).get_trace()
        model_trace = poutine.trace(
            poutine.replay(particle_plate(model), guide_trace)).get_trace()
        self.samples = {name: site["value"]
                        for name, site in model_trace.nodes.items()
                        if site["type"] == "sample"
                        if not site["is_observed"]
                        if not site_is_subsample(site)}
        if haar:
            haar.aux_to_user(self.samples)
    assert all(v.size(0) == num_samples for v in self.samples.values()), \
        {k: tuple(v.shape) for k, v in self.samples.items()}

    return losses