コード例 #1
0
def test_advanced_indexing_lazy(output_shape):
    """Check substitution of lazily built index expressions into a Tensor.

    ``x`` has inputs ``i``, ``j``, ``k`` of sizes 2, 3 and 4.  The index
    expressions ``i = 1 - u``, ``j = 2 - v`` and ``k = u + v`` are built
    under the ``lazy`` interpretation, and the expected result is assembled
    element by element from their materialized data.  Every way of supplying
    the three substitutions (positionally, by keyword, and split across
    chained calls in any grouping) must be equivalent to that result.
    """
    x_inputs = OrderedDict([
        ('i', bint(2)),
        ('j', bint(3)),
        ('k', bint(4)),
    ])
    x = Tensor(randn((2, 3, 4) + output_shape), x_inputs)
    u = Variable('u', bint(2))
    v = Variable('v', bint(3))
    with interpretation(lazy):
        i = Number(1, 2) - u
        j = Number(2, 3) - v
        k = u + v

    # Build the reference result with plain indexing.  The loop indices use
    # their own names so they do not shadow the funsor Variables u and v.
    expected_data = empty((2, 3) + output_shape)
    i_data = x.materialize(i).data
    j_data = x.materialize(j).data
    k_data = x.materialize(k).data
    for u_idx in range(2):
        for v_idx in range(3):
            source_index = (i_data[u_idx], j_data[v_idx],
                            k_data[u_idx, v_idx])
            expected_data[u_idx, v_idx] = x.data[source_index]
    expected_inputs = OrderedDict([
        ('u', bint(2)),
        ('v', bint(3)),
    ])
    expected = Tensor(expected_data, expected_inputs)

    candidates = (
        # All three substitutions in a single call.
        x(i, j, k),
        x(i=i, j=j, k=k),
        # Two substitutions first, then the remaining one.
        x(i=i, j=j)(k=k),
        x(j=j, k=k)(i=i),
        x(k=k, i=i)(j=j),
        # One substitution first, then the remaining two.
        x(i=i)(j=j, k=k),
        x(j=j)(k=k, i=i),
        x(k=k)(i=i, j=j),
        # One substitution at a time, in every order.
        x(i=i)(j=j)(k=k),
        x(i=i)(k=k)(j=j),
        x(j=j)(i=i)(k=k),
        x(j=j)(k=k)(i=i),
        x(k=k)(i=i)(j=j),
        x(k=k)(j=j)(i=i),
    )
    for actual in candidates:
        assert_equiv(expected, actual)
コード例 #2
0
    def diff_fn(p_data):
        """Compare a tensor with the reduction of its materialized sample.

        Wraps ``p_data`` in a ``Tensor`` over ``be_inputs``, samples from it,
        materializes the sample, reduces over input ``'n'`` with
        ``ops.logaddexp`` and aligns the result to ``p``'s inputs.

        Returns a pair ``(difference, mq)``.  ``difference`` is the
        ``probe``-weighted sum of ``exp`` of the reduced sample data minus
        the same sum for ``p``'s data, and ``mq`` is the reduced sample.
        ``be_inputs``, ``sampled_vars``, ``sample_inputs``, ``rng_key`` and
        ``probe`` come from the enclosing scope.
        """
        p = Tensor(p_data, be_inputs)
        sample = p.sample(sampled_vars, sample_inputs, rng_key=rng_key)
        reduced = p.materialize(sample).reduce(ops.logaddexp, 'n')
        mq = reduced.align(tuple(p.inputs))

        # Distinct names keep the raw parameter separate from the aligned
        # data.
        _aligned_inputs, (p_aligned, mq_aligned) = align_tensors(p, mq)
        assert p_aligned.shape == mq_aligned.shape

        mq_term = (ops.exp(mq_aligned) * probe).sum()
        p_term = (ops.exp(p_aligned) * probe).sum()
        return mq_term - p_term, mq
コード例 #3
0
ファイル: test_tensor.py プロジェクト: ordabayevy/funsor
def test_diagonal_rename():
    """Renaming two inputs to the same Variable matches the materialized form.

    Substituting the Variable ``d`` for both ``a`` and ``b`` must give the
    same result as substituting ``x.materialize(d)`` for both.
    """
    data = randn(2, 2, 3)
    inputs = OrderedDict(
        a=funsor.Bint[2],
        b=funsor.Bint[2],
        c=funsor.Bint[3],
    )
    x = Tensor(data, inputs, 'real')
    d = Variable("d", funsor.Bint[2])

    # Reference: substitute the materialized form of ``d`` for both inputs.
    materialized_d = x.materialize(d)
    expected = x(a=materialized_d, b=materialized_d)

    # Under test: substitute the Variable itself for both inputs.
    actual = x(a=d, b=d)
    assert_close(actual, expected)
コード例 #4
0
    def eager_subs(self, subs):
        """Substitute values for this funsor's inputs.

        ``subs`` must be a tuple of ``(name, value)`` pairs.  Pairs whose
        name is not in ``self.inputs`` are dropped.  Values that are neither
        ``Variable`` nor ``Slice`` are first passed through the
        ``materialize`` method of a ``Tensor`` built from ``self.info_vec``.

        The remaining pairs are sorted into groups, and the first non-empty
        group in this order is handled, with the later groups passed along
        as the remainder: variables, non-real Number/Tensor/Slice values,
        real Number/Tensor values, affine values, then everything else.
        Returns ``self`` unchanged when no pair applies.
        """
        assert isinstance(subs, tuple)
        prototype = Tensor(self.info_vec)

        def _prepare(value):
            # Variables and Slices are kept as they are; any other value is
            # handed to prototype.materialize.
            if isinstance(value, (Variable, Slice)):
                return value
            return prototype.materialize(value)

        subs = tuple((name, _prepare(value))
                     for name, value in subs
                     if name in self.inputs)
        if not subs:
            return self

        def _select(keep):
            # Pairs whose value satisfies ``keep``, in their original order.
            return tuple((name, value) for name, value in subs if keep(value))

        def _has_affine_inputs(value):
            return is_affine(value) and affine_inputs(value)

        def _is_lazy(value):
            if isinstance(value, (Number, Tensor, Variable, Slice)):
                return False
            return not _has_affine_inputs(value)

        def _is_var(value):
            return isinstance(value, Variable)

        def _is_int(value):
            return (isinstance(value, (Number, Tensor, Slice))
                    and value.dtype != 'real')

        def _is_real(value):
            return (isinstance(value, (Number, Tensor))
                    and value.dtype == 'real')

        def _is_affine(value):
            return _has_affine_inputs(value) and not isinstance(value, Variable)

        # Constants and Affine funsors are eagerly substituted;
        # everything else is lazily substituted.
        lazy_subs = _select(_is_lazy)
        var_subs = _select(_is_var)
        int_subs = _select(_is_int)
        real_subs = _select(_is_real)
        affine_subs = _select(_is_affine)

        # Dispatch on the first non-empty group; each handler also receives
        # the concatenation of all later groups.
        if var_subs:
            remainder = int_subs + real_subs + affine_subs + lazy_subs
            return self._eager_subs_var(var_subs, remainder)
        if int_subs:
            remainder = real_subs + affine_subs + lazy_subs
            return self._eager_subs_int(int_subs, remainder)
        if real_subs:
            remainder = affine_subs + lazy_subs
            return self._eager_subs_real(real_subs, remainder)
        if affine_subs:
            return self._eager_subs_affine(affine_subs, lazy_subs)
        return reflect(Subs, self, lazy_subs)