コード例 #1
0
def test_extract_affine(expr):
    """Check that ``extract_affine`` reproduces a Contraction/Einsum funsor.

    ``expr`` is a source string that evaluates to a ``Contraction`` or
    ``Einsum`` funsor.  Its decomposition ``(const, coeffs)`` is first checked
    structurally (a Tensor constant of the right shape, one
    ``(coeff, eqn)`` pair per real input, in input order), then numerically.
    Evaluating the funsor at random values of its real inputs must match
    ``const`` plus the sum of ``Einsum(eqn, (coeff, value))`` terms.
    """
    # expr comes from the test's own parametrization, not from external input.
    funsor_expr = eval(expr)
    assert isinstance(funsor_expr, (Contraction, Einsum))
    real_inputs = OrderedDict(
        [(name, domain) for name, domain in funsor_expr.inputs.items()
         if domain.dtype == 'real'])

    const, coeffs = extract_affine(funsor_expr)

    # Structural checks on the decomposition.
    assert isinstance(const, Tensor)
    assert const.shape == funsor_expr.shape
    assert list(coeffs) == list(real_inputs)
    for name, coeff_and_eqn in coeffs.items():
        coeff, eqn = coeff_and_eqn
        assert isinstance(name, str)
        assert isinstance(coeff, Tensor)
        assert isinstance(eqn, str)

    # Numerical check: substitute random values for every real input.
    values = {}
    for name, domain in real_inputs.items():
        values[name] = random_tensor(OrderedDict(), domain)
    expected = funsor_expr(**values)
    assert isinstance(expected, Tensor)

    terms = [Einsum(eqn, (coeff, values[name]))
             for name, (coeff, eqn) in coeffs.items()]
    actual = const + sum(terms)
    assert isinstance(actual, Tensor)
    assert_close(actual, expected)
コード例 #2
0
    def _eager_subs_affine(self, subs, remaining_subs):
        """Eagerly substitute affine expressions for real inputs of this Gaussian.

        Each ``(k, v)`` pair in ``subs`` is decomposed with
        ``extract_affine(v)``.  When the constant and every coefficient are
        ``Tensor``s, the substitution is folded into a new ``Gaussian`` plus a
        ``Tensor`` constant term.  Otherwise the pair is appended to
        ``remaining_subs`` and left for lazy substitution.

        Args:
            subs: iterable of ``(name, value)`` pairs to substitute.
            remaining_subs: ``(name, value)`` pairs to substitute lazily.
                Non-affine pairs from ``subs`` are appended to it.
                Presumably a tuple, since it is extended with ``+= (k, v),``;
                confirm against callers.

        Returns:
            ``reflect(Subs, self, remaining_subs)`` if no pair could be
            substituted eagerly.  Otherwise ``Gaussian(...) + Tensor(...)``,
            wrapped in ``Subs(..., remaining_subs)`` when ``remaining_subs``
            is non-empty.

        Note:
            The ``coeffs`` dicts returned by ``extract_affine`` are
            overwritten in place with their aligned coefficients.
        """
        # Extract an affine representation.
        affine = OrderedDict()
        for k, v in subs:
            const, coeffs = extract_affine(v)
            if (isinstance(const, Tensor) and all(
                    isinstance(coeff, Tensor)
                    for coeff, _ in coeffs.values())):
                affine[k] = const, coeffs
            else:
                remaining_subs += (k, v),
        if not affine:
            return reflect(Subs, self, remaining_subs)

        # Align integer dimensions.
        old_int_inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, old_int_inputs),
            Tensor(self.precision, old_int_inputs)
        ]
        for const, coeffs in affine.values():
            tensors.append(const)
            tensors.extend(coeff for coeff, _ in coeffs.values())
        new_int_inputs, tensors = align_tensors(*tensors, expand=True)
        # The aligned results are consumed below with next() in exactly the
        # order they were appended above: info_vec, precision, then for each
        # substitution its const followed by its coefficients.
        tensors = (Tensor(x, new_int_inputs) for x in tensors)
        old_info_vec = next(tensors).data
        old_precision = next(tensors).data
        for old_k, (const, coeffs) in affine.items():
            const = next(tensors)
            for new_k, (coeff, eqn) in coeffs.items():
                coeff = next(tensors)
                coeffs[new_k] = coeff, eqn
            affine[old_k] = const, coeffs
        batch_shape = old_info_vec.shape[:-1]

        # Align real dimensions.
        # Each substituted input is removed and replaced by the real inputs of
        # its affine expression.  new_shape is the leading part of the
        # coefficient shape.  Its length is the number of index characters of
        # the second einsum operand in eqn (the substituted variable).
        # NOTE(review): if a substituted name also appears among the inputs of
        # its own replacement (e.g. x -> x + 1), it is re-added here, and the
        # identity branch below then treats it as unchanged, ignoring its
        # const/coeffs.  Confirm callers never pass such substitutions.
        old_real_inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype == 'real')
        new_real_inputs = old_real_inputs.copy()
        for old_k, (const, coeffs) in affine.items():
            del new_real_inputs[old_k]
            for new_k, (coeff, eqn) in coeffs.items():
                new_shape = coeff.shape[:len(eqn.split('->')[0].split(',')[1])]
                new_real_inputs[new_k] = Reals[new_shape]
        old_offsets, old_dim = _compute_offsets(old_real_inputs)
        new_offsets, new_dim = _compute_offsets(new_real_inputs)
        new_inputs = new_int_inputs.copy()
        new_inputs.update(new_real_inputs)

        # Construct a blockwise affine representation of the substitution.
        # subs_matrix has shape (new_dim, old_dim).  It therefore plays the
        # role of At (the transpose of A) in the derivation further below,
        # and subs_vector plays the role of B.
        subs_vector = BlockVector(batch_shape + (old_dim, ))
        subs_matrix = BlockMatrix(batch_shape + (new_dim, old_dim))
        for old_k, old_offset in old_offsets.items():
            old_size = old_real_inputs[old_k].num_elements
            old_slice = slice(old_offset, old_offset + old_size)
            if old_k in new_real_inputs:
                # Input not substituted: it maps to itself via an identity block.
                new_offset = new_offsets[old_k]
                new_slice = slice(new_offset, new_offset + old_size)
                subs_matrix[..., new_slice, old_slice] = \
                    ops.new_eye(self.info_vec, batch_shape + (old_size,))
                continue
            # Substituted input: the constant goes into B and each coefficient
            # into the (new_k, old_k) block of At.
            const, coeffs = affine[old_k]
            old_shape = old_real_inputs[old_k].shape
            assert const.data.shape == batch_shape + old_shape
            subs_vector[..., old_slice] = const.data.reshape(batch_shape +
                                                             (old_size, ))
            for new_k, new_offset in new_offsets.items():
                if new_k in coeffs:
                    coeff, eqn = coeffs[new_k]
                    new_size = new_real_inputs[new_k].num_elements
                    new_slice = slice(new_offset, new_offset + new_size)
                    assert coeff.shape == new_real_inputs[
                        new_k].shape + old_shape
                    subs_matrix[..., new_slice, old_slice] = \
                        coeff.data.reshape(batch_shape + (new_size, old_size))
        subs_vector = subs_vector.as_tensor()
        subs_matrix = subs_matrix.as_tensor()
        subs_matrix_t = ops.transpose(subs_matrix, -1, -2)

        # Construct the new funsor. Suppose the old Gaussian funsor g has density
        #   g(x) = < x | i - 1/2 P x>
        # Now define a new funsor f by substituting x = A y + B:
        #   f(y) = g(A y + B)
        #        = < A y + B | i - 1/2 P (A y + B) >
        #        = < y | At (i - P B) - 1/2 At P A y > + < B | i - 1/2 P B >
        #        =: < y | i' - 1/2 P' y > + C
        # where  P' = At P A  and  i' = At (i - P B)  parametrize a new Gaussian
        # and  C = < B | i - 1/2 P B >  parametrize a new Tensor.
        precision = subs_matrix @ old_precision @ subs_matrix_t
        info_vec = _mv(subs_matrix,
                       old_info_vec - _mv(old_precision, subs_vector))
        const = _vv(subs_vector,
                    old_info_vec - 0.5 * _mv(old_precision, subs_vector))
        result = Gaussian(info_vec, precision, new_inputs) + Tensor(
            const, new_int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result
コード例 #3
0
def eager_mvn(loc, scale_tril, value):
    """Eagerly rewrite a multivariate normal log-density as a Gaussian funsor.

    Write ``L = scale_tril`` and ``n = L.shape[-1]``.  The log-density is::

        log p = -n/2 log(2 pi) - sum(log diag L) - 1/2 |L^{-1} (loc - value)|^2

    When ``loc`` and ``value`` are affine in their real inputs ``x_k``, the
    residual decomposes as ``L^{-1} (loc - value) = sum_k C_k x_k + c``.  The
    quadratic term then becomes a Gaussian over the real inputs with
    ``precision = C C^T`` and ``info_vec = -C c``.  The input-independent part
    is returned as a funsor added to that Gaussian.

    Args:
        loc: funsor with a 1-dimensional output.
        scale_tril: Tensor-like funsor with a 2-dimensional output.  Its
            ``.data`` is used as a lower-triangular matrix.
        value: funsor with the same output as ``loc``.

    Returns:
        ``log_prob + Gaussian(...)``.  Returns ``None`` (stay lazy) when
        ``loc`` or ``value`` is not affine, or when the affine decomposition
        is not made up solely of ``Tensor``s.
    """
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    # Extract an affine representation of L^{-1} (loc - value).
    # Build the identity with scale_tril's dtype and device.  A default
    # torch.eye is float32 on CPU, which makes triangular_solve fail for
    # float64 or CUDA scale_tril data.
    scale_tril_data = scale_tril.data
    eye = torch.eye(scale_tril_data.size(-1),
                    dtype=scale_tril_data.dtype,
                    device=scale_tril_data.device).expand(scale_tril_data.shape)
    # Solves L X = I for X, i.e. X = L^{-1} (the square root of the precision).
    prec_sqrt = Tensor(
        eye.triangular_solve(scale_tril_data, upper=False).solution,
        scale_tril.inputs)
    affine = prec_sqrt @ (loc - value)
    const, coeffs = extract_affine(affine)
    if not isinstance(const, Tensor):
        return None  # lazy
    if not all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values()):
        return None  # lazy

    # Compute log_prob using funsors.
    scale_diag = Tensor(scale_tril_data.diagonal(dim1=-1, dim2=-2),
                        scale_tril.inputs)
    log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi) -
                scale_diag.log().sum() - 0.5 * (const**2).sum())

    # Dovetail to avoid variable name collision in einsum.  Index letters of
    # the first copy map to even offsets from 'a', those of the second copy
    # to odd offsets, so the two copies never share a letter.
    # NOTE: results stay within 'a'..'z' only for input letters 'a'..'m'.
    equations1 = [
        ''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a')) for c in eqn)
        for _, eqn in coeffs.values()
    ]
    equations2 = [
        ''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a') + 1)
                for c in eqn) for _, eqn in coeffs.values()
    ]

    real_inputs = OrderedDict(
        (k, v) for k, v in affine.inputs.items() if v.dtype == 'real')
    assert tuple(real_inputs) == tuple(coeffs)

    # Align and broadcast tensors over their integer (batch) inputs.
    neg_const = -const
    tensors = [neg_const] + [coeff for coeff, _ in coeffs.values()]
    inputs, tensors = align_tensors(*tensors, expand=True)
    # One aligned array per funsor passed above (used as raw arrays with
    # torch.einsum below): the negated constant, then each coefficient in
    # real_inputs order.
    neg_const_data, coeff_data = tensors[0], tensors[1:]
    dim = sum(d.num_elements for d in real_inputs.values())
    batch_shape = neg_const_data.shape[:-1]

    info_vec = BlockVector(batch_shape + (dim, ))
    precision = BlockMatrix(batch_shape + (dim, dim))
    offset1 = 0
    for i1, (v1, c1) in enumerate(zip(real_inputs, coeff_data)):
        size1 = real_inputs[v1].num_elements
        slice1 = slice(offset1, offset1 + size1)
        inputs1, output1 = equations1[i1].split('->')
        input11, input12 = inputs1.split(',')
        assert input11 == input12 + output1
        # info_vec block for v1: -C_1 c, contracting the output index.
        info_vec[..., slice1] = torch.einsum(
            f'...{input11},...{output1}->...{input12}', c1, neg_const_data) \
            .reshape(batch_shape + (size1,))
        offset2 = 0
        for i2, (v2, c2) in enumerate(zip(real_inputs, coeff_data)):
            size2 = real_inputs[v2].num_elements
            slice2 = slice(offset2, offset2 + size2)
            inputs2, output2 = equations2[i2].split('->')
            input21, input22 = inputs2.split(',')
            assert input21 == input22 + output2
            # precision block (v1, v2): C_1 C_2^T.  c2 is labelled with
            # output1 so that the shared output index is contracted.
            precision[..., slice1, slice2] = torch.einsum(
                f'...{input11},...{input22}{output1}->...{input12}{input22}', c1, c2) \
                .reshape(batch_shape + (size1, size2))
            offset2 += size2
        offset1 += size1

    info_vec = info_vec.as_tensor()
    precision = precision.as_tensor()
    inputs.update(real_inputs)
    return log_prob + Gaussian(info_vec, precision, inputs)