def test_add_op_broadcast():
    """Add a (10, 5) tensor to a (1, 5) tensor and compare with NumPy.

    Relay program under test::

        fn (x, y) { return x + y; }

    ``y`` has a leading dimension of 1, so the expected result is NumPy's
    broadcast sum of the two float32 inputs.
    """
    env = Environment()
    lhs = relay.var('x', shape=(10, 5))
    rhs = relay.var('y', shape=(1, 5))
    program = relay.Function([lhs, rhs], add(lhs, rhs))

    # Draw the inputs in the same order as before: lhs first, then rhs.
    lhs_data = np.random.rand(10, 5).astype('float32')
    rhs_data = np.random.rand(1, 5).astype('float32')
    inputs = [lhs_data, rhs_data]
    expected = lhs_data + rhs_data
    check_rts(env, program, inputs, expected)
def test_add_op_scalar():
    """Add two rank-0 (scalar) float32 values and compare with NumPy.

    Relay program under test::

        fn (x, y) { return x + y; }

    Both variables are declared with ``shape=()``; the inputs are 10.0 and
    1.0, and the expected result is their NumPy sum.
    """
    env = Environment()
    lhs = relay.var('x', shape=())
    rhs = relay.var('y', shape=())
    program = relay.Function([lhs, rhs], add(lhs, rhs))

    lhs_data = np.array(10.0, dtype='float32')
    rhs_data = np.array(1.0, dtype='float32')
    inputs = [lhs_data, rhs_data]
    expected = lhs_data + rhs_data
    check_rts(env, program, inputs, expected)
def assert_has_type(expr, typ, env=None):
    """Run ``infer_type`` on ``expr`` and require the result type to equal ``typ``.

    NOTE(review): another ``assert_has_type`` is defined immediately after
    this one in the same file. At import time it replaces this definition,
    so this version is never called. Consider removing one of the two.

    Args:
        expr: the expression to type-check.
        typ: the type ``expr`` is expected to have.
        env: environment passed to ``infer_type``. If omitted (or ``None``),
            a fresh ``Environment({})`` is created for this call.

    Raises:
        RuntimeError: if the inferred ``checked_type`` differs from ``typ``.
    """
    # A default of ``Environment({})`` in the signature is evaluated only once,
    # at definition time, and would be shared by every call. Build a new one
    # per call instead so calls cannot leak state into each other.
    if env is None:
        env = Environment({})
    checked_type = infer_type(env, expr).checked_type
    if checked_type != typ:
        raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ))
def assert_has_type(expr, typ, env=None):
    """Run ``check_expr`` on ``expr`` and require the result type to equal ``typ``.

    This definition replaces the earlier ``assert_has_type`` in this file at
    import time.

    NOTE(review): the earlier definition calls ``infer_type`` and reads
    ``checked_type`` as an attribute. This one calls ``check_expr`` and calls
    ``checked_type()`` as a method. Confirm which form matches the current
    API and keep only one definition.

    Args:
        expr: the expression to type-check.
        typ: the type ``expr`` is expected to have.
        env: environment passed to ``check_expr``. If omitted (or ``None``),
            a fresh ``Environment({})`` is created for this call.

    Raises:
        AssertionError: if the checked type differs from ``typ``.
    """
    # Build the environment per call. A default of ``Environment({})`` in the
    # signature would be created once and shared by every call.
    if env is None:
        env = Environment({})
    checked_expr = check_expr(env, expr)
    actual = checked_expr.checked_type()
    # Raise explicitly rather than with a bare ``assert``. A bare ``assert`` is
    # removed under ``python -O``, which would make this check silently pass.
    if actual != typ:
        raise AssertionError("Type mismatch %s vs %s" % (actual, typ))