def forward(self, x):
    """Chain the identity sub-modules, then negate and add ``self.param``.

    Order of application: ``self.identity``, then every entry of
    ``self.identity_dict`` (looked up by key, in the dict's iteration order),
    then every module in ``self.identity_list``. The final value is
    ``F.neg(h) + self.param``.
    """
    h = self.identity(x)
    # Keys are iterated and each entry is fetched via ``[key]`` (rather than
    # iterating values) so the access pattern matches the original exactly.
    for key in self.identity_dict:
        h = self.identity_dict[key](h)
    for layer in self.identity_list:
        h = layer(h)
    negated = F.neg(h)
    return negated + self.param
def test_insert():
    """Insert ``F.neg`` after the ``F.relu`` output and check the rewired graph.

    The output of the ``F.relu`` call found via ``as_unique()`` is negated
    inside ``graph.insert_exprs()``. Its uses are then redirected to the
    negated value with ``replace_node``, and the graph is recompiled. The
    assertion requires ``1 - new_output == expect - 1``. NOTE(review): this
    is consistent with the block computing ``relu(...) + 1``, but
    ``_init_block`` is not visible here — confirm.
    """
    traced_module, x, expect = _init_block()
    graph = traced_module.graph

    relu_expr = graph.get_function_by_type(F.relu).as_unique()
    relu_out = relu_expr.outputs[0]
    with graph.insert_exprs():
        neg_out = F.neg(relu_out)
    graph.replace_node({relu_out: neg_out})
    graph.compile()

    # Evaluate in the same order as before: expected side first, then the
    # modified module.
    expected = expect - 1
    actual = 1 - traced_module(x)
    np.testing.assert_allclose(expected, actual, atol=1e-6)
def test_id_and_name():
    """Check node/expr id consistency and node-name uniqueness across edits.

    The checks are run on three states of the module:

    1. the freshly traced module from ``_init_module()``;
    2. after a pickle round-trip, with the global ``Node``/``Expr`` id
       counters moved to 159/1024, followed by inserting ``F.neg`` after
       every ``F.relu`` output;
    3. after another pickle round-trip, with the module wrapped in
       ``NewModule`` and re-traced with ``trace_module``.

    Each state is checked directly and in flattened form. The checks are:

    * node ids are unique, and ``max + 1`` equals ``graph._total_ids[0]``;
    * expr ids are unique, and ``max + 1`` equals ``graph._total_ids[1]``;
    * node names are unique (flattened form only).
    """

    def _check_id(module):
        # Verify per-graph id bookkeeping: unique ids, counter == max id + 1.
        graph = module.graph
        total_ids = graph._total_ids
        node_ids = [node._id for node in graph.nodes().as_list()]
        assert len(set(node_ids)) == len(node_ids)
        assert max(node_ids) + 1 == total_ids[0]
        expr_ids = [expr._id for expr in graph.exprs().as_list()]
        assert len(set(expr_ids)) == len(expr_ids)
        assert max(expr_ids) + 1 == total_ids[1]

    def _check_name(flattened_module):
        # Verify no two nodes in the flattened graph share a name.
        names = [node._name for node in flattened_module.graph.nodes().as_list()]
        assert len(set(names)) == len(names)

    def _check_all(module):
        # Run the id check on the module and its flattened form, plus the
        # name check on the flattened form.
        _check_id(module)
        flattened = module.flatten()
        _check_id(flattened)
        _check_name(flattened)

    traced_module, x, _ = _init_module()
    _check_all(traced_module)

    # pickle check
    traced_module = pickle.loads(pickle.dumps(traced_module))
    # NOTE(review): presumably moves the global id counters away from the
    # values the graph was built with, so the checks confirm that newly
    # inserted nodes/exprs still get graph-consistent ids — confirm.
    Node._set_next_id(159)
    Expr._set_next_id(1024)
    relu_exprs = traced_module.graph.get_function_by_type(F.relu).as_list()
    for relu_expr in relu_exprs:
        relu_out = relu_expr.outputs[0]
        owner_graph = relu_expr.top_graph
        with owner_graph.insert_exprs():
            neg_out = F.neg(relu_out)
        owner_graph.replace_node({relu_out: neg_out})
        owner_graph.compile()
    _check_all(traced_module)

    # check trace TracedModule
    traced_module = pickle.loads(pickle.dumps(traced_module))
    wrapper = NewModule(traced_module)
    traced_module = trace_module(wrapper, x)
    _check_all(traced_module)
def test_op():
    """Run ``F.neg`` on a device tensor fed through graph input/output callbacks.

    A 10-element float32 tensor is created on device ``"xpux"`` and fed into
    the graph by an input callback. It is negated, and the result is captured
    in a ``Future`` via an output callback. The captured value must be the
    exact negation of the input.
    """
    graph = mgb_graph.Graph()
    dev_x = make_dev_tensor(np.random.randn(10).astype("float32"), device="xpux")
    inp, _ = mgb_graph.input_callback(
        lambda: dev_x, device=dev_x.comp_node, dtype=dev_x.dtype, graph=graph
    )
    negated = F.neg(inp)

    result = Future()
    out = mgb_graph.output_callback(result.set_result, negated)
    func = graph.compile(out)
    func()

    np.testing.assert_equal(dev_x.numpy(), -result.result().numpy())
def test_exception():
    """Check that an exception raised by an input callback reaches the caller.

    The graph's only input comes from ``throw_exc``, which raises
    ``RuntimeError("QwQ")``. Executing the compiled graph and fetching the
    output must raise an exception whose message contains ``"QwQ"``. If no
    exception is raised at all, the test fails.
    """
    err_msg = "QwQ"

    def throw_exc():
        raise RuntimeError(err_msg)

    g = mgb_graph.Graph()
    x, _ = mgb_graph.input_callback(throw_exc, device="xpux", dtype="float32", graph=g)
    y = mgb_graph.OutputNode(F.neg(x))
    f = g.compile(y.outputs[0])
    try:
        # NOTE(review): the error may surface either at execute() or when the
        # value is fetched; both calls are guarded so either is accepted.
        f.execute()
        y.get_value()
    except Exception as exc:
        # Only the message is checked. Presumably the runtime may re-raise the
        # callback's RuntimeError under a different exception type — confirm.
        assert err_msg in str(exc)
    else:
        # Previously the test passed silently when nothing was raised.
        raise AssertionError(
            "expected the exception raised by the input callback to propagate"
        )