def find_domain(op, *domains):
    r"""
    Finds the :class:`Domain` resulting when applying ``op`` to ``domains``.

    :param callable op: An operation.
    :param Domain \*domains: One or more input domains.
    """
    assert callable(op), op
    assert all(isinstance(arg, Domain) for arg in domains)
    if len(domains) == 1:
        dtype = domains[0].dtype
        shape = domains[0].shape
        if op is ops.log or op is ops.exp:
            dtype = 'real'
        elif isinstance(op, ops.ReshapeOp):
            shape = op.shape
        elif isinstance(op, ops.AssociativeOp):
            shape = ()
        return Domain(shape, dtype)

    lhs, rhs = domains
    if isinstance(op, ops.GetitemOp):
        dtype = lhs.dtype
        shape = lhs.shape[:op.offset] + lhs.shape[1 + op.offset:]
        return Domain(shape, dtype)
    elif op == ops.matmul:
        assert lhs.shape and rhs.shape
        if len(rhs.shape) == 1:
            assert lhs.shape[-1] == rhs.shape[-1]
            shape = lhs.shape[:-1]
        elif len(lhs.shape) == 1:
            assert lhs.shape[-1] == rhs.shape[-2]
            shape = rhs.shape[:-2] + rhs.shape[-1:]
        else:
            assert lhs.shape[-1] == rhs.shape[-2]
            shape = broadcast_shape(lhs.shape[:-1],
                                    rhs.shape[:-2] + (1,)) + rhs.shape[-1:]
        return Domain(shape, 'real')

    if lhs.dtype == 'real' or rhs.dtype == 'real':
        dtype = 'real'
    elif op in (ops.add, ops.mul, ops.pow, ops.max, ops.min):
        dtype = op(lhs.dtype - 1, rhs.dtype - 1) + 1
    elif op in (ops.and_, ops.or_, ops.xor):
        dtype = 2
    elif lhs.dtype == rhs.dtype:
        dtype = lhs.dtype
    else:
        raise NotImplementedError('TODO')

    if lhs.shape == rhs.shape:
        shape = lhs.shape
    else:
        shape = broadcast_shape(lhs.shape, rhs.shape)
    return Domain(shape, dtype)
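# A small usage sketch of the dtype arithmetic above, constructing domains
# with the same Domain(shape, dtype) constructor the function itself uses
# (helper aliases such as Bint/Reals vary across funsor versions, so this is
# only illustrative). Values of dtype 2 range over {0, 1} and values of
# dtype 3 over {0, 1, 2}, so their sum ranges over {0, ..., 3}, i.e. dtype 4:
#
#     assert find_domain(ops.add, Domain((), 2), Domain((), 3)) == Domain((), 4)
#     # Any 'real' operand makes the result 'real', with broadcast shape:
#     assert find_domain(ops.mul,
#                        Domain((2, 1), 'real'),
#                        Domain((3,), 2)) == Domain((2, 3), 'real')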
def _eager_contract_tensors(reduced_vars, terms, backend):
    iter_symbols = map(opt_einsum.get_symbol, itertools.count())
    symbols = defaultdict(functools.partial(next, iter_symbols))
    inputs = OrderedDict()
    einsum_inputs = []
    operands = []
    for term in terms:
        inputs.update(term.inputs)
        einsum_inputs.append(
            "".join(symbols[k] for k in term.inputs)
            + "".join(symbols[i - len(term.shape)]
                      for i, size in enumerate(term.shape) if size != 1))

        # Squeeze absent event dims to be compatible with einsum.
        data = term.data
        batch_shape = data.shape[:len(data.shape) - len(term.shape)]
        event_shape = tuple(size for size in term.shape if size != 1)
        data = data.reshape(batch_shape + event_shape)
        operands.append(data)

    for k in reduced_vars:
        del inputs[k]
    batch_shape = tuple(v.size for v in inputs.values())
    event_shape = broadcast_shape(*(term.shape for term in terms))
    einsum_output = (
        "".join(symbols[k] for k in inputs)
        + "".join(symbols[dim] for dim in range(-len(event_shape), 0)
                  if dim in symbols))
    equation = ",".join(einsum_inputs) + "->" + einsum_output
    data = opt_einsum.contract(equation, *operands, backend=backend)
    data = data.reshape(batch_shape + event_shape)
    return Tensor(data, inputs)
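# A toy illustration of the equation-building idea above, independent of any
# funsor internals: each named input maps to an einsum symbol, and reduced
# variables simply do not appear on the output side of the equation. The
# helper name below is hypothetical, not part of the library.
def _einsum_demo():
    import numpy as np
    x = np.random.rand(2, 3)  # named inputs ("i", "j")
    y = np.random.rand(3, 4)  # named inputs ("j", "k")
    # Reducing over "j" yields the equation "ij,jk->ik":
    z = opt_einsum.contract("ij,jk->ik", x, y)
    assert z.shape == (2, 4)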
def eager_multinomial(total_count, probs, value):
    # Multinomial.log_prob() supports inhomogeneous total_count only by
    # avoiding passing total_count to the constructor.
    inputs, (total_count, probs, value) = align_tensors(total_count, probs, value)
    shape = broadcast_shape(total_count.shape + (1,), probs.shape, value.shape)
    probs = Tensor(probs.expand(shape), inputs)
    value = Tensor(value.expand(shape), inputs)
    total_count = Number(total_count.max().item())  # Used by distributions validation code.
    return Multinomial.eager_log_prob(total_count, probs, value)
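# Why the max() trick above is safe: torch.distributions.Multinomial.log_prob()
# reads the count from the value tensor itself (via value.sum(-1)), so with
# validation disabled the constructor's total_count does not affect the
# log density. A toy check (hypothetical helper, not part of the library):
def _multinomial_total_count_demo():
    import torch
    probs = torch.ones(3) / 3
    value = torch.tensor([2., 2., 1.])  # actual total count is 5
    lp5 = torch.distributions.Multinomial(5, probs, validate_args=False).log_prob(value)
    lp9 = torch.distributions.Multinomial(9, probs, validate_args=False).log_prob(value)
    assert torch.allclose(lp5, lp9)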
def __init__(self, initial_dist, transition_dist, observation_dist,
             validate_args=None):
    assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
    assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
    assert isinstance(observation_dist, torch.distributions.MultivariateNormal)
    hidden_dim = initial_dist.event_shape[0]
    assert transition_dist.event_shape[0] == hidden_dim + hidden_dim
    obs_dim = observation_dist.event_shape[0] - hidden_dim
    shape = broadcast_shape(initial_dist.batch_shape + (1,),
                            transition_dist.batch_shape,
                            observation_dist.batch_shape)
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + (obs_dim,)

    # Convert distributions to funsors.
    init = dist_to_funsor(initial_dist)(value="state")
    trans = mvn_to_funsor(transition_dist, ("time",),
                          OrderedDict([("state", Reals[hidden_dim]),
                                       ("state(time=1)", Reals[hidden_dim])]))
    obs = mvn_to_funsor(observation_dist, ("time",),
                        OrderedDict([("state(time=1)", Reals[hidden_dim]),
                                     ("value", Reals[obs_dim])]))

    # Construct the joint funsor.
    # Compare with pyro.distributions.hmm.GaussianMRF.log_prob().
    with interpretation(lazy):
        time = Variable("time", Bint[time_shape[0]])
        value = Variable("value", Reals[time_shape[0], obs_dim])
        logp_oh = trans + obs(value=value["time"])
        logp_oh = MarkovProduct(ops.logaddexp, ops.add,
                                logp_oh, time, {"state": "state(time=1)"})
        logp_oh += init
        logp_oh = logp_oh.reduce(ops.logaddexp,
                                 frozenset({"state", "state(time=1)"}))
        logp_h = trans + obs.reduce(ops.logaddexp, "value")
        logp_h = MarkovProduct(ops.logaddexp, ops.add,
                               logp_h, time, {"state": "state(time=1)"})
        logp_h += init
        logp_h = logp_h.reduce(ops.logaddexp,
                               frozenset({"state", "state(time=1)"}))
        funsor_dist = logp_oh - logp_h

    dtype = "real"
    super(GaussianMRF, self).__init__(funsor_dist, batch_shape, event_shape,
                                      dtype, validate_args)
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
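# The two MarkovProduct passes above implement the GaussianMRF likelihood as a
# ratio of normalizers (a sketch of the identity, in the notation of the code):
#
#     log p(value) = logp_oh - logp_h
#     logp_oh = log \int exp( init(x_0) + \sum_t [ trans(x_{t-1}, x_t)
#                           + obs(x_t, value_t) ] ) dx_{0:T}
#     logp_h  = log \int\int exp( ... ) dx_{0:T} dvalue_{1:T}
#
# i.e. logp_oh integrates out only the hidden states with the observations
# fixed, while logp_h additionally integrates out the observations.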
def test_switching_linear_hmm_shape(init_cat_shape, init_mvn_shape,
                                    trans_cat_shape, trans_mat_shape,
                                    trans_mvn_shape, obs_mat_shape,
                                    obs_mvn_shape):
    hidden_dim, obs_dim = obs_mat_shape[-2:]
    assert trans_mat_shape[-2:] == (hidden_dim, hidden_dim)

    init_logits = torch.randn(init_cat_shape)
    init_mvn = random_mvn(init_mvn_shape, hidden_dim)
    trans_logits = torch.randn(trans_cat_shape)
    trans_matrix = torch.randn(trans_mat_shape)
    trans_mvn = random_mvn(trans_mvn_shape, hidden_dim)
    obs_matrix = torch.randn(obs_mat_shape)
    obs_mvn = random_mvn(obs_mvn_shape, obs_dim)

    init_shape = broadcast_shape(init_cat_shape, init_mvn_shape)
    shape = broadcast_shape(init_shape[:-1] + (1, init_shape[-1]),
                            trans_cat_shape[:-1],
                            trans_mat_shape[:-2],
                            trans_mvn_shape,
                            obs_mat_shape[:-2],
                            obs_mvn_shape)
    expected_batch_shape, time_shape = shape[:-2], shape[-2:-1]
    expected_event_shape = time_shape + (obs_dim,)

    actual_dist = SwitchingLinearHMM(init_logits, init_mvn, trans_logits,
                                     trans_matrix, trans_mvn, obs_matrix,
                                     obs_mvn)
    assert actual_dist.event_shape == expected_event_shape
    assert actual_dist.batch_shape == expected_batch_shape

    data = obs_mvn.expand(shape).sample()[..., 0, :]
    actual_log_prob = actual_dist.log_prob(data)
    assert actual_log_prob.shape == expected_batch_shape
    check_expand(actual_dist, data)

    final_cat, final_mvn = actual_dist.filter(data)
    assert isinstance(final_cat, dist.Categorical)
    assert isinstance(final_mvn, dist.MultivariateNormal)
    assert final_cat.batch_shape == actual_dist.batch_shape
    assert final_mvn.batch_shape == actual_dist.batch_shape + final_cat.logits.shape[-1:]
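# The shape arithmetic above inserts a size-1 time axis just before the class
# axis, since the initial distribution has no time dependence. A toy check of
# the pattern (torch.broadcast_shapes stands in for this file's
# broadcast_shape helper; the function name below is hypothetical):
def _time_axis_broadcast_demo():
    import torch
    init_shape = (5, 3)       # (batch, class)
    trans_cat_shape = (7, 3)  # (time, class)
    shape = torch.broadcast_shapes(init_shape[:-1] + (1, init_shape[-1]),
                                   trans_cat_shape)
    assert shape == (5, 7, 3)  # batch=(5,), time=(7,), class cardinality 3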
def __init__(self, initial_dist, transition_matrix, transition_dist,
             observation_matrix, observation_dist, validate_args=None):
    assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
    assert isinstance(transition_matrix, torch.Tensor)
    assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
    assert isinstance(observation_matrix, torch.Tensor)
    assert isinstance(observation_dist, torch.distributions.MultivariateNormal)
    hidden_dim, obs_dim = observation_matrix.shape[-2:]
    assert obs_dim >= hidden_dim // 2, "obs_dim must be at least half of hidden_dim"
    assert initial_dist.event_shape == (hidden_dim,)
    assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
    assert transition_dist.event_shape == (hidden_dim,)
    assert observation_dist.event_shape == (obs_dim,)
    shape = broadcast_shape(initial_dist.batch_shape + (1,),
                            transition_matrix.shape[:-2],
                            transition_dist.batch_shape,
                            observation_matrix.shape[:-2],
                            observation_dist.batch_shape)
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + (obs_dim,)

    # Convert distributions to funsors.
    init = dist_to_funsor(initial_dist)(value="state")
    trans = matrix_and_mvn_to_funsor(transition_matrix, transition_dist,
                                     ("time",), "state", "state(time=1)")
    obs = matrix_and_mvn_to_funsor(observation_matrix, observation_dist,
                                   ("time",), "state(time=1)", "value")
    dtype = "real"

    # Construct the joint funsor.
    with interpretation(lazy):
        value = Variable("value", Reals[time_shape[0], obs_dim])
        result = trans + obs(value=value["time"])
        result = MarkovProduct(ops.logaddexp, ops.add,
                               result, "time", {"state": "state(time=1)"})
        result = init + result.reduce(ops.logaddexp, "state(time=1)")
        funsor_dist = result.reduce(ops.logaddexp, "state")

    super(GaussianHMM, self).__init__(funsor_dist, batch_shape, event_shape,
                                      dtype, validate_args)
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
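# A sketch of the density the lazy expression above encodes, with x_t the
# hidden state and y_t = value[t]:
#
#     log p(y_{1:T}) = log \int p(x_0)
#                          \prod_{t=1}^{T} p(x_t | x_{t-1}) p(y_t | x_t)
#                          dx_{0:T}
#
# MarkovProduct performs the product over t (renaming "state(time=1)" back to
# "state" between steps), and the two final .reduce(ops.logaddexp, ...) calls
# integrate out the final and initial hidden states respectively.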
def test_gaussian_mrf_log_prob(init_shape, trans_shape, obs_shape,
                               hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_dist = random_mvn(trans_shape, hidden_dim + hidden_dim)
    obs_dist = random_mvn(obs_shape, hidden_dim + obs_dim)

    actual_dist = GaussianMRF(init_dist, trans_dist, obs_dist)
    expected_dist = dist.GaussianMRF(init_dist, trans_dist, obs_dist)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    data = obs_dist.expand(batch_shape).sample()[..., hidden_dim:]
    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-4, rtol=1e-4)
    check_expand(actual_dist, data)
def test_discrete_normal_log_prob(init_shape, trans_shape, obs_shape, state_dim):
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    loc = torch.randn(obs_shape + (state_dim,))
    scale = torch.randn(obs_shape + (state_dim,)).exp()
    obs_dist = dist.Normal(loc, scale)

    actual_dist = DiscreteHMM(init_logits, trans_logits, obs_dist)
    expected_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    data = obs_dist.expand(batch_shape + (state_dim,)).sample()
    data = data[(slice(None),) * len(batch_shape) + (0,)]
    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, rtol=5e-5)
    check_expand(actual_dist, data)
def test_discrete_categorical_log_prob(init_shape, trans_shape, obs_shape, state_dim):
    obs_dim = 4
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    obs_logits = torch.randn(obs_shape + (state_dim, obs_dim))
    obs_dist = dist.Categorical(logits=obs_logits)

    actual_dist = DiscreteHMM(init_logits, trans_logits, obs_dist)
    expected_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    data = obs_dist.expand(batch_shape + (state_dim,)).sample()
    data = data[(slice(None),) * len(batch_shape) + (0,)]
    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob)
    check_expand(actual_dist, data)
def __init__(self, initial_logits, transition_logits, observation_dist,
             validate_args=None):
    assert isinstance(initial_logits, torch.Tensor)
    assert isinstance(transition_logits, torch.Tensor)
    assert isinstance(observation_dist, torch.distributions.Distribution)
    assert initial_logits.dim() >= 1
    assert transition_logits.dim() >= 2
    assert len(observation_dist.batch_shape) >= 1
    shape = broadcast_shape(initial_logits.shape[:-1] + (1,),
                            transition_logits.shape[:-2],
                            observation_dist.batch_shape[:-1])
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + observation_dist.event_shape
    self._has_rsample = observation_dist.has_rsample

    # Normalize.
    initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
    transition_logits = transition_logits - transition_logits.logsumexp(-1, True)

    # Convert tensors and distributions to funsors.
    init = tensor_to_funsor(initial_logits, ("state",))
    trans = tensor_to_funsor(transition_logits, ("time", "state", "state(time=1)"))
    obs = dist_to_funsor(observation_dist, ("time", "state(time=1)"))
    dtype = obs.inputs["value"].dtype

    # Construct the joint funsor.
    with interpretation(lazy):
        # TODO perform math here once sequential_sum_product has been
        # implemented as a first-class funsor.
        funsor_dist = Variable("value", obs.inputs["value"])  # a bogus value
        # Until funsor_dist is defined, we save factors for hand-computation
        # in .log_prob().
        self._init = init
        self._trans = trans
        self._obs = obs

    super(DiscreteHMM, self).__init__(funsor_dist, batch_shape, event_shape,
                                      dtype, validate_args)
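# The normalization above turns arbitrary logits into log-probabilities along
# the last dim. A toy check of the identity (hypothetical helper name):
def _logit_normalization_demo():
    import torch
    logits = torch.randn(3, 4)
    normalized = logits - logits.logsumexp(-1, True)
    assert torch.allclose(normalized.exp().sum(-1), torch.ones(3))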
def eager_multinomial(total_count, probs, value):
    # Multinomial.log_prob() supports inhomogeneous total_count only by
    # avoiding passing total_count to the constructor.
    inputs, (total_count, probs, value) = align_tensors(total_count, probs, value)
    shape = broadcast_shape(total_count.shape + (1,), probs.shape, value.shape)
    probs = Tensor(ops.expand(probs, shape), inputs)
    value = Tensor(ops.expand(value, shape), inputs)
    if get_backend() == "torch":
        # Used by distributions validation code.
        total_count = Number(ops.amax(total_count, None).item())
    else:
        total_count = Tensor(ops.expand(total_count, shape[:-1]), inputs)
    backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    return backend_dist.Multinomial.eager_log_prob(total_count, probs, value)  # noqa: F821
def test_discrete_mvn_log_prob(init_shape, trans_shape, obs_shape, state_dim):
    event_size = 4
    init_logits = torch.randn(init_shape + (state_dim,))
    trans_logits = torch.randn(trans_shape + (state_dim, state_dim))
    loc = torch.randn(obs_shape + (state_dim, event_size))
    cov = torch.randn(obs_shape + (state_dim, event_size, 2 * event_size))
    cov = cov.matmul(cov.transpose(-1, -2))
    scale_tril = torch.cholesky(cov)
    obs_dist = dist.MultivariateNormal(loc, scale_tril=scale_tril)

    actual_dist = DiscreteHMM(init_logits, trans_logits, obs_dist)
    expected_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
    assert actual_dist.event_shape == expected_dist.event_shape
    assert actual_dist.batch_shape == expected_dist.batch_shape

    batch_shape = broadcast_shape(init_shape + (1,), trans_shape, obs_shape)
    data = obs_dist.expand(batch_shape + (state_dim,)).sample()
    data = data[(slice(None),) * len(batch_shape) + (0,)]
    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob)
    check_expand(actual_dist, data)
def test_gaussian_hmm_log_prob(init_shape, trans_mat_shape, trans_mvn_shape,
                               obs_mat_shape, obs_mvn_shape, hidden_dim, obs_dim):
    init_dist = random_mvn(init_shape, hidden_dim)
    trans_mat = torch.randn(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim)
    obs_mat = torch.randn(obs_mat_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(obs_mvn_shape, obs_dim)

    actual_dist = GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    expected_dist = dist.GaussianHMM(init_dist, trans_mat, trans_dist,
                                     obs_mat, obs_dist)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    shape = broadcast_shape(init_shape + (1,), trans_mat_shape,
                            trans_mvn_shape, obs_mat_shape, obs_mvn_shape)
    data = obs_dist.expand(shape).sample()
    assert data.shape == actual_dist.shape()

    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
    check_expand(actual_dist, data)
def test_gaussian_hmm_log_prob_null_dynamics(init_shape, trans_mat_shape,
                                             trans_mvn_shape, obs_mvn_shape,
                                             hidden_dim):
    obs_dim = hidden_dim
    init_dist = random_mvn(init_shape, hidden_dim)

    # Impose null dynamics.
    trans_mat = torch.zeros(trans_mat_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(trans_mvn_shape, hidden_dim, diag=True)

    # Trivial observation matrix (hidden_dim == obs_dim).
    obs_mat = torch.eye(hidden_dim)
    obs_dist = random_mvn(obs_mvn_shape, obs_dim, diag=True)

    actual_dist = GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    expected_dist = dist.GaussianHMM(init_dist, trans_mat, trans_dist,
                                     obs_mat, obs_dist)
    assert actual_dist.batch_shape == expected_dist.batch_shape
    assert actual_dist.event_shape == expected_dist.event_shape

    shape = broadcast_shape(init_shape + (1,), trans_mat_shape,
                            trans_mvn_shape, obs_mvn_shape)
    data = obs_dist.expand(shape).sample()
    assert data.shape == actual_dist.shape()

    actual_log_prob = actual_dist.log_prob(data)
    expected_log_prob = expected_dist.log_prob(data)
    assert_close(actual_log_prob, expected_log_prob, atol=1e-5, rtol=1e-5)
    check_expand(actual_dist, data)

    obs_cov = obs_dist.covariance_matrix.diagonal(dim1=-1, dim2=-2)
    trans_cov = trans_dist.covariance_matrix.diagonal(dim1=-1, dim2=-2)
    sum_scale = (obs_cov + trans_cov).sqrt()
    sum_loc = trans_dist.loc + obs_dist.loc
    analytic_log_prob = dist.Normal(sum_loc, sum_scale).log_prob(data).sum(-1).sum(-1)
    assert_close(analytic_log_prob, actual_log_prob, atol=1.0e-5)
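# Why the analytic check above holds: with trans_mat = 0 and obs_mat = I,
# each observation decomposes as
#
#     x_t = 0 * x_{t-1} + eps_t,   eps_t ~ N(trans_loc, trans_cov)
#     y_t = x_t + eta_t,           eta_t ~ N(obs_loc, obs_cov)
#
# so the y_t are i.i.d. with y_t ~ N(trans_loc + obs_loc, trans_cov + obs_cov),
# and since both covariances are diagonal the log-prob factorizes over
# coordinates and time, matching the double .sum(-1).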
def matrix_and_mvn_to_funsor(matrix, mvn, event_dims=(),
                             x_name="value_x", y_name="value_y"):
    """
    Convert a noisy affine function to a Gaussian.

    The noisy affine function is defined as::

        y = x @ matrix + mvn.sample()

    The result is a non-normalized Gaussian funsor with two real inputs,
    ``x_name`` and ``y_name``, corresponding to a conditional distribution of
    real vector ``y`` given real vector ``x``.

    :param torch.Tensor matrix: A matrix with rightmost shape
        ``(x_size, y_size)``.
    :param mvn: A multivariate normal distribution with
        ``event_shape == (y_size,)``.
    :type mvn: torch.distributions.MultivariateNormal or
        torch.distributions.Independent of torch.distributions.Normal
    :param tuple event_dims: A tuple of names for rightmost dimensions.
        These will be assigned to ``result.inputs`` of type ``Bint``.
    :param str x_name: The name of the ``x`` random variable.
    :param str y_name: The name of the ``y`` random variable.
    :return: A funsor with given ``real_inputs`` and possibly additional
        Bint inputs.
    :rtype: funsor.terms.Funsor
    """
    assert (isinstance(mvn, torch.distributions.MultivariateNormal) or
            (isinstance(mvn, torch.distributions.Independent) and
             isinstance(mvn.base_dist, torch.distributions.Normal)))
    assert isinstance(matrix, torch.Tensor)
    x_size, y_size = matrix.shape[-2:]
    assert mvn.event_shape == (y_size,)

    # Handle diagonal normal distributions as an efficient special case.
    if isinstance(mvn, torch.distributions.Independent):
        return AffineNormal(tensor_to_funsor(matrix, event_dims, 2),
                            tensor_to_funsor(mvn.base_dist.loc, event_dims, 1),
                            tensor_to_funsor(mvn.base_dist.scale, event_dims, 1),
                            Variable(x_name, Reals[x_size]),
                            Variable(y_name, Reals[y_size]))

    info_vec = mvn.loc.unsqueeze(-1).cholesky_solve(mvn.scale_tril).squeeze(-1)
    log_prob = (-0.5 * y_size * math.log(2 * math.pi)
                - mvn.scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
                - 0.5 * (info_vec * mvn.loc).sum(-1))

    batch_shape = broadcast_shape(matrix.shape[:-2], mvn.batch_shape)
    P_yy = mvn.precision_matrix.expand(batch_shape + (y_size, y_size))
    neg_P_xy = matrix.matmul(P_yy)
    P_xy = -neg_P_xy
    P_yx = P_xy.transpose(-1, -2)
    P_xx = neg_P_xy.matmul(matrix.transpose(-1, -2))
    precision = torch.cat([torch.cat([P_xx, P_xy], -1),
                           torch.cat([P_yx, P_yy], -1)], -2)
    info_y = info_vec.expand(batch_shape + (y_size,))
    info_x = -matrix.matmul(info_y.unsqueeze(-1)).squeeze(-1)
    info_vec = torch.cat([info_x, info_y], -1)

    info_vec = tensor_to_funsor(info_vec, event_dims, 1)
    precision = tensor_to_funsor(precision, event_dims, 2)
    inputs = info_vec.inputs.copy()
    inputs[x_name] = Reals[x_size]
    inputs[y_name] = Reals[y_size]
    return tensor_to_funsor(log_prob, event_dims) + Gaussian(
        info_vec.data, precision.data, inputs)
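# A sketch of the block-precision algebra above. Writing M = matrix (shape
# (x_size, y_size)), mu = mvn.loc, and P = mvn.precision_matrix, the density
# of y = x @ M + noise contributes the quadratic form
#
#     -1/2 (y - M^T x - mu)^T P (y - M^T x - mu)
#
# whose expansion in the joint vector (x, y) yields the block precision
#
#     [[ M P M^T,  -M P ],
#      [ -P M^T,    P   ]]
#
# and information vector (-M P mu, P mu): exactly the P_xx, P_xy, P_yx, P_yy
# blocks and info_x, info_y halves assembled above (cholesky_solve computes
# info_vec = P mu from the Cholesky factor).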
def _eager_subs_real(self, subs, remaining_subs):
    # Broadcast all component tensors.
    subs = OrderedDict(subs)
    int_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                             if d.dtype != 'real')
    tensors = [Tensor(self.info_vec, int_inputs),
               Tensor(self.precision, int_inputs)]
    tensors.extend(subs.values())
    int_inputs, tensors = align_tensors(*tensors)
    batch_dim = len(tensors[0].shape) - 1
    batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
    (info_vec, precision), values = tensors[:2], tensors[2:]
    offsets, event_size = _compute_offsets(self.inputs)
    slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
              for k, offset in offsets.items()]

    # Expand all substituted values.
    values = OrderedDict(zip(subs, values))
    for k, value in values.items():
        value = value.reshape(value.shape[:batch_dim] + (-1,))
        if not get_tracing_state():
            assert value.shape[-1] == self.inputs[k].num_elements
        values[k] = ops.expand(value, batch_shape + value.shape[-1:])

    # Try to perform a complete substitution of all real variables,
    # resulting in a Tensor.
    if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
        # Form the concatenated value.
        value = BlockVector(batch_shape + (event_size,))
        for k, i in slices:
            if k in values:
                value[..., i] = values[k]
        value = value.as_tensor()

        # Evaluate the non-normalized log density.
        result = _vv(value, info_vec - 0.5 * _mv(precision, value))
        result = Tensor(result, int_inputs)
        assert result.output == Real
        return Subs(result, remaining_subs) if remaining_subs else result

    # Perform a partial substitution of a subset of real variables, resulting
    # in a Joint. We split real inputs into two sets: a for the preserved and
    # b for the substituted.
    b = frozenset(k for k, v in subs.items())
    a = frozenset(k for k, d in self.inputs.items()
                  if d.dtype == 'real' and k not in b)
    prec_aa = ops.cat(-2, *[ops.cat(-1, *[precision[..., i1, i2]
                                          for k2, i2 in slices if k2 in a])
                            for k1, i1 in slices if k1 in a])
    prec_ab = ops.cat(-2, *[ops.cat(-1, *[precision[..., i1, i2]
                                          for k2, i2 in slices if k2 in b])
                            for k1, i1 in slices if k1 in a])
    prec_bb = ops.cat(-2, *[ops.cat(-1, *[precision[..., i1, i2]
                                          for k2, i2 in slices if k2 in b])
                            for k1, i1 in slices if k1 in b])
    info_a = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in a])
    info_b = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in b])
    value_b = ops.cat(-1, *[values[k] for k, i in slices if k in b])

    info_vec = info_a - _mv(prec_ab, value_b)
    log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
    precision = ops.expand(prec_aa, info_vec.shape + info_vec.shape[-1:])
    inputs = int_inputs.copy()
    for k, d in self.inputs.items():
        if k not in subs:
            inputs[k] = d
    result = Gaussian(info_vec, precision, inputs) + Tensor(log_scale, int_inputs)
    return Subs(result, remaining_subs) if remaining_subs else result
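# A sketch of the partial-substitution algebra above. In information form the
# non-normalized log density over u = (u_a, u_b) is
#
#     f(u) = u^T i - 1/2 u^T P u
#
# and fixing u_b = v gives
#
#     f(u_a, v) = u_a^T (i_a - P_ab v) - 1/2 u_a^T P_aa u_a
#                 + (v^T i_b - 1/2 v^T P_bb v)
#
# i.e. a Gaussian over u_a with info_vec = i_a - P_ab v and precision = P_aa,
# plus the scalar log_scale term computed above.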
def __init__(self, initial_logits, initial_mvn, transition_logits,
             transition_matrix, transition_mvn, observation_matrix,
             observation_mvn, exact=False, validate_args=None):
    assert isinstance(initial_logits, torch.Tensor)
    assert isinstance(initial_mvn, torch.distributions.MultivariateNormal)
    assert isinstance(transition_logits, torch.Tensor)
    assert isinstance(transition_matrix, torch.Tensor)
    assert isinstance(transition_mvn, torch.distributions.MultivariateNormal)
    assert isinstance(observation_matrix, torch.Tensor)
    assert isinstance(observation_mvn, torch.distributions.MultivariateNormal)
    hidden_cardinality = initial_logits.size(-1)
    hidden_dim, obs_dim = observation_matrix.shape[-2:]
    assert obs_dim >= hidden_dim // 2, "obs_dim must be at least half of hidden_dim"
    assert initial_mvn.event_shape[0] == hidden_dim
    assert transition_logits.size(-1) == hidden_cardinality
    assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
    assert transition_mvn.event_shape[0] == hidden_dim
    assert observation_mvn.event_shape[0] == obs_dim
    init_shape = broadcast_shape(initial_logits.shape, initial_mvn.batch_shape)
    shape = broadcast_shape(init_shape[:-1] + (1, init_shape[-1]),
                            transition_logits.shape[:-1],
                            transition_matrix.shape[:-2],
                            transition_mvn.batch_shape,
                            observation_matrix.shape[:-2],
                            observation_mvn.batch_shape)
    assert shape[-1] == hidden_cardinality
    batch_shape, time_shape = shape[:-2], shape[-2:-1]
    event_shape = time_shape + (obs_dim,)

    # Normalize.
    initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
    transition_logits = transition_logits - transition_logits.logsumexp(-1, True)

    # Convert tensors and distributions to funsors.
    init = (tensor_to_funsor(initial_logits, ("class",))
            + dist_to_funsor(initial_mvn, ("class",))(value="state"))
    trans = (tensor_to_funsor(transition_logits, ("time", "class", "class(time=1)"))
             + matrix_and_mvn_to_funsor(transition_matrix, transition_mvn,
                                        ("time", "class(time=1)"),
                                        "state", "state(time=1)"))
    obs = matrix_and_mvn_to_funsor(observation_matrix, observation_mvn,
                                   ("time", "class(time=1)"),
                                   "state(time=1)", "value")
    if "class(time=1)" not in set(trans.inputs).union(obs.inputs):
        raise ValueError("neither transition nor observation depend on discrete state")
    dtype = "real"

    # Construct the joint funsor.
    with interpretation(lazy):
        # TODO perform math here once sequential_sum_product has been
        # implemented as a first-class funsor.
        funsor_dist = Variable("value", obs.inputs["value"])  # a bogus value
        # Until funsor_dist is defined, we save factors for hand-computation
        # in .log_prob().
        self._init = init
        self._trans = trans
        self._obs = obs

    super(SwitchingLinearHMM, self).__init__(funsor_dist, batch_shape,
                                             event_shape, dtype, validate_args)
    self.exact = exact