def test_block_matrix():
    """Round-trip scalar, row, column and square blocks through a BlockMatrix
    and check it densifies to the same tensor."""
    shape = (10, 10)
    dense = torch.zeros(shape)
    block = BlockMatrix(shape)
    # Each case: (index into the matrix, shape of the random fill).
    cases = [
        ((1, 1), ()),
        ((1, slice(3, 5)), (2,)),
        ((slice(3, 5), 1), (2,)),
        ((slice(3, 5), slice(3, 5)), (2, 2)),
    ]
    for index, fill_shape in cases:
        dense[index] = torch.randn(fill_shape)
        block[index] = dense[index]
    assert_close(block.as_tensor(), dense)
def test_block_matrix_batched(batch_shape):
    """Like test_block_matrix, but with arbitrary leading batch dimensions."""
    shape = batch_shape + (10, 10)
    dense = torch.zeros(shape)
    block = BlockMatrix(shape)
    # Each case: (row index, column index, event shape of the random fill).
    cases = [
        (1, 1, ()),
        (1, slice(3, 5), (2,)),
        (slice(3, 5), 1, (2,)),
        (slice(3, 5), slice(3, 5), (2, 2)),
    ]
    for i, j, event_shape in cases:
        dense[..., i, j] = torch.randn(batch_shape + event_shape)
        block[..., i, j] = dense[..., i, j]
    assert_close(block.as_tensor(), dense)
def test_block_matrix(sparse):
    """Round-trip assigned blocks through a BlockMatrix.

    When ``sparse`` is True only the (1, 1) scalar entry is assigned;
    otherwise row, column and square blocks are filled as well.
    """
    shape = (10, 10)
    dense = zeros(shape)
    block = BlockMatrix(shape)

    # The scalar entry is always exercised.
    dense[1, 1] = randn(())
    block[1, 1] = dense[1, 1]

    if not sparse:
        # Row slice, column slice, and square slice blocks.
        slice_cases = [
            ((1, slice(3, 5)), (2,)),
            ((slice(3, 5), 1), (2,)),
            ((slice(3, 5), slice(3, 5)), (2, 2)),
        ]
        for index, fill_shape in slice_cases:
            dense[index] = randn(fill_shape)
            block[index] = dense[index]

    assert_close(block.as_tensor(), dense)
def eager_normal(loc, scale, value):
    """Convert a Normal(loc, scale) density at ``value`` into a Gaussian funsor
    by extracting the affine dependence of ``(loc - value) / scale`` on its
    real-valued inputs."""
    affine = (loc - value) / scale
    # This eager pattern only fires when the standardized residual is Affine.
    assert isinstance(affine, Affine)
    real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items()
                              if v.dtype == 'real')
    # Only scalar (shapeless) real inputs are supported by this path.
    assert not any(v.shape for v in real_inputs.values())

    # Gather the constant term and one coefficient per real input, then align
    # them over a common set of named inputs and broadcast to a common shape.
    tensors = [affine.const] + [c for v, c in affine.coeffs.items()]
    inputs, tensors = align_tensors(*tensors)
    tensors = torch.broadcast_tensors(*tensors)
    const, coeffs = tensors[0], tensors[1:]

    # Total dimension of the flattened real inputs.
    dim = sum(d.num_elements for d in real_inputs.values())
    loc = BlockVector(const.shape + (dim,))
    # NOTE(review): only position 0 of loc is filled, using the first
    # coefficient — this looks like it assumes a single real input (or that
    # the mean lies along the first coordinate); confirm against callers.
    loc[..., 0] = -const / coeffs[0]

    # precision[i, j] = c_i * c_j, i.e. the outer product of coefficients
    # (rank-one precision matrix).
    precision = BlockMatrix(const.shape + (dim, dim))
    for i, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
        for j, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
            precision[..., i, j] = c1 * c2

    loc = loc.as_tensor()
    precision = precision.as_tensor()
    # Normal normalization constant: -log(scale * sqrt(2*pi)).
    log_prob = -0.5 * math.log(2 * math.pi) - scale.log()
    return log_prob + Gaussian(loc, precision, affine.inputs)
def eager_normal(loc, scale, value):
    """Convert a Normal(loc, scale) density at ``value`` into a Gaussian funsor
    in information form (info_vec, precision), probing the affine expression
    ``(loc - value) / scale`` to recover its constant and coefficients."""
    affine = (loc - value) / scale
    if not affine.is_affine:
        return None  # fall back to lazy evaluation

    real_inputs = OrderedDict(
        (k, v) for k, v in affine.inputs.items() if v.dtype == 'real')
    int_inputs = OrderedDict(
        (k, v) for k, v in affine.inputs.items() if v.dtype != 'real')
    # Only scalar (shapeless) real inputs are supported by this path.
    assert not any(v.shape for v in real_inputs.values())

    # Probe the affine function: const = f(0, ..., 0), and each coefficient is
    # f(e_k) - const where e_k is a one-hot substitution of the real inputs.
    const = affine(**{k: 0. for k, v in real_inputs.items()})
    coeffs = OrderedDict()
    for c in real_inputs.keys():
        coeffs[c] = affine(
            **{k: 1. if c == k else 0. for k in real_inputs.keys()}) - const

    # Align constant and coefficients over a common set of named inputs.
    tensors = [const] + list(coeffs.values())
    inputs, tensors = align_tensors(*tensors, expand=True)
    const, coeffs = tensors[0], tensors[1:]

    # Total dimension of the flattened real inputs.
    dim = sum(d.num_elements for d in real_inputs.values())
    loc = BlockVector(const.shape + (dim, ))
    # NOTE(review): only position 0 of loc is filled, using the first
    # coefficient — this looks like it assumes a single real input; confirm.
    loc[..., 0] = -const / coeffs[0]

    # precision[i, j] = c_i * c_j (outer product of coefficients).
    precision = BlockMatrix(const.shape + (dim, dim))
    for i, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
        for j, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
            precision[..., i, j] = c1 * c2

    loc = loc.as_tensor()
    precision = precision.as_tensor()
    # Information vector of the Gaussian: P @ loc.
    info_vec = precision.matmul(loc.unsqueeze(-1)).squeeze(-1)
    # Normalizer minus the quadratic term at the mode, so that adding the
    # information-form Gaussian reproduces the Normal log density.
    log_prob = -0.5 * math.log(
        2 * math.pi) - scale.data.log() - 0.5 * (loc * info_vec).sum(-1)
    return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, affine.inputs)
def eager_mvn(loc, scale_tril, value):
    """Convert a MultivariateNormal(loc, scale_tril) density at ``value`` into
    a Gaussian funsor in information form, via the affine expression
    ``scale_tril^{-1} @ (loc - value)``.

    Returns None (lazy) whenever the expression is not recognizably affine.
    """
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    # Extract an affine representation.
    # prec_sqrt = scale_tril^{-1}, computed by solving L @ X = I.
    eye = torch.eye(scale_tril.data.size(-1)).expand(scale_tril.data.shape)
    prec_sqrt = Tensor(
        eye.triangular_solve(scale_tril.data, upper=False).solution,
        scale_tril.inputs)
    affine = prec_sqrt @ (loc - value)
    const, coeffs = extract_affine(affine)
    if not isinstance(const, Tensor):
        return None  # lazy
    if not all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values()):
        return None  # lazy

    # Compute log_prob using funsors.
    # log|det scale_tril| is the sum of log-diagonal entries of L, and the
    # constant contributes a -0.5 * ||const||^2 quadratic term.
    scale_diag = Tensor(scale_tril.data.diagonal(dim1=-1, dim2=-2),
                        scale_tril.inputs)
    log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi)
                - scale_diag.log().sum() - 0.5 * (const**2).sum())

    # Dovetail to avoid variable name collision in einsum: the first copy of
    # each equation maps letters onto even codepoints, the second onto odd
    # codepoints, so the two coefficient equations never share an index name.
    equations1 = [''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a'))
                          for c in eqn)
                  for _, eqn in coeffs.values()]
    equations2 = [''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a') + 1)
                          for c in eqn)
                  for _, eqn in coeffs.values()]

    real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items()
                              if v.dtype == 'real')
    # extract_affine is expected to produce one coefficient per real input,
    # in matching order.
    assert tuple(real_inputs) == tuple(coeffs)

    # Align and broadcast tensors.
    neg_const = -const
    tensors = [neg_const] + [coeff for coeff, _ in coeffs.values()]
    inputs, tensors = align_tensors(*tensors, expand=True)
    neg_const, coeffs = tensors[0], tensors[1:]
    dim = sum(d.num_elements for d in real_inputs.values())
    batch_shape = neg_const.shape[:-1]

    # Assemble the information vector and precision matrix block by block,
    # one (flattened) real input per block, tracking running offsets.
    info_vec = BlockVector(batch_shape + (dim, ))
    precision = BlockMatrix(batch_shape + (dim, dim))
    offset1 = 0
    for i1, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
        size1 = real_inputs[v1].num_elements
        slice1 = slice(offset1, offset1 + size1)
        inputs1, output1 = equations1[i1].split('->')
        input11, input12 = inputs1.split(',')
        assert input11 == input12 + output1
        # info_vec block = c1^T @ (-const), contracted over the output index.
        info_vec[..., slice1] = torch.einsum(
            f'...{input11},...{output1}->...{input12}', c1, neg_const) \
            .reshape(batch_shape + (size1,))
        offset2 = 0
        for i2, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
            size2 = real_inputs[v2].num_elements
            slice2 = slice(offset2, offset2 + size2)
            inputs2, output2 = equations2[i2].split('->')
            input21, input22 = inputs2.split(',')
            assert input21 == input22 + output2
            # precision block = c1^T @ c2, contracted over the shared output
            # index (output1 is renamed into c2's equation on purpose).
            precision[..., slice1, slice2] = torch.einsum(
                f'...{input11},...{input22}{output1}->...{input12}{input22}',
                c1, c2) \
                .reshape(batch_shape + (size1, size2))
            offset2 += size2
        offset1 += size1

    info_vec = info_vec.as_tensor()
    precision = precision.as_tensor()
    inputs.update(real_inputs)
    return log_prob + Gaussian(info_vec, precision, inputs)