Exemplo n.º 1
0
 def forward(self, x):
     """Apply ``self.identity`` and every sub-module held in ``self.identity_dict``
     and ``self.identity_list`` to ``x``, then negate and add ``self.param``."""
     out = self.identity(x)
     # Dict-held modules run in the dict's key iteration order.
     for key in self.identity_dict:
         layer = self.identity_dict[key]
         out = layer(out)
     # List-held modules run in list order.
     for layer in self.identity_list:
         out = layer(out)
     negated = F.neg(out)
     return negated + self.param
Exemplo n.º 2
0
def test_insert():
    """Swap the unique relu output for its negation and check the recompiled output.

    The assertion requires ``expect - 1 == 1 - new_output`` (within atol=1e-6).
    """
    traced_module, inp, expect = _init_block()
    graph = traced_module.graph
    relu_expr = graph.get_function_by_type(F.relu).as_unique()
    relu_out = relu_expr.outputs[0]
    # Exprs created inside this context are recorded into the graph.
    with graph.insert_exprs():
        neg_out = F.neg(relu_out)
    graph.replace_node({relu_out: neg_out})
    graph.compile()
    actual = traced_module(inp)
    np.testing.assert_allclose(expect - 1, 1 - actual, atol=1e-6)
Exemplo n.º 3
0
def test_id_and_name():
    """Check node/expr id bookkeeping and node-name uniqueness across graph edits.

    The checks run on three versions of the module: as first traced, after a
    pickle round-trip plus relu->neg edits, and after re-tracing ``NewModule``
    built around an unpickled copy. Each version is checked both as-is and in
    flattened form.
    """

    def _assert_unique_with_max(ids, total):
        # Ids must be pairwise distinct, and the largest must be one less than
        # the graph's counter entry.
        assert len(set(ids)) == len(ids)
        assert max(ids) + 1 == total

    def _check_id(traced_module):
        # _total_ids[0] tracks nodes, _total_ids[1] tracks exprs.
        total_ids = traced_module.graph._total_ids
        node_ids = [node._id for node in traced_module.graph.nodes().as_list()]
        _assert_unique_with_max(node_ids, total_ids[0])
        expr_ids = [expr._id for expr in traced_module.graph.exprs().as_list()]
        _assert_unique_with_max(expr_ids, total_ids[1])

    def _check_name(flattened_module):
        node_names = [
            node._name for node in flattened_module.graph.nodes().as_list()
        ]
        assert len(set(node_names)) == len(node_names)

    def _check_flattened(traced_module):
        # Flatten, then run both the id and the name checks on the result.
        flattened_module = traced_module.flatten()
        _check_id(flattened_module)
        _check_name(flattened_module)

    traced_module, x, _ = _init_module()
    _check_id(traced_module)
    _check_flattened(traced_module)

    # Pickle round-trip, then move the global Node/Expr id counters.
    # NOTE(review): presumably this checks that nodes/exprs inserted below take
    # ids from the graph's own counter rather than these globals -- confirm.
    traced_module = pickle.loads(pickle.dumps(traced_module))
    Node._set_next_id(159)
    Expr._set_next_id(1024)

    # Replace each relu output with its negation, editing the expr's top_graph.
    for relu_expr in traced_module.graph.get_function_by_type(F.relu).as_list():
        relu_out = relu_expr.outputs[0]
        owner_graph = relu_expr.top_graph
        with owner_graph.insert_exprs():
            neg_out = F.neg(relu_out)
        owner_graph.replace_node({relu_out: neg_out})
        owner_graph.compile()
    _check_id(traced_module)
    _check_flattened(traced_module)

    # Trace a module that wraps an unpickled TracedModule.
    traced_module = pickle.loads(pickle.dumps(traced_module))
    traced_module = trace_module(NewModule(traced_module), x)
    _check_id(traced_module)
    _check_flattened(traced_module)
Exemplo n.º 4
0
def test_op():
    """Negate a device tensor inside an ``mgb_graph.Graph`` and compare with the input."""
    comp_graph = mgb_graph.Graph()
    host_data = np.random.randn(10).astype("float32")
    dev_tensor = make_dev_tensor(host_data, device="xpux")
    # The input callback hands the prepared device tensor to the graph.
    inp_var, _ = mgb_graph.input_callback(
        lambda: dev_tensor,
        device=dev_tensor.comp_node,
        dtype=dev_tensor.dtype,
        graph=comp_graph,
    )
    neg_var = F.neg(inp_var)
    # The output callback stores the computed value into the Future.
    result = Future()
    out_var = mgb_graph.output_callback(result.set_result, neg_var)
    run = comp_graph.compile(out_var)
    run()

    np.testing.assert_equal(dev_tensor.numpy(), -result.result().numpy())
Exemplo n.º 5
0
def test_exception():
    """An exception raised by an input callback must reach the caller.

    The callback's message has to appear in the exception raised while
    executing the compiled graph or fetching its output. The test fails if no
    exception surfaces at all.
    """
    err_msg = "QwQ"

    def throw_exc():
        raise RuntimeError(err_msg)

    g = mgb_graph.Graph()
    x, _ = mgb_graph.input_callback(throw_exc, device="xpux", dtype="float32", graph=g)
    y = mgb_graph.OutputNode(F.neg(x))
    f = g.compile(y.outputs[0])
    try:
        f.execute()
        y.get_value()
    except Exception as exc:
        # The runtime may re-wrap the RuntimeError in its own exception type,
        # so match on the message text rather than on the class.
        assert err_msg in str(exc)
    else:
        # Without this branch the test would pass even if the error were lost.
        raise AssertionError(
            "expected the input callback's RuntimeError to propagate"
        )