예제 #1
    def predict(self, forecast=0):
        """
        Predict latent variables and optionally forecast forward.

        This may be run only after :meth:`fit_mcmc` (or :meth:`fit_svi`) and
        draws the same ``num_samples`` as were drawn during fitting.

        :param int forecast: The number of time steps to forecast forward.
        :returns: A dictionary mapping sample site name (or compartment name)
            to a tensor whose first dimension corresponds to sample batching.
        :rtype: dict
        :raises RuntimeError: If no posterior samples are available.
        """
        if self.num_quant_bins > 1:
            # NOTE(review): this branch had lost its body; restored as a call to
            # the module-level ``_require_double_precision`` helper (assumed to
            # be defined alongside this class) — confirm.
            _require_double_precision()
        if not self.samples:
            raise RuntimeError("Missing samples, try running .fit_mcmc() first")

        samples = self.samples
        num_samples = len(next(iter(samples.values())))
        # Batch over posterior samples in a dim to the left of all model plates.
        particle_plate = pyro.plate("particles", num_samples,
                                    dim=-1 - self.max_plate_nesting)

        # Sample discrete auxiliary variables conditioned on the continuous
        # variables sampled by _quantized_model. This samples only time steps
        # [0:duration]. Here infer_discrete runs a forward-filter
        # backward-sample algorithm.
        logger.info("Predicting latent variables for {} time steps..."
                    .format(self.duration))
        model = self._sequential_model
        model = poutine.condition(model, samples)
        model = particle_plate(model)
        if not self.relaxed:
            model = infer_discrete(model, first_available_dim=-2 - self.max_plate_nesting)
        trace = poutine.trace(model).get_trace()
        # Expand every value to its distribution's full shape so each site
        # carries the leading sample dim checked by the assertion below.
        samples = OrderedDict((name, site["value"].expand(site["fn"].shape()))
                              for name, site in trace.nodes.items()
                              if site["type"] == "sample"
                              if not site_is_subsample(site)
                              if not site_is_factor(site))
        assert all(v.size(0) == num_samples for v in samples.values()), \
            {k: tuple(v.shape) for k, v in samples.items()}

        # Optionally forecast with the forward _generative_model. This samples
        # time steps [duration:duration+forecast].
        if forecast:
            logger.info("Forecasting {} steps ahead...".format(forecast))
            model = self._generative_model
            model = poutine.condition(model, samples)
            model = particle_plate(model)
            trace = poutine.trace(model).get_trace(forecast)
            samples = OrderedDict((name, site["value"])
                                  for name, site in trace.nodes.items()
                                  if site["type"] == "sample"
                                  if not site_is_subsample(site)
                                  if not site_is_factor(site))

        self._concat_series(samples, trace, forecast)
        assert all(v.size(0) == num_samples for v in samples.values()), \
            {k: tuple(v.shape) for k, v in samples.items()}
        return samples
예제 #2
 def _adjust_to_data(self, trace, data_trace):
     # Rewrite the stochastic sites of ``trace`` in place so they line up with
     # ``data_trace``: subsample sites copy their distribution and value, and
     # every site takes over ``data_trace``'s conditional-independence stack and
     # distribution, then has its value re-indexed along each plate dim.
     # NOTE(review): this excerpt is incomplete — the subscript key of the
     # ``cond_indep_stack`` assignment and the default argument passed to
     # ``subsampled_idxs.get`` are cut off, and no return statement is visible.
     # Maps plate name -> index tensor; cached per plate name so that all sites
     # in the same plate are re-indexed with the same indices.
     subsampled_idxs = dict()
     for name, site in trace.iter_stochastic_nodes():
         # Adjust subsample sites
         if site_is_subsample(site):
             site["fn"] = data_trace.nodes[name]["fn"]
             site["value"] = data_trace.nodes[name]["value"]
         # Adjust sites under conditionally independent stacks
         # NOTE(review): a subsample site also falls through to the loop below,
         # where ``assert not site_is_subsample(site)`` would fail if its stack
         # is non-empty — confirm whether a ``continue`` was lost here.
         orig_cis_stack = site["cond_indep_stack"]
         site["cond_indep_stack"] = data_trace.nodes[name][
         assert len(orig_cis_stack) == len(site["cond_indep_stack"])
         site["fn"] = data_trace.nodes[name]["fn"]
         for ocis, cis in zip(orig_cis_stack, site["cond_indep_stack"]):
             # Select random sub-indices to replay values under conditionally independent stacks.
             # Otherwise, we assume there is a dependence of indexes between training data
             # and prediction data.
             assert ocis.name == cis.name
             assert not site_is_subsample(site)
             # Plate dims are relative to the batch shape; shift left by the
             # event dims to index into the value tensor.
             batch_dim = cis.dim - site["fn"].event_dim
             # Presumably draws ``cis.size`` random indices into the original
             # plate of size ``ocis.size`` on first use — TODO confirm.
             subsampled_idxs[cis.name] = subsampled_idxs.get(
                               ocis.size, (cis.size, ),
             site["value"] = site["value"].index_select(
                 batch_dim, subsampled_idxs[cis.name])
예제 #3
파일: util.py 프로젝트: pyro-ppl/pyro
def save_visualization(trace, graph_output):
    """
    DEPRECATED Use :func:`pyro.infer.inspect.render_model()` instead.

    Take a trace generated by poutine.trace with `graph_type='dense'`
    and render the graph with the output saved to file.

    - non-reparameterized stochastic nodes are salmon
    - reparameterized stochastic nodes are half salmon, half grey
    - observation nodes are green

    Subsample sites are omitted entirely; other nodes that are neither
    stochastic, reparameterized nor observed get no styled node.

    :param pyro.poutine.Trace trace: a trace to be visualized
    :param graph_output: the graph will be saved to graph_output.pdf
    :type graph_output: str

    Example::

        trace = pyro.poutine.trace(model, graph_type="dense").get_trace()
        save_visualization(trace, 'output')
    """
    import warnings

    warnings.warn(
        "`save_visualization` function is deprecated and will be removed in "
        "a future version.")

    # Imported lazily so graphviz is only required when this is called.
    import graphviz

    g = graphviz.Digraph()

    for label, node in trace.nodes.items():
        if site_is_subsample(node):
            continue
        shape = "ellipse"
        if label in trace.stochastic_nodes and label not in trace.reparameterized_nodes:
            fillcolor = "salmon"
        elif label in trace.reparameterized_nodes:
            fillcolor = "lightgrey;.5:salmon"
        elif label in trace.observation_nodes:
            fillcolor = "darkolivegreen3"
        else:
            # only visualize RVs
            continue
        g.node(label, label=label, shape=shape, style="filled", fillcolor=fillcolor)

    for label1, label2 in trace.edges:
        if site_is_subsample(trace.nodes[label1]):
            continue
        if site_is_subsample(trace.nodes[label2]):
            continue
        g.edge(label1, label2)

    g.render(graph_output, view=False, cleanup=True)
예제 #4
    def _non_compartmental(self):
        """
        A dict mapping name -> (distribution, is_regional) for all
        non-compartmental sites in :meth:`transition`. For simple models this
        is often empty; for time-heterogeneous models this may contain
        time-local latent variables.
        """
        # Trace a simple invocation of .transition().
        with torch.no_grad(), poutine.block():
            params = self.global_model()
            prev = self.initialize(params)
            for name in self.approximate:
                prev[name + "_approx"] = prev[name]
            curr = prev.copy()
            with poutine.trace() as tr:
                self.transition(params, curr, 0)
            flows = self.compute_flows(prev, curr, 0)

        # Extract latent variables that are not compartmental flows.
        result = OrderedDict()
        for name, site in tr.trace.iter_stochastic_nodes():
            if name in flows or site_is_subsample(site):
                continue
            # Sites sampled at time step 0 are named "<series>_0".
            assert name.endswith("_0"), name
            name = name[:-2]
            assert name in self.series, name
            # TODO This supports only the region_plate. For full plate support,
            # this could be replaced by a self.plate() method as in EasyGuide.
            is_regional = any(f.name == "region" for f in site["cond_indep_stack"])
            result[name] = site["fn"], is_regional
        return result
예제 #5
파일: infer.py 프로젝트: yufengwa/pyro
def posterior_replay(model, posterior_samples, *args, **kwargs):
    """
    Given a model and samples from the posterior (potentially with conjugate sites
    collapsed), return a `dict` of samples from the posterior with conjugate sites
    uncollapsed. Note that this can also be used to generate samples from the
    posterior predictive distribution.

    :param model: Python callable.
    :param dict posterior_samples: posterior samples keyed by site name.
    :param args: arguments to `model`.
    :param kwargs: keyword arguments to `model`. The special keyword
        ``num_samples`` is consumed here (not forwarded to `model`). It must be
        given when `posterior_samples` is empty; otherwise it defaults to the
        leading dimension of the posterior sample tensors.
    :return: `dict` of samples from the posterior, each stacked along a new
        leading sample dimension.
    """
    posterior_samples = posterior_samples.copy()
    num_samples = kwargs.pop("num_samples", None)
    assert posterior_samples or num_samples, "`num_samples` must be provided if `posterior_samples` is empty."
    if num_samples is None:
        num_samples = next(iter(posterior_samples.values())).shape[0]

    return_samples = defaultdict(list)
    for i in range(num_samples):
        conditioned_nodes = {k: v[i] for k, v in posterior_samples.items()}
        # Run the collapsed model conditioned on the i-th posterior draw, then
        # replay it through the uncollapsed model to recover conjugate sites.
        collapsed_trace = poutine.trace(
            poutine.condition(collapse_conjugate(model), conditioned_nodes)
        ).get_trace(*args, **kwargs)
        trace = poutine.trace(
            uncollapse_conjugate(model, collapsed_trace)
        ).get_trace(*args, **kwargs)
        for name, site in trace.iter_stochastic_nodes():
            if not site_is_subsample(site):
                return_samples[name].append(site["value"])

    return {k: torch.stack(v) for k, v in return_samples.items()}
예제 #6
파일: util.py 프로젝트: zyxue/pyro
def save_visualization(trace, graph_output):
    """
    Take a trace generated by poutine.trace with `graph_type='dense'` and render
    the graph with the output saved to file.

    - non-reparameterized stochastic nodes are salmon
    - reparameterized stochastic nodes are half salmon, half grey
    - observation nodes are green

    Subsample sites are omitted entirely; other nodes that are neither
    stochastic, reparameterized nor observed get no styled node.

    :param pyro.poutine.Trace trace: a trace to be visualized
    :param graph_output: the graph will be saved to graph_output.pdf
    :type graph_output: str

    Example::

        trace = pyro.poutine.trace(model, graph_type="dense").get_trace()
        save_visualization(trace, 'output')
    """
    g = graphviz.Digraph()

    for label, node in trace.nodes.items():
        if site_is_subsample(node):
            continue
        shape = 'ellipse'
        if label in trace.stochastic_nodes and label not in trace.reparameterized_nodes:
            fillcolor = 'salmon'
        elif label in trace.reparameterized_nodes:
            fillcolor = 'lightgrey;.5:salmon'
        elif label in trace.observation_nodes:
            fillcolor = 'darkolivegreen3'
        else:
            # only visualize RVs
            continue
        g.node(label, label=label, shape=shape, style='filled', fillcolor=fillcolor)

    for label1, label2 in trace.edges:
        if site_is_subsample(trace.nodes[label1]):
            continue
        if site_is_subsample(trace.nodes[label2]):
            continue
        g.edge(label1, label2)

    g.render(graph_output, view=False, cleanup=True)
예제 #7
 def full_mass(self):
     """
     A list of a single tuple of the names of global random variables, i.e.
     the non-subsample stochastic sites recorded while running
     :meth:`global_model`.
     """
     # Record the sites of global_model() under no_grad and poutine.block().
     with torch.no_grad(), poutine.block(), poutine.trace() as tr:
         self.global_model()
     return [tuple(name for name, site in tr.trace.iter_stochastic_nodes()
                   if not site_is_subsample(site))]
예제 #8
파일: util.py 프로젝트: lewisKit/pyro
def get_iarange_stacks(trace):
    """
    This builds a dict mapping site name to a set of iarange stacks.  Each
    iarange stack is a list of :class:`CondIndepStackFrame`s corresponding to
    an :class:`iarange`.  This information is used by :class:`Trace_ELBO` and
    other ELBO implementations.

    Only ``"sample"`` sites that are not subsample sites are included, and only
    vectorized frames are kept in each stack.

    :param trace: a trace whose ``nodes`` maps site names to site dicts.
    :returns: a dict mapping site name to its list of vectorized frames.
    :rtype: dict
    """
    return {
        name: [frame for frame in node["cond_indep_stack"] if frame.vectorized]
        for name, node in trace.nodes.items()
        if node["type"] == "sample" and not site_is_subsample(node)
    }
예제 #9
def get_plate_stacks(trace):
    """
    This builds a dict mapping site name to a set of plate stacks.  Each
    plate stack is a list of :class:`CondIndepStackFrame`s corresponding to
    an :class:`plate`.  This information is used by :class:`Trace_ELBO` and
    other ELBO implementations.

    Only ``"sample"`` sites that are not subsample sites are included, and only
    vectorized frames are kept in each stack.

    :param trace: a trace whose ``nodes`` maps site names to site dicts.
    :returns: a dict mapping site name to its list of vectorized frames.
    :rtype: dict
    """
    return {
        name: [frame for frame in node["cond_indep_stack"] if frame.vectorized]
        for name, node in trace.nodes.items()
        if node["type"] == "sample" and not site_is_subsample(node)
    }
예제 #10
파일: util.py 프로젝트: Magica-Chen/pyro
def save_visualization(trace, graph_output):
    """
    Take a trace generated by poutine.trace with `graph_type='dense'` and render
    the graph with the output saved to file.

    - non-reparameterized stochastic nodes are salmon
    - reparameterized stochastic nodes are half salmon, half grey
    - observation nodes are green

    Subsample sites are omitted entirely; other nodes that are neither
    stochastic, reparameterized nor observed get no styled node.

    :param pyro.poutine.Trace trace: a trace to be visualized
    :param graph_output: the graph will be saved to graph_output.pdf
    :type graph_output: str

    Example::

        trace = pyro.poutine.trace(model, graph_type="dense").get_trace()
        save_visualization(trace, 'output')
    """
    g = graphviz.Digraph()

    for label, node in trace.nodes.items():
        if site_is_subsample(node):
            continue
        shape = 'ellipse'
        if label in trace.stochastic_nodes and label not in trace.reparameterized_nodes:
            fillcolor = 'salmon'
        elif label in trace.reparameterized_nodes:
            fillcolor = 'lightgrey;.5:salmon'
        elif label in trace.observation_nodes:
            fillcolor = 'darkolivegreen3'
        else:
            # only visualize RVs
            continue
        g.node(label, label=label, shape=shape, style='filled', fillcolor=fillcolor)

    for label1, label2 in trace.edges:
        if site_is_subsample(trace.nodes[label1]):
            continue
        if site_is_subsample(trace.nodes[label2]):
            continue
        g.edge(label1, label2)

    g.render(graph_output, view=False, cleanup=True)
예제 #11
def _sample_posterior(model, first_available_dim, temperature, *args,

    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        approx = funsor.montecarlo.MonteCarlo()
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    with block(), enum(first_available_dim=first_available_dim):
        # XXX replay against an empty Trace to ensure densities are not double-counted
        model_tr = trace(replay(model,
                                trace=Trace())).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            terms["log_factors"] + terms["log_measures"],
            eliminate=terms["measure_vars"] | terms["plate_vars"],
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.nodes.items():
        if node["type"] != "sample" or site_is_subsample(node):
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            node["funsor"] = {"value": node["funsor"]["value"](**sample_subs)}
            node["funsor"]["log_measure"] = approx_factors[node["funsor"]
            node["funsor"]["value"] = _get_support_value(
                node["funsor"]["log_measure"], name)
            sample_subs[name] = node["funsor"]["value"]

    with replay(trace=sample_tr):
        return model(*args, **kwargs)
예제 #12
    def _process_message(self, msg):
        """
        Do nothing while ``self._block`` is set or for subsample sites.
        Otherwise, stop further processing of sample sites whose distribution
        is a ``Funsor`` or whose value is a string or ``Funsor``.
        """
        if self._block:
            return
        if site_is_subsample(msg):
            return

        # Block sample statements.
        if msg["type"] == "sample":
            if isinstance(msg["fn"], Funsor) or isinstance(
                    msg["value"], (str, Funsor)):
                msg["stop"] = True
예제 #13
 def series(self):
     """
     A frozenset of names of sample sites that are sampled each time step.
     """
     # Trace a simple invocation of .transition().
     with torch.no_grad(), poutine.block():
         params = self.global_model()
         prev = self.initialize(params)
         for name in self.approximate:
             prev[name + "_approx"] = prev[name]
         curr = prev.copy()
         with poutine.trace() as tr:
             self.transition(params, curr, 0)
     # Sites sampled at time step 0 are named "<series>_0"; keep the prefix.
     return frozenset(re.match("(.*)_0", name).group(1)
                      for name, site in tr.trace.nodes.items()
                      if site["type"] == "sample"
                      if not site_is_subsample(site))
예제 #14
def is_sample_site(msg):
    """
    Decide whether ``msg`` describes a genuine, non-deterministic sample site.

    Returns ``False`` for non-sample messages, subsample sites, observed sites
    whose mask is exactly ``False``, and sites whose distribution (after
    unwrapping any ``base_dist`` chain) is a ``Delta``; ``True`` otherwise.
    """
    if msg["type"] != "sample" or site_is_subsample(msg):
        return False

    # A fully masked observation contributes nothing, so it is ignored.
    if msg["is_observed"] and msg["mask"] is False:
        return False

    # Peel wrapper distributions down to the innermost base distribution;
    # deterministic sites show up there as a Delta.
    innermost = msg["fn"]
    while hasattr(innermost, "base_dist"):
        innermost = innermost.base_dist
    return type(innermost).__name__ != "Delta"
예제 #15
def evaluate_log_posterior_density(model, posterior_samples, baseball_dataset):
    """
    Evaluate the log probability density of observing the unseen data (season hits)
    given a model and posterior distribution over the parameters.

    The result is logged at INFO level and returned.

    :param model: the model used for posterior prediction.
    :param posterior_samples: posterior samples of the model parameters.
    :param baseball_dataset: the full dataset; its held-out split is scored.
    :returns: the log posterior predictive density of the held-out data.
    """
    _, test, _ = train_test_split(baseball_dataset)
    at_bats_season, hits_season = test[:, 0], test[:, 1]
    with ignore_experimental_warning():
        trace = predictive(model, posterior_samples, at_bats_season, hits_season,
                           parallel=True, return_trace=True)
    # Use LogSumExp trick to evaluate $log(1/num_samples \sum_i p(new_data | \theta^{i})) $,
    # where $\theta^{i}$ are parameter samples from the model's posterior.
    log_joint = 0.
    for site in trace.nodes.values():
        if site["type"] == "sample" and not site_is_subsample(site):
            # We use `sum_rightmost(x, -1)` to take the sum of all rightmost dimensions of `x`
            # except the first dimension (which corresponding to the number of posterior samples)
            site_log_prob_sum = sum_rightmost(site['log_prob'], -1)
            log_joint += site_log_prob_sum
    posterior_pred_density = torch.logsumexp(log_joint, dim=0) - math.log(log_joint.shape[0])
    logging.info("\nLog posterior predictive density")
    # Previously the computed value was never reported; log and return it.
    logging.info("{:.4f}\n".format(posterior_pred_density))
    return posterior_pred_density
예제 #16
    def _pyro_post_sample(self, msg):
        """
        Replace a freshly sampled value with the value recorded for the same
        site in ``self.trace``.

        If the shapes agree, the recorded value is used as-is. Otherwise the
        shapes must differ only in their first mismatched dim, where the model's
        value is longer. The recorded value then fills the leading part of that
        dim and the model's own sample supplies the remainder.

        Subsample sites and sites absent from ``self.trace`` are left untouched.
        """
        if site_is_subsample(msg):
            return

        name = msg["name"]
        if name not in self.trace:
            return

        model_value = msg["value"]
        guide_value = self.trace.nodes[name]["value"]
        if model_value.shape == guide_value.shape:
            msg["value"] = guide_value
            return

        # Search for a single dim with mismatched size.
        assert model_value.dim() == guide_value.dim()
        for dim in range(model_value.dim()):
            if model_value.size(dim) != guide_value.size(dim):
                break
        assert model_value.size(dim) > guide_value.size(dim)
        assert model_value.shape[dim + 1:] == guide_value.shape[dim + 1:]
        split = guide_value.size(dim)
        # Keep the model's samples only past the recorded prefix along `dim`.
        index = (slice(None), ) * dim + (slice(split, None), )
        msg["value"] = torch.cat([guide_value, model_value[index]], dim=dim)
예제 #17
def get_model_relations(
    model: Callable,
    model_args: Optional[tuple] = None,
    model_kwargs: Optional[dict] = None,
) -> dict:
    """
    Infer relations of RVs and plates from given model and optionally data.
    See https://github.com/pyro-ppl/pyro/issues/949 for more details.

    This returns a dictionary with keys:

    -  "sample_sample" maps each downstream sample site to a list of the upstream
       sample sites on which it depends;
    -  "sample_param" maps each sample site to a list of the upstream param
       sites on which it depends;
    -  "sample_dist" maps each sample site to the name of the distribution at
       that site;
    -  "param_constraint" maps each param site to the string form of its
       constraint;
    -  "plate_sample" maps each plate name to a list of the sample sites within
       that plate; and
    -  "observed" is a list of observed sample sites.

    For example for the model::

        def model(data):
            m = pyro.sample('m', dist.Normal(0, 1))
            sd = pyro.sample('sd', dist.LogNormal(m, 1))
            with pyro.plate('N', len(data)):
                pyro.sample('obs', dist.Normal(m, sd), obs=data)

    the relation is::

        {'sample_sample': {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']},
         'sample_param': {'m': [], 'sd': [], 'obs': []},
         'sample_dist': {'m': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
         'param_constraint': {},
         'plate_sample': {'N': ['obs']},
         'observed': ['obs']}

    :param callable model: A model to inspect.
    :param model_args: Optional tuple of model args.
    :param model_kwargs: Optional dict of model kwargs.
    :rtype: dict
    """
    if model_args is None:
        model_args = ()
    if model_kwargs is None:
        model_kwargs = {}

    # Trace the model once with provenance tracking, inside a forked torch RNG
    # context, without gradients, and with validation disabled.
    # NOTE(review): the argument to validation_enabled was cut off in this
    # excerpt and is restored as False — confirm.
    with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
        with TrackProvenance():
            trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)

    sample_sample = {}
    sample_param = {}
    sample_dist = {}
    param_constraint = {}
    plate_sample = defaultdict(list)
    observed = []

    def _get_type_from_frozenname(frozen_name):
        return trace.nodes[frozen_name]["type"]

    for name, site in trace.nodes.items():
        if site["type"] == "param":
            param_constraint[name] = str(site["kwargs"]["constraint"])

        if site["type"] != "sample" or site_is_subsample(site):
            continue

        # The provenance of this site's log-density names every site it depends
        # on; compute it once and split the upstreams by site type.
        upstreams = [
            upstream
            for upstream in get_provenance(site["fn"].log_prob(site["value"]))
            if upstream != name
        ]
        sample_sample[name] = [
            upstream
            for upstream in upstreams
            if _get_type_from_frozenname(upstream) == "sample"
        ]
        sample_param[name] = [
            upstream
            for upstream in upstreams
            if _get_type_from_frozenname(upstream) == "param"
        ]

        sample_dist[name] = _get_dist_name(site["fn"])
        for frame in site["cond_indep_stack"]:
            plate_sample[frame.name].append(name)
        if site["is_observed"]:
            observed.append(name)

    def _resolve_plate_samples(plate_samples):
        # Split partially-overlapping plates so that plate membership nests.
        for p, pv in plate_samples.items():
            pv = set(pv)
            for q, qv in plate_samples.items():
                qv = set(qv)
                if len(pv & qv) > 0 and len(pv - qv) > 0 and len(qv - pv) > 0:
                    plate_samples_ = plate_samples.copy()
                    plate_samples_[q] = pv & qv
                    plate_samples_[q + "__CLONE"] = qv - pv
                    return _resolve_plate_samples(plate_samples_)
        return plate_samples

    plate_sample = _resolve_plate_samples(plate_sample)
    # convert set to list to keep order of variables
    plate_sample = {
        k: [name for name in trace.nodes if name in v]
        for k, v in plate_sample.items()
    }

    return {
        "sample_sample": sample_sample,
        "sample_param": sample_param,
        "sample_dist": sample_dist,
        "param_constraint": param_constraint,
        "plate_sample": dict(plate_sample),
        "observed": observed,
    }
예제 #18
파일: guides.py 프로젝트: yufengwa/pyro
def prototype_hide_fn(msg):
    """
    Hide filter for building the prototype trace: hide every message except
    unobserved, non-subsample ``"sample"`` sites, so that only stochastic
    sites are recorded.
    """
    if msg["type"] != "sample":
        return True
    return msg["is_observed"] or site_is_subsample(msg)
예제 #19
 def _process_message(self, msg):
     """
     Return early, leaving ``msg`` untouched, when ``self._block`` is set or
     ``msg`` is a subsample site.
     """
     if self._block:
         return
     if site_is_subsample(msg):
         return
예제 #20
 def _pyro_post_sample(self, msg):
     """
     Return early, leaving ``msg`` untouched, when ``self._block`` is set or
     ``msg`` is a subsample site.
     """
     if self._block:
         return
     if site_is_subsample(msg):
         return
예제 #21
파일: gaussian.py 프로젝트: pyro-ppl/pyro
    def _setup_prototype(self, *args, **kwargs) -> None:
        """
        Build the guide's parameters from the traced prototype model.

        Creates location/scale parameters (``self.locs``, ``self.scales``) for
        each latent sample site, and a white vector plus precision square root
        (``self.white_vecs``, ``self.prec_sqrts``) for each factor. It records
        per-site factors, plates, event sizes and unconstrained event shapes,
        using the model's dependency structure.

        NOTE(review): this excerpt is incomplete. The subscript on the
        ``get_dependencies`` result, the body of the non-sample/subsample guard,
        the closing of the ``ValueError(...)`` and ``frozenset().union(...)``
        calls, the ``self.scales`` assignment and the ``torch.rand(...)`` call
        are all cut off.
        """
        super()._setup_prototype(*args, **kwargs)

        self.locs = PyroModule()
        self.scales = PyroModule()
        self.white_vecs = PyroModule()
        self.prec_sqrts = PyroModule()
        self._factors = OrderedDict()
        self._plates = OrderedDict()
        self._event_numel = OrderedDict()
        self._unconstrained_event_shapes = OrderedDict()

        # Trace model dependencies.
        # The original model is consumed here and the reference is dropped.
        model = self._original_model[0]
        self._original_model = None
        self.dependencies = poutine.block(get_dependencies)(model, args, kwargs)[

        # Eliminate observations with no upstream latents.
        for d, upstreams in list(self.dependencies.items()):
            if all(self.prototype_trace.nodes[u]["is_observed"] for u in upstreams):
                del self.dependencies[d]
                del self.prototype_trace.nodes[d]

        # Collect factors and plates.
        for d, site in self.prototype_trace.nodes.items():
            # Prune non-essential parts of the trace to save memory.
            # NOTE(review): ``pruned_site`` is not used again in the visible code.
            pruned_site, site = site, site.copy()

            # Collect factors and plates.
            # NOTE(review): this guard has no body here; presumably ``continue``.
            if site["type"] != "sample" or site_is_subsample(site):
            assert all(f.vectorized for f in site["cond_indep_stack"])
            self._factors[d] = self._compress_site(site)
            plates = frozenset(site["cond_indep_stack"])
            if site["fn"].batch_shape != _plates_to_shape(plates):
                raise ValueError(
                    f"Shape mismatch at site '{d}'. "
                    "Are you missing a pyro.plate() or .to_event()?"
            if site["is_observed"]:
                # Break irrelevant observation plates.
                plates &= frozenset().union(
                    *(self._plates[u] for u in self.dependencies[d] if u != d)
            self._plates[d] = plates

            # Create location-scale parameters, one per latent variable.
            if site["is_observed"]:
                # This may slightly overestimate, e.g. for Multinomial.
                self._event_numel[d] = site["fn"].event_shape.numel()
                # Account for broken irrelevant observation plates.
                for f in set(site["cond_indep_stack"]) - plates:
                    self._event_numel[d] *= f.size
                # NOTE(review): as shown, observed sites fall through and also
                # get loc parameters, contradicting "one per latent variable";
                # a ``continue`` was likely lost here — confirm.
            with helpful_support_errors(site):
                init_loc = biject_to(site["fn"].support).inv(site["value"]).detach()
            batch_shape = site["fn"].batch_shape
            event_shape = init_loc.shape[len(batch_shape) :]
            self._unconstrained_event_shapes[d] = event_shape
            self._event_numel[d] = event_shape.numel()
            event_dim = len(event_shape)
            deep_setattr(self.locs, d, PyroParam(init_loc, event_dim=event_dim))
                    torch.full_like(init_loc, self._init_scale),

        # Create parameters for dependencies, one per factor.
        for d, site in self._factors.items():
            # u_size: total unconstrained size of latent upstream sites,
            # including plates of ``u`` that ``d`` does not share.
            u_size = 0
            for u in self.dependencies[d]:
                if not self._factors[u]["is_observed"]:
                    broken_shape = _plates_to_shape(self._plates[u] - self._plates[d])
                    u_size += broken_shape.numel() * self._event_numel[u]
            d_size = self._event_numel[d]
            if site["is_observed"]:
                d_size = min(d_size, u_size)  # just an optimization
            batch_shape = _plates_to_shape(self._plates[d])

            # Create parameters of each Gaussian factor.
            # NOTE(review): ``init_loc`` here is whatever the first loop assigned
            # last, so it only supplies dtype/device and requires at least one
            # latent site.
            white_vec = init_loc.new_zeros(batch_shape + (d_size,))
            # We initialize with noise to avoid singular gradient.
            prec_sqrt = torch.rand(
                batch_shape + (u_size, d_size),
            if not site["is_observed"]:
                # Initialize the [d,d] block to the identity matrix.
                prec_sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)
            deep_setattr(self.white_vecs, d, PyroParam(white_vec, event_dim=1))
            deep_setattr(self.prec_sqrts, d, PyroParam(prec_sqrt, event_dim=2))
예제 #22
    def fit_svi(self, *,
                betas=(0.8, 0.99),
        # NOTE(review): this excerpt is incomplete. The signature shows only
        # ``betas``, but the body also reads num_samples, num_steps,
        # num_particles, learning_rate, learning_rate_decay, haar, guide_rank,
        # init_scale, jit, log_every and a catch-all ``options`` dict. The
        # docstring below has lost its delimiters. Several calls are unclosed,
        # an ``else:`` before the final ``raise ValueError`` is missing, and
        # the ``if haar:`` near the end has no body.
        # NOTE(review): the docstring types learning_rate/learning_rate_decay as
        # ``int`` and describes learning_rate_decay as a "Learning rate"; both
        # look like float-valued settings (decay is applied as
        # ``learning_rate_decay ** (1 / num_steps)``) — confirm and fix.
        Runs stochastic variational inference to generate posterior samples.

        This runs :class:`~pyro.infer.svi.SVI`, setting the ``.samples``
        attribute on completion.

        This approximate inference method is useful for quickly iterating on
        probabilistic models.

        :param int num_samples: Number of posterior samples to draw from the
            trained guide. Defaults to 100.
        :param int num_steps: Number of :class:`~pyro.infer.svi.SVI` steps.
        :param int num_particles: Number of :class:`~pyro.infer.svi.SVI`
            particles per step.
        :param int learning_rate: Learning rate for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
        :param int learning_rate_decay: Learning rate for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer. Note this
            is decay over the entire schedule, not per-step decay.
        :param tuple betas: Momentum parameters for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
        :param bool haar: Whether to use a Haar wavelet reparameterizer.
        :param int guide_rank: Rank of the auto normal guide. If zero (default)
            use an :class:`~pyro.infer.autoguide.AutoNormal` guide. If a
            positive integer or None, use an
            :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
            If the string "full", use an
            :class:`~pyro.infer.autoguide.AutoMultivariateNormal` guide. These
            latter two require more ``num_steps`` to fit.
        :param float init_scale: Initial scale of the
            :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
        :param bool jit: Whether to use a jit compiled ELBO.
        :param int log_every: How often to log svi losses.
        :param int heuristic_num_particles: Passed to :meth:`heuristic` as
            ``num_particles``. Defaults to 1024.
        :returns: Time series of SVI losses (useful to diagnose convergence).
        :rtype: list
        # Save configuration for .predict().
        self.relaxed = True
        self.num_quant_bins = 1

        # Setup Haar wavelet transform.
        if haar:
            # The time dim sits left of the region dim for regional models.
            time_dim = -2 if self.is_regional else -1
            dims = {"auxiliary": time_dim}
            supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
            for name, (fn, is_regional) in self._non_compartmental.items():
                dims[name] = time_dim - fn.event_dim
                supports[name] = fn.support
            # Note ``haar`` is rebound from a flag to the reparameterizer object.
            haar = _HaarSplitReparam(0, self.duration, dims, supports)

        # Heuristically initialize to feasible latents.
        # Options prefixed "heuristic_" are stripped of the prefix and forwarded
        # to self._heuristic; any other leftover option is rejected.
        heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
                             for k in list(options)
                             if k.startswith("heuristic_")}
        assert not options, "unrecognized options: {}".format(", ".join(options))
        init_strategy = self._heuristic(haar, **heuristic_options)

        # Configure variational inference.
        logger.info("Running inference...")
        model = self._relaxed_model
        if haar:
            model = haar.reparam(model)
        if guide_rank == 0:
            guide = AutoNormal(model, init_loc_fn=init_strategy, init_scale=init_scale)
        elif guide_rank == "full":
            guide = AutoMultivariateNormal(model, init_loc_fn=init_strategy,
        elif guide_rank is None or isinstance(guide_rank, int):
            guide = AutoLowRankMultivariateNormal(model, init_loc_fn=init_strategy,
                                                  init_scale=init_scale, rank=guide_rank)
            raise ValueError("Invalid guide_rank: {}".format(guide_rank))
        Elbo = JitTrace_ELBO if jit else Trace_ELBO
        elbo = Elbo(max_plate_nesting=self.max_plate_nesting,
                    num_particles=num_particles, vectorize_particles=True,
        # Per-step decay chosen so the total decay over num_steps equals
        # learning_rate_decay.
        optim = ClippedAdam({"lr": learning_rate, "betas": betas,
                             "lrd": learning_rate_decay ** (1 / num_steps)})
        svi = SVI(model, guide, optim, elbo)

        # Run inference.
        start_time = default_timer()
        losses = []
        # NOTE(review): no ``losses.append(loss)`` is visible in this loop, so as
        # shown the returned list stays empty — confirm.
        for step in range(1 + num_steps):
            # Loss is normalized per time step.
            loss = svi.step() / self.duration
            if step % log_every == 0:
                logger.info("step {} loss = {:0.4g}".format(step, loss))
        elapsed = default_timer() - start_time
        logger.info("SVI took {:0.1f} seconds, {:0.1f} step/sec"
                    .format(elapsed, (1 + num_steps) / elapsed))

        # Draw posterior samples.
        with torch.no_grad():
            particle_plate = pyro.plate("particles", num_samples,
                                        dim=-1 - self.max_plate_nesting)
            guide_trace = poutine.trace(particle_plate(guide)).get_trace()
            model_trace = poutine.trace(
                poutine.replay(particle_plate(model), guide_trace)).get_trace()
            # Keep latent (unobserved, non-subsample) sample sites only.
            self.samples = {name: site["value"] for name, site in model_trace.nodes.items()
                            if site["type"] == "sample"
                            if not site["is_observed"]
                            if not site_is_subsample(site)}
            if haar:
        assert all(v.size(0) == num_samples for v in self.samples.values()), \
            {k: tuple(v.shape) for k, v in self.samples.items()}

        return losses