示例#1
0
    def transform(self, node):
        """Apply ``self.kanren_relation`` to ``node`` and build a rewrite.

        The relation is run, requesting a single solution, on
        ``node.default_output()``. During the run, ``self.relation_lvars``
        are treated as logic variables. The solutions are then passed
        through ``self.results_filter``.

        Returns ``False`` in three cases:

        - ``node`` is not a ``tt.Apply``;
        - ``self.node_filter(node)`` is truthy;
        - the filtered result is falsy.

        If the result is a tuple-form expression headed by ``dict``, it is
        evaluated into an ``{old: new}`` replacement mapping. Otherwise the
        reified result is handed to ``self.adjust_outputs``.
        """
        # Skip anything that is not an Apply node, or that the filter rejects.
        if not isinstance(node, tt.Apply) or self.node_filter(node):
            return False

        target = node.default_output()

        with variables(*self.relation_lvars):
            out_lvar = var()
            solutions = run(1, out_lvar, (self.kanren_relation, target, out_lvar))

        picked = self.results_filter(solutions)
        if not picked:
            return False

        # A tuple whose head is ``dict`` is a tuple-form dict expression: it
        # describes a set of replacements rather than a single new output.
        if not (isinstance(picked, tuple) and picked[0] == dict):
            return self.adjust_outputs(node, reify_meta(picked))

        # Evaluate the expression, unwrap the keys (``.obj``) and reify the
        # values back into Theano objects.
        replacements = {meta_key.obj: reify_meta(meta_val)
                        for meta_key, meta_val in evalt(picked).items()}

        assert all(key in node.fgraph.variables for key in replacements)
        return replacements
示例#2
0
def test_objects():
    """Check ``eq_assoccomm`` on ``add`` terms.

    ``Add`` is first declared commutative and associative.
    """
    fact(commutative, Add)
    fact(associative, Add)

    # Ground terms differing only in argument order unify; the identical
    # goal is deliberately evaluated twice.
    for _ in range(2):
        assert tuple(goaleval(eq_assoccomm(add(1, 2, 3), add(3, 1, 2)))({}))

    x = var('x')

    # The logic variable binds to the missing argument when it is last...
    all_states = tuple(goaleval(eq_assoccomm(add(1, 2, 3), add(1, 2, x)))({}))
    assert reify(x, all_states[0]) == 3

    # ...and when it is first.
    first_state = next(goaleval(eq_assoccomm(add(1, 2, 3), add(x, 2, 1)))({}))
    assert reify(x, first_state) == 3

    # A whole expression declared as a logic variable binds to another
    # expression.
    lvar_expr = add(1, 2, 3)
    with variables(lvar_expr):
        target = add(5, 6)
        state = next(goaleval(eq_assoccomm(lvar_expr, target))({}))
        assert reify(lvar_expr, state) == target
示例#3
0
def test_metatize():
    """Check ``metatize`` on several kinds of input.

    Covers tensors, containers, logic variables, arrays, arbitrary
    objects and ``Op`` classes.
    """
    # A Theano tensor maps to a meta object whose ``base`` is the tensor's type.
    vec = tt.vector('vec')
    meta_vec = metatize(vec)
    assert meta_vec.base == type(vec)

    # Lists stay lists and iterators stay iterators; every element is
    # turned into a MetaSymbol.
    meta_list = metatize([1, 2, 3])
    assert isinstance(meta_list, list)
    assert all(isinstance(elem, MetaSymbol) for elem in meta_list)

    meta_iter = metatize(iter([1, 2, 3]))
    assert isinstance(meta_iter, Iterator)
    assert all(isinstance(elem, MetaSymbol) for elem in meta_iter)

    # A logic variable comes back as a logic variable.
    assert isvar(metatize(var()))

    # A tensor declared as a logic variable compares equal to the input and
    # is itself a logic variable.
    with variables(vec):
        as_lvar = metatize(vec)
        assert as_lvar == vec
        assert isvar(as_lvar)

    # NumPy arrays are wrapped in a MetaSymbol.
    assert isinstance(metatize(np.r_[1, 2, 3]), MetaSymbol)

    # An object with no meta counterpart makes ``metatize`` raise.
    class TestClass(object):
        pass

    with pytest.raises(Exception):
        metatize(TestClass())

    # An ``Op`` subclass maps to a MetaOp subclass that can wrap instances.
    class TestOp(tt.gof.Op):
        pass

    meta_op_cls = metatize(TestOp)
    assert issubclass(meta_op_cls, MetaOp)

    op_instance = TestOp()
    meta_op = meta_op_cls(obj=op_instance)
    assert isinstance(meta_op, MetaSymbol)
    assert meta_op.obj == op_instance
    assert meta_op.base == TestOp
示例#4
0
    def transform(self, node):
        """Rewrite ``node`` using the results of ``self.kanren_relation``.

        The relation is run on the node's default output, or on
        ``node.outputs`` when ``default_output()`` raises
        ``AttributeError``. During the run, ``self.relation_lvars`` are
        treated as logic variables. The solutions are then reduced by
        ``self.results_filter``.

        Parameters
        ----------
        node
            Candidate node; anything that is not a ``tt.Apply`` is ignored.

        Returns
        -------
        ``False`` when no rewrite applies. That covers a non-``Apply`` node,
        a node rejected by ``self.node_filter``, or a falsy filtered result.
        Otherwise the result is either a ``{old_variable: new_variable}``
        replacement dict or the value produced by ``self.adjust_outputs``.

        Raises
        ------
        ValueError
            If the chosen result, after evaluation, is neither a mapping/list
            of replacement pairs nor a ``tt.Variable``.
        """
        if not isinstance(node, tt.Apply):
            return False

        if self.node_filter(node):
            return False

        try:
            input_expr = node.default_output()
        except AttributeError:
            # No single default output is available; relate against all
            # outputs instead (presumably multi-output nodes -- confirm).
            input_expr = node.outputs

        with variables(*self.relation_lvars):
            q = var()
            kanren_results = run(None, q, self.kanren_relation(input_expr, q))

        chosen_res = self.results_filter(kanren_results)

        if chosen_res:
            # Evaluate tuple-form expressions into objects first.
            if isinstance(chosen_res, ExpressionTuple):
                chosen_res = eval_and_reify_meta(chosen_res)

            # Normalize a mapping into a list of (old, new) pairs.
            if isinstance(chosen_res, dict):
                chosen_res = list(chosen_res.items())

            if isinstance(chosen_res, list):
                # We got a dictionary of replacements
                new_node = {eval_and_reify_meta(k): eval_and_reify_meta(v) for k, v in chosen_res}

                assert all(k in node.fgraph.variables for k in new_node)
            elif isinstance(chosen_res, tt.Variable):
                # Attempt to automatically format the output for multi-output
                # `Apply` nodes.
                new_node = self.adjust_outputs(node, eval_and_reify_meta(chosen_res))
            else:
                # Interpolate the offending value explicitly; a plain string
                # literal would leave the "{chosen_res}" placeholder verbatim.
                raise ValueError(
                    "Unsupported FunctionGraph replacement variable type: {chosen_res}".format(
                        chosen_res=chosen_res
                    )
                )

            return new_node
        else:
            return False
示例#5
0
def test_objects():
    """Unify ``add`` terms modulo associativity and commutativity.

    ``Add`` is registered as commutative and associative. The
    ``eq_assoccomm`` goals are then evaluated against an empty
    substitution.
    """
    fact(commutative, Add)
    fact(associative, Add)

    def stream(goal):
        # Evaluate ``goal`` starting from the empty substitution.
        return goaleval(goal)({})

    # Reordered ground arguments unify; the same goal is evaluated twice.
    assert tuple(stream(eq_assoccomm(add(1, 2, 3), add(3, 1, 2))))
    assert tuple(stream(eq_assoccomm(add(1, 2, 3), add(3, 1, 2))))

    x = var('x')

    # The logic variable binds to the argument missing on the other side,
    # whether it appears last or first.
    trailing = tuple(stream(eq_assoccomm(add(1, 2, 3), add(1, 2, x))))
    assert reify(x, trailing[0]) == 3

    leading = next(stream(eq_assoccomm(add(1, 2, 3), add(x, 2, 1))))
    assert reify(x, leading) == 3

    # An entire expression can be declared a logic variable and bound to
    # another expression.
    expr_lvar = add(1, 2, 3)
    with variables(expr_lvar):
        other = add(5, 6)
        assert reify(expr_lvar, next(stream(eq_assoccomm(expr_lvar, other)))) == other
def test_unification():
    """Check ``unify``/``reify`` across logic variables, meta and Theano graphs.

    Covers:

    - reification of logic variables;
    - unification of meta ``add``/``mul``/``inv`` expressions against
      Theano graphs;
    - Theano variables declared as logic variables via ``variables``;
    - round-tripping vector, scalar and constant arguments through
      ``unify``/``reify``.
    """
    x, y, a, b = tt.dvectors('xyab')
    x_s = tt.scalar('x_s')
    y_s = tt.scalar('y_s')
    c_tt = tt.constant(1, 'c')
    d_tt = tt.constant(2, 'd')

    x_l = var('x_l')
    y_l = var('y_l')

    # Reification substitutes bindings, both for a bare logic variable and
    # inside a meta expression.
    assert a == reify(x_l, {x_l: a}).reify()
    test_expr = mt.add(1, mt.mul(2, x_l))
    test_reify_res = reify(test_expr, {x_l: a})
    assert graph_equal(test_reify_res.reify(), 1 + 2 * a)

    # Logic variables unify with whole Theano graphs and with sub-terms of
    # meta expressions.
    z = tt.add(b, a)
    assert {x_l: z} == unify(x_l, z)
    assert b == unify(mt.add(x_l, a), mt.add(b, a))[x_l].reify()

    res = unify(mt.inv(mt.add(x_l, a)), mt.inv(mt.add(b, y_l)))
    assert res[x_l].reify() == b
    assert res[y_l].reify() == a

    # TODO: This produces a `DimShuffle` so that the scalar constant `1`
    # will match the dimensions of the vector `b`.  That `DimShuffle` isn't
    # handled by the logic variable form.
    # assert unify(mt.add(x_l, 1), mt.add(b_l, 1))[x] == b

    # A Theano variable declared as a logic variable unifies directly: on
    # its own, inside lists/tuples, and as an argument of Theano graphs.
    with variables(x):
        assert unify(x, b)[x] == b
        assert unify([x], [b])[x] == b
        assert unify((x, ), (b, ))[x] == b
        assert unify(x + 1, b + 1)[x].reify() == b
        assert unify(x + a, b + a)[x].reify() == b
        assert unify(a + b, a + x)[x].reify() == b

    # A meta ``add`` of two logic variables unifies with, and reifies back
    # to, Theano ``add`` graphs.
    mt_expr_add = mt.add(x_l, y_l)

    # The parameters are vectors
    tt_expr_add_1 = tt.add(x, y)
    assert graph_equal(
        tt_expr_add_1,
        reify(mt_expr_add, unify(mt_expr_add, tt_expr_add_1)).reify())

    # The parameters are scalars
    tt_expr_add_2 = tt.add(x_s, y_s)
    assert graph_equal(
        tt_expr_add_2,
        reify(mt_expr_add, unify(mt_expr_add, tt_expr_add_2)).reify())

    # The parameters are constants
    tt_expr_add_3 = tt.add(c_tt, d_tt)
    assert graph_equal(
        tt_expr_add_3,
        reify(mt_expr_add, unify(mt_expr_add, tt_expr_add_3)).reify())