Example #1
File: hmm.py Project: xidulu/pyro
def model_1(sequences, lengths, args, batch_size=None, include_prior=True):
    # Sometimes it is safe to ignore jit warnings. Here we use the
    # pyro.util.ignore_jit_warnings context manager to silence warnings about
    # conversion to integer, since we know all three numbers will be the same
    # across all invocations of the model.
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim,
                                        data_dim]).to_event(2))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # We subsample batch_size items out of num_sequences items. Note that since
    # we're using dim=-1 for the tones plate, we need to batch over a different
    # dimension, here dim=-2.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # If we are not using the jit, then we can vary the program structure
        # each call by running for a dynamically determined number of time
        # steps, lengths.max(). However if we are using the jit, then we try to
        # keep a single program structure for all minibatches; the fixed
        # structure ends up being faster, since each new program structure
        # would trigger a new jit compilation stage.
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[x.squeeze(-1)]),
                                obs=sequences[batch, t])
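A typical training loop for a model like this (mirroring Pyro's hmm.py example) pairs SVI with TraceEnum_ELBO so the discrete states are enumerated out, and MAP-estimates the global sites with an AutoDelta guide. A minimal sketch, assuming model_1 and its imports are in scope, with hypothetical toy data and hyperparameters:

import argparse
import torch
import pyro
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

# Hypothetical toy data: 10 binary sequences of length 8 with 4 tones each.
sequences = torch.bernoulli(torch.full((10, 8, 4), 0.3))
lengths = torch.full((10,), 8, dtype=torch.long)
args = argparse.Namespace(hidden_dim=2, jit=False)

# MAP-estimate the global parameter sites; max_plate_nesting=2 covers the
# sequences (dim=-2) and tones (dim=-1) plates.
guide = AutoDelta(poutine.block(model_1, expose=["probs_x", "probs_y"]))
svi = SVI(model_1, guide, Adam({"lr": 0.05}), TraceEnum_ELBO(max_plate_nesting=2))
for step in range(100):
    loss = svi.step(sequences, lengths, args, batch_size=5)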
Example #2
def model(data):
    # `dim` (the number of hidden states) is assumed to be defined in the
    # enclosing scope.
    initialize = pyro.sample("initialize", dist.Dirichlet(torch.ones(dim)))
    with pyro.plate("states", dim):
        transition = pyro.sample("transition", dist.Dirichlet(torch.ones(dim, dim)))
        emission_loc = pyro.sample(
            "emission_loc", dist.Normal(torch.zeros(dim), torch.ones(dim))
        )
        emission_scale = pyro.sample(
            "emission_scale", dist.LogNormal(torch.zeros(dim), torch.ones(dim))
        )
    x = None
    with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]):
        for t, y in pyro.markov(enumerate(data)):
            x = pyro.sample(
                "x_{}".format(t),
                dist.Categorical(initialize if x is None else transition[x]),
                infer={"enumerate": "parallel"},
            )
            pyro.sample(
                "y_{}".format(t),
                dist.Normal(emission_loc[x], emission_scale[x]),
                obs=y,
            )
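Since the states are marked for parallel enumeration, MAP state sequences can be decoded with infer_discrete. A minimal sketch, assuming the model above and its imports are in scope; dim and data are hypothetical, and the globals are drawn from the prior here purely for illustration:

import torch
import pyro
from pyro import poutine
from pyro.infer import infer_discrete

dim = 3                  # hypothetical number of hidden states
data = torch.randn(20)   # hypothetical observed sequence

# temperature=0 gives Viterbi-style MAP states; first_available_dim must lie
# to the left of the model's plates (the "states" plate occupies dim=-1).
decoder = infer_discrete(model, first_available_dim=-2, temperature=0)
trace = poutine.trace(decoder).get_trace(data)
states = torch.stack([trace.nodes["x_{}".format(t)]["value"] for t in range(len(data))])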
Example #3
def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with handlers.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).expand_by(
                [hidden_dim]).to_event(2))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim,
                                        data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: the tones plate provides a hidden
        # torch.arange, and w and x are kept as tensors so that the indexing
        # probs_y[w, x, tones] broadcasts to the correct distribution shape for
        # the y sample sites.
        w = x = torch.tensor(0, dtype=torch.long)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t),
                                dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(Vindex(probs_x)[w, x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])
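Vindex (from pyro.ops.indexing) generalizes advanced indexing so that extra batch dims introduced on the left, e.g. by enumeration of w and x, broadcast correctly against each other. A small shape check along the lines of Pyro's Vindex documentation:

import torch
from pyro.ops.indexing import Vindex

batch_shape = (2,)
event_shape = (3, 4, 5)
data = torch.rand(batch_shape + event_shape)
i = torch.randint(0, 3, batch_shape)
j = torch.randint(0, 4, batch_shape)

# Batched indexing into the first two event dims; the batch dims of i and j
# broadcast, and the trailing event dim is kept as a slice.
result = Vindex(data)[..., i, j, :]
assert result.shape == batch_shape + (5,)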
Example #4
File: hmm.py Project: pyro-ppl/pyro
def model_5(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length

    # Initialize a global module instance if needed.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        y = torch.zeros(data_dim)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x]),
                    infer={"enumerate": "parallel"},
                )
                # Note that since each tone depends on all tones at the previous
                # time step, the tones at different time steps now need to live
                # in separate plates.
                with pyro.plate("tones_{}".format(t), data_dim, dim=-1):
                    y = pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(logits=tones_generator(x, y)),
                        obs=sequences[batch, t],
                    )
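TonesGenerator is defined elsewhere in Pyro's hmm.py (the real module runs the previous tones through a 1-d convolution). Below is a simplified dense stand-in, with a hypothetical args.nn_dim width, that illustrates the shape contract: x may carry an enumeration dim on the left, so it is one-hot encoded into its trailing singleton dim before the linear layers, keeping the enumeration dim aligned with the resulting Bernoulli batch shape.

import torch
import torch.nn as nn

class TonesGenerator(nn.Module):
    # Simplified stand-in: maps the (possibly enumerated) state x and the
    # previous tones y to Bernoulli logits over the current tones.
    def __init__(self, args, data_dim):
        super().__init__()
        self.hidden_dim = args.hidden_dim
        self.x_to_hidden = nn.Linear(args.hidden_dim, args.nn_dim)
        self.y_to_hidden = nn.Linear(data_dim, args.nn_dim)
        self.hidden_to_logits = nn.Linear(args.nn_dim, data_dim)
        self.relu = nn.ReLU()

    def forward(self, x, y):
        # One-hot encode into x's trailing singleton dim so an enumeration
        # dim at dim -3 stays at dim -3 of the logits' batch shape.
        x_onehot = nn.functional.one_hot(x.squeeze(-1), self.hidden_dim).float()
        h = self.relu(self.x_to_hidden(x_onehot) + self.y_to_hidden(y))
        return self.hidden_to_logits(h)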
Example #5
def ubersum(equation, *operands, **kwargs):
    """
    Generalized batched sum-product algorithm via tensor message passing.

    This generalizes :func:`~pyro.ops.einsum.contract` in two ways:

    1.  Multiple outputs are allowed, and intermediate results can be shared.
    2.  Inputs and outputs can be batched along symbols given in ``batch_dims``;
        reductions along ``batch_dims`` are product reductions.

    The best way to understand this function is to try the examples below,
    which show how :func:`ubersum` calls can be implemented as multiple calls
    to :func:`~pyro.ops.einsum.contract` (which is generally more expensive).

    To illustrate multiple outputs, note that the following are equivalent::

        z1, z2, z3 = ubersum('ab,bc->a,b,c', x, y)  # multiple outputs

        backend = 'pyro.ops.einsum.torch_log'
        z1 = contract('ab,bc->a', x, y, backend=backend)
        z2 = contract('ab,bc->b', x, y, backend=backend)
        z3 = contract('ab,bc->c', x, y, backend=backend)

    To illustrate batched inputs, note that the following are equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('ab,ai,bi->b', w, x, y, batch_dims='i')

        z = contract('ab,a,a,a,b,b,b->b', w, *x, *y, backend=backend)

    When a sum dimension `a` always appears with a batch dimension `i`,
    then `a` corresponds to a distinct symbol for each slice along `i`. Thus
    the following are equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('ai,ai->', x, y, batch_dims='i')

        z = contract('a,b,c,a,b,c->', *x, *y, backend=backend)

    When such a sum dimension appears in the output, it must be
    accompanied by all of its batch dimensions, e.g. the following are
    equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('abi,abi->bi', x, y, batch_dims='i')

        z0 = contract('ab,ac,ad,ab,ac,ad->b', *x, *y, backend=backend)
        z1 = contract('ab,ac,ad,ab,ac,ad->c', *x, *y, backend=backend)
        z2 = contract('ab,ac,ad,ab,ac,ad->d', *x, *y, backend=backend)
        z = torch.stack([z0, z1, z2], dim=-1)

    Note that each batch slice through the output is multilinear in all batch
    slices through all inputs, thus e.g. batch matrix multiply would be
    implemented *without* ``batch_dims``, so the following are all equivalent::

        xy = ubersum('abc,acd->abd', x, y, batch_dims='')
        xy = torch.stack([xa.mm(ya) for xa, ya in zip(x, y)])
        xy = torch.bmm(x, y)

    Among all valid equations, some computations are polynomial in the sizes
    of the input tensors, while others are exponential. This function raises
    :py:class:`NotImplementedError` whenever the computation is exponential.

    :param str equation: An einsum equation, optionally with multiple outputs.
    :param torch.Tensor operands: A collection of tensors.
    :param str batch_dims: An optional string of batch dims.
    :param dict cache: An optional :func:`~opt_einsum.shared_intermediates`
        cache.
    :param bool modulo_total: Optionally allow ubersum to arbitrarily scale
        each result batch, which can significantly reduce computation. This is
        safe to set whenever each result batch denotes a nonnormalized
        probability distribution whose total is not of interest.
    :return: a tuple of tensors of requested shape, one entry per output.
    :rtype: tuple
    :raises ValueError: if tensor sizes mismatch or an output requests a
        batched dim without that dim's batch dims.
    :raises NotImplementedError: if contraction would have cost exponential in
        the size of any input tensor.
    """
    # Extract kwargs.
    cache = kwargs.pop('cache', None)
    batch_dims = kwargs.pop('batch_dims', '')
    backend = kwargs.pop('backend', 'pyro.ops.einsum.torch_log')
    modulo_total = kwargs.pop('modulo_total', False)
    try:
        Ring = BACKEND_TO_RING[backend]
    except KeyError:
        raise NotImplementedError('\n'.join(
            ['Only the following pyro backends are currently implemented:'] +
            list(BACKEND_TO_RING)))

    # Parse generalized einsum equation.
    if '.' in equation:
        raise NotImplementedError(
            'ubersum does not yet support ellipsis notation')
    inputs, outputs = equation.split('->')
    inputs = inputs.split(',')
    outputs = outputs.split(',')
    assert len(inputs) == len(operands)
    assert all(isinstance(x, torch.Tensor) for x in operands)
    if not modulo_total and any(outputs):
        raise NotImplementedError(
            'Try setting modulo_total=True and ensuring that your use case '
            'allows an arbitrary scale factor on each result batch.')
    if len(operands) != len(set(operands)):
        operands = [x[...] for x in operands]  # ensure tensors are unique

    # Check sizes.
    with ignore_jit_warnings():
        dim_to_size = {}
        for dims, term in zip(inputs, operands):
            for dim, size in zip(dims, map(int, term.shape)):
                old = dim_to_size.setdefault(dim, size)
                if old != size:
                    raise ValueError(
                        u"Dimension size mismatch at dim '{}': {} vs {}".
                        format(dim, size, old))

    # Construct a tensor tree shared by all outputs.
    tensor_tree = OrderedDict()
    batch_dims = frozenset(batch_dims)
    for dims, term in zip(inputs, operands):
        assert len(dims) == term.dim()
        term._pyro_dims = dims
        ordinal = batch_dims.intersection(dims)
        tensor_tree.setdefault(ordinal, []).append(term)

    # Compute outputs, sharing intermediate computations.
    results = []
    with shared_intermediates(cache) as cache:
        ring = Ring(cache, dim_to_size=dim_to_size)
        for output in outputs:
            sum_dims = set(output).union(*inputs) - set(batch_dims)
            term = contract_to_tensor(
                tensor_tree,
                sum_dims,
                target_ordinal=batch_dims.intersection(output),
                target_dims=sum_dims.intersection(output),
                ring=ring)
            if term._pyro_dims != output:
                term = term.permute(*map(term._pyro_dims.index, output))
                term._pyro_dims = output
            results.append(term)
    return tuple(results)
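A quick smoke test of ubersum (the pyro.ops.contract module path matches where this function lives in Pyro's source; modulo_total=True is required here because the outputs are nontrivial):

import torch
from pyro.ops.contract import ubersum

x = torch.randn(3, 4)   # dims 'ab'
y = torch.randn(4, 5)   # dims 'bc'

# Two log-space sum-product contractions computed in one call, sharing
# intermediate results.
z1, z2 = ubersum('ab,bc->a,c', x, y, modulo_total=True)
assert z1.shape == (3,) and z2.shape == (5,)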
Example #6
def _key(self):
    # Convert a tensor-valued size to a plain Python int so that the key
    # tuple is hashable; silences the jit's tensor-to-number warning.
    with ignore_jit_warnings(["Converting a tensor to a Python number"]):
        size = self.size.item() if isinstance(self.size, torch.Tensor) else self.size
        return self.name, self.dim, size, self.counter
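All of these snippets rely on pyro.util.ignore_jit_warnings: outside a jit trace it is effectively a no-op, while under torch.jit.trace it suppresses TracerWarnings, optionally restricted to the given message prefixes or (message, category) pairs. A minimal usage sketch:

import torch
from pyro.util import ignore_jit_warnings

size = torch.tensor(4)
# Under torch.jit.trace, int(size) would emit a TracerWarning starting with
# "Converting a tensor to a Python number"; the filter silences exactly that.
with ignore_jit_warnings(["Converting a tensor to a Python number"]):
    size = int(size)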
Example #7
def torus_dbn(phis=None,
              psis=None,
              lengths=None,
              num_sequences=None,
              num_states=55,
              prior_conc=0.1,
              prior_loc=0.0,
              prior_length_shape=100.,
              prior_length_rate=100.,
              prior_kappa_min=10.,
              prior_kappa_max=1000.):
    # From https://pyro.ai/examples/hmm.html
    with ignore_jit_warnings():
        if lengths is not None:
            assert num_sequences is None
            num_sequences = int(lengths.shape[0])
        else:
            assert num_sequences is not None
    transition_probs = pyro.sample(
        'transition_probs',
        dist.Dirichlet(
            torch.ones(num_states, num_states, dtype=torch.float) *
            num_states).to_event(1))
    length_shape = pyro.sample('length_shape',
                               dist.HalfCauchy(prior_length_shape))
    length_rate = pyro.sample('length_rate',
                              dist.HalfCauchy(prior_length_rate))
    phi_locs = pyro.sample(
        'phi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) *
            prior_conc).to_event(1))
    phi_kappas = pyro.sample(
        'phi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) *
            prior_kappa_max).to_event(1))
    psi_locs = pyro.sample(
        'psi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) *
            prior_conc).to_event(1))
    psi_kappas = pyro.sample(
        'psi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) *
            prior_kappa_max).to_event(1))
    element_plate = pyro.plate('elements', 1, dim=-1)
    with pyro.plate('sequences', num_sequences, dim=-2) as batch:
        if lengths is not None:
            lengths = lengths[batch]
            obs_length = lengths.float().unsqueeze(-1)
        else:
            obs_length = None
        state = 0
        sam_lengths = pyro.sample('length',
                                  dist.TransformedDistribution(
                                      dist.GammaPoisson(
                                          length_shape, length_rate),
                                      AffineTransform(0., 1.)),
                                  obs=obs_length)
        if lengths is None:
            lengths = sam_lengths.squeeze(-1).long()
        for t in pyro.markov(range(lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                state = pyro.sample(f'state_{t}',
                                    dist.Categorical(transition_probs[state]),
                                    infer={'enumerate': 'parallel'})
                if phis is not None:
                    obs_phi = Vindex(phis)[batch, t].unsqueeze(-1)
                else:
                    obs_phi = None
                if psis is not None:
                    obs_psi = Vindex(psis)[batch, t].unsqueeze(-1)
                else:
                    obs_psi = None
                with element_plate:
                    pyro.sample(f'phi_{t}',
                                dist.VonMises(phi_locs[state],
                                              phi_kappas[state]),
                                obs=obs_phi)
                    pyro.sample(f'psi_{t}',
                                dist.VonMises(psi_locs[state],
                                              psi_kappas[state]),
                                obs=obs_psi)
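With no observations passed, the model acts as a generator: sequence lengths come from the GammaPoisson length model and dihedral angles from the VonMises emissions. A minimal prior-simulation sketch, assuming torus_dbn and its imports are in scope:

import pyro
from pyro import poutine

# Sample two sequences from the prior; the enumeration hints are inert in a
# plain trace, so states are drawn sequentially.
trace = poutine.trace(torus_dbn).get_trace(num_sequences=2)
# phi_0 exists provided the sampled lengths include at least one time step.
phi0 = trace.nodes['phi_0']['value']  # shape (2, 1): sequences x element plate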