Esempio n. 1
0
def config_enumerate(guide=None, default="sequential"):
    """
    Annotate the enumerable sites of a guide with an enumeration strategy,
    i.e. ``site["infer"]["enumerate"] = default``.

    Existing annotations ``infer={"enumerate": ...}`` are not overwritten.

    The function can be called directly::

        guide = config_enumerate(guide)

    or applied as a decorator, with or without arguments::

        @config_enumerate
        def guide1(*args, **kwargs):
            ...

        @config_enumerate(default="parallel")
        def guide2(*args, **kwargs):
            ...

    :param callable guide: a pyro model that will be used as a guide in
        :class:`~pyro.infer.svi.SVI`. When omitted, a decorator is returned.
    :param str default: one of "sequential", "parallel", or None.
    :return: an annotated guide (or a decorator producing one)
    :rtype: callable
    :raises ValueError: if ``default`` is not one of the accepted values.
    """
    accepted = ("sequential", "parallel", None)
    if default not in accepted:
        template = "Invalid default value. Expected 'sequential', 'parallel', or None, but got {}"
        raise ValueError(template.format(repr(default)))

    if guide is not None:
        return poutine.infer_config(guide, config_fn=_config_enumerate(default))

    # No guide supplied: we are being used as ``@config_enumerate(...)``, so
    # return a decorator that remembers the validated strategy.
    def decorate(guide):
        return config_enumerate(guide, default=default)

    return decorate
Esempio n. 2
0
def config_enumerate(guide=None, default="sequential"):
    """
    Configure each enumerable site of a guide to enumerate with the given
    method, ``site["infer"]["enumerate"] = default``.

    Usable as a plain function::

        guide = config_enumerate(guide)

    or as a decorator::

        @config_enumerate
        def guide1(*args, **kwargs):
            ...

        @config_enumerate(default="parallel")
        def guide2(*args, **kwargs):
            ...

    Existing annotations ``infer={"enumerate": ...}`` are left as they are.

    :param callable guide: a pyro model that will be used as a guide in
        :class:`~pyro.infer.svi.SVI`.
    :param str default: one of "sequential", "parallel", or None.
    :return: an annotated guide
    :rtype: callable
    :raises ValueError: if ``default`` is not an accepted strategy.
    """
    strategies = ("sequential", "parallel", None)
    if default not in strategies:
        shown = repr(default)
        raise ValueError(
            "Invalid default value. Expected 'sequential', 'parallel', or None, but got {}".format(shown))

    if guide is None:
        # Decorator-with-arguments form: defer the work until the guide arrives.
        def wrap(guide):
            return config_enumerate(guide, default=default)

        return wrap

    config_fn = _config_enumerate(default)
    return poutine.infer_config(guide, config_fn=config_fn)
Esempio n. 3
0
def config_enumerate(guide=None, default="parallel", expand=False, num_samples=None):
    """
    Configures enumeration for all relevant sites in a guide. This is mainly
    used in conjunction with :class:`~pyro.infer.traceenum_elbo.TraceEnum_ELBO`.

    When configuring for exhaustive enumeration of discrete variables, this
    configures all sample sites whose distribution satisfies
    ``.has_enumerate_support == True``.
    When configuring for local parallel Monte Carlo sampling via
    ``default="parallel", num_samples=n``, this configures all sample sites.
    This does not overwrite existing annotations ``infer={"enumerate": ...}``.

    This can be used as either a function::

        guide = config_enumerate(guide)

    or as a decorator::

        @config_enumerate
        def guide1(*args, **kwargs):
            ...

        @config_enumerate(default="sequential", expand=True)
        def guide2(*args, **kwargs):
            ...

    :param callable guide: a pyro model that will be used as a guide in
        :class:`~pyro.infer.svi.SVI`. When omitted, a decorator is returned.
    :param str default: Which enumerate strategy to use, one of
        "sequential", "parallel", or None. Defaults to "parallel".
    :param bool expand: Whether to expand enumerated sample values. See
        :meth:`~pyro.distributions.Distribution.enumerate_support` for details.
        This only applies to exhaustive enumeration, where ``num_samples=None``.
        If ``num_samples`` is not ``None``, then the samples will always be
        expanded.
    :param num_samples: if not ``None``, use local Monte Carlo sampling rather
        than exhaustive enumeration. This makes sense for both continuous and
        discrete distributions. Must be a positive integer.
    :type num_samples: int or None
    :return: an annotated guide
    :rtype: callable
    :raises ValueError: if ``default``, ``expand`` or ``num_samples`` is
        invalid, or if ``num_samples`` is combined with
        ``default="sequential"``.
    """
    if default not in ["sequential", "parallel", None]:
        raise ValueError("Invalid default value. Expected 'sequential', 'parallel', or None, but got {}".format(
            repr(default)))
    if expand not in [True, False]:
        raise ValueError("Invalid expand value. Expected True or False, but got {}".format(repr(expand)))
    if num_samples is not None:
        # A sample count has to be a whole number. Checking against Integral
        # (rather than Number) rejects values such as 2.5 up front and also
        # avoids a TypeError from ``complex > 0``.
        if not (isinstance(num_samples, numbers.Integral) and num_samples > 0):
            raise ValueError("Invalid num_samples, expected None or positive integer, but got {}".format(
                repr(num_samples)))
        if default == "sequential":
            raise ValueError('Local sampling does not support "sequential" sampling; '
                             'use "parallel" sampling instead.')

    # Support usage as a decorator:
    if guide is None:
        return lambda guide: config_enumerate(guide, default=default, expand=expand, num_samples=num_samples)

    return poutine.infer_config(guide, config_fn=_config_enumerate(default, expand, num_samples))
Esempio n. 4
0
    def test_infer_config_sample(self):
        """Check the ``infer`` dicts recorded for sites "a", "b" and "p" after ``infer_config``."""
        configured = poutine.infer_config(self.model, config_fn=self.config_fn)
        nodes = poutine.trace(configured).get_trace().nodes

        # Expected ``infer`` annotation per site name, checked in this order.
        expected_infer = (
            ("a", {"enumerate": "parallel", "blah": True}),
            ("b", {"blah": True}),
            ("p", {}),
        )
        for site_name, infer in expected_infer:
            assert nodes[site_name]["infer"] == infer
Esempio n. 5
0
    def test_infer_config_sample(self):
        """Trace the model wrapped by ``infer_config`` and verify each site's ``infer`` dict."""
        wrapped_model = poutine.infer_config(self.model, config_fn=self.config_fn)
        trace = poutine.trace(wrapped_model).get_trace()

        infer_a = trace.nodes["a"]["infer"]
        infer_b = trace.nodes["b"]["infer"]
        infer_p = trace.nodes["p"]["infer"]

        assert infer_a == {"enumerate": "parallel", "blah": True}
        assert infer_b == {"blah": True}
        assert infer_p == {}
Esempio n. 6
0
File: hmm.py Progetto: pyro-ppl/pyro
def main(args):
    """
    Train the model selected by ``args.model`` on the JSB chorales training
    split with SVI, then log the per-observation training and test losses and
    the total number of learned parameters.

    :param args: parsed command line options. The attributes read here are
        ``cuda``, ``model``, ``truncate``, ``seed``, ``print_shapes``,
        ``batch_size``, ``learning_rate``, ``tmc``, ``tmc_num_samples``,
        ``jit``, ``time_compilation`` and ``num_steps``.
    :raises NotImplementedError: if both ``args.tmc`` and ``args.jit`` are set.
    """
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")

    logging.info("Loading data")
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info("-" * 40)
    model = models[args.model]
    train_split = data["train"]
    logging.info("Training {} on {} sequences".format(model.__name__, len(train_split["sequences"])))
    sequences = train_split["sequences"]
    lengths = train_split["sequence_lengths"]

    # Keep only the notes that occur at least once in the training set
    # (the original author notes that this drops 37 of the 88 notes).
    note_counts = (sequences == 1).sum(0).sum(0)
    present_notes = note_counts > 0
    sequences = sequences[..., present_notes]

    max_len = args.truncate
    if max_len:
        lengths = lengths.clamp(max=max_len)
        sequences = sequences[:, :max_len]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()

    # Training is MAP Baum-Welch: MAP estimation while the hidden state x is
    # marginalized out. The automatic guide therefore only learns point
    # estimates for the conditional probability tables, whose site names
    # start with "probs_".
    def is_probs_site(msg):
        return msg["name"].startswith("probs_")

    guide = AutoDelta(poutine.block(model, expose_fn=is_probs_site))

    # Optional shape debugging: log the shape of every site's distribution,
    # value and log_prob tensor (SVI prints the same on most errors).
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        trace_kwargs = dict(args=args, batch_size=args.batch_size)
        guide_trace = poutine.trace(guide).get_trace(sequences, lengths, **trace_kwargs)
        enumerated_model = poutine.enum(model, first_available_dim)
        replayed_model = poutine.replay(enumerated_model, guide_trace)
        model_trace = poutine.trace(replayed_model).get_trace(sequences, lengths, **trace_kwargs)
        logging.info(model_trace.format_shapes())

    # Enumeration needs a TraceEnum-style ELBO with max_plate_nesting declared.
    optim = Adam({"lr": args.learning_rate})
    if not args.tmc:
        elbo_cls = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
        elbo = elbo_cls(
            max_plate_nesting=1 if model is model_0 else 2,
            strict_enumeration_warning=(model is not model_7),
            jit_options={"time_compilation": args.time_compilation},
        )
        svi = SVI(model, guide, optim, elbo)
    elif args.jit:
        raise NotImplementedError("jit support not yet added for TraceTMC_ELBO")
    else:
        def tmc_config(msg):
            # Only sites already marked for parallel enumeration are switched
            # to sampling; every other site gets no extra configuration.
            if msg["infer"].get("enumerate", None) == "parallel":
                return {"num_samples": args.tmc_num_samples, "expand": False}
            return {}

        elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2)
        tmc_model = poutine.infer_config(model, tmc_config)
        svi = SVI(tmc_model, guide, optim, elbo)

    # Optimize on minibatches of size args.batch_size.
    logging.info("Step\tLoss")
    for step in range(args.num_steps):
        step_loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size)
        logging.info("{: >5d}\t{}".format(step, step_loss / num_observations))

    if args.jit and args.time_compilation:
        compile_time = elbo._differentiable_loss.compile_time
        logging.debug("time to compile: {} s.".format(compile_time))

    # Evaluate on the whole training set, leaving out the prior term so the
    # numbers are comparable across models.
    train_loss = elbo.loss(model, guide, sequences, lengths, args, include_prior=False)
    logging.info("training loss = {}".format(train_loss / num_observations))

    # Evaluate on the test split.
    logging.info("-" * 40)
    test_split = data["test"]
    logging.info("Evaluating on {} test sequences".format(len(test_split["sequences"])))
    sequences = test_split["sequences"][..., present_notes]
    lengths = test_split["sequence_lengths"]
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # Because unseen notes were removed above (to make the problem a bit
    # easier and for numerical stability), this test loss may not be directly
    # comparable to numbers reported on this dataset elsewhere.
    test_loss = elbo.loss(model, guide, sequences, lengths, args=args, include_prior=False)
    logging.info("test loss = {}".format(test_loss / num_observations))

    # Higher-capacity models are expected to fit better but eventually
    # overfit the training set, so report the parameter count as well.
    capacity = 0
    for value in pyro.get_param_store().values():
        capacity += value.reshape(-1).size(0)
    logging.info("{} capacity = {} parameters".format(model.__name__, capacity))
Esempio n. 7
0
def main(args):
    """
    Train the model selected by ``args.model`` on a capture-history dataset
    with Stochastic Variational Inference and print the final loss.

    :param args: parsed command line options. The attributes read here are
        ``dataset`` ("dipper" or "vole"), ``model``, ``learning_rate``,
        ``tmc``, ``tmc_num_samples`` and ``num_steps``.
    :raises ValueError: if ``args.dataset`` is unknown, or if model "4" or
        "5" is requested for the vole dataset, for which no sex data exists.
    """
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # load data (the CSV files live next to this script)
    if args.dataset == "dipper":
        capture_history_name = "dipper_capture_history.csv"
    elif args.dataset == "vole":
        capture_history_name = "meadow_voles_capture_history.csv"
    else:
        raise ValueError("Available datasets are 'dipper' and 'vole'.")
    data_dir = os.path.dirname(os.path.abspath(__file__))
    capture_history_file = os.path.join(data_dir, capture_history_name)

    # the first CSV column is dropped; the remaining columns are kept
    capture_history = torch.tensor(
        np.genfromtxt(capture_history_file, delimiter=",")).float()[:, 1:]
    N, T = capture_history.shape
    print(
        "Loaded {} capture history for {} individuals collected over {} time periods."
        .format(args.dataset, N, T))

    # models "4" and "5" additionally take per-individual sex data
    needs_sex = args.model in ["4", "5"]
    if args.dataset == "dipper" and needs_sex:
        sex_file = os.path.join(data_dir, "dipper_sex.csv")
        sex = torch.tensor(np.genfromtxt(sex_file, delimiter=",")).float()[:, 1]
        print("Loaded dipper sex data.")
    elif args.dataset == "vole" and needs_sex:
        # Format the *whole* message: ``"a" + "b".format(x)`` would only
        # format "b" and leave the ``{}`` placeholder in "a" untouched.
        raise ValueError(
            ("Cannot run model_{} on meadow voles data, since we lack sex "
             "information for these animals.").format(args.model))
    else:
        sex = None

    model = models[args.model]

    # we use poutine.block to only expose the continuous latent variables
    # in the models to AutoDiagonalNormal (all of which begin with 'phi'
    # or 'rho')
    def expose_fn(msg):
        return msg["name"][0:3] in ["phi", "rho"]

    # we use a mean field diagonal normal variational distribution (i.e. guide)
    # for the continuous latent variables.
    guide = AutoDiagonalNormal(poutine.block(model, expose_fn=expose_fn))

    # since we enumerate the discrete random variables,
    # we need to use TraceEnum_ELBO or TraceTMC_ELBO.
    optim = Adam({"lr": args.learning_rate})
    if args.tmc:
        elbo = TraceTMC_ELBO(max_plate_nesting=1)
        # only sites already marked for parallel enumeration are reconfigured
        tmc_model = poutine.infer_config(
            model,
            lambda msg: {"num_samples": args.tmc_num_samples, "expand": False}
            if msg["infer"].get("enumerate", None) == "parallel" else {})
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        elbo = TraceEnum_ELBO(max_plate_nesting=1,
                              num_particles=20,
                              vectorize_particles=True)
        svi = SVI(model, guide, optim, elbo)

    losses = []

    print(
        "Beginning training of model_{} with Stochastic Variational Inference."
        .format(args.model))

    for step in range(args.num_steps):
        loss = svi.step(capture_history, sex)
        losses.append(loss)
        # report every 20 steps (but not at step 0) and on the final step,
        # averaging over the most recent (up to) 20 losses
        if (step % 20 == 0 and step > 0) or step == args.num_steps - 1:
            print("[iteration %03d] loss: %.3f" %
                  (step, np.mean(losses[-20:])))

    # evaluate final trained model
    elbo_eval = TraceEnum_ELBO(max_plate_nesting=1,
                               num_particles=2000,
                               vectorize_particles=True)
    svi_eval = SVI(model, guide, optim, elbo_eval)
    print("Final loss: %.4f" % svi_eval.evaluate_loss(capture_history, sex))