Exemplo n.º 1
0
def test_assoccomm():
    """Exercise `buildo`, `eq_comm` and `eq_assoc` on Theano/meta graphs.

    Every goal is built by calling its constructor directly (as already done
    for `eq_comm`), instead of mixing in the tuple-goal form
    ``(eq_assoc, u, v)``.
    """
    from symbolic_pymc.relations import buildo

    x, a, b, c = tt.dvectors('xabc')
    test_expr = x + 1
    q = var('q')

    # Operator and inputs both given: the single solution is `q` itself.
    res = run(1, q, buildo(tt.add, test_expr.owner.inputs, test_expr))
    assert q == res[0]

    # Solve for the operator.
    res = run(1, q, buildo(q, test_expr.owner.inputs, test_expr))
    assert tt.add == res[0].reify()

    # Solve for the operator's inputs.
    res = run(1, q, buildo(tt.add, q, test_expr))
    assert mt(tuple(test_expr.owner.inputs)) == res[0]

    # One logic variable for the remaining queries; named `x_lv` so it is not
    # confused with the Theano vector `x` above.
    x_lv = var('x')

    res = run(0, x_lv, eq_comm(mt.mul(a, b), mt.mul(b, x_lv)))
    assert (mt(a), ) == res

    res = run(0, x_lv, eq_comm(mt.add(a, b), mt.add(b, x_lv)))
    assert (mt(a), ) == res

    res = run(0, x_lv, eq_assoc(mt.add(a, b, c), mt.add(a, x_lv)))

    # TODO: `res[0]` should return `etuple`s.  Since `eq_assoc` effectively
    # picks apart the results of `arguments(...)`, I don't know if we can
    # keep the `etuple`s around.  We might be able to convert the results
    # to `etuple`s automatically by wrapping `eq_assoc`, though.
    res_obj = etuple(*res[0]).eval_obj
    assert res_obj == mt(b + c)

    res = run(0, x_lv, eq_assoc(mt.mul(a, b, c), mt.mul(a, x_lv)))
    res_obj = etuple(*res[0]).eval_obj
    assert res_obj == mt(b * c)
Exemplo n.º 2
0
def test_assoccomm():
    """Relate a Theano graph to its operator/arguments with `applyo`, then
    solve `eq_comm` and `eq_assoc` goals over meta-object expressions."""
    x, a, b, c = tt.dvectors("xabc")
    test_expr = x + 1
    q = var()

    # Operator and arguments both given: the sole solution is `q` itself.
    res = run(1, q, applyo(tt.add, etuple(*test_expr.owner.inputs), test_expr))
    assert q == res[0]

    # Operator unknown.
    res = run(1, q, applyo(q, etuple(*test_expr.owner.inputs), test_expr))
    assert tt.add == res[0].reify()

    # Arguments unknown.
    res = run(1, q, applyo(tt.add, q, test_expr))
    assert mt(tuple(test_expr.owner.inputs)) == res[0]

    # Fresh logic variable, kept distinct from the Theano vector `x`.
    hole = var()

    comm_mul = run(0, hole, eq_comm(mt.mul(a, b), mt.mul(b, hole)))
    assert (mt(a), ) == comm_mul

    comm_add = run(0, hole, eq_comm(mt.add(a, b), mt.add(b, hole)))
    assert (mt(a), ) == comm_add

    (assoc_add, ) = run(0, hole, eq_assoc(mt.add(a, b, c), mt.add(a, hole)))
    assert assoc_add == mt(b + c)

    (assoc_mul, ) = run(0, hole, eq_assoc(mt.mul(a, b, c), mt.mul(a, hole)))
    assert assoc_mul == mt(b * c)
Exemplo n.º 3
0
def test_eq_comm_object():
    """Match `Add` objects with `eq_comm` / `eq_assoccomm` after declaring
    `Add` commutative and associative.

    NOTE(review): the `fact` declarations presumably persist on the global
    relations after this test; nothing here removes them.
    """
    x = var("x")

    for relation in (commutative, associative):
        fact(relation, Add)

    assert run(0, x, eq_comm(add(1, 2, 3), add(3, 1, x))) == (2, )

    comm_solutions = set(run(0, x, eq_comm(add(1, 2), x)))
    assert comm_solutions == {add(1, 2), add(2, 1)}

    assoccomm_solutions = set(run(0, x, eq_assoccomm(add(1, 2, 3), add(1, x))))
    assert assoccomm_solutions == {add(2, 3), add(3, 2)}
Exemplo n.º 4
0
def test_eq_comm_object():
    """Commutative and associative-commutative matching of `Add` objects."""
    x = var('x')
    fact(commutative, Add)
    fact(associative, Add)

    perm_goal = eq_comm(add(1, 2, 3), add(3, 1, x))
    assert run(0, x, perm_goal) == (2, )

    # With `x` unbound, both argument orders of `add(1, 2)` are solutions.
    orderings = set(run(0, x, eq_comm(add(1, 2), x)))
    assert orderings == {add(1, 2), add(2, 1)}

    # Regrouping `add(1, 2, 3)` as `add(1, x)` leaves `x` as the sum of 2 and 3.
    regroupings = set(run(0, x, eq_assoccomm(add(1, 2, 3), add(1, x))))
    assert regroupings == {add(2, 3), add(3, 2)}
Exemplo n.º 5
0
def test_eq_comm():
    """Check which tuple expressions `eq_comm` can and cannot unify.

    `c` is the commutative operator; `a` is not commutative, so permuting its
    arguments must fail.
    """
    should_unify = [
        (1, 1),
        ((c, 1, 2, 3), (c, 1, 2, 3)),
        ((c, 3, 2, 1), (c, 1, 2, 3)),
    ]
    should_not_unify = [
        ((a, 3, 2, 1), (a, 1, 2, 3)),  # not commutative
        ((3, c, 2, 1), (c, 1, 2, 3)),  # `c` is not in head position
        ((c, 1, 2, 1), (c, 1, 2, 3)),  # different arguments
        ((a, 1, 2, 3), (c, 1, 2, 3)),  # different operators
    ]

    for lhs, rhs in should_unify:
        assert results(eq_comm(lhs, rhs))
    for lhs, rhs in should_not_unify:
        assert not results(eq_comm(lhs, rhs))

    # A bare variable on one side yields at least six solutions.
    assert len(results(eq_comm((c, 3, 2, 1), x))) >= 6
    assert results(eq_comm(x, y)) == ({x: y}, )
Exemplo n.º 6
0
def test_eq_comm():
    """`eq_comm` unifies permuted arguments only under the commutative `c`."""

    def comm_results(lhs, rhs):
        # Solutions of a single `eq_comm` goal, as returned by `results`.
        return results(eq_comm(lhs, rhs))

    assert comm_results(1, 1)
    assert comm_results((c, 1, 2, 3), (c, 1, 2, 3))
    assert comm_results((c, 3, 2, 1), (c, 1, 2, 3))
    assert not comm_results((a, 3, 2, 1), (a, 1, 2, 3))  # not commutative
    assert not comm_results((3, c, 2, 1), (c, 1, 2, 3))
    assert not comm_results((c, 1, 2, 1), (c, 1, 2, 3))
    assert not comm_results((a, 1, 2, 3), (c, 1, 2, 3))

    assert len(comm_results((c, 3, 2, 1), x)) >= 6
    assert comm_results(x, y) == ({x: y}, )
Exemplo n.º 7
0
def test_commutativity():
    """The meta `add` operator is commutative, and `eq_comm` matches `1 + 2`
    both against `2 + 1` and against the pattern `2 + q`."""
    with enable_lvar_defaults('names'):
        add_1_mt = mt(1) + mt(2)
        add_2_mt = mt(2) + mt(1)

    # `run` returns a tuple of solutions and never `False`, so
    # `res is not False` always passed.  Success means a non-empty tuple.
    res = run(0, var('q'), commutative(add_1_mt.base_operator))
    assert res

    res = run(0, var('q'), eq_comm(add_1_mt, add_2_mt))
    assert res

    with enable_lvar_defaults('names'):
        add_pattern_mt = mt(2) + var('q')

    # Swapping the arguments of `1 + 2` aligns `2` with `2`, binding `q` to
    # the first argument of `add_1_mt`.
    res = run(0, var('q'), eq_comm(add_1_mt, add_pattern_mt))
    assert res[0] == add_1_mt.base_arguments[0]
Exemplo n.º 8
0
def test_commutativity_tfp():
    """Match a TFP normal log-density graph with `eq` and `eq_comm`.

    Plain `eq` should match the pattern only against the original graph.
    `eq_comm` is expected to match both the original and the graph returned
    by `normalize_tf_graph`.
    """

    def vector_placeholder(name):
        # float32 placeholder with one dimension of unknown length.
        return tf.compat.v1.placeholder(tf.float32,
                                        name=name,
                                        shape=tf.TensorShape([None]))

    with tf.Graph().as_default():
        mu_tf = vector_placeholder("mu")
        tau_tf = vector_placeholder("tau")

        normal_tfp = tfd.normal.Normal(mu_tf, tau_tf)

        value_tf = vector_placeholder("value")

        normal_log_lik = normal_tfp.log_prob(value_tf)

    normal_log_lik_opt = normalize_tf_graph(normal_log_lik)

    with enable_lvar_defaults("names", "node_attrs"):
        tfp_normal_pattern_mt = mt_normal_log_prob(var(), var(), var())

    unopt_mt = mt(normal_log_lik)
    opt_mt = mt(normal_log_lik_opt)

    # The pattern is the form of an unnormalized TFP normal PDF, so it
    # matches the original graph exactly...
    exact_unopt = run(0, True, eq(unopt_mt, tfp_normal_pattern_mt))
    assert exact_unopt == (True, )
    # ...but *not* the Grappler-optimized graph, because Grappler reorders
    # terms (e.g. the log + constant variance/normalization term).
    exact_opt = run(0, True, eq(opt_mt, tfp_normal_pattern_mt))
    assert exact_opt == ()

    # XXX: `eq_comm` is, unfortunately, order sensitive!  The ground graph
    # must be passed as the LHS.
    for graph_mt in (unopt_mt, opt_mt):
        assert run(0, True, eq_comm(graph_mt,
                                    tfp_normal_pattern_mt)) == (True, )
Exemplo n.º 9
0
def test_deep_commutativity():
    """`eq_comm` permutes arguments in nested sub-expressions as well as at
    the top level."""
    x, y = var('x'), var('y')

    lhs = (c, (c, 1, x), y)
    rhs = (c, 2, (c, 3, 1))

    solutions = run(0, (x, y), eq_comm(lhs, rhs))
    assert solutions == ((3, 2), )
Exemplo n.º 10
0
def test_deep_commutativity():
    """Nested commutative terms unify after swapping at both levels."""
    x = var('x')
    y = var('y')

    # The outer swap pairs `(c, 1, x)` with `(c, 3, 1)` and `y` with `2`.
    # The inner swap then forces `x = 3`.
    goal = eq_comm((c, (c, 1, x), y), (c, 2, (c, 3, 1)))
    assert run(0, (x, y), goal) == ((3, 2), )
Exemplo n.º 11
0
def test_eq_comm():
    """Exhaustive checks of `eq_comm` with a single registered commutative op.

    The global `commutative` relation is cleared first, so only `comm_op` is
    commutative here.  The operators "op" and "div" are never declared
    commutative and act as non-commutative controls.
    """
    x, y, z = var(), var(), var()

    # Start from an empty `commutative` relation so facts declared by other
    # tests cannot affect the results below.
    commutative.facts.clear()
    commutative.index.clear()

    comm_op = "comm_op"

    fact(commutative, comm_op)

    assert run(0, True, eq_comm(1, 1)) == (True, )
    assert run(0, True, eq_comm((comm_op, 1, 2, 3),
                                (comm_op, 1, 2, 3))) == (True, )

    assert run(0, True, eq_comm((comm_op, 3, 2, 1),
                                (comm_op, 1, 2, 3))) == (True, )
    assert run(0, y, eq_comm((comm_op, 3, y, 1), (comm_op, 1, 2, 3))) == (2, )
    assert run(0, (x, y), eq_comm((comm_op, x, y, 1), (comm_op, 1, 2, 3))) == (
        (2, 3),
        (3, 2),
    )
    assert run(0, (x, y), eq_comm((comm_op, 2, 3, 1), (comm_op, 1, x, y))) == (
        (2, 3),
        (3, 2),
    )

    assert not run(0, True, eq_comm(("op", 3, 2, 1),
                                    ("op", 1, 2, 3)))  # not commutative
    assert not run(0, True, eq_comm((3, comm_op, 2, 1), (comm_op, 1, 2, 3)))
    assert not run(0, True, eq_comm((comm_op, 1, 2, 1), (comm_op, 1, 2, 3)))
    assert not run(0, True, eq_comm(("op", 1, 2, 3), (comm_op, 1, 2, 3)))

    # Test for variable args
    res = run(4, (x, y), eq_comm(x, y))
    exp_res_form = (
        (etuple(comm_op, x, y), etuple(comm_op, y, x)),
        (x, y),
        (etuple(etuple(comm_op, x, y)), etuple(etuple(comm_op, y, x))),
        (etuple(comm_op, x, y, z), etuple(comm_op, x, z, y)),
    )

    # `zip` stops at the shorter input, so fewer solutions than expected
    # would otherwise go unnoticed.
    assert len(res) == len(exp_res_form)

    for res_pair, exp_pair in zip(res, exp_res_form):
        s = unify(res_pair, exp_pair)
        assert s is not False
        # The match must only rename logic variables, not bind them to terms.
        assert all(isvar(i) for i in reify((x, y, z), s))

    # Make sure it can unify single elements
    assert (3, ) == run(0, x, eq_comm((comm_op, 1, 2, 3), (comm_op, 2, x, 1)))

    # `eq_comm` should propagate through
    assert (3, ) == run(
        0, x,
        eq_comm(("div", 1, (comm_op, 1, 2, 3)),
                ("div", 1, (comm_op, 2, x, 1))))
    # Now it should not
    assert () == run(
        0, x,
        eq_comm(("div", 1, ("div", 1, 2, 3)), ("div", 1, ("div", 2, x, 1))))

    expected_res = {(1, 2, 3), (2, 1, 3), (3, 1, 2), (1, 3, 2), (2, 3, 1),
                    (3, 2, 1)}
    assert expected_res == set(
        run(0, (x, y, z), eq_comm((comm_op, 1, 2, 3), (comm_op, x, y, z))))
    assert expected_res == set(
        run(0, (x, y, z), eq_comm((comm_op, x, y, z), (comm_op, 1, 2, 3))))
    assert expected_res == set(
        run(
            0,
            (x, y, z),
            eq_comm(("div", 1, (comm_op, 1, 2, 3)),
                    ("div", 1, (comm_op, x, y, z))),
        ))

    # Nested commutative terms: permutation at both levels.
    e1 = (comm_op, (comm_op, 1, x), y)
    e2 = (comm_op, 2, (comm_op, 3, 1))
    assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2), )

    # Commutative term nested inside a plain tuple.
    e1 = ((comm_op, 3, 1), )
    e2 = ((comm_op, 1, x), )

    assert run(0, x, eq_comm(e1, e2)) == (3, )

    e1 = (2, (comm_op, 3, 1))
    e2 = (y, (comm_op, 1, x))

    assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2), )