Example 1
def test_replay_enumerate_poutine(depth, first_available_dim):
    num_particles = 2
    y_dist = Categorical(torch.tensor([0.5, 0.25, 0.25]))

    def guide():
        pyro.sample("y", y_dist, infer={"enumerate": "parallel"})

    guide = poutine.enum(guide, first_available_dim=depth + first_available_dim)
    guide = poutine.trace(guide)
    guide_trace = guide.get_trace()

    def model():
        pyro.sample("x", Bernoulli(0.5))
        for i in range(depth):
            pyro.sample("a_{}".format(i), Bernoulli(0.5), infer={"enumerate": "parallel"})
        pyro.sample("y", y_dist, infer={"enumerate": "parallel"})
        for i in range(depth):
            pyro.sample("b_{}".format(i), Bernoulli(0.5), infer={"enumerate": "parallel"})

    model = poutine.enum(model, first_available_dim=first_available_dim)
    model = poutine.replay(model, trace=guide_trace)
    model = poutine.trace(model)

    for i in range(num_particles):
        tr = model.get_trace()
        assert tr.nodes["y"]["value"] is guide_trace.nodes["y"]["value"]
        tr.compute_log_prob()
        log_prob = sum(site["log_prob"] for name, site in tr.iter_stochastic_nodes())
        actual_shape = log_prob.shape
        expected_shape = (2,) * depth + (3,) + (2,) * depth + (1,) * first_available_dim
        assert actual_shape == expected_shape, 'error on iteration {}'.format(i)
Example 2
def test_replay_enumerate_poutine(depth, first_available_dim):
    num_particles = 2
    y_dist = Categorical(torch.tensor([0.5, 0.25, 0.25]))

    def guide():
        pyro.sample("y", y_dist, infer={"enumerate": "parallel"})

    guide = poutine.enum(guide, first_available_dim=first_available_dim - depth)
    guide = poutine.trace(guide)
    guide_trace = guide.get_trace()

    def model():
        pyro.sample("x", Bernoulli(0.5))
        for i in range(depth):
            pyro.sample("a_{}".format(i), Bernoulli(0.5), infer={"enumerate": "parallel"})
        pyro.sample("y", y_dist, infer={"enumerate": "parallel"})
        for i in range(depth):
            pyro.sample("b_{}".format(i), Bernoulli(0.5), infer={"enumerate": "parallel"})

    model = poutine.enum(model, first_available_dim=first_available_dim)
    model = poutine.replay(model, trace=guide_trace)
    model = poutine.trace(model)

    for i in range(num_particles):
        tr = model.get_trace()
        assert tr.nodes["y"]["value"] is guide_trace.nodes["y"]["value"]
        tr.compute_log_prob()
        log_prob = sum(site["log_prob"] for name, site in tr.iter_stochastic_nodes())
        actual_shape = log_prob.shape
        expected_shape = (2,) * depth + (3,) + (2,) * depth + (1,) * (-1 - first_available_dim)
        assert actual_shape == expected_shape, 'error on iteration {}'.format(i)
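Examples 1 and 2 are two revisions of the same Pyro test and differ only in the `first_available_dim` convention: the first counts reserved dims from the right with a nonnegative integer, the second (newer Pyro) passes a negative dim index, typically `-1 - max_plate_nesting`. Below is a minimal, self-contained sketch of the enum-then-replay pattern they exercise, under the negative-dim convention (the toy guide/model are illustrative, not part of the test):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

y_dist = dist.Categorical(torch.tensor([0.5, 0.25, 0.25]))

def guide():
    pyro.sample("y", y_dist, infer={"enumerate": "parallel"})

def model():
    pyro.sample("y", y_dist, infer={"enumerate": "parallel"})

# Enumerate the guide, trace it, then replay the (also enumerated) model against it.
guide_trace = poutine.trace(poutine.enum(guide, first_available_dim=-1)).get_trace()
replayed = poutine.replay(poutine.enum(model, first_available_dim=-1), trace=guide_trace)
model_trace = poutine.trace(replayed).get_trace()

# replay makes the model reuse the value enumerated in the guide trace.
assert model_trace.nodes["y"]["value"] is guide_trace.nodes["y"]["value"]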
Example 3
def _initialize_model_properties(self):
    if self.max_plate_nesting is None:
        self._guess_max_plate_nesting()
    # Wrap model in `poutine.enum` to enumerate over discrete latent sites.
    # No-op if model does not have any discrete latents.
    self.model = poutine.enum(config_enumerate(self.model),
                              first_available_dim=-1 - self.max_plate_nesting)
    if self._automatic_transform_enabled:
        self.transforms = {}
    trace = poutine.trace(self.model).get_trace(*self._args, **self._kwargs)
    for name, node in trace.iter_stochastic_nodes():
        if isinstance(node["fn"], _Subsample):
            continue
        if node["fn"].has_enumerate_support:
            self._has_enumerable_sites = True
            continue
        site_value = node["value"]
        if node["fn"].support is not constraints.real and self._automatic_transform_enabled:
            self.transforms[name] = biject_to(node["fn"].support).inv
            site_value = self.transforms[name](node["value"])
        self._r_shapes[name] = site_value.shape
        self._r_numels[name] = site_value.numel()
    self._trace_prob_evaluator = TraceEinsumEvaluator(
        trace, self._has_enumerable_sites, self.max_plate_nesting)
    mass_matrix_size = sum(self._r_numels.values())
    if self.full_mass:
        initial_mass_matrix = eye_like(site_value, mass_matrix_size)
    else:
        initial_mass_matrix = site_value.new_ones(mass_matrix_size)
    self._adapter.inverse_mass_matrix = initial_mass_matrix
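The `poutine.enum(config_enumerate(model), first_available_dim=-1 - max_plate_nesting)` wrapping above is the same pattern used for HMC/NUTS setup in the later examples. A standalone sketch under assumed values (the toy model, `max_plate_nesting = 0`, and the printed shape are illustrative, not taken from the snippet):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import config_enumerate

max_plate_nesting = 0  # the toy model below has no pyro.plate contexts

@config_enumerate
def toy_model():
    z = pyro.sample("z", dist.Categorical(torch.ones(3) / 3.0))
    pyro.sample("x", dist.Normal(z.float(), 1.0), obs=torch.tensor(0.5))

enum_model = poutine.enum(toy_model, first_available_dim=-1 - max_plate_nesting)
tr = poutine.trace(enum_model).get_trace()
tr.compute_log_prob()
# "z" is enumerated in parallel over its 3-element support.
print(tr.nodes["z"]["value"].shape)  # expected: torch.Size([3])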
Example 4
def test_enumerate_poutine(depth, first_available_dim):
    num_particles = 2

    def model():
        pyro.sample("x", Bernoulli(0.5))
        for i in range(depth):
            pyro.sample("a_{}".format(i),
                        Bernoulli(0.5),
                        infer={"enumerate": "parallel"})

    model = poutine.enum(model, first_available_dim=first_available_dim)
    model = poutine.trace(model)

    for i in range(num_particles):
        tr = model.get_trace()
        tr.compute_log_prob()
        log_prob = sum(site["log_prob"]
                       for name, site in tr.iter_stochastic_nodes())
        actual_shape = log_prob.shape
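        # Each enumerated Bernoulli "a_i" adds one dim of size 2; the trailing
        # (1,)s pad over the (-1 - first_available_dim) rightmost dims reserved
        # for plates (negative-dim convention in this revision of the test).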
        expected_shape = (2, ) * depth
        if depth:
            expected_shape = expected_shape + (1, ) * (-1 - first_available_dim)
        assert actual_shape == expected_shape, 'error on iteration {}'.format(
            i)
Example 5
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Runs the guide and runs the model against the guide,
        with the result packaged as a trace generator.
        """
        # enable parallel enumeration
        guide = poutine.enum(guide,
                             first_available_dim=self.max_iarange_nesting)

        for i in range(self.num_particles):
            for guide_trace in iter_discrete_traces("flat", guide, *args,
                                                    **kwargs):
                model_trace = poutine.trace(poutine.replay(model,
                                                           trace=guide_trace),
                                            graph_type="flat").get_trace(
                                                *args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(model_trace, guide_trace,
                                            self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()
                if is_validation_enabled():
                    for site in model_trace.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)
                    any_enumerated = False
                    for site in guide_trace.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)
                            if site["infer"].get("enumerate"):
                                any_enumerated = True
                    if self.strict_enumeration_warning and not any_enumerated:
                        warnings.warn(
                            'TraceEnum_ELBO found no sample sites configured for enumeration. '
                            'If you want to enumerate sites, you need to @config_enumerate or set '
                            'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                            'If you do not want to enumerate, consider using Trace_ELBO instead.'
                        )

                yield model_trace, guide_trace
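For reference, a user-level counterpart of the internals above: a sketch of running TraceEnum_ELBO with an enumerated guide in the current Pyro API (the toy model, guide, and parameter names are illustrative):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

def model():
    z = pyro.sample("z", dist.Categorical(torch.tensor([0.5, 0.5])))
    pyro.sample("x", dist.Normal(z.float(), 1.0), obs=torch.tensor(0.2))

@config_enumerate
def guide():
    probs = pyro.param("probs", torch.tensor([0.5, 0.5]),
                       constraint=dist.constraints.simplex)
    pyro.sample("z", dist.Categorical(probs))

# TraceEnum_ELBO sums out the enumerated "z" instead of sampling it.
svi = SVI(model, guide, Adam({"lr": 0.01}), TraceEnum_ELBO(max_plate_nesting=0))
loss = svi.step()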
Example 6
def test_enumerate_poutine(depth, first_available_dim):
    num_particles = 2

    def model():
        pyro.sample("x", Bernoulli(0.5))
        for i in range(depth):
            pyro.sample("a_{}".format(i), Bernoulli(0.5), infer={"enumerate": "parallel"})

    model = poutine.enum(model, first_available_dim=first_available_dim)
    model = poutine.trace(model)

    for i in range(num_particles):
        tr = model.get_trace()
        tr.compute_log_prob()
        log_prob = sum(site["log_prob"] for name, site in tr.iter_stochastic_nodes())
        actual_shape = log_prob.shape
        expected_shape = (2,) * depth
        if depth:
            expected_shape = expected_shape + (1,) * first_available_dim
        assert actual_shape == expected_shape, 'error on iteration {}'.format(i)
Example 7
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Runs the guide and runs the model against the guide,
        with the result packaged as a trace generator.
        """
        # enable parallel enumeration
        guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

        for i in range(self.num_particles):
            for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
                model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                            graph_type="flat").get_trace(*args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()
                if is_validation_enabled():
                    for site in model_trace.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)
                    any_enumerated = False
                    for site in guide_trace.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)
                            if site["infer"].get("enumerate"):
                                any_enumerated = True
                    if self.strict_enumeration_warning and not any_enumerated:
                        warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. '
                                      'If you want to enumerate sites, you need to @config_enumerate or set '
                                      'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                                      'If you do not want to enumerate, consider using Trace_ELBO instead.')

                yield model_trace, guide_trace
Example 8
def main(args):
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logging.info('Loading data')
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info('-' * 40)
    model = models[args.model]
    logging.info('Training {} on {} sequences'.format(
        model.__name__, len(data['train']['sequences'])))
    sequences = data['train']['sequences']
    lengths = data['train']['sequence_lengths']

    # find all the notes that are present at least once in the training set
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths.clamp_(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    pyro.enable_validation(True)

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        poutine.block(model,
                      expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = poutine.trace(
            poutine.replay(poutine.enum(model, first_available_dim),
                           guide_trace)).get_trace(sequences,
                                                   lengths,
                                                   args=args,
                                                   batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
    elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2)
    optim = Adam({'lr': args.learning_rate})
    svi = SVI(model, guide, optim, elbo)

    # We'll train on small minibatches.
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(sequences,
                        lengths,
                        args=args,
                        batch_size=args.batch_size)
        logging.info('{: >5d}\t{}'.format(step, loss / num_observations))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model,
                           guide,
                           sequences,
                           lengths,
                           args,
                           include_prior=False)
    logging.info('training loss = {}'.format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info('-' * 40)
    logging.info('Evaluating on {} test sequences'.format(
        len(data['test']['sequences'])))
    sequences = data['test']['sequences'][..., present_notes]
    lengths = data['test']['sequence_lengths']
    if args.truncate:
        lengths.clamp_(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          args=args,
                          include_prior=False)
    logging.info('test loss = {}'.format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info('{} capacity = {} parameters'.format(model.__name__,
                                                      capacity))
Example 9
              r"The value argument must be within the support"
    with pytest.raises(ValueError, match=exp_msg):
        tr.compute_score_parts()


def _model(a=torch.tensor(1.), b=torch.tensor(1.)):
    latent = pyro.sample("latent", dist.Beta(a, b))
    return pyro.sample("test_site",
                       dist.Bernoulli(latent),
                       obs=torch.tensor(1.))


@pytest.mark.parametrize('wrapper', [
    lambda fn: poutine.block(fn),
    lambda fn: poutine.condition(fn, {'latent': 0.9}),
    lambda fn: poutine.enum(fn, -1),
    lambda fn: poutine.replay(fn,
                              poutine.trace(fn).get_trace()),
])
def test_pickling(wrapper):
    wrapped = wrapper(_model)
    # default protocol cannot serialize torch.Size objects (see https://github.com/pytorch/pytorch/issues/20823)
    deserialized = pickle.loads(
        pickle.dumps(wrapped, protocol=pickle.HIGHEST_PROTOCOL))
    obs = torch.tensor(0.5)
    pyro.set_rng_seed(0)
    actual_trace = poutine.trace(deserialized).get_trace(obs)
    pyro.set_rng_seed(0)
    expected_trace = poutine.trace(wrapped).get_trace(obs)
    assert tuple(actual_trace) == tuple(expected_trace.nodes)
    assert_close([
Example 10
File: hmm.py Project: pyro-ppl/pyro
def main(args):
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")

    logging.info("Loading data")
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info("-" * 40)
    model = models[args.model]
    logging.info("Training {} on {} sequences".format(
        model.__name__, len(data["train"]["sequences"])))
    sequences = data["train"]["sequences"]
    lengths = data["train"]["sequence_lengths"]

    # find all the notes that are present at least once in the training set
    present_notes = (sequences == 1).sum(0).sum(0) > 0
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        poutine.block(model,
                      expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = poutine.trace(
            poutine.replay(poutine.enum(model, first_available_dim),
                           guide_trace)).get_trace(sequences,
                                                   lengths,
                                                   args=args,
                                                   batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optim = Adam({"lr": args.learning_rate})
    if args.tmc:
        if args.jit:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2)
        tmc_model = poutine.infer_config(
            model,
            lambda msg: {
                "num_samples": args.tmc_num_samples,
                "expand": False
            } if msg["infer"].get("enumerate", None) == "parallel" else {},
        )  # noqa: E501
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
        elbo = Elbo(
            max_plate_nesting=1 if model is model_0 else 2,
            strict_enumeration_warning=(model is not model_7),
            jit_options={"time_compilation": args.time_compilation},
        )
        svi = SVI(model, guide, optim, elbo)

    # We'll train on small minibatches.
    logging.info("Step\tLoss")
    for step in range(args.num_steps):
        loss = svi.step(sequences,
                        lengths,
                        args=args,
                        batch_size=args.batch_size)
        logging.info("{: >5d}\t{}".format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug("time to compile: {} s.".format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model,
                           guide,
                           sequences,
                           lengths,
                           args,
                           include_prior=False)
    logging.info("training loss = {}".format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info("-" * 40)
    logging.info("Evaluating on {} test sequences".format(
        len(data["test"]["sequences"])))
    sequences = data["test"]["sequences"][..., present_notes]
    lengths = data["test"]["sequence_lengths"]
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          args=args,
                          include_prior=False)
    logging.info("test loss = {}".format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info("{} capacity = {} parameters".format(model.__name__,
                                                      capacity))
Example 11
def get_log_prob_fn(
    model,
    model_args=(),
    model_kwargs={},
    implementation="pyro",
    automatic_transform_enabled=False,
    transforms=None,
    max_plate_nesting=None,
    jit_compile=False,
    jit_options=None,
    skip_jit_warnings=False,
    **kwargs,
) -> (Callable, Dict[str, Any]):
    """
    Given a Python callable with Pyro primitives, generates the following model-specific
    functions:
    - a log prob function whose input is a set of parameters and whose output
      is the log prob of the model
    - transforms to transform latent sites of `model` to unconstrained space

    Args:
        model: a Pyro model which contains Pyro primitives.
        model_args: optional args taken by `model`.
        model_kwargs: optional kwargs taken by `model`.
        implementation: Switches between the available implementations.
        automatic_transform_enabled: Whether or not to try to infer transforms
            to unconstrained space.
        transforms: Optional dictionary that specifies a transform
            for a sample site with constrained support to unconstrained space. The
            transform should be invertible, and implement `log_abs_det_jacobian`.
            If not specified and the model has sites with constrained support,
            automatic transformations will be applied, as specified in
            `torch.distributions.constraint_registry`.
        max_plate_nesting: Optional bound on the max number of nested
            `pyro.plate` contexts. This is required if the model contains
            discrete sample sites that can be enumerated over in parallel. Will
            try to infer it automatically if not provided.
        jit_compile: Optional parameter denoting whether to use
            the PyTorch JIT to trace the log density computation, and use this
            optimized executable trace in the integrator.
        jit_options: A dictionary containing optional arguments for
            the `torch.jit.trace` function.
        skip_jit_warnings: Flag to ignore warnings from the JIT
            tracer when `jit_compile=True`. Default is False.

    Returns:
        `log_prob_fn` and `transforms`
    """
    if transforms is None:
        transforms = {}

    if max_plate_nesting is None:
        max_plate_nesting = _guess_max_plate_nesting(model, model_args,
                                                     model_kwargs)

    model = poutine.enum(config_enumerate(model),
                         first_available_dim=-1 - max_plate_nesting)
    model_trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)

    has_enumerable_sites = False
    for name, node in model_trace.iter_stochastic_nodes():
        fn = node["fn"]

        if isinstance(fn, _Subsample):
            if fn.subsample_size is not None and fn.subsample_size < fn.size:
                raise NotImplementedError(
                    "Model with subsample sites are not supported.")
            continue

        if fn.has_enumerate_support:
            has_enumerable_sites = True
            continue

        if automatic_transform_enabled:
            transforms[name] = biject_to(fn.support).inv
        else:
            transforms[name] = dist.transforms.identity_transform
        # Reinterpret batch dimensions of transform to get log abs det jac summed over
        # parameter dimensions.
        if not isinstance(transforms[name], IndependentTransform):
            transforms[name] = IndependentTransform(
                transforms[name], reinterpreted_batch_ndims=1)

    if implementation == "pyro":
        trace_prob_evaluator = TraceEinsumEvaluator(model_trace,
                                                    has_enumerable_sites,
                                                    max_plate_nesting)

        lp_maker = _LPMaker(model, model_args, model_kwargs,
                            trace_prob_evaluator, transforms)

        lp_fn = lp_maker.get_lp_fn(jit_compile, skip_jit_warnings, jit_options)

    elif implementation == "experimental":
        assert automatic_transform_enabled is False

        if jit_compile:
            warnings.warn("Will not JIT compile, unsupported for now.")

        def lp_fn(input_dict):
            excluded_nodes = set(["_INPUT", "_RETURN"])

            for key, value in input_dict.items():
                model_trace.nodes[key]["value"] = value

            replayed_model = pyro.poutine.replay(model, model_trace)

            log_p = 0
            for trace_enum in iter_discrete_traces("flat", fn=replayed_model):
                trace_enum.compute_log_prob()

                for node_name, node in trace_enum.nodes.items():
                    if node_name in excluded_nodes:
                        continue

                    if node["log_prob"].ndim == 1:
                        log_p += trace_enum.nodes[node_name]["log_prob"]
                    else:
                        log_p += trace_enum.nodes[node_name]["log_prob"].sum(
                            dim=1)

            return log_p

    else:
        raise NotImplementedError

    return lp_fn, transforms
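A hypothetical call to the helper defined above, following its docstring (the toy model is illustrative, and the exact input format accepted by the returned `log_prob_fn` depends on the chosen implementation, so treat this only as a sketch):

import torch
import pyro
import pyro.distributions as dist

def toy_model():
    loc = pyro.sample("loc", dist.Normal(torch.zeros(1), torch.ones(1)))
    pyro.sample("obs", dist.Normal(loc, torch.ones(1)), obs=torch.tensor([0.3]))

log_prob_fn, transforms = get_log_prob_fn(toy_model, automatic_transform_enabled=True)
# Per the docstring: log_prob_fn maps parameter values to the model's log prob,
# and transforms maps each latent site name to its unconstraining transform.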
Example 12
def initialize_model(model, model_args=(), model_kwargs={}, transforms=None, max_plate_nesting=None,
                     jit_compile=False, jit_options=None, skip_jit_warnings=False, num_chains=1,
                     init_strategy=init_to_uniform, initial_params=None):
    """
    Given a Python callable with Pyro primitives, generates the following model-specific
    properties needed for inference using HMC/NUTS kernels:

    - initial parameters to be sampled using a HMC kernel,
    - a potential function whose input is a dict of parameters in unconstrained space,
    - transforms to transform latent sites of `model` to unconstrained space,
    - a prototype trace to be used in MCMC to consume traces from sampled parameters.

    :param model: a Pyro model which contains Pyro primitives.
    :param tuple model_args: optional args taken by `model`.
    :param dict model_kwargs: optional kwargs taken by `model`.
    :param dict transforms: Optional dictionary that specifies a transform
        for a sample site with constrained support to unconstrained space. The
        transform should be invertible, and implement `log_abs_det_jacobian`.
        If not specified and the model has sites with constrained support,
        automatic transformations will be applied, as specified in
        :mod:`torch.distributions.constraint_registry`.
    :param int max_plate_nesting: Optional bound on max number of nested
        :func:`pyro.plate` contexts. This is required if model contains
        discrete sample sites that can be enumerated over in parallel.
    :param bool jit_compile: Optional parameter denoting whether to use
        the PyTorch JIT to trace the log density computation, and use this
        optimized executable trace in the integrator.
    :param dict jit_options: A dictionary containing optional arguments for
        the :func:`torch.jit.trace` function.
    :param bool skip_jit_warnings: Flag to ignore warnings from the JIT
        tracer when ``jit_compile=True``. Default is False.
    :param int num_chains: Number of parallel chains. If `num_chains > 1`,
        the returned `initial_params` will be a list with `num_chains` elements.
    :param callable init_strategy: A per-site initialization function.
        See the :ref:`autoguide-initialization` section for available functions.
    :param dict initial_params: dict containing initial tensors in unconstrained
        space to initiate the Markov chain.
    :returns: a tuple of (`initial_params`, `potential_fn`, `transforms`, `prototype_trace`)
    """
    # XXX `transforms` domains are sites' supports
    # FIXME: find a good pattern to deal with `transforms` arg
    if transforms is None:
        automatic_transform_enabled = True
        transforms = {}
    else:
        automatic_transform_enabled = False
    if max_plate_nesting is None:
        max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
    # Wrap model in `poutine.enum` to enumerate over discrete latent sites.
    # No-op if model does not have any discrete latents.
    model = poutine.enum(config_enumerate(model),
                         first_available_dim=-1 - max_plate_nesting)
    prototype_model = poutine.trace(InitMessenger(init_strategy)(model))
    model_trace = prototype_model.get_trace(*model_args, **model_kwargs)
    has_enumerable_sites = False
    prototype_samples = {}
    for name, node in model_trace.iter_stochastic_nodes():
        fn = node["fn"]
        if isinstance(fn, _Subsample):
            if fn.subsample_size is not None and fn.subsample_size < fn.size:
                raise NotImplementedError("HMC/NUTS does not support model with subsample sites.")
            continue
        if node["fn"].has_enumerate_support:
            has_enumerable_sites = True
            continue
        # we need to detach here because this sample can be a leaf variable,
        # so we can't change its requires_grad flag to calculate its grad in
        # velocity_verlet
        prototype_samples[name] = node["value"].detach()
        if automatic_transform_enabled:
            transforms[name] = biject_to(node["fn"].support).inv

    trace_prob_evaluator = TraceEinsumEvaluator(model_trace,
                                                has_enumerable_sites,
                                                max_plate_nesting)

    pe_maker = _PEMaker(model, model_args, model_kwargs, trace_prob_evaluator, transforms)

    if initial_params is None:
        prototype_params = {k: transforms[k](v) for k, v in prototype_samples.items()}
        # Note that we deliberately do not exercise jit compilation here so as to
        # enable potential_fn to be picklable (a torch._C.Function cannot be pickled).
        # We pass model_trace merely for computational savings.
        initial_params = _find_valid_initial_params(model, model_args, model_kwargs, transforms,
                                                    pe_maker.get_potential_fn(), prototype_params,
                                                    num_chains=num_chains, init_strategy=init_strategy,
                                                    trace=model_trace)
    potential_fn = pe_maker.get_potential_fn(jit_compile, skip_jit_warnings, jit_options)
    return initial_params, potential_fn, transforms, model_trace
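In Pyro this utility is exposed as `pyro.infer.mcmc.util.initialize_model`; assuming that import path, a minimal sketch of consuming the returned tuple (the toy model and data are illustrative):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc.util import initialize_model

def toy_model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

data = torch.tensor([0.1, 0.2, -0.3])
init_params, potential_fn, transforms, prototype_trace = initialize_model(
    toy_model, model_args=(data,))
# potential_fn takes a dict of unconstrained parameters and returns the potential
# energy (negative log joint); init_params is a valid starting point for HMC/NUTS.
print(potential_fn(init_params))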