示例#1
0
def test_replace_vars():
    """Check that ``cgtools.replace_vars`` substitutes a var in a built graph.

    The graph first computes ``(a + a) * const + a``. ``const`` is then
    replaced by ``a + a``, so feeding ``a = 5`` must produce
    ``(5 + 5) * (5 + 5) + 5 == 105``.
    """
    device = "xpux"
    graph = mgb_graph.Graph()
    # NOTE(review): meaning of the 0b100 bit is not visible here — confirm
    # against the graph options documentation.
    graph.options.async_exec_level = 0b100

    inp = mgb_graph.InputNode(device=device, dtype=np.float32, graph=graph)
    const = graph.make_const(1.234, device=device)
    inp_var = inp.outputs[0]

    add = Elemwise(Elemwise.Mode.ADD)
    mul = Elemwise(Elemwise.Mode.MUL)

    # doubled = a + a ; scaled = doubled * const ; total = scaled + a
    doubled = apply_normal_varnode(add, inp_var, inp_var)[0]
    scaled = apply_normal_varnode(mul, doubled, const)[0]
    total = apply_normal_varnode(add, scaled, inp_var)[0]

    # Swap the constant for (a + a); exactly one replaced endpoint is expected.
    replaced = cgtools.replace_vars([total._node], {const._node: doubled._node})
    (new_node,) = replaced

    sink = mgb_graph.OutputNode(mgb_graph.VarNode(new_node))
    compiled = graph.compile(sink.outputs[0])

    # Execution is launched before the input value is supplied, as in the
    # original ordering (presumably the input node waits for set_value —
    # TODO confirm).
    compiled.execute()
    inp.set_value(make_dev_tensor(5.0, device=device))

    result = sink.get_value().numpy()
    np.testing.assert_equal(result, np.array([105.0]))
示例#2
0
def test_op():
    """Run a NEGATE elemwise op through input/output callbacks and check it.

    A random float32 vector of length 10 is fed through ``input_callback``,
    negated, and captured with ``output_callback`` into a ``Future``; the
    captured value must be the exact negation of the input.
    """
    graph = mgb_graph.Graph()
    host_data = np.random.randn(10).astype("float32")
    x = Tensor(host_data, device="xpux")._dev_tensor()

    def provide_input():
        # Always hand the same device tensor to the graph.
        return x

    in_var, _ = mgb_graph.input_callback(
        provide_input, device=x.comp_node, dtype=x.dtype, graph=graph
    )

    negate = Elemwise(Elemwise.Mode.NEGATE)
    negated = mgb_graph.apply_normal_varnode(negate, in_var)[0]

    # The output callback stores the computed value in the future.
    captured = Future()
    out_var = mgb_graph.output_callback(captured.set_result, negated)

    run = graph.compile(out_var)
    run()

    np.testing.assert_equal(x.numpy(), -captured.result().numpy())
示例#3
0
def test_exception():
    """Check that an error raised inside an input callback reaches the caller.

    The input callback always raises ``RuntimeError("QwQ")``. Executing the
    compiled graph (or fetching its output) must therefore raise an exception
    whose message contains that text. If nothing is raised the test fails
    explicitly instead of passing vacuously.
    """
    err_msg = "QwQ"

    def throw_exc():
        raise RuntimeError(err_msg)

    g = mgb_graph.Graph()
    x, _ = mgb_graph.input_callback(throw_exc, device="xpux", dtype="float32", graph=g)
    neg = Elemwise(Elemwise.Mode.NEGATE)
    y = mgb_graph.OutputNode(mgb_graph.apply_normal_varnode(neg, x)[0])
    f = g.compile(y.outputs[0])
    try:
        # Either call may be the one that surfaces the callback's error.
        f.execute()
        y.get_value()
    except Exception as exc:
        assert err_msg in str(exc)
    else:
        # Without this branch a swallowed callback error would go unnoticed.
        raise AssertionError(
            "expected the input callback error to propagate, but no exception was raised"
        )
示例#4
0
def test_assert_equal():
    """Round-trip an ``AssertEqual`` graph through dump/load and run it.

    Two float32 host-to-device inputs are compared by ``builtin.AssertEqual``
    (``maxerr=1e-5``). The graph is dumped, loaded with ``Net.load``, dumped
    again, and finally executed via ``GraphInference`` on two equal arrays.
    """
    device = "xpux"
    graph = G.Graph()
    lhs = graph.make_h2d(dtype=np.float32, device=device)
    rhs = graph.make_h2d(dtype=np.float32, device=device)

    assert_op = builtin.AssertEqual(maxerr=1e-5)
    checked = G.apply_normal_varnode(assert_op, lhs._node, rhs._node)[0]
    graph.compile(checked)

    # Serialize the graph; only the first element of the dump result is used.
    dumped = G.dump_graph([checked])
    model_buf = io.BytesIO()
    model_buf.write(dumped[0])
    model_buf.seek(0)
    net = Net.load(model_buf)

    # Dump the loaded network again and run inference from that second dump.
    redump_buf = io.BytesIO()
    net.dump(redump_buf)
    redump_buf.seek(0)

    inference = GraphInference(redump_buf)
    values = [1.0, 2.0]
    inference.run(np.array(values), np.array(values))
示例#5
0
 def assert_equal(expect, real, **kwargs):
     """Apply ``builtin.AssertEqual`` to *expect* and *real*.

     Keyword arguments are forwarded unchanged to ``builtin.AssertEqual``.
     The op must yield exactly one output, which is returned wrapped in
     ``G.VarNode``.
     """
     outputs = G.apply_normal_varnode(builtin.AssertEqual(**kwargs), expect, real)
     # Single-element unpacking fails if the op returns anything but one output.
     [node] = outputs
     return G.VarNode(node)
示例#6
0
 def typecvt(x, dt=None):
     """Apply ``ops.TypeCvt(dtype=dt)`` to *x* and return its single output."""
     cvt_op = ops.TypeCvt(dtype=dt)
     # Exactly one output is expected; unpacking fails otherwise.
     [converted] = G.apply_normal_varnode(cvt_op, x)
     return converted