def model_1(sequences, lengths, args, batch_size=None, include_prior=True):
    # Sometimes it is safe to ignore jit warnings. Here we use the
    # pyro.util.ignore_jit_warnings context manager to silence warnings about
    # conversion to integer, since we know all three numbers will be the same
    # across all invocations to the model.
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # We subsample batch_size items out of num_sequences items. Note that since
    # we're using dim=-1 for the notes plate, we need to batch over a different
    # dimension, here dim=-2.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # If we are not using the jit, then we can vary the program structure
        # each call by running for a dynamically determined number of time
        # steps, lengths.max(). However if we are using the jit, then we try to
        # keep a single program structure for all minibatches; the fixed
        # structure ends up being faster since each program structure would
        # need to trigger a new jit compile stage.
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[x.squeeze(-1)]),
                                obs=sequences[batch, t])
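
# A minimal, hedged sketch of how a model like model_1 is typically trained in Pyro:
# enumerate the discrete x_t sites with TraceEnum_ELBO and fit the global probs_*
# sites with a MAP (AutoDelta) guide. The data, lengths, and args below are
# hypothetical stand-ins, not part of the original example; import paths assume a
# recent Pyro release.
import torch
from argparse import Namespace
import pyro
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

sequences = torch.zeros(10, 30, 4)                    # (num_sequences, max_length, data_dim)
lengths = torch.full((10,), 30, dtype=torch.long)
args = Namespace(hidden_dim=16, jit=False)            # stand-in for the example's CLI args

guide = AutoDelta(poutine.block(model_1, expose=["probs_x", "probs_y"]))
elbo = TraceEnum_ELBO(max_plate_nesting=2)            # plates: sequences (dim=-2), tones (dim=-1)
svi = SVI(model_1, guide, Adam({"lr": 0.05}), elbo)
for step in range(10):
    loss = svi.step(sequences, lengths, args, batch_size=8)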
def model(data):
    initialize = pyro.sample("initialize", dist.Dirichlet(torch.ones(dim)))
    with pyro.plate("states", dim):
        transition = pyro.sample("transition", dist.Dirichlet(torch.ones(dim, dim)))
        emission_loc = pyro.sample(
            "emission_loc", dist.Normal(torch.zeros(dim), torch.ones(dim))
        )
        emission_scale = pyro.sample(
            "emission_scale", dist.LogNormal(torch.zeros(dim), torch.ones(dim))
        )
    x = None
    with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]):
        for t, y in pyro.markov(enumerate(data)):
            x = pyro.sample(
                "x_{}".format(t),
                dist.Categorical(initialize if x is None else transition[x]),
                infer={"enumerate": "parallel"},
            )
            pyro.sample(
                "y_{}".format(t),
                dist.Normal(emission_loc[x], emission_scale[x]),
                obs=y,
            )
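
# A hedged sketch of posterior inference for the Gaussian-emission HMM above: Pyro's
# NUTS kernel marginalizes the enumerated discrete x_t sites, leaving only the
# continuous parameters for HMC. `dim` and `data` are hypothetical stand-ins for the
# globals the model expects; the jit-warning handling in the model suggests it is also
# used with jit_compile=True, which is omitted here for brevity.
import torch
import pyro
from pyro.infer import MCMC, NUTS

dim = 3                       # number of hidden states read by `model`
data = torch.randn(20)        # a single observed sequence of length 20

kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=100, warmup_steps=100)
mcmc.run(data)
posterior = mcmc.get_samples()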
def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with handlers.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).expand_by(
                [hidden_dim]).to_event(2))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: we declare a hidden torch.arange and
        # ensure that w and x are always tensors so we can unsqueeze them below,
        # thus ensuring that the x sample sites have correct distribution shape.
        w = x = torch.tensor(0, dtype=torch.long)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t),
                                dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(Vindex(probs_x)[w, x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])
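
# model_4 relies on Vindex for broadcasted fancy indexing of the joint transition
# tensor probs_x. A tiny illustrative sketch (shapes chosen arbitrarily): indexing the
# first two dims of a (3, 3, 4) tensor with index tensors living on different batch
# dims yields broadcast(w, x).shape + (4,).
import torch
from pyro.ops.indexing import Vindex

probs = torch.rand(3, 3, 4)                       # indexed as probs[w, x, :]
w = torch.tensor([0, 1, 2]).reshape(3, 1, 1)      # e.g. an enumerated w
x = torch.tensor([0, 1, 2]).reshape(1, 3, 1)      # e.g. an enumerated x on another dim
out = Vindex(probs)[w, x]
assert out.shape == (3, 3, 1, 4)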
def model_5(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    # Initialize a global module instance if needed.
    global tones_generator
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        y = torch.zeros(data_dim)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x]),
                    infer={"enumerate": "parallel"},
                )
                # Note that since each tone depends on all tones at a previous time step,
                # the tones at different time steps now need to live in separate plates.
                with pyro.plate("tones_{}".format(t), data_dim, dim=-1):
                    y = pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(logits=tones_generator(x, y)),
                        obs=sequences[batch, t],
                    )
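
# model_5 references a global `tones_generator` built from a TonesGenerator class
# defined elsewhere in the example. The sketch below is a hypothetical, simplified
# stand-in (not the original architecture): it one-hot encodes the discrete state x and
# combines it with the previous tones y to produce per-tone Bernoulli logits. Squeezing
# the trailing singleton dim of x before one-hot encoding keeps x's enumeration dim
# aligned in the resulting logits tensor.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TonesGenerator(nn.Module):
    def __init__(self, args, data_dim):
        super().__init__()
        self.hidden_dim = args.hidden_dim
        self.x_to_logits = nn.Linear(args.hidden_dim, data_dim)
        self.y_to_logits = nn.Linear(data_dim, data_dim)

    def forward(self, x, y):
        # x: long tensor of hidden states (possibly enumerated); y: previous tones.
        x_onehot = F.one_hot(x.squeeze(-1), self.hidden_dim).float()
        return self.x_to_logits(x_onehot) + self.y_to_logits(y)

tones_generator = None  # model_5 lazily instantiates and registers this via pyro.module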
def ubersum(equation, *operands, **kwargs):
    """
    Generalized batched sum-product algorithm via tensor message passing.

    This generalizes :func:`~pyro.ops.einsum.contract` in two ways:

    1. Multiple outputs are allowed, and intermediate results can be shared.
    2. Inputs and outputs can be batched along symbols given in ``batch_dims``;
       reductions along ``batch_dims`` are product reductions.

    The best way to understand this function is to try the examples below,
    which show how :func:`ubersum` calls can be implemented as multiple calls
    to :func:`~pyro.ops.einsum.contract` (which is generally more expensive).

    To illustrate multiple outputs, note that the following are equivalent::

        z1, z2, z3 = ubersum('ab,bc->a,b,c', x, y)  # multiple outputs

        backend = 'pyro.ops.einsum.torch_log'
        z1 = contract('ab,bc->a', x, y, backend=backend)
        z2 = contract('ab,bc->b', x, y, backend=backend)
        z3 = contract('ab,bc->c', x, y, backend=backend)

    To illustrate batched inputs, note that the following are equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('ab,ai,bi->b', w, x, y, batch_dims='i')

        z = contract('ab,a,a,a,b,b,b->b', w, *x, *y, backend=backend)

    When a sum dimension `a` always appears with a batch dimension `i`, then
    `a` corresponds to a distinct symbol for each slice of `a`. Thus the
    following are equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('ai,ai->', x, y, batch_dims='i')

        z = contract('a,b,c,a,b,c->', *x, *y, backend=backend)

    When such a sum dimension appears in the output, it must be accompanied by
    all of its batch dimensions, e.g. the following are equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('abi,abi->bi', x, y, batch_dims='i')

        z0 = contract('ab,ac,ad,ab,ac,ad->b', *x, *y, backend=backend)
        z1 = contract('ab,ac,ad,ab,ac,ad->c', *x, *y, backend=backend)
        z2 = contract('ab,ac,ad,ab,ac,ad->d', *x, *y, backend=backend)
        z = torch.stack([z0, z1, z2])

    Note that each batch slice through the output is multilinear in all batch
    slices through all inputs, thus e.g. batch matrix multiply would be
    implemented *without* ``batch_dims``, so the following are all equivalent::

        xy = ubersum('abc,acd->abd', x, y, batch_dims='')
        xy = torch.stack([xa.mm(ya) for xa, ya in zip(x, y)])
        xy = torch.bmm(x, y)

    Among all valid equations, some computations are polynomial in the sizes of
    the input tensors and other computations are exponential in the sizes of
    the input tensors. This function raises :py:class:`NotImplementedError`
    whenever the computation is exponential.

    :param str equation: An einsum equation, optionally with multiple outputs.
    :param torch.Tensor operands: A collection of tensors.
    :param str batch_dims: An optional string of batch dims.
    :param dict cache: An optional :func:`~opt_einsum.shared_intermediates`
        cache.
    :param bool modulo_total: Optionally allow ubersum to arbitrarily scale
        each result batch, which can significantly reduce computation. This is
        safe to set whenever each result batch denotes a nonnormalized
        probability distribution whose total is not of interest.
    :return: a tuple of tensors of requested shape, one entry per output.
    :rtype: tuple
    :raises ValueError: if tensor sizes mismatch or an output requests a
        batched dim without that dim's batch dims.
    :raises NotImplementedError: if contraction would have cost exponential in
        the size of any input tensor.
    """
    # Extract kwargs.
    cache = kwargs.pop('cache', None)
    batch_dims = kwargs.pop('batch_dims', '')
    backend = kwargs.pop('backend', 'pyro.ops.einsum.torch_log')
    modulo_total = kwargs.pop('modulo_total', False)
    try:
        Ring = BACKEND_TO_RING[backend]
    except KeyError:
        raise NotImplementedError('\n'.join(
            ['Only the following pyro backends are currently implemented:'] +
            list(BACKEND_TO_RING)))

    # Parse generalized einsum equation.
    if '.' in equation:
        raise NotImplementedError('ubersum does not yet support ellipsis notation')
    inputs, outputs = equation.split('->')
    inputs = inputs.split(',')
    outputs = outputs.split(',')
    assert len(inputs) == len(operands)
    assert all(isinstance(x, torch.Tensor) for x in operands)
    if not modulo_total and any(outputs):
        raise NotImplementedError(
            'Try setting modulo_total=True and ensuring that your use case '
            'allows an arbitrary scale factor on each result batch.')
    if len(operands) != len(set(operands)):
        operands = [x[...] for x in operands]  # ensure tensors are unique

    # Check sizes.
    with ignore_jit_warnings():
        dim_to_size = {}
        for dims, term in zip(inputs, operands):
            for dim, size in zip(dims, map(int, term.shape)):
                old = dim_to_size.setdefault(dim, size)
                if old != size:
                    raise ValueError(
                        u"Dimension size mismatch at dim '{}': {} vs {}".format(
                            dim, size, old))

    # Construct a tensor tree shared by all outputs.
    tensor_tree = OrderedDict()
    batch_dims = frozenset(batch_dims)
    for dims, term in zip(inputs, operands):
        assert len(dims) == term.dim()
        term._pyro_dims = dims
        ordinal = batch_dims.intersection(dims)
        tensor_tree.setdefault(ordinal, []).append(term)

    # Compute outputs, sharing intermediate computations.
    results = []
    with shared_intermediates(cache) as cache:
        ring = Ring(cache, dim_to_size=dim_to_size)
        for output in outputs:
            sum_dims = set(output).union(*inputs) - set(batch_dims)
            term = contract_to_tensor(
                tensor_tree, sum_dims,
                target_ordinal=batch_dims.intersection(output),
                target_dims=sum_dims.intersection(output),
                ring=ring)
            if term._pyro_dims != output:
                term = term.permute(*map(term._pyro_dims.index, output))
                term._pyro_dims = output
            results.append(term)
    return tuple(results)
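
# A small usage sketch of the ubersum implementation above, following the docstring's
# batched-input example. The tensor sizes are arbitrary; modulo_total=True is required
# here because the output is nonempty, so each result batch is only defined up to a
# constant scale factor.
import torch

w = torch.randn(2, 3)    # dims 'ab'
x = torch.randn(2, 4)    # dims 'ai' (batch dim 'i' of size 4)
y = torch.randn(3, 4)    # dims 'bi'
z, = ubersum('ab,ai,bi->b', w, x, y, batch_dims='i', modulo_total=True)
assert z.shape == (3,)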
def _key(self):
    with ignore_jit_warnings(["Converting a tensor to a Python number"]):
        size = self.size.item() if isinstance(self.size, torch.Tensor) else self.size
    return self.name, self.dim, size, self.counter
def torus_dbn(phis=None, psis=None, lengths=None, num_sequences=None,
              num_states=55, prior_conc=0.1, prior_loc=0.0,
              prior_length_shape=100., prior_length_rate=100.,
              prior_kappa_min=10., prior_kappa_max=1000.):
    # From https://pyro.ai/examples/hmm.html
    with ignore_jit_warnings():
        if lengths is not None:
            assert num_sequences is None
            num_sequences = int(lengths.shape[0])
        else:
            assert num_sequences is not None
    transition_probs = pyro.sample(
        'transition_probs',
        dist.Dirichlet(
            torch.ones(num_states, num_states, dtype=torch.float) * num_states
        ).to_event(1))
    length_shape = pyro.sample('length_shape', dist.HalfCauchy(prior_length_shape))
    length_rate = pyro.sample('length_rate', dist.HalfCauchy(prior_length_rate))
    phi_locs = pyro.sample(
        'phi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) * prior_conc).to_event(1))
    phi_kappas = pyro.sample(
        'phi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) * prior_kappa_max).to_event(1))
    psi_locs = pyro.sample(
        'psi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) * prior_conc).to_event(1))
    psi_kappas = pyro.sample(
        'psi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) * prior_kappa_max).to_event(1))
    element_plate = pyro.plate('elements', 1, dim=-1)
    with pyro.plate('sequences', num_sequences, dim=-2) as batch:
        if lengths is not None:
            lengths = lengths[batch]
            obs_length = lengths.float().unsqueeze(-1)
        else:
            obs_length = None
        state = 0
        sam_lengths = pyro.sample(
            'length',
            dist.TransformedDistribution(
                dist.GammaPoisson(length_shape, length_rate),
                AffineTransform(0., 1.)),
            obs=obs_length)
        if lengths is None:
            lengths = sam_lengths.squeeze(-1).long()
        for t in pyro.markov(range(lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                state = pyro.sample(f'state_{t}',
                                    dist.Categorical(transition_probs[state]),
                                    infer={'enumerate': 'parallel'})
                if phis is not None:
                    obs_phi = Vindex(phis)[batch, t].unsqueeze(-1)
                else:
                    obs_phi = None
                if psis is not None:
                    obs_psi = Vindex(psis)[batch, t].unsqueeze(-1)
                else:
                    obs_psi = None
                with element_plate:
                    pyro.sample(f'phi_{t}',
                                dist.VonMises(phi_locs[state], phi_kappas[state]),
                                obs=obs_phi)
                    pyro.sample(f'psi_{t}',
                                dist.VonMises(psi_locs[state], psi_kappas[state]),
                                obs=obs_psi)
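
# A hedged sketch of running torus_dbn under the prior with synthetic, hypothetical
# data (random dihedral angles), mainly to inspect sample-site shapes; real usage would
# pass observed angles and run enumeration-aware inference as in the other examples.
import math
import torch
from pyro import poutine

phis = torch.rand(5, 10) * 2 * math.pi - math.pi   # 5 sequences, each of length 10
psis = torch.rand(5, 10) * 2 * math.pi - math.pi
lengths = torch.full((5,), 10, dtype=torch.long)
trace = poutine.trace(torus_dbn).get_trace(
    phis=phis, psis=psis, lengths=lengths, num_states=7)
print(trace.format_shapes())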