Beispiel #1
0
def test_add_op_broadcast():
    """
    Add a (10, 5) tensor to a (1, 5) tensor, exercising broadcasting.

    The Relay function under test corresponds to:

        fn (x, y) {
            return x + y;
        }

    The runtime result is handed to ``check_rts`` together with the NumPy
    sum of the same inputs, which serves as the reference value.
    """
    env = Environment()
    lhs = relay.var('x', shape=(10, 5))
    rhs = relay.var('y', shape=(1, 5))
    func = relay.Function([lhs, rhs], add(lhs, rhs))
    # Random inputs are drawn in the same order as before: lhs, then rhs.
    lhs_data = np.random.rand(10, 5).astype('float32')
    rhs_data = np.random.rand(1, 5).astype('float32')
    # NumPy broadcasts the (1, 5) operand across the 10 rows.
    expected = lhs_data + rhs_data
    check_rts(env, func, [lhs_data, rhs_data], expected)
Beispiel #2
0
def test_add_op_scalar():
    """
    Add two rank-0 (scalar) float32 values.

    The Relay function under test corresponds to:

        fn (x, y) {
            return x + y;
        }

    The runtime result is handed to ``check_rts`` together with the NumPy
    sum of the same inputs, which serves as the reference value.
    """
    env = Environment()
    scalar_shape = ()
    lhs = relay.var('x', shape=scalar_shape)
    rhs = relay.var('y', shape=scalar_shape)
    func = relay.Function([lhs, rhs], add(lhs, rhs))
    lhs_data = np.array(10.0, dtype='float32')
    rhs_data = np.array(1.0, dtype='float32')
    expected = lhs_data + rhs_data
    check_rts(env, func, [lhs_data, rhs_data], expected)
Beispiel #3
0
def assert_has_type(expr, typ, env=None):
    """Raise if the type inferred for ``expr`` differs from ``typ``.

    Args:
        expr: Expression to type-check; passed to ``infer_type``.
        typ: Expected type, compared with ``!=`` against the
            ``checked_type`` attribute of ``infer_type``'s result.
        env: Environment passed to ``infer_type``. When omitted (``None``),
            a fresh ``Environment({})`` is created for this call.

    Raises:
        RuntimeError: If the inferred type does not match ``typ``.
    """
    # A default of ``Environment({})`` in the signature is evaluated only
    # once, when the ``def`` runs, so every call omitting ``env`` would share
    # one object (and anything ``infer_type`` might record in it). Build a
    # fresh environment per call instead.
    if env is None:
        env = Environment({})
    checked_type = infer_type(env, expr).checked_type
    if checked_type != typ:
        raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ))
Beispiel #4
0
def assert_has_type(expr, typ, env=None):
    """Assert that the checked type of ``expr`` equals ``typ``.

    Args:
        expr: Expression to type-check; passed to ``check_expr``.
        typ: Expected type, compared with ``==`` against the value returned
            by ``checked_type()`` on ``check_expr``'s result.
        env: Environment passed to ``check_expr``. When omitted (``None``),
            a fresh ``Environment({})`` is created for this call.

    Raises:
        AssertionError: If the checked type does not equal ``typ``; the
            message shows both types.
    """
    # A default of ``Environment({})`` in the signature is evaluated only
    # once, when the ``def`` runs, and would be shared across calls. Build a
    # fresh environment per call instead.
    if env is None:
        env = Environment({})
    checked_type = check_expr(env, expr).checked_type()
    # The message is only formatted on failure, so it costs nothing when the
    # types match, and it makes failures diagnosable.
    assert checked_type == typ, "Type mismatch %s vs %s" % (checked_type, typ)