def test_replace_vars():
    """Check that ``cgtools.replace_vars`` substitutes a var inside a graph.

    The graph computes ``(a + a) * const + a``. The constant is then replaced
    by ``a + a``, so the rewritten output is ``(a + a) * (a + a) + a``; for
    ``a == 5`` that is ``105``.
    """
    graph = mgb_graph.Graph()
    # NOTE(review): meaning of the 0b100 bit is not visible here -- confirm
    # against the graph options documentation.
    graph.options.async_exec_level = 0b100
    dev = "xpux"
    inp = mgb_graph.InputNode(device=dev, dtype=np.float32, graph=graph)
    scalar = graph.make_const(1.234, device=dev)

    adder = Elemwise(Elemwise.Mode.ADD)
    multiplier = Elemwise(Elemwise.Mode.MUL)
    inp_var = inp.outputs[0]

    # Build (a + a) * const + a.
    doubled = apply_normal_varnode(adder, inp_var, inp_var)[0]
    scaled = apply_normal_varnode(multiplier, doubled, scalar)[0]
    total = apply_normal_varnode(adder, scaled, inp_var)[0]

    # Swap the constant for (a + a); exactly one endpoint comes back.
    replaced = cgtools.replace_vars([total._node], {scalar._node: doubled._node})
    (new_node,) = replaced

    sink = mgb_graph.OutputNode(mgb_graph.VarNode(new_node))
    compiled = graph.compile(sink.outputs[0])
    # Execution is started before the input value is supplied, as in the
    # original ordering (presumably asynchronous -- confirm).
    compiled.execute()
    inp.set_value(make_dev_tensor(5.0, device=dev))

    actual = sink.get_value().numpy()
    np.testing.assert_equal(actual, np.array([105.0]))
def test_op():
    """Run a NEGATE elemwise op through input/output callbacks.

    A random 10-element float32 tensor is fed in via ``input_callback``; the
    negated result is delivered to a ``Future`` via ``output_callback`` and
    compared against the input.
    """
    graph = mgb_graph.Graph()
    host_data = np.random.randn(10).astype("float32")
    x = Tensor(host_data, device="xpux")._dev_tensor()

    def feed():
        # Input callback: always supplies the same device tensor.
        return x

    in_var, _ = mgb_graph.input_callback(
        feed, device=x.comp_node, dtype=x.dtype, graph=graph
    )
    negate = Elemwise(Elemwise.Mode.NEGATE)
    negated = mgb_graph.apply_normal_varnode(negate, in_var)[0]

    result = Future()
    out_var = mgb_graph.output_callback(result.set_result, negated)

    compiled = graph.compile(out_var)
    compiled()

    np.testing.assert_equal(x.numpy(), -result.result().numpy())
def test_exception():
    """Check that an error raised in an input callback reaches the caller.

    The input callback raises ``RuntimeError(err_msg)``. Executing the graph
    and fetching the output must raise an exception whose text contains
    ``err_msg``. The test fails if no exception is raised at all.
    """
    err_msg = "QwQ"

    def throw_exc():
        raise RuntimeError(err_msg)

    g = mgb_graph.Graph()
    x, _ = mgb_graph.input_callback(throw_exc, device="xpux", dtype="float32", graph=g)
    neg = Elemwise(Elemwise.Mode.NEGATE)
    y = mgb_graph.OutputNode(mgb_graph.apply_normal_varnode(neg, x)[0])
    f = g.compile(y.outputs[0])
    try:
        f.execute()
        y.get_value()
    except Exception as exc:
        assert err_msg in str(exc)
    else:
        # Without this branch the test would pass vacuously when the callback
        # error is swallowed instead of being propagated.
        raise AssertionError(
            "expected an exception containing {!r}, but none was raised".format(
                err_msg
            )
        )
def test_assert_equal():
    """Round-trip an ``AssertEqual`` graph through dump/load and run it.

    Builds a graph with two host-to-device inputs compared by
    ``AssertEqual(maxerr=1e-5)``, dumps it, reloads it with ``Net.load``,
    dumps it again and runs the result with two identical arrays.
    """
    graph = G.Graph()
    lhs = graph.make_h2d(dtype=np.float32, device="xpux")
    rhs = graph.make_h2d(dtype=np.float32, device="xpux")
    check_op = builtin.AssertEqual(maxerr=1e-5)
    out = G.apply_normal_varnode(check_op, lhs._node, rhs._node)[0]
    graph.compile(out)

    # First serialisation: the dumped model bytes are the first element of
    # the value returned by dump_graph.
    dumped = G.dump_graph([out])
    model_buf = io.BytesIO(dumped[0])

    # Second serialisation: reload as a Net and dump it once more.
    net = Net.load(model_buf)
    redumped_buf = io.BytesIO()
    net.dump(redumped_buf)
    redumped_buf.seek(0)

    runner = GraphInference(redumped_buf)
    sample = [1.0, 2.0]
    runner.run(np.array(sample), np.array(sample))
def assert_equal(expect, real, **kwargs):
    """Apply a ``builtin.AssertEqual`` op to ``expect`` and ``real``.

    Args:
        expect: first operand passed to the op.
        real: second operand passed to the op.
        **kwargs: forwarded unchanged to ``builtin.AssertEqual``.

    Returns:
        The op's single output wrapped in ``G.VarNode``.
    """
    # The op must produce exactly one output; unpacking enforces that.
    [checked] = G.apply_normal_varnode(builtin.AssertEqual(**kwargs), expect, real)
    return G.VarNode(checked)
def typecvt(x, dt=None):
    """Apply an ``ops.TypeCvt`` op with ``dtype=dt`` to ``x``.

    Args:
        x: operand passed to the op.
        dt: value forwarded as the ``dtype`` argument of ``ops.TypeCvt``.

    Returns:
        The op's single output, as returned by ``G.apply_normal_varnode``.
    """
    cvt_op = ops.TypeCvt(dtype=dt)
    # Exactly one output is expected; unpacking enforces that.
    [converted] = G.apply_normal_varnode(cvt_op, x)
    return converted