def __init__(self, base_dist, *, gate=None, gate_logits=None, validate_args=None):
    if (gate is None) == (gate_logits is None):
        raise ValueError(
            "Either `gate` or `gate_logits` must be specified, but not both."
        )
    if gate is not None:
        batch_shape = broadcast_shape(gate.shape, base_dist.batch_shape)
        self.gate = gate.expand(batch_shape)
    else:
        batch_shape = broadcast_shape(gate_logits.shape, base_dist.batch_shape)
        self.gate_logits = gate_logits.expand(batch_shape)
    if base_dist.event_shape:
        raise ValueError("ZeroInflatedDistribution expected empty "
                         "base_dist.event_shape but got {}".format(
                             base_dist.event_shape))
    self.base_dist = base_dist.expand(batch_shape)
    event_shape = torch.Size()
    super().__init__(batch_shape, event_shape, validate_args)
def find_domain(op, *domains):
    r"""
    Finds the :class:`Domain` resulting when applying ``op`` to ``domains``.

    :param callable op: An operation.
    :param Domain \*domains: One or more input domains.
    """
    assert callable(op), op
    assert all(isinstance(arg, Domain) for arg in domains)
    if len(domains) == 1:
        dtype = domains[0].dtype
        shape = domains[0].shape
        if op is ops.log or op is ops.exp:
            dtype = 'real'
        elif isinstance(op, ops.ReshapeOp):
            shape = op.shape
        elif isinstance(op, ops.AssociativeOp):
            shape = ()
        return Domain(shape, dtype)

    lhs, rhs = domains
    if isinstance(op, ops.GetitemOp):
        dtype = lhs.dtype
        shape = lhs.shape[:op.offset] + lhs.shape[1 + op.offset:]
        return Domain(shape, dtype)
    elif op == ops.matmul:
        assert lhs.shape and rhs.shape
        if len(rhs.shape) == 1:
            assert lhs.shape[-1] == rhs.shape[-1]
            shape = lhs.shape[:-1]
        elif len(lhs.shape) == 1:
            assert lhs.shape[-1] == rhs.shape[-2]
            shape = rhs.shape[:-2] + rhs.shape[-1:]
        else:
            assert lhs.shape[-1] == rhs.shape[-2]
            shape = broadcast_shape(lhs.shape[:-1],
                                    rhs.shape[:-2] + (1,)) + rhs.shape[-1:]
        return Domain(shape, 'real')

    if lhs.dtype == 'real' or rhs.dtype == 'real':
        dtype = 'real'
    elif op in (ops.add, ops.mul, ops.pow, ops.max, ops.min):
        dtype = op(lhs.dtype - 1, rhs.dtype - 1) + 1
    elif op in (ops.and_, ops.or_, ops.xor):
        dtype = 2
    elif lhs.dtype == rhs.dtype:
        dtype = lhs.dtype
    else:
        raise NotImplementedError('TODO')

    if lhs.shape == rhs.shape:
        shape = lhs.shape
    else:
        shape = broadcast_shape(lhs.shape, rhs.shape)
    return Domain(shape, dtype)
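# Illustrative sketch (not from the source), assuming ``Domain(shape, dtype)``
# and ``ops`` are in scope as used above: a binary op on two real domains
# broadcasts their shapes.
#
#     find_domain(ops.add, Domain((3, 1), 'real'), Domain((4,), 'real'))
#     # -> Domain((3, 4), 'real')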
def test_broadcast(mask_shape, component0_shape, component1_shape, value_shape):
    mask = torch.empty(torch.Size(mask_shape)).bernoulli_(0.5).bool()
    component0 = dist.Normal(torch.zeros(component0_shape), 1.0)
    component1 = dist.Exponential(torch.ones(component1_shape))
    value = torch.ones(value_shape)
    d = dist.MaskedMixture(mask, component0, component1)

    d_shape = broadcast_shape(mask_shape, component0_shape, component1_shape)
    assert d.batch_shape == d_shape

    log_prob_shape = broadcast_shape(d_shape, value_shape)
    assert d.log_prob(value).shape == log_prob_shape
def test_stable_hmm_shape(init_shape, trans_mat_shape, trans_dist_shape,
                          obs_mat_shape, obs_dist_shape, hidden_dim, obs_dim):
    stability = dist.Uniform(0, 2).sample()
    init_dist = random_stable(stability, init_shape + (hidden_dim,)).to_event(1)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_stable(stability, trans_dist_shape + (hidden_dim,)).to_event(1)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_stable(stability, obs_dist_shape + (obs_dim,)).to_event(1)
    d = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                       duration=4)

    shape = broadcast_shape(init_shape + (4,), trans_mat_shape, trans_dist_shape,
                            obs_mat_shape, obs_dist_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    x = d.rsample()
    assert x.shape == d.shape()
    x = d.rsample((6,))
    assert x.shape == (6,) + d.shape()
    x = d.expand((6, 5)).rsample()
    assert x.shape == (6, 5) + d.event_shape
def _eager_contract_tensors(reduced_vars, terms, backend):
    iter_symbols = map(opt_einsum.get_symbol, itertools.count())
    symbols = defaultdict(functools.partial(next, iter_symbols))
    inputs = OrderedDict()
    einsum_inputs = []
    operands = []
    for term in terms:
        inputs.update(term.inputs)
        einsum_inputs.append(
            "".join(symbols[k] for k in term.inputs)
            + "".join(symbols[i - len(term.shape)]
                      for i, size in enumerate(term.shape) if size != 1))

        # Squeeze absent event dims to be compatible with einsum.
        data = term.data
        batch_shape = data.shape[:data.dim() - len(term.shape)]
        event_shape = tuple(size for size in term.shape if size != 1)
        data = data.reshape(batch_shape + event_shape)
        operands.append(data)

    for k in reduced_vars:
        del inputs[k]
    batch_shape = tuple(v.size for v in inputs.values())
    event_shape = broadcast_shape(*(term.shape for term in terms))
    einsum_output = (
        "".join(symbols[k] for k in inputs)
        + "".join(symbols[dim] for dim in range(-len(event_shape), 0)
                  if dim in symbols))
    equation = ",".join(einsum_inputs) + "->" + einsum_output
    data = opt_einsum.contract(equation, *operands, backend=backend)
    data = data.reshape(batch_shape + event_shape)
    return Tensor(data, inputs)
def get_scipy_batch_logpdf(self, idx):
    if not self.scipy_arg_fn:
        return
    dist_params = self.get_dist_params(idx, wrap_tensor=False)
    dist_params_wrapped = self.get_dist_params(idx)
    dist_params = self._convert_logits_to_ps(dist_params)
    test_data = self.get_test_data(idx, wrap_tensor=False)
    test_data_wrapped = self.get_test_data(idx)
    shape = broadcast_shape(self.pyro_dist(**dist_params_wrapped).shape(),
                            test_data_wrapped.size())
    log_prob = []
    for i in range(len(test_data)):
        batch_params = {}
        for k in dist_params:
            param = np.broadcast_to(dist_params[k], shape)
            batch_params[k] = param[i]
        args, kwargs = self.scipy_arg_fn(**batch_params)
        if self.is_discrete:
            log_prob.append(self.scipy_dist.logpmf(test_data[i], *args, **kwargs))
        else:
            log_prob.append(self.scipy_dist.logpdf(test_data[i], *args, **kwargs))
    return log_prob
def test_gaussian_hmm_log_prob(init_shape, trans_mat_shape, trans_mvn_shape,
                               obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    actual_dist = GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    expected_dist = dist.GaussianHMM(init_dist, trans_mat, trans_dist,
                                     obs_mat, obs_dist)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    shape = broadcast_shape(init_shape + (1,), trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape)
    data = obs_dist.expand(shape).sample()
    assert data.shape == actual_dist.shape()

    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
    check_expand(actual_dist, data)
def test_gamma_gaussian_hmm_shape(scale_shape, init_shape, trans_mat_shape,
                                  trans_mvn_shape, obs_mat_shape, obs_mvn_shape,
                                  hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    scale_dist = random_gamma(scale_shape)
    d = dist.GammaGaussianHMM(scale_dist, init_dist, trans_mat, trans_dist,
                              obs_mat, obs_dist)

    shape = broadcast_shape(scale_shape + (1,), init_shape + (1,),
                            trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    data = obs_dist.expand(shape).sample()
    assert data.shape == d.shape()

    actual = d.log_prob(data)
    assert actual.shape == expected_batch_shape
    check_expand(d, data)

    mixing, final = d.filter(data)
    assert isinstance(mixing, dist.Gamma)
    assert mixing.batch_shape == d.batch_shape
    assert mixing.event_shape == ()
    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == (hidden_dim,)
def test_discrete_hmm_shape(ok, init_shape, trans_shape, obs_shape, event_shape,
                            state_dim):
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    obs_logits = torch.randn(obs_shape + (state_dim,) + event_shape)
    obs_dist = dist.Bernoulli(logits=obs_logits).to_event(len(event_shape))
    data = obs_dist.sample()[(slice(None),) * len(obs_shape) + (0,)]

    if not ok:
        with pytest.raises(ValueError):
            d = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
            d.log_prob(data)
        return

    d = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert d.support.event_dim == d.event_dim
    actual = d.log_prob(data)
    expected_shape = broadcast_shape(init_shape, trans_shape[:-1], obs_shape[:-1])
    assert actual.shape == expected_shape
    check_expand(d, data)

    final = d.filter(data)
    assert isinstance(final, dist.Categorical)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == ()
    assert final.support.upper_bound == state_dim - 1
def __init__(self, initial_dist, transition_dist, observation_dist,
             validate_args=None):
    assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
    assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
    assert isinstance(observation_dist, torch.distributions.MultivariateNormal)
    hidden_dim = initial_dist.event_shape[0]
    assert transition_dist.event_shape[0] == hidden_dim + hidden_dim
    obs_dim = observation_dist.event_shape[0] - hidden_dim
    shape = broadcast_shape(initial_dist.batch_shape + (1,),
                            transition_dist.batch_shape,
                            observation_dist.batch_shape)
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + (obs_dim,)
    super(GaussianMRF, self).__init__(batch_shape, event_shape,
                                      validate_args=validate_args)
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
    self._init = mvn_to_gaussian(initial_dist)
    self._trans = mvn_to_gaussian(transition_dist)
    self._obs = mvn_to_gaussian(observation_dist)
def __init__(self, base_dist, mask):
    if broadcast_shape(mask.shape, base_dist.batch_shape) != base_dist.batch_shape:
        raise ValueError("Expected mask.shape to be broadcastable to "
                         "base_dist.batch_shape, actual {} vs {}".format(
                             mask.shape, base_dist.batch_shape))
    self.base_dist = base_dist
    self._mask = mask
    super(MaskedDistribution, self).__init__(base_dist.batch_shape,
                                             base_dist.event_shape)
def test_gaussian_hmm_shape(diag, init_shape, trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    shape = broadcast_shape(init_shape + (1,), trans_mat_shape, trans_mvn_shape,
                            obs_mat_shape, obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape

    data = obs_dist.expand(shape).sample()
    assert data.shape == d.shape()

    actual = d.log_prob(data)
    assert actual.shape == expected_batch_shape
    check_expand(d, data)

    final = d.filter(data)
    assert isinstance(final, dist.MultivariateNormal)
    assert final.batch_shape == d.batch_shape
    assert final.event_shape == (hidden_dim,)
def sample(self, guide_name, fn, infer=None):
    """
    Wrapper around ``pyro.sample()`` to create a single auxiliary sample site
    and then unpack to multiple sample sites for model replay.

    :param str guide_name: The name of the auxiliary guide site.
    :param callable fn: A distribution with shape ``self.event_shape``.
    :param dict infer: Optional inference configuration dict.
    :returns: A pair ``(guide_z, model_zs)`` where ``guide_z`` is the single
        concatenated blob and ``model_zs`` is a dict mapping site name to
        constrained model sample.
    :rtype: tuple
    """
    # Sample a packed tensor.
    if fn.event_shape != self.event_shape:
        raise ValueError(
            "Invalid fn.event_shape for group: expected {}, actual {}".format(
                tuple(self.event_shape), tuple(fn.event_shape)))
    if infer is None:
        infer = {}
    infer["is_auxiliary"] = True
    guide_z = pyro.sample(guide_name, fn, infer=infer)
    common_batch_shape = guide_z.shape[:-1]

    model_zs = {}
    pos = 0
    for site in self.prototype_sites:
        name = site["name"]
        fn = site["fn"]

        # Extract slice from packed sample.
        size = self._site_sizes[name]
        batch_shape = broadcast_shape(common_batch_shape,
                                      self._site_batch_shapes[name])
        unconstrained_z = guide_z[..., pos:pos + size]
        unconstrained_z = unconstrained_z.reshape(batch_shape + fn.event_shape)
        pos += size

        # Transform to constrained space.
        transform = biject_to(fn.support)
        z = transform(unconstrained_z)
        log_density = transform.inv.log_abs_det_jacobian(z, unconstrained_z)
        log_density = sum_rightmost(
            log_density, log_density.dim() - z.dim() + fn.event_dim)
        delta_dist = dist.Delta(z, log_density=log_density,
                                event_dim=fn.event_dim)

        # Replay model sample statement.
        with ExitStack() as stack:
            for frame in site["cond_indep_stack"]:
                plate = self.guide.plate(frame.name)
                if plate not in runtime._PYRO_STACK:
                    stack.enter_context(plate)
            model_zs[name] = pyro.sample(name, delta_dist)

    return guide_z, model_zs
def __init__(self, mu, sigma, *args, **kwargs):
    torch_dist = torch.distributions.Normal(mu, sigma)
    x_shape = torch.Size(broadcast_shape(mu.size(), sigma.size(), strict=True))
    event_dim = 1
    super(Normal, self).__init__(torch_dist, x_shape, event_dim, *args, **kwargs)
def __init__(
    self,
    total_count,
    logits,
    multiplicative_noise_scale,
    *,
    num_quad_points=8,
    validate_args=None,
):
    if num_quad_points < 1:
        raise ValueError("num_quad_points must be positive.")

    total_count, logits, multiplicative_noise_scale = broadcast_all(
        total_count, logits, multiplicative_noise_scale)

    self.quad_points, self.log_weights = get_quad_rule(num_quad_points, logits)
    quad_logits = (
        logits.unsqueeze(-1)
        + multiplicative_noise_scale.unsqueeze(-1) * self.quad_points)
    self.nb_dist = NegativeBinomial(total_count=total_count.unsqueeze(-1),
                                    logits=quad_logits)

    self.multiplicative_noise_scale = multiplicative_noise_scale
    self.total_count = total_count
    self.logits = logits
    self.num_quad_points = num_quad_points

    batch_shape = broadcast_shape(multiplicative_noise_scale.shape,
                                  self.nb_dist.batch_shape[:-1])
    event_shape = torch.Size()
    super().__init__(batch_shape, event_shape, validate_args)
def __init__(self, leaf_times, rate_grid, *, validate_args=None):
    batch_shape = broadcast_shape(leaf_times.shape[:-1], rate_grid.shape[:-1])
    event_shape = (leaf_times.size(-1) - 1,)
    self.leaf_times = leaf_times
    self.rate_grid = rate_grid
    super().__init__(batch_shape, event_shape, validate_args=validate_args)
def sample(self, sample_shape=torch.Size([])):
    """
    :param ~torch.Size sample_shape: Sample shape; the last dimension must be
        ``num_steps`` and must be broadcastable to ``(batch_size, num_steps)``.
        ``batch_size`` must be an int, not a tuple.
    """
    # shape: batch_size x num_steps x categorical_size
    shape = broadcast_shape(
        torch.Size(list(self.batch_shape) + [1, 1]),
        torch.Size(list(sample_shape) + [1]),
        torch.Size((1, 1, self.event_shape[-1])),
    )
    # state: batch_size x state_dim
    state = OneHotCategorical(logits=self.initial_logits).sample()
    # sample: batch_size x num_steps x categorical_size
    sample = torch.zeros(shape)
    for i in range(shape[-2]):
        # batch_size x 1 x state_dim @
        # batch_size x state_dim x categorical_size
        obs_logits = torch.matmul(state.unsqueeze(-2),
                                  self.observation_logits).squeeze(-2)
        sample[:, i, :] = OneHotCategorical(logits=obs_logits).sample()
        # batch_size x 1 x state_dim @
        # batch_size x state_dim x state_dim
        trans_logits = torch.matmul(state.unsqueeze(-2),
                                    self.transition_logits).squeeze(-2)
        state = OneHotCategorical(logits=trans_logits).sample()
    return sample
def __init__(self, initial_logits, transition_logits, observation_dist,
             validate_args=None):
    if initial_logits.dim() < 1:
        raise ValueError(
            "expected initial_logits to have at least one dim, "
            "actual shape = {}".format(initial_logits.shape))
    if transition_logits.dim() < 2:
        raise ValueError(
            "expected transition_logits to have at least two dims, "
            "actual shape = {}".format(transition_logits.shape))
    if len(observation_dist.batch_shape) < 1:
        raise ValueError(
            "expected observation_dist to have at least one batch dim, "
            "actual .batch_shape = {}".format(observation_dist.batch_shape))
    shape = broadcast_shape(initial_logits.shape[:-1] + (1,),
                            transition_logits.shape[:-2],
                            observation_dist.batch_shape[:-1])
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + observation_dist.event_shape
    self.initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
    self.transition_logits = transition_logits - transition_logits.logsumexp(-1, True)
    self.observation_dist = observation_dist
    super(DiscreteHMM, self).__init__(batch_shape, event_shape,
                                      validate_args=validate_args)
def _make_phylogeny(leaf_times, coal_times):
    assert leaf_times.size(-1) == 1 + coal_times.size(-1)

    # Expand shapes to match.
    N = leaf_times.size(-1)
    batch_shape = broadcast_shape(leaf_times.shape[:-1], coal_times.shape[:-1])
    if leaf_times.shape[:-1] != batch_shape:
        leaf_times = leaf_times.expand(batch_shape + (N,))
    if coal_times.shape[:-1] != batch_shape:
        coal_times = coal_times.expand(batch_shape + (N - 1,))

    # Combine N sampling events (leaf_times) plus N-1 coalescent events
    # (coal_times) into a pair (times, signs) of arrays of length 2N-1, where
    # leaf sample sign is +1 and coalescent sign is -1.
    times = torch.cat([coal_times, leaf_times], dim=-1)
    signs = torch.linspace(1.5 - N, N - 0.5,
                           2 * N - 1).sign()  # e.g. [-1, -1, +1, +1, +1]

    # Sort the events reverse-ordered in time, i.e. latest to earliest.
    times, index = times.sort(dim=-1, descending=True)
    signs = signs[index]
    inv_index = index.new_empty(index.shape)
    inv_index.scatter_(-1, index, torch.arange(2 * N - 1).expand_as(index))

    # Compute the number n of lineages preceding each event, then the binomial
    # coefficients that will multiply the base coalescence rate.
    lineages = signs.cumsum(-1)
    binomial = lineages * (lineages - 1) / 2

    # Compute the binomial coefficient following each coalescent event.
    coal_index = inv_index[..., :N - 1]
    coal_binomial = binomial.gather(-1, coal_index - 1)

    return _Phylogeny(times, signs, lineages, binomial, coal_binomial)
def test_studentt_hmm_shape(
    init_shape,
    trans_mat_shape,
    trans_dist_shape,
    obs_mat_shape,
    obs_dist_shape,
    hidden_dim,
    obs_dim,
):
    init_dist = random_studentt(init_shape + (hidden_dim,)).to_event(1)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_studentt(trans_dist_shape + (hidden_dim,)).to_event(1)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_studentt(obs_dist_shape + (obs_dim,)).to_event(1)
    d = dist.LinearHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)

    shape = broadcast_shape(
        init_shape + (1,),
        trans_mat_shape,
        trans_dist_shape,
        obs_mat_shape,
        obs_dist_shape,
    )
    expected_batch_shape, time_shape = shape[:-1], shape[-1:]
    expected_event_shape = time_shape + (obs_dim,)
    assert d.batch_shape == expected_batch_shape
    assert d.event_shape == expected_event_shape
    assert d.support.event_dim == d.event_dim

    x = d.rsample()
    assert x.shape == d.shape()
    x = d.rsample((6,))
    assert x.shape == (6,) + d.shape()
    x = d.expand((6, 5)).rsample()
    assert x.shape == (6, 5) + d.event_shape
def matrix_and_mvn_to_gaussian(matrix, mvn):
    """
    Convert a noisy affine function to a Gaussian. The noisy affine function
    is defined as::

        y = x @ matrix + mvn.sample()

    :param ~torch.Tensor matrix: A matrix with rightmost shape ``(x_dim, y_dim)``.
    :param ~torch.distributions.MultivariateNormal mvn: A multivariate normal
        distribution.
    :return: A Gaussian with broadcasted batch shape and ``.dim() == x_dim + y_dim``.
    :rtype: ~pyro.ops.gaussian.Gaussian
    """
    assert (isinstance(mvn, torch.distributions.MultivariateNormal) or
            (isinstance(mvn, torch.distributions.Independent) and
             isinstance(mvn.base_dist, torch.distributions.Normal)))
    assert isinstance(matrix, torch.Tensor)
    x_dim, y_dim = matrix.shape[-2:]
    assert mvn.event_shape == (y_dim,)
    batch_shape = broadcast_shape(matrix.shape[:-2], mvn.batch_shape)
    matrix = matrix.expand(batch_shape + (x_dim, y_dim))
    mvn = mvn.expand(batch_shape)

    # Handle diagonal normal distributions as an efficient special case.
    if isinstance(mvn, torch.distributions.Independent):
        return AffineNormal(matrix, mvn.base_dist.loc, mvn.base_dist.scale)

    y_gaussian = mvn_to_gaussian(mvn)
    result = _matrix_and_gaussian_to_gaussian(matrix, y_gaussian)
    assert result.batch_shape == batch_shape
    assert result.dim() == x_dim + y_dim
    return result
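# Illustrative usage (not from the source): per the docstring above, a (3, 2)
# matrix and a 2-dimensional MultivariateNormal yield a Gaussian over the joint
# (x, y) space, so ``dim() == 3 + 2``. Variable names here are hypothetical.
example_matrix = torch.randn(3, 2)
example_mvn = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
example_g = matrix_and_mvn_to_gaussian(example_matrix, example_mvn)
assert example_g.dim() == 5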
def __init__(self, mask, component0, component1, validate_args=None):
    if not torch.is_tensor(mask) or mask.dtype != torch.bool:
        raise ValueError('Expected mask to be a BoolTensor but got {}'.format(
            type(mask)))
    if component0.event_shape != component1.event_shape:
        raise ValueError('components event_shape disagree: {} vs {}'.format(
            component0.event_shape, component1.event_shape))
    batch_shape = broadcast_shape(mask.shape, component0.batch_shape,
                                  component1.batch_shape)
    if mask.shape != batch_shape:
        mask = mask.expand(batch_shape)
    if component0.batch_shape != batch_shape:
        component0 = component0.expand(batch_shape)
    if component1.batch_shape != batch_shape:
        component1 = component1.expand(batch_shape)
    self.mask = mask
    self.component0 = component0
    self.component1 = component1
    super().__init__(batch_shape, component0.event_shape, validate_args)

    # We need to disable _validate_sample on each component since samples are
    # only valid on the component from which they are drawn. Instead we perform
    # validation using a MaskedConstraint.
    self.component0._validate_args = False
    self.component1._validate_args = False
def log_prob(self, value):
    if self._mask is False:
        shape = broadcast_shape(self.base_dist.batch_shape,
                                value.shape[:value.dim() - self.event_dim])
        return torch.zeros((), device=value.device).expand(shape)
    if self._mask is True:
        return self.base_dist.log_prob(value)
    return scale_and_mask(self.base_dist.log_prob(value), mask=self._mask)
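# Illustrative sketch (not from the source): a masked distribution keeps the
# broadcast batch shape of log_prob while zeroing out masked entries. The
# ``.mask()`` helper used here is assumed from Pyro's distribution mixin.
#
#     base = dist.Normal(torch.zeros(3), 1.0)
#     masked = base.mask(torch.zeros(3, dtype=torch.bool))
#     assert masked.log_prob(torch.randn(2, 3)).shape == (2, 3)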
def batch_shape(self):
    return broadcast_shape(
        self.log_normalizer.shape,
        self.info_vec.shape[:-1],
        self.precision.shape[:-2],
        self.alpha.shape,
        self.beta.shape,
    )
def infer_shapes(loc, covariance_matrix=None, precision_matrix=None,
                 scale_tril=None):
    batch_shape, event_shape = loc[:-1], loc[-1:]
    for matrix in [covariance_matrix, precision_matrix, scale_tril]:
        if matrix is not None:
            batch_shape = broadcast_shape(batch_shape, matrix[:-2])
    return batch_shape, event_shape
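# Illustrative example (not from the source): the arguments here are shapes,
# not tensors, so a (5, 3) ``loc`` shape with a (3, 3) covariance shape splits
# into batch and event shapes as:
assert infer_shapes((5, 3), covariance_matrix=(3, 3)) == ((5,), (3,))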
def score_parts(self, value):
    shape = broadcast_shape(self.batch_shape,
                            value.shape[:value.dim() - self.event_dim])
    log_prob, score_function, entropy_term = self.base_dist.score_parts(value)
    log_prob = sum_rightmost(log_prob, self.reinterpreted_batch_ndims).expand(shape)
    if not isinstance(score_function, numbers.Number):
        score_function = sum_rightmost(
            score_function, self.reinterpreted_batch_ndims).expand(shape)
    if not isinstance(entropy_term, numbers.Number):
        entropy_term = sum_rightmost(
            entropy_term, self.reinterpreted_batch_ndims).expand(shape)
    return ScoreParts(log_prob, score_function, entropy_term)
def eager_multinomial(total_count, probs, value):
    # Multinomial.log_prob() supports inhomogeneous total_count only by
    # avoiding passing total_count to the constructor.
    inputs, (total_count, probs, value) = align_tensors(total_count, probs, value)
    shape = broadcast_shape(total_count.shape + (1,), probs.shape, value.shape)
    probs = Tensor(probs.expand(shape), inputs)
    value = Tensor(value.expand(shape), inputs)
    total_count = Number(total_count.max().item())  # Used by distributions validation code.
    return Multinomial.eager_log_prob(total_count=total_count, probs=probs,
                                      value=value)
def test_support_shape(dist):
    for idx in range(dist.get_num_test_data()):
        dist_params = dist.get_dist_params(idx)
        d = dist.pyro_dist(**dist_params)
        assert d.support.event_dim == d.event_dim
        x = dist.get_test_data(idx)
        ok = d.support.check(x)
        assert ok.shape == broadcast_shape(d.batch_shape,
                                           x.shape[:x.dim() - d.event_dim])
        assert ok.all()
def __init__(self, leaf_times, rate=1., *, validate_args=None):
    rate = torch.as_tensor(rate, dtype=leaf_times.dtype,
                           device=leaf_times.device)
    batch_shape = broadcast_shape(rate.shape, leaf_times.shape[:-1])
    event_shape = (leaf_times.size(-1) - 1,)
    self.leaf_times = leaf_times
    self.rate = rate
    super().__init__(batch_shape, event_shape, validate_args=validate_args)
def _gather(tensor, dim, index):
    """
    Like :func:`torch.gather` but broadcasts.
    """
    if dim != -1:
        raise NotImplementedError
    shape = broadcast_shape(tensor.shape[:-1], index.shape[:-1]) + (-1,)
    tensor = tensor.expand(shape)
    index = index.expand(shape)
    return tensor.gather(dim, index)
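# Illustrative usage of ``_gather`` (not from the source): the leading dims of
# ``tensor`` and ``index`` are broadcast before gathering along the last dim.
example_tensor = torch.randn(2, 1, 5)
example_index = torch.randint(5, (3, 4))
assert _gather(example_tensor, -1, example_index).shape == (2, 3, 4)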
def forward(self, *input_args):
    # We have a single object: regardless of type, we don't care about single
    # objects; we just index into the object.
    if len(input_args) == 1:
        input_args = input_args[0]

    # Don't concat things that are just single objects.
    if torch.is_tensor(input_args):
        return input_args
    else:
        if self.allow_broadcast:
            shape = broadcast_shape(*[s.shape[:-1] for s in input_args]) + (-1,)
            input_args = [s.expand(shape) for s in input_args]
        return torch.cat(input_args, dim=-1)
def _log_prob_shape(dist, x_size=torch.Size()):
    event_dims = len(dist.event_shape)
    expected_shape = broadcast_shape(dist.shape(), x_size, strict=True)
    if event_dims > 0:
        expected_shape = expected_shape[:-event_dims]
    return expected_shape
def test_broadcast_shape(shapes):
    assert broadcast_shape(*shapes) == np.broadcast(*map(np.empty, shapes)).shape
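# Illustrative example (not from the source): the test above pins
# ``broadcast_shape`` to NumPy broadcasting semantics, e.g.
assert broadcast_shape((3, 1), (4,)) == (3, 4)
assert np.broadcast(np.empty((3, 1)), np.empty((4,))).shape == (3, 4)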
def test_broadcast_shape_error(shapes):
    with pytest.raises((ValueError, RuntimeError)):
        broadcast_shape(*shapes)
def test_broadcast_shape_strict(shapes):
    assert (broadcast_shape(*shapes, strict=True)
            == np.broadcast(*map(np.empty, shapes)).shape)
def test_broadcast_shape_strict_error(shapes):
    with pytest.raises(ValueError):
        broadcast_shape(*shapes, strict=True)
def log_prob(self, value):
    shape = broadcast_shape(self.batch_shape,
                            value.shape[:value.dim() - self.event_dim])
    return sum_rightmost(self.base_dist.log_prob(value),
                         self.reinterpreted_batch_ndims).expand(shape)