def eager_subs(self, subs):
    assert isinstance(subs, tuple)
    subs = tuple((k, v if isinstance(v, (Variable, Slice)) else materialize(v))
                 for k, v in subs if k in self.inputs)
    if not subs:
        return self

    # Constants and Affine funsors are eagerly substituted;
    # everything else is lazily substituted.
    lazy_subs = tuple((k, v) for k, v in subs
                      if not isinstance(v, (Number, Tensor, Variable, Slice))
                      and not (is_affine(v) and affine_inputs(v)))
    var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
    int_subs = tuple((k, v) for k, v in subs
                     if isinstance(v, (Number, Tensor, Slice))
                     if v.dtype != 'real')
    real_subs = tuple((k, v) for k, v in subs
                      if isinstance(v, (Number, Tensor))
                      if v.dtype == 'real')
    affine_subs = tuple((k, v) for k, v in subs
                        if is_affine(v) and affine_inputs(v)
                        and not isinstance(v, Variable))
    if var_subs:
        return self._eager_subs_var(
            var_subs, int_subs + real_subs + affine_subs + lazy_subs)
    if int_subs:
        return self._eager_subs_int(int_subs, real_subs + affine_subs + lazy_subs)
    if real_subs:
        return self._eager_subs_real(real_subs, affine_subs + lazy_subs)
    if affine_subs:
        return self._eager_subs_affine(affine_subs, lazy_subs)
    return reflect(Subs, self, lazy_subs)
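The method above partitions substitutions by kind (variable renames, integer, real, affine, and lazily-handled values) and dispatches one group at a time. A minimal standalone sketch of that partitioning pattern, using toy stand-ins (the `Var` class and `partition_subs` helper here are illustrative, not funsor's actual API):

```python
from numbers import Number


class Var:
    """Toy stand-in for a variable-rename substitution value."""
    def __init__(self, name):
        self.name = name


def partition_subs(subs):
    """Split (name, value) substitutions into variable renames,
    numeric substitutions, and everything else (handled lazily)."""
    var_subs = tuple((k, v) for k, v in subs if isinstance(v, Var))
    num_subs = tuple((k, v) for k, v in subs if isinstance(v, Number))
    lazy_subs = tuple((k, v) for k, v in subs
                      if not isinstance(v, (Var, Number)))
    return var_subs, num_subs, lazy_subs


var_subs, num_subs, lazy_subs = partition_subs(
    [('x', Var('y')), ('i', 2), ('z', 'opaque')])
```

As in `eager_subs`, each substitution lands in exactly one group, and the lazy group is defined by exclusion so nothing is dropped.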
def eager_normal(loc, scale, value):
    assert loc.output == Real
    assert scale.output == Real
    assert value.output == Real
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    # Convert to a Gaussian over an auxiliary variable, then substitute.
    info_vec = ops.new_zeros(scale.data, scale.data.shape + (1,))
    precision = ops.pow(scale.data, -2).reshape(scale.data.shape + (1, 1))
    log_prob = -0.5 * math.log(2 * math.pi) - ops.log(scale).sum()
    inputs = scale.inputs.copy()
    var = gensym('value')
    inputs[var] = Real
    gaussian = log_prob + Gaussian(info_vec, precision, inputs)
    return gaussian(**{var: value - loc})
def eager_mvn(loc, scale_tril, value):
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    # Convert to a Gaussian over an auxiliary variable, then substitute.
    info_vec = scale_tril.data.new_zeros(scale_tril.data.shape[:-1])
    precision = ops.cholesky_inverse(scale_tril.data)
    scale_diag = Tensor(scale_tril.data.diagonal(dim1=-1, dim2=-2),
                        scale_tril.inputs)
    log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi)
                - scale_diag.log().sum())
    inputs = scale_tril.inputs.copy()
    var = gensym('value')
    inputs[var] = reals(scale_diag.shape[0])
    gaussian = log_prob + Gaussian(info_vec, precision, inputs)
    return gaussian(**{var: value - loc})
def test_smoke(expr, expected_type):
    t = Tensor(randn(2, 3), OrderedDict([('i', bint(2)), ('j', bint(3))]))
    assert isinstance(t, Tensor)

    n = Number(2.)
    assert isinstance(n, Number)

    x = Variable('x', reals())
    assert isinstance(x, Variable)

    y = Variable('y', reals())
    assert isinstance(y, Variable)

    result = eval(expr)
    assert isinstance(result, expected_type)
    assert is_affine(result)
def test_affine_subs(expr, expected_type, expected_inputs):
    expected_output = reals()

    t = Tensor(randn(2, 3), OrderedDict([('i', bint(2)), ('j', bint(3))]))
    assert isinstance(t, Tensor)

    n = Number(2.)
    assert isinstance(n, Number)

    x = Variable('x', reals())
    assert isinstance(x, Variable)

    y = Variable('y', reals())
    assert isinstance(y, Variable)

    z = Variable('z', reals())
    assert isinstance(z, Variable)

    result = eval(expr)
    assert isinstance(result, expected_type)
    check_funsor(result, expected_inputs, expected_output)
    assert is_affine(result)
def test_extract_affine(expr):
    x = eval(expr)
    assert is_affine(x)
    assert isinstance(x, (Unary, Contraction, Einsum))
    real_inputs = OrderedDict((k, d) for k, d in x.inputs.items()
                              if d.dtype == 'real')

    const, coeffs = extract_affine(x)
    assert isinstance(const, Tensor)
    assert const.shape == x.shape
    assert list(coeffs) == list(real_inputs)
    for name, (coeff, eqn) in coeffs.items():
        assert isinstance(name, str)
        assert isinstance(coeff, Tensor)
        assert isinstance(eqn, str)

    subs = {k: random_tensor(OrderedDict(), d) for k, d in real_inputs.items()}
    expected = x(**subs)
    assert isinstance(expected, Tensor)

    actual = const + sum(Einsum(eqn, (coeff, subs[k]))
                         for k, (coeff, eqn) in coeffs.items())
    assert isinstance(actual, Tensor)
    assert_close(actual, expected)
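The test above checks that `extract_affine` recovers `(const, coeffs)` such that the original funsor equals `const` plus a sum of `Einsum` terms. For a scalar function the same idea reduces to probing at zero and at a unit input; a minimal pure-Python sketch (the `extract_affine_1d` name is illustrative, not funsor's API):

```python
def extract_affine_1d(f):
    """Recover (const, coeff) with f(x) == const + coeff * x,
    assuming f really is affine: probe f at x=0 and x=1."""
    const = f(0.0)
    coeff = f(1.0) - const
    return const, coeff


const, coeff = extract_affine_1d(lambda x: 3.0 * x + 2.0)
# const == 2.0, coeff == 3.0
```

The multivariate `extract_affine` generalizes this by probing at zero and at basis vectors for each real input, which is why the test can reconstruct `x` exactly from `const` and `coeffs` whenever `is_affine(x)` holds.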
def test_not_is_affine(expr):
    x = eval(expr)
    assert not is_affine(x)
def eager_mvn(loc, scale_tril, value):
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    # Extract an affine representation.
    eye = torch.eye(scale_tril.data.size(-1)).expand(scale_tril.data.shape)
    prec_sqrt = Tensor(eye.triangular_solve(scale_tril.data, upper=False).solution,
                       scale_tril.inputs)
    affine = prec_sqrt @ (loc - value)
    const, coeffs = extract_affine(affine)
    if not isinstance(const, Tensor):
        return None  # lazy
    if not all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values()):
        return None  # lazy

    # Compute log_prob using funsors.
    scale_diag = Tensor(scale_tril.data.diagonal(dim1=-1, dim2=-2),
                        scale_tril.inputs)
    log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi)
                - scale_diag.log().sum()
                - 0.5 * (const ** 2).sum())

    # Dovetail to avoid variable name collision in einsum.
    equations1 = [''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a'))
                          for c in eqn)
                  for _, eqn in coeffs.values()]
    equations2 = [''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a') + 1)
                          for c in eqn)
                  for _, eqn in coeffs.values()]

    real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items()
                              if v.dtype == 'real')
    assert tuple(real_inputs) == tuple(coeffs)

    # Align and broadcast tensors.
    neg_const = -const
    tensors = [neg_const] + [coeff for coeff, _ in coeffs.values()]
    inputs, tensors = align_tensors(*tensors, expand=True)
    neg_const, coeffs = tensors[0], tensors[1:]

    dim = sum(d.num_elements for d in real_inputs.values())
    batch_shape = neg_const.shape[:-1]
    info_vec = BlockVector(batch_shape + (dim,))
    precision = BlockMatrix(batch_shape + (dim, dim))
    offset1 = 0
    for i1, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
        size1 = real_inputs[v1].num_elements
        slice1 = slice(offset1, offset1 + size1)
        inputs1, output1 = equations1[i1].split('->')
        input11, input12 = inputs1.split(',')
        assert input11 == input12 + output1
        info_vec[..., slice1] = torch.einsum(
            f'...{input11},...{output1}->...{input12}', c1, neg_const) \
            .reshape(batch_shape + (size1,))
        offset2 = 0
        for i2, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
            size2 = real_inputs[v2].num_elements
            slice2 = slice(offset2, offset2 + size2)
            inputs2, output2 = equations2[i2].split('->')
            input21, input22 = inputs2.split(',')
            assert input21 == input22 + output2
            precision[..., slice1, slice2] = torch.einsum(
                f'...{input11},...{input22}{output1}->...{input12}{input22}',
                c1, c2) \
                .reshape(batch_shape + (size1, size2))
            offset2 += size2
        offset1 += size1
    info_vec = info_vec.as_tensor()
    precision = precision.as_tensor()
    inputs.update(real_inputs)
    return log_prob + Gaussian(info_vec, precision, inputs)
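The "dovetail" step above renames einsum index letters so the two operands of the precision einsum can never collide: one equation gets even letters, the other odd letters. A standalone sketch of that character trick, assuming lowercase index letters as in the code above (the `dovetail` helper name is illustrative):

```python
def dovetail(eqn, parity):
    """Remap each einsum index letter to an even (parity=0) or odd
    (parity=1) position: 'a'->'a', 'b'->'c', ... vs 'a'->'b', 'b'->'d', ...
    Two equations dovetailed with different parities share no letters."""
    return ''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a') + parity)
                   for c in eqn)


eq1 = dovetail('ab,b->a', 0)  # even letters only
eq2 = dovetail('ab,b->a', 1)  # odd letters only
```

Because the two renamed alphabets are disjoint, the f-string in the nested loop can concatenate indices from `equations1[i1]` and `equations2[i2]` into one einsum without accidental contraction.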