def eager_delta(v, log_density, value):
    """Evaluate a Delta distribution's log_prob at ``value`` as a Tensor funsor.

    This handles ``event_dim`` explicitly, and hence cannot use the generic
    ``Delta.eager_log_prob()`` method.
    """
    assert v.output == value.output
    # Record the event rank before align_tensors rebinds ``v`` to raw data.
    n_event_dims = len(v.output.shape)
    inputs, (v_data, density_data, value_data) = align_tensors(v, log_density, value)
    log_prob = dist.Delta(v_data, density_data, n_event_dims).log_prob(value_data)
    return Tensor(log_prob, inputs)
def eager_multinomial(total_count, probs, value):
    """Evaluate Multinomial log-probability, allowing inhomogeneous counts.

    ``Multinomial.log_prob()`` supports inhomogeneous ``total_count`` only by
    avoiding passing ``total_count`` to the constructor, so we broadcast the
    tensor arguments ourselves and pass only a scalar count upper bound.
    """
    inputs, (count_data, probs_data, value_data) = align_tensors(total_count, probs, value)
    full_shape = broadcast_shape(count_data.shape + (1,), probs_data.shape, value_data.shape)
    probs = Tensor(probs_data.expand(full_shape), inputs)
    value = Tensor(value_data.expand(full_shape), inputs)
    # The max count is used by distributions validation code.
    total_count = Number(count_data.max().item())
    return Multinomial.eager_log_prob(total_count=total_count, probs=probs, value=value)
def eager_mvn(loc, scale_tril, value):
    """Convert ``MultivariateNormal(loc, scale_tril).log_prob(value)`` to
    Tensor + Gaussian form.
    """
    # Normalize argument order so the free Variable, if any, ends up in ``value``.
    if isinstance(loc, Variable):
        loc, value = value, loc

    dim, = loc.output.shape  # event dimension of the MVN
    inputs, (loc, scale_tril) = align_tensors(loc, scale_tril)
    inputs.update(value.inputs)
    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')

    # Log normalizer: -dim/2 * log(2*pi) - log|det L|, where the determinant
    # of the triangular factor L is the product of its diagonal entries.
    log_prob = -0.5 * dim * math.log(2 * math.pi) - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
    # precision = (L^-1)^T (L^-1) = (L L^T)^-1, i.e. the inverse covariance.
    inv_scale_tril = torch.inverse(scale_tril)
    precision = torch.matmul(inv_scale_tril.transpose(-1, -2), inv_scale_tril)
    return Tensor(log_prob, int_inputs) + Gaussian(loc, precision, inputs)
def eager_normal(loc, scale, value):
    """Convert ``Normal(loc, scale).log_prob(value)`` to Tensor + Gaussian form."""
    # Normalize argument order so the free Variable, if any, ends up in ``value``.
    if isinstance(loc, Variable):
        loc, value = value, loc

    inputs, (loc_data, scale_data) = align_tensors(loc, scale)
    loc_data, scale_data = torch.broadcast_tensors(loc_data, scale_data)
    inputs.update(value.inputs)
    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')

    # Normalizer part, a function of scale only.
    log_prob = -0.5 * math.log(2 * math.pi) - scale_data.log()
    # Quadratic part as a rank-1 Gaussian over the single real input.
    gaussian_loc = loc_data.unsqueeze(-1)
    gaussian_precision = scale_data.pow(-2).unsqueeze(-1).unsqueeze(-1)
    return Tensor(log_prob, int_inputs) + Gaussian(gaussian_loc, gaussian_precision, inputs)
def eager_normal(loc, scale, value):
    """Convert ``Normal(loc, scale).log_prob(value)`` to Tensor + Gaussian
    (information-vector) form.
    """
    # Normalize argument order so the free Variable, if any, ends up in ``value``.
    if isinstance(loc, Variable):
        loc, value = value, loc

    inputs, (loc_data, scale_data) = align_tensors(loc, scale, expand=True)
    inputs.update(value.inputs)
    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')

    prec = scale_data.pow(-2)
    info_vec = (prec * loc_data).unsqueeze(-1)
    precision = prec.unsqueeze(-1).unsqueeze(-1)
    # Normalizer minus the constant quadratic term -1/2 * loc * prec * loc.
    log_prob = (-0.5 * math.log(2 * math.pi)
                - scale_data.log()
                - 0.5 * (loc_data * info_vec).squeeze(-1))
    return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs)
def test_batched_einsum(equation, batch1, batch2):
    """Einsum over funsors with batch inputs should agree with torch.einsum
    applied with '...' prepended to every operand spec.
    """
    specs, out_spec = equation.split('->')
    operand_specs = specs.split(',')
    sizes = dict(a=2, b=3, c=4, i=5, j=6)
    batch1 = OrderedDict([(k, bint(sizes[k])) for k in batch1])
    batch2 = OrderedDict([(k, bint(sizes[k])) for k in batch2])
    operands = [random_tensor(batch, reals(*(sizes[d] for d in spec)))
                for batch, spec in zip([batch1, batch2], operand_specs)]
    actual = Einsum(equation, tuple(operands))

    batched_equation = ','.join('...' + spec for spec in operand_specs) + '->...' + out_spec
    joint_inputs, raw_tensors = align_tensors(*operands)
    batch_shape = tuple(v.size for v in joint_inputs.values())
    raw_tensors = [x.expand(batch_shape + f.shape) for x, f in zip(raw_tensors, operands)]
    expected = Tensor(torch.einsum(batched_equation, raw_tensors), joint_inputs)
    assert_close(actual, expected, atol=1e-5, rtol=None)
def test_binary_funsor_funsor(symbol, dims1, dims2):
    """Binary ops on two Tensor funsors broadcast over the union of inputs."""
    sizes = {'a': 3, 'b': 4, 'c': 5}
    inputs1 = OrderedDict((d, bint(sizes[d])) for d in dims1)
    inputs2 = OrderedDict((d, bint(sizes[d])) for d in dims2)
    data1 = torch.rand(tuple(sizes[d] for d in dims1)) + 0.5
    data2 = torch.rand(tuple(sizes[d] for d in dims2)) + 0.5
    if symbol in BOOLEAN_OPS:
        # Boolean ops operate on byte data with a 2-valued dtype.
        dtype = 2
        data1 = data1.byte()
        data2 = data2.byte()
    else:
        dtype = 'real'
    x1 = Tensor(data1, inputs1, dtype)
    x2 = Tensor(data2, inputs2, dtype)
    expected_inputs, (aligned1, aligned2) = align_tensors(x1, x2)
    expected_data = binary_eval(symbol, aligned1, aligned2)
    actual = binary_eval(symbol, x1, x2)
    check_funsor(actual, expected_inputs, Domain((), dtype), expected_data)
def test_tensor_distribution(event_inputs, batch_inputs, test_grad):
    """Monte Carlo samples of a random Tensor should approximate the Tensor
    itself, in both value and (optionally) gradient.
    """
    num_samples = 50000
    sample_inputs = OrderedDict(n=bint(num_samples))
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    sampled_vars = frozenset(event_inputs)

    p = random_tensor(be_inputs)
    p.data.requires_grad_(test_grad)
    q = p.sample(sampled_vars, sample_inputs)
    # Marginalize out the sample dimension and compare against the original.
    mq = materialize(q).reduce(ops.logaddexp, 'n').align(tuple(p.inputs))
    assert_close(mq, p, atol=0.1, rtol=None)

    if test_grad:
        _, (p_data, mq_data) = align_tensors(p, mq)
        assert p_data.shape == mq_data.shape
        # Probe gradients through a random linear functional of exp(.).
        probe = torch.randn(p_data.shape)
        expected = grad((p_data.exp() * probe).sum(), [p.data])[0]
        actual = grad((mq_data.exp() * probe).sum(), [p.data])[0]
        assert_close(actual, expected, atol=0.1, rtol=None)
def eager_normal(loc, scale, value):
    """Eagerly convert a Normal density to Tensor + Gaussian form via the
    standardized affine residual ``(loc - value) / scale``.
    """
    affine = (loc - value) / scale
    assert isinstance(affine, Affine)
    real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items()
                              if v.dtype == 'real')
    # Only scalar real inputs are supported here.
    assert not any(v.shape for v in real_inputs.values())

    # Align the affine constant and per-input coefficients to shared inputs.
    tensors = [affine.const] + [c for v, c in affine.coeffs.items()]
    inputs, tensors = align_tensors(*tensors)
    tensors = torch.broadcast_tensors(*tensors)
    const, coeffs = tensors[0], tensors[1:]

    # Assemble a Gaussian over all real inputs with precision = c c^T, the
    # outer product of coefficients (unit variance after standardization).
    dim = sum(d.num_elements for d in real_inputs.values())
    loc = BlockVector(const.shape + (dim,))
    # NOTE(review): only entry 0 of ``loc`` is filled; presumably consistent
    # with the rank-1 precision below -- confirm behavior for dim > 1.
    loc[..., 0] = -const / coeffs[0]
    precision = BlockMatrix(const.shape + (dim, dim))
    for i, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
        for j, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
            precision[..., i, j] = c1 * c2
    loc = loc.as_tensor()
    precision = precision.as_tensor()

    # Normalizer depends only on scale.
    log_prob = -0.5 * math.log(2 * math.pi) - scale.log()
    return log_prob + Gaussian(loc, precision, affine.inputs)
def eager_normal(loc, scale, value):
    """Eagerly convert a Normal density to Tensor + Gaussian
    (information-vector) form via the residual ``(loc - value) / scale``.

    Returns None (stay lazy) when the residual is not affine in its real
    inputs.
    """
    affine = (loc - value) / scale
    if not affine.is_affine:
        return None

    real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items()
                              if v.dtype == 'real')
    int_inputs = OrderedDict((k, v) for k, v in affine.inputs.items()
                             if v.dtype != 'real')
    # Only scalar real inputs are supported here.
    assert not any(v.shape for v in real_inputs.values())

    # Probe the affine function to recover its pieces:
    # const = affine(0), coeff_k = affine(e_k) - const.
    const = affine(**{k: 0. for k, v in real_inputs.items()})
    coeffs = OrderedDict()
    for c in real_inputs.keys():
        coeffs[c] = affine(**{k: 1. if c == k else 0.
                              for k in real_inputs.keys()}) - const

    # Align everything to shared inputs.
    tensors = [const] + list(coeffs.values())
    inputs, tensors = align_tensors(*tensors, expand=True)
    const, coeffs = tensors[0], tensors[1:]

    # Assemble a Gaussian over all real inputs with precision = c c^T, the
    # outer product of coefficients (unit variance after standardization).
    dim = sum(d.num_elements for d in real_inputs.values())
    loc = BlockVector(const.shape + (dim,))
    # NOTE(review): only entry 0 of ``loc`` is filled; presumably consistent
    # with the rank-1 precision below -- confirm behavior for dim > 1.
    loc[..., 0] = -const / coeffs[0]
    precision = BlockMatrix(const.shape + (dim, dim))
    for i, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
        for j, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
            precision[..., i, j] = c1 * c2
    loc = loc.as_tensor()
    precision = precision.as_tensor()
    info_vec = precision.matmul(loc.unsqueeze(-1)).squeeze(-1)

    # Normalizer minus the constant quadratic term -1/2 <loc, info_vec>.
    log_prob = -0.5 * math.log(2 * math.pi) - scale.data.log() - 0.5 * (loc * info_vec).sum(-1)
    return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, affine.inputs)
def eager_log_prob(cls, **params):
    """Evaluate ``cls.dist_class(**params).log_prob(value)`` as a Tensor funsor.

    ``params`` must include a ``value`` entry; all entries are aligned to a
    common set of inputs before constructing the distribution.
    """
    inputs, raw_tensors = align_tensors(*params.values())
    aligned = dict(zip(params, raw_tensors))
    value = aligned.pop('value')
    log_prob = cls.dist_class(**aligned).log_prob(value)
    return Tensor(log_prob, inputs)
def eager_mvn(loc, scale_tril, value):
    """Eagerly convert ``MultivariateNormal(loc, scale_tril).log_prob(value)``
    to Tensor + Gaussian form via an affine whitened residual.

    Returns None (stay lazy) whenever loc/value are not affine or their
    extracted affine pieces are not concrete Tensors.
    """
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    # Extract an affine representation of the whitened residual
    # L^-1 (loc - value), where L = scale_tril.
    eye = torch.eye(scale_tril.data.size(-1)).expand(scale_tril.data.shape)
    prec_sqrt = Tensor(eye.triangular_solve(scale_tril.data, upper=False).solution,
                       scale_tril.inputs)
    affine = prec_sqrt @ (loc - value)
    const, coeffs = extract_affine(affine)
    if not isinstance(const, Tensor):
        return None  # lazy
    if not all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values()):
        return None  # lazy

    # Compute log_prob using funsors: the normalizer (via the triangular
    # factor's diagonal) plus the constant quadratic term of the residual.
    scale_diag = Tensor(scale_tril.data.diagonal(dim1=-1, dim2=-2), scale_tril.inputs)
    log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi)
                - scale_diag.log().sum() - 0.5 * (const**2).sum())

    # Dovetail to avoid variable name collision in einsum: coefficient
    # equations are re-lettered onto even (equations1) / odd (equations2)
    # letter codes so the two operand sets never share an index name.
    equations1 = [''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a')) for c in eqn)
                  for _, eqn in coeffs.values()]
    equations2 = [''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a') + 1) for c in eqn)
                  for _, eqn in coeffs.values()]

    real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items()
                              if v.dtype == 'real')
    assert tuple(real_inputs) == tuple(coeffs)

    # Align and broadcast tensors.
    neg_const = -const
    tensors = [neg_const] + [coeff for coeff, _ in coeffs.values()]
    inputs, tensors = align_tensors(*tensors, expand=True)
    neg_const, coeffs = tensors[0], tensors[1:]
    dim = sum(d.num_elements for d in real_inputs.values())
    batch_shape = neg_const.shape[:-1]

    # Assemble blockwise info_vec and precision, one block per real input,
    # flattening each input's event shape; the contractions follow the
    # re-lettered einsum equations above.
    info_vec = BlockVector(batch_shape + (dim,))
    precision = BlockMatrix(batch_shape + (dim, dim))
    offset1 = 0
    for i1, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
        size1 = real_inputs[v1].num_elements
        slice1 = slice(offset1, offset1 + size1)
        inputs1, output1 = equations1[i1].split('->')
        input11, input12 = inputs1.split(',')
        assert input11 == input12 + output1
        info_vec[..., slice1] = torch.einsum(
            f'...{input11},...{output1}->...{input12}', c1, neg_const) \
            .reshape(batch_shape + (size1,))
        offset2 = 0
        for i2, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
            size2 = real_inputs[v2].num_elements
            slice2 = slice(offset2, offset2 + size2)
            inputs2, output2 = equations2[i2].split('->')
            input21, input22 = inputs2.split(',')
            assert input21 == input22 + output2
            precision[..., slice1, slice2] = torch.einsum(
                f'...{input11},...{input22}{output1}->...{input12}{input22}', c1, c2) \
                .reshape(batch_shape + (size1, size2))
            offset2 += size2
        offset1 += size1

    info_vec = info_vec.as_tensor()
    precision = precision.as_tensor()
    inputs.update(real_inputs)
    return log_prob + Gaussian(info_vec, precision, inputs)
def _eager_subs_affine(self, subs, remaining_subs):
    """Substitute affine functions of real variables into this Gaussian.

    Substitutions whose values are not concrete affine Tensors are pushed
    onto ``remaining_subs`` for later (possibly lazy) handling.
    """
    # Extract an affine representation.
    affine = OrderedDict()
    for k, v in subs:
        const, coeffs = extract_affine(v)
        if (isinstance(const, Tensor) and
                all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values())):
            affine[k] = const, coeffs
        else:
            remaining_subs += (k, v),
    if not affine:
        return reflect(Subs, self, remaining_subs)

    # Align integer dimensions.
    old_int_inputs = OrderedDict((k, v) for k, v in self.inputs.items()
                                 if v.dtype != 'real')
    tensors = [Tensor(self.info_vec, old_int_inputs),
               Tensor(self.precision, old_int_inputs)]
    for const, coeffs in affine.values():
        tensors.append(const)
        tensors.extend(coeff for coeff, _ in coeffs.values())
    new_int_inputs, tensors = align_tensors(*tensors, expand=True)
    # Consume the aligned tensors back in the same order they were appended.
    tensors = (Tensor(x, new_int_inputs) for x in tensors)
    old_info_vec = next(tensors).data
    old_precision = next(tensors).data
    for old_k, (const, coeffs) in affine.items():
        const = next(tensors)
        for new_k, (coeff, eqn) in coeffs.items():
            coeff = next(tensors)
            coeffs[new_k] = coeff, eqn
        affine[old_k] = const, coeffs
    batch_shape = old_info_vec.data.shape[:-1]

    # Align real dimensions: substituted inputs are replaced by the inputs of
    # their affine expressions.
    old_real_inputs = OrderedDict((k, v) for k, v in self.inputs.items()
                                  if v.dtype == 'real')
    new_real_inputs = old_real_inputs.copy()
    for old_k, (const, coeffs) in affine.items():
        del new_real_inputs[old_k]
        for new_k, (coeff, eqn) in coeffs.items():
            new_shape = coeff.shape[:len(eqn.split('->')[0].split(',')[1])]
            new_real_inputs[new_k] = reals(*new_shape)
    old_offsets, old_dim = _compute_offsets(old_real_inputs)
    new_offsets, new_dim = _compute_offsets(new_real_inputs)
    new_inputs = new_int_inputs.copy()
    new_inputs.update(new_real_inputs)

    # Construct a blockwise affine representation of the substitution:
    # old_value = subs_matrix^T @ new_value + subs_vector.
    subs_vector = BlockVector(batch_shape + (old_dim,))
    subs_matrix = BlockMatrix(batch_shape + (new_dim, old_dim))
    for old_k, old_offset in old_offsets.items():
        old_size = old_real_inputs[old_k].num_elements
        old_slice = slice(old_offset, old_offset + old_size)
        if old_k in new_real_inputs:
            # Unsubstituted input: identity block.
            new_offset = new_offsets[old_k]
            new_slice = slice(new_offset, new_offset + old_size)
            subs_matrix[..., new_slice, old_slice] = \
                torch.eye(old_size).expand(batch_shape + (-1, -1))
            continue
        const, coeffs = affine[old_k]
        old_shape = old_real_inputs[old_k].shape
        assert const.data.shape == batch_shape + old_shape
        subs_vector[..., old_slice] = const.data.reshape(batch_shape + (old_size,))
        for new_k, new_offset in new_offsets.items():
            if new_k in coeffs:
                coeff, eqn = coeffs[new_k]
                new_size = new_real_inputs[new_k].num_elements
                new_slice = slice(new_offset, new_offset + new_size)
                assert coeff.shape == new_real_inputs[new_k].shape + old_shape
                subs_matrix[..., new_slice, old_slice] = \
                    coeff.data.reshape(batch_shape + (new_size, old_size))
    subs_vector = subs_vector.as_tensor()
    subs_matrix = subs_matrix.as_tensor()
    subs_matrix_t = subs_matrix.transpose(-1, -2)

    # Construct the new funsor. Suppose the old Gaussian funsor g has density
    #   g(x) = < x | i - 1/2 P x>
    # Now define a new funsor f by substituting x = A y + B:
    #   f(y) = g(A y + B)
    #        = < A y + B | i - 1/2 P (A y + B) >
    #        = < y | At (i - P B) - 1/2 At P A y > + < B | i - 1/2 P B >
    #       =: < y | i' - 1/2 P' y > + C
    # where P' = At P A and i' = At (i - P B) parametrize a new Gaussian
    # and C = < B | i - 1/2 P B > parametrize a new Tensor.
    precision = subs_matrix @ old_precision @ subs_matrix_t
    info_vec = _mv(subs_matrix, old_info_vec - _mv(old_precision, subs_vector))
    const = _vv(subs_vector, old_info_vec - 0.5 * _mv(old_precision, subs_vector))
    result = Gaussian(info_vec, precision, new_inputs) + Tensor(const, new_int_inputs)
    return Subs(result, remaining_subs) if remaining_subs else result
def _eager_subs_real(self, subs, remaining_subs):
    """Substitute concrete values for real inputs of this Gaussian.

    A complete substitution evaluates the (non-normalized) log density to a
    Tensor; a partial substitution yields a smaller Gaussian plus a Tensor
    log-scale correction.
    """
    # Broadcast all component tensors.
    subs = OrderedDict(subs)
    int_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                             if d.dtype != 'real')
    tensors = [Tensor(self.info_vec, int_inputs),
               Tensor(self.precision, int_inputs)]
    tensors.extend(subs.values())
    int_inputs, tensors = align_tensors(*tensors)
    batch_dim = tensors[0].dim() - 1
    batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
    (info_vec, precision), values = tensors[:2], tensors[2:]
    # Flat slice of each real input within the concatenated event vector.
    offsets, event_size = _compute_offsets(self.inputs)
    slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
              for k, offset in offsets.items()]

    # Expand all substituted values.
    values = OrderedDict(zip(subs, values))
    for k, value in values.items():
        value = value.reshape(value.shape[:batch_dim] + (-1,))
        if not torch._C._get_tracing_state():
            assert value.size(-1) == self.inputs[k].num_elements
        values[k] = value.expand(batch_shape + value.shape[-1:])

    # Try to perform a complete substitution of all real variables, resulting in a Tensor.
    if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
        # Form the concatenated value.
        value = BlockVector(batch_shape + (event_size,))
        for k, i in slices:
            if k in values:
                value[..., i] = values[k]
        value = value.as_tensor()

        # Evaluate the non-normalized log density <v | i - 1/2 P v>.
        result = _vv(value, info_vec - 0.5 * _mv(precision, value))
        result = Tensor(result, int_inputs)
        assert result.output == reals()
        return Subs(result, remaining_subs) if remaining_subs else result

    # Perform a partial substution of a subset of real variables, resulting in a Joint.
    # We split real inputs into two sets: a for the preserved and b for the substituted.
    b = frozenset(k for k, v in subs.items())
    a = frozenset(k for k, d in self.inputs.items()
                  if d.dtype == 'real' and k not in b)
    # Gather the block structure of precision and info_vec w.r.t. the split.
    prec_aa = torch.cat([torch.cat([precision[..., i1, i2]
                                    for k2, i2 in slices if k2 in a], dim=-1)
                         for k1, i1 in slices if k1 in a], dim=-2)
    prec_ab = torch.cat([torch.cat([precision[..., i1, i2]
                                    for k2, i2 in slices if k2 in b], dim=-1)
                         for k1, i1 in slices if k1 in a], dim=-2)
    prec_bb = torch.cat([torch.cat([precision[..., i1, i2]
                                    for k2, i2 in slices if k2 in b], dim=-1)
                         for k1, i1 in slices if k1 in b], dim=-2)
    info_a = torch.cat([info_vec[..., i] for k, i in slices if k in a], dim=-1)
    info_b = torch.cat([info_vec[..., i] for k, i in slices if k in b], dim=-1)
    value_b = torch.cat([values[k] for k, i in slices if k in b], dim=-1)

    # Condition on the substituted block in information form:
    # i_a' = i_a - P_ab v_b; log_scale collects the terms quadratic in v_b.
    info_vec = info_a - _mv(prec_ab, value_b)
    log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
    precision = prec_aa.expand(info_vec.shape + (-1,))
    inputs = int_inputs.copy()
    for k, d in self.inputs.items():
        if k not in subs:
            inputs[k] = d
    result = Gaussian(info_vec, precision, inputs) + Tensor(log_scale, int_inputs)
    return Subs(result, remaining_subs) if remaining_subs else result
def eager_subs(self, subs):
    """Substitute values into this Gaussian's inputs.

    Variables are renamed, integer values slice into batch dimensions, and a
    complete set of real values evaluates the log density to a Tensor.
    Partial substitution of real variables is not implemented here.
    """
    assert isinstance(subs, tuple)
    subs = tuple((k, materialize(to_funsor(v, self.inputs[k])))
                 for k, v in subs if k in self.inputs)
    if not subs:
        return self

    # Constants and Variables are eagerly substituted;
    # everything else is lazily substituted.
    lazy_subs = tuple((k, v) for k, v in subs
                      if not isinstance(v, (Number, Tensor, Variable)))
    var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
    int_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                     if v.dtype != 'real')
    real_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                      if v.dtype == 'real')
    if not (var_subs or int_subs or real_subs):
        return reflect(Subs, self, lazy_subs)

    # First perform any variable substitutions.
    if var_subs:
        rename = {k: v.name for k, v in var_subs}
        inputs = OrderedDict((rename.get(k, k), d) for k, d in self.inputs.items())
        if len(inputs) != len(self.inputs):
            raise ValueError("Variable substitution name conflict")
        var_result = Gaussian(self.loc, self.precision, inputs)
        # Defer the remaining substitutions to the renamed Gaussian.
        return Subs(var_result, int_subs + real_subs + lazy_subs)

    # Next perform any integer substitution, i.e. slicing into a batch.
    if int_subs:
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                                 if d.dtype != 'real')
        real_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                                  if d.dtype == 'real')
        # Reuse Tensor substitution to slice both parameter tensors.
        tensors = [self.loc, self.precision]
        funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
        inputs = funsors[0].inputs.copy()
        inputs.update(real_inputs)
        int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
        return Subs(int_result, real_subs + lazy_subs)

    # Try to perform a complete substitution of all real variables, resulting in a Tensor.
    real_subs = OrderedDict(subs)
    assert real_subs and not int_subs
    if all(k in real_subs for k, d in self.inputs.items() if d.dtype == 'real'):
        # Broadcast all component tensors.
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                                 if d.dtype != 'real')
        tensors = [Tensor(self.loc, int_inputs),
                   Tensor(self.precision, int_inputs)]
        tensors.extend(real_subs.values())
        inputs, tensors = align_tensors(*tensors)
        batch_dim = tensors[0].dim() - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (loc, precision), values = tensors[:2], tensors[2:]

        # Form the concatenated value.
        offsets, event_size = _compute_offsets(self.inputs)
        value = BlockVector(batch_shape + (event_size,))
        for k, value_k in zip(real_subs, values):
            offset = offsets[k]
            value_k = value_k.reshape(value_k.shape[:batch_dim] + (-1,))
            if not torch._C._get_tracing_state():
                assert value_k.size(-1) == self.inputs[k].num_elements
            value_k = value_k.expand(batch_shape + value_k.shape[-1:])
            value[..., offset:offset + self.inputs[k].num_elements] = value_k
        value = value.as_tensor()

        # Evaluate the non-normalized log density -1/2 (v-loc)^T P (v-loc).
        result = -0.5 * _vmv(precision, value - loc)
        result = Tensor(result, inputs)
        assert result.output == reals()
        return Subs(result, lazy_subs)

    # Perform a partial substution of a subset of real variables, resulting in a Joint.
    # See "The Matrix Cookbook" (November 15, 2012) ss. 8.1.3 eq. 353.
    # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
    raise NotImplementedError('TODO implement partial substitution of real variables')
def eager_subs(self, subs):
    """Substitute values into this Gaussian's inputs.

    Variables are renamed, integer values and Slices index into batch
    dimensions, a complete set of real values evaluates the log density to a
    Tensor, and a partial set of real values yields a smaller Gaussian plus a
    Tensor log-scale correction.
    """
    assert isinstance(subs, tuple)
    subs = tuple((k, v if isinstance(v, (Variable, Slice)) else materialize(v))
                 for k, v in subs if k in self.inputs)
    if not subs:
        return self

    # Constants and Variables are eagerly substituted;
    # everything else is lazily substituted.
    lazy_subs = tuple((k, v) for k, v in subs
                      if not isinstance(v, (Number, Tensor, Variable, Slice)))
    var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
    int_subs = tuple((k, v) for k, v in subs
                     if isinstance(v, (Number, Tensor, Slice))
                     if v.dtype != 'real')
    real_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                      if v.dtype == 'real')
    if not (var_subs or int_subs or real_subs):
        return reflect(Subs, self, lazy_subs)

    # First perform any variable substitutions.
    if var_subs:
        rename = {k: v.name for k, v in var_subs}
        inputs = OrderedDict((rename.get(k, k), d) for k, d in self.inputs.items())
        if len(inputs) != len(self.inputs):
            raise ValueError("Variable substitution name conflict")
        var_result = Gaussian(self.info_vec, self.precision, inputs)
        # Defer the remaining substitutions to the renamed Gaussian.
        return Subs(var_result, int_subs + real_subs + lazy_subs)

    # Next perform any integer substitution, i.e. slicing into a batch.
    if int_subs:
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                                 if d.dtype != 'real')
        real_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                                  if d.dtype == 'real')
        # Reuse Tensor substitution to slice both parameter tensors.
        tensors = [self.info_vec, self.precision]
        funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
        inputs = funsors[0].inputs.copy()
        inputs.update(real_inputs)
        int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
        return Subs(int_result, real_subs + lazy_subs)

    # Broadcast all component tensors.
    real_subs = OrderedDict(subs)
    assert real_subs and not int_subs
    int_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                             if d.dtype != 'real')
    tensors = [Tensor(self.info_vec, int_inputs),
               Tensor(self.precision, int_inputs)]
    tensors.extend(real_subs.values())
    int_inputs, tensors = align_tensors(*tensors)
    batch_dim = tensors[0].dim() - 1
    batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
    (info_vec, precision), values = tensors[:2], tensors[2:]
    # Flat slice of each real input within the concatenated event vector.
    offsets, event_size = _compute_offsets(self.inputs)
    slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
              for k, offset in offsets.items()]

    # Expand all substituted values.
    values = OrderedDict(zip(real_subs, values))
    for k, value in values.items():
        value = value.reshape(value.shape[:batch_dim] + (-1,))
        if not torch._C._get_tracing_state():
            assert value.size(-1) == self.inputs[k].num_elements
        values[k] = value.expand(batch_shape + value.shape[-1:])

    # Try to perform a complete substitution of all real variables, resulting in a Tensor.
    if all(k in real_subs for k, d in self.inputs.items() if d.dtype == 'real'):
        # Form the concatenated value.
        value = BlockVector(batch_shape + (event_size,))
        for k, i in slices:
            if k in values:
                value[..., i] = values[k]
        value = value.as_tensor()

        # Evaluate the non-normalized log density <v | i - 1/2 P v>.
        result = _vv(value, info_vec - 0.5 * _mv(precision, value))
        result = Tensor(result, int_inputs)
        assert result.output == reals()
        return Subs(result, lazy_subs)

    # Perform a partial substution of a subset of real variables, resulting in a Joint.
    # We split real inputs into two sets: a for the preserved and b for the substituted.
    b = frozenset(k for k, v in real_subs.items())
    a = frozenset(k for k, d in self.inputs.items()
                  if d.dtype == 'real' and k not in b)
    # Gather the block structure of precision and info_vec w.r.t. the split.
    prec_aa = torch.cat([torch.cat([precision[..., i1, i2]
                                    for k2, i2 in slices if k2 in a], dim=-1)
                         for k1, i1 in slices if k1 in a], dim=-2)
    prec_ab = torch.cat([torch.cat([precision[..., i1, i2]
                                    for k2, i2 in slices if k2 in b], dim=-1)
                         for k1, i1 in slices if k1 in a], dim=-2)
    prec_bb = torch.cat([torch.cat([precision[..., i1, i2]
                                    for k2, i2 in slices if k2 in b], dim=-1)
                         for k1, i1 in slices if k1 in b], dim=-2)
    info_a = torch.cat([info_vec[..., i] for k, i in slices if k in a], dim=-1)
    info_b = torch.cat([info_vec[..., i] for k, i in slices if k in b], dim=-1)
    value_b = torch.cat([values[k] for k, i in slices if k in b], dim=-1)

    # Condition on the substituted block in information form:
    # i_a' = i_a - P_ab v_b; log_scale collects the terms quadratic in v_b.
    info_vec = info_a - _mv(prec_ab, value_b)
    log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
    precision = prec_aa.expand(info_vec.shape + (-1,))
    inputs = int_inputs.copy()
    for k, d in self.inputs.items():
        if k not in real_subs:
            inputs[k] = d
    return Gaussian(info_vec, precision, inputs) + Tensor(log_scale, int_inputs)