def test_extract_affine(expr):
    """Check that ``extract_affine`` reproduces an affine funsor.

    ``expr`` is the source text of a funsor expression.  The test evaluates
    it, decomposes it into a constant plus one ``(coefficient, equation)``
    pair per real input, and verifies that recombining those pieces at
    random real-valued points matches evaluating the funsor directly.
    """
    # NOTE: ``eval`` is applied to the test's own expression strings; it must
    # never be fed untrusted input.
    funsor_expr = eval(expr)
    assert isinstance(funsor_expr, (Contraction, Einsum))

    # Only the real-valued inputs take part in the affine decomposition.
    real_inputs = OrderedDict()
    for input_name, domain in funsor_expr.inputs.items():
        if domain.dtype == 'real':
            real_inputs[input_name] = domain

    const, coeffs = extract_affine(funsor_expr)

    # Structural checks on the decomposition.
    assert isinstance(const, Tensor)
    assert const.shape == funsor_expr.shape
    assert list(coeffs) == list(real_inputs)
    for name, (coeff, eqn) in coeffs.items():
        assert isinstance(name, str)
        assert isinstance(coeff, Tensor)
        assert isinstance(eqn, str)

    # Evaluate the original funsor at a random point of its real inputs.
    subs = {
        input_name: random_tensor(OrderedDict(), domain)
        for input_name, domain in real_inputs.items()
    }
    expected = funsor_expr(**subs)
    assert isinstance(expected, Tensor)

    # Rebuild the same value from the constant and the linear terms.
    linear_terms = (
        Einsum(eqn, (coeff, subs[input_name]))
        for input_name, (coeff, eqn) in coeffs.items()
    )
    actual = const + sum(linear_terms)
    assert isinstance(actual, Tensor)
    assert_close(actual, expected)
def _eager_subs_affine(self, subs, remaining_subs):
    """Eagerly substitute affine expressions for real inputs.

    Each ``(name, value)`` pair in ``subs`` is decomposed with
    ``extract_affine``.  Pairs whose constant and coefficients are all
    ``Tensor``s are folded into a new ``Gaussian`` plus a ``Tensor``
    constant; every other pair is appended to ``remaining_subs`` and left
    as a lazy ``Subs``.

    :param subs: iterable of ``(name, value)`` substitution pairs.
    :param remaining_subs: substitutions already set aside by the caller;
        pairs that cannot be handled here are appended to it.
    :return: the substituted funsor, wrapped in ``Subs`` when any
        substitutions remain.
    """
    # Decompose each substituted value into const + sum(coeff * variable).
    affine = OrderedDict()
    for name, value in subs:
        const, coeffs = extract_affine(value)
        all_tensors = isinstance(const, Tensor) and all(
            isinstance(coeff, Tensor) for coeff, _ in coeffs.values())
        if all_tensors:
            affine[name] = const, coeffs
        else:
            remaining_subs += ((name, value),)
    if not affine:
        return reflect(Subs, self, remaining_subs)

    # Broadcast this Gaussian's parameters and every affine piece to one
    # common set of non-real (batch) inputs.
    old_int_inputs = OrderedDict(
        (k, v) for k, v in self.inputs.items() if v.dtype != 'real')
    to_align = [
        Tensor(self.info_vec, old_int_inputs),
        Tensor(self.precision, old_int_inputs),
    ]
    for const, coeffs in affine.values():
        to_align.append(const)
        to_align.extend(coeff for coeff, _ in coeffs.values())
    new_int_inputs, aligned_data = align_tensors(*to_align, expand=True)

    # Consume the aligned tensors in exactly the order they were queued.
    aligned = (Tensor(data, new_int_inputs) for data in aligned_data)
    old_info_vec = next(aligned).data
    old_precision = next(aligned).data
    for old_k, (_, coeffs) in affine.items():
        aligned_const = next(aligned)
        for new_k, (_, eqn) in coeffs.items():
            # Re-assigning an existing key keeps the dict size fixed, so
            # this is safe while iterating.
            coeffs[new_k] = next(aligned), eqn
        affine[old_k] = aligned_const, coeffs
    batch_shape = old_info_vec.shape[:-1]

    # Work out the real inputs of the result: substituted variables are
    # dropped and the variables of their affine expressions are added.
    old_real_inputs = OrderedDict(
        (k, v) for k, v in self.inputs.items() if v.dtype == 'real')
    new_real_inputs = old_real_inputs.copy()
    for old_k, (_, coeffs) in affine.items():
        del new_real_inputs[old_k]
        for new_k, (coeff, eqn) in coeffs.items():
            # The second einsum operand names the new variable's dims; the
            # coefficient's leading dims have that same rank.
            operands = eqn.split('->')[0]
            variable_dims = operands.split(',')[1]
            new_real_inputs[new_k] = Reals[coeff.shape[:len(variable_dims)]]
    old_offsets, old_dim = _compute_offsets(old_real_inputs)
    new_offsets, new_dim = _compute_offsets(new_real_inputs)
    new_inputs = new_int_inputs.copy()
    new_inputs.update(new_real_inputs)

    # Assemble the substitution x = A y + B blockwise.  ``subs_vector``
    # holds B and ``subs_matrix`` holds At (shape new_dim x old_dim).
    subs_vector = BlockVector(batch_shape + (old_dim,))
    subs_matrix = BlockMatrix(batch_shape + (new_dim, old_dim))
    for old_k, old_offset in old_offsets.items():
        old_size = old_real_inputs[old_k].num_elements
        old_slice = slice(old_offset, old_offset + old_size)

        if old_k in new_real_inputs:
            # Variable survives the substitution: identity block.
            kept_offset = new_offsets[old_k]
            kept_slice = slice(kept_offset, kept_offset + old_size)
            subs_matrix[..., kept_slice, old_slice] = ops.new_eye(
                self.info_vec, batch_shape + (old_size,))
            continue

        const, coeffs = affine[old_k]
        old_shape = old_real_inputs[old_k].shape
        assert const.data.shape == batch_shape + old_shape
        subs_vector[..., old_slice] = const.data.reshape(
            batch_shape + (old_size,))

        for new_k, new_offset in new_offsets.items():
            if new_k not in coeffs:
                continue
            coeff, eqn = coeffs[new_k]
            new_size = new_real_inputs[new_k].num_elements
            new_slice = slice(new_offset, new_offset + new_size)
            assert coeff.shape == new_real_inputs[new_k].shape + old_shape
            subs_matrix[..., new_slice, old_slice] = coeff.data.reshape(
                batch_shape + (new_size, old_size))

    subs_vector = subs_vector.as_tensor()
    subs_matrix = subs_matrix.as_tensor()
    subs_matrix_t = ops.transpose(subs_matrix, -1, -2)

    # Build the result.  With the old Gaussian density written as
    #   g(x) = < x | i - 1/2 P x >
    # substituting x = A y + B gives
    #   f(y) = g(A y + B)
    #        = < A y + B | i - 1/2 P (A y + B) >
    #        = < y | At (i - P B) - 1/2 At P A y > + < B | i - 1/2 P B >
    #        =: < y | i' - 1/2 P' y > + C
    # so the new Gaussian has P' = At P A and i' = At (i - P B), and the
    # leftover C = < B | i - 1/2 P B > becomes a Tensor.
    precision = subs_matrix @ old_precision @ subs_matrix_t
    info_vec = _mv(
        subs_matrix, old_info_vec - _mv(old_precision, subs_vector))
    const = _vv(
        subs_vector, old_info_vec - 0.5 * _mv(old_precision, subs_vector))
    result = (Gaussian(info_vec, precision, new_inputs)
              + Tensor(const, new_int_inputs))
    if remaining_subs:
        return Subs(result, remaining_subs)
    return result
def eager_mvn(loc, scale_tril, value):
    """Eagerly evaluate a multivariate normal with affine ``loc``/``value``.

    The whitened residual ``scale_tril^-1 (loc - value)`` is decomposed with
    ``extract_affine`` into a constant plus linear terms in the real inputs,
    and the log density is rewritten as a ``Tensor`` log-normalizer plus a
    ``Gaussian`` over those real inputs.

    :param loc: funsor with a rank-1 output.
    :param scale_tril: funsor with a rank-2 output; its ``.data`` is used as
        the lower-triangular scale factor.
    :param value: funsor whose output matches ``loc.output``.
    :return: ``log_prob + Gaussian(...)``, or ``None`` to stay lazy when the
        arguments are not affine or their affine parts are not ``Tensor``s.
    """
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    # Extract an affine representation.
    # The identity must share dtype and device with ``scale_tril.data``;
    # a default float32/CPU identity makes ``triangular_solve`` fail for
    # float64 or CUDA inputs.
    scale_data = scale_tril.data
    eye = torch.eye(
        scale_data.size(-1),
        dtype=scale_data.dtype,
        device=scale_data.device,
    ).expand(scale_data.shape)
    prec_sqrt = Tensor(
        eye.triangular_solve(scale_data, upper=False).solution,
        scale_tril.inputs)
    affine = prec_sqrt @ (loc - value)
    const, coeffs = extract_affine(affine)
    if not isinstance(const, Tensor):
        return None  # lazy
    if not all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values()):
        return None  # lazy

    # Compute log_prob using funsors.
    scale_diag = Tensor(scale_data.diagonal(dim1=-1, dim2=-2),
                        scale_tril.inputs)
    log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi)
                - scale_diag.log().sum()
                - 0.5 * (const ** 2).sum())

    # Dovetail to avoid variable name collision in einsum: the first copy of
    # each equation is mapped onto every other letter (a, c, e, ...) and the
    # second copy onto the letters in between (b, d, f, ...).
    # NOTE(review): the mapping only stays within 'a'..'z' for source
    # letters 'a'..'m'; equations using later letters would produce invalid
    # einsum subscripts.
    def dovetail(eqn, shift):
        return ''.join(
            c if c in ',->' else chr(ord(c) * 2 - ord('a') + shift)
            for c in eqn)

    def split_equation(eqn):
        # "coeff_dims,var_dims->output" with coeff_dims == var_dims + output.
        operands, output = eqn.split('->')
        coeff_dims, var_dims = operands.split(',')
        assert coeff_dims == var_dims + output
        return coeff_dims, var_dims, output

    equations1 = [dovetail(eqn, 0) for _, eqn in coeffs.values()]
    equations2 = [dovetail(eqn, 1) for _, eqn in coeffs.values()]

    real_inputs = OrderedDict(
        (k, v) for k, v in affine.inputs.items() if v.dtype == 'real')
    assert tuple(real_inputs) == tuple(coeffs)

    # Align and broadcast tensors.
    neg_const = -const
    tensors = [neg_const] + [coeff for coeff, _ in coeffs.values()]
    inputs, tensors = align_tensors(*tensors, expand=True)
    neg_const, coeffs = tensors[0], tensors[1:]
    dim = sum(d.num_elements for d in real_inputs.values())
    batch_shape = neg_const.shape[:-1]

    # Parse the equations and lay out the blocks once, instead of redoing
    # both on every iteration of the nested loop below.
    terms1 = [split_equation(eqn) for eqn in equations1]
    terms2 = [split_equation(eqn) for eqn in equations2]
    sizes = [d.num_elements for d in real_inputs.values()]
    slices = []
    offset = 0
    for size in sizes:
        slices.append(slice(offset, offset + size))
        offset += size

    info_vec = BlockVector(batch_shape + (dim,))
    precision = BlockMatrix(batch_shape + (dim, dim))
    for c1, size1, slice1, (input11, input12, output1) in zip(
            coeffs, sizes, slices, terms1):
        # Contract the coefficient with the negated constant over the
        # output dims to get this variable's block of the info vector.
        info_vec[..., slice1] = torch.einsum(
            f'...{input11},...{output1}->...{input12}', c1, neg_const
        ).reshape(batch_shape + (size1,))
        for c2, size2, slice2, (_, input22, _) in zip(
                coeffs, sizes, slices, terms2):
            # Contract the two coefficients over their shared output dims;
            # the second operand uses the dovetailed letters so its
            # variable dims cannot clash with the first operand's.
            precision[..., slice1, slice2] = torch.einsum(
                f'...{input11},...{input22}{output1}'
                f'->...{input12}{input22}',
                c1, c2
            ).reshape(batch_shape + (size1, size2))

    info_vec = info_vec.as_tensor()
    precision = precision.as_tensor()
    inputs.update(real_inputs)
    return log_prob + Gaussian(info_vec, precision, inputs)