def test_get_opr_seq():
    """Dump a small traced network and count its operators via ``cgtools.get_oprs_seq``.

    The traced graph is dumped to an in-memory buffer and loaded back. The
    operator sequence is then collected twice: once with the second
    positional argument of ``get_oprs_seq`` set to ``True`` (expected 5
    operators) and once with it set to ``False`` (expected 6 operators).
    NOTE(review): that flag presumably controls pruning of reshape-related
    operators — confirm against ``cgtools``.
    """

    class Net(M.Module):
        def __init__(self):
            super().__init__()
            # Random (1, 1, 4, 4) float32 tensor, captured into the graph as a constant.
            self.data = megengine.tensor(
                np.random.random((1, 1, 4, 4)), dtype=np.float32
            )

        def forward(self, input):
            # Build the target shape (side, side) as an int32 tensor on the
            # input's device, where ``side`` is the input's leading dimension.
            side = input.shape[0]
            target_shape = astensor1d(
                (side, side), self.data, dtype="int32", device=input.device
            )
            reshaped = F.reshape(self.data, target_shape)
            return input + reshaped

    network = Net()
    sample = megengine.tensor(np.random.random((4, 4)), dtype=np.float32)

    @trace(symbolic=True, capture_as_const=True)
    def func(inp, *, net=None):
        return net(inp)

    func(sample, net=network)

    buffer = io.BytesIO()
    func.dump(buffer, optimize_for_inference=False)
    buffer.seek(0)
    *_, outputs = mgb_graph.load_graph(buffer)

    # (flag passed to get_oprs_seq, expected operator count)
    for flag, expected_count in ((True, 5), (False, 6)):
        assert len(cgtools.get_oprs_seq(outputs, flag)) == expected_count
def test_goptions_log_exp():
    """Check that ``opt_level=1`` shrinks the dumped graph of ``log(exp(x))``.

    The same expression is traced once with ``opt_level=0`` and once with
    ``opt_level=1``. The optimized dump is expected to contain exactly two
    fewer operators (presumably because the log/exp pair is folded away).
    """

    @trace(symbolic=True, opt_level=0, capture_as_const=True)
    def f(x):
        return log(exp(x))

    @trace(symbolic=True, opt_level=1, capture_as_const=True)
    def g(x):
        return log(exp(x))

    def _dump_and_get_oprs(traced):
        # Dump to an in-memory buffer rather than a mkstemp() file: mkstemp
        # returns an open OS-level file descriptor that must be closed, and
        # the file it creates is never deleted automatically. The previous
        # version leaked both.
        buf = io.BytesIO()
        traced.dump(buf, optimize_for_inference=False)
        buf.seek(0)
        *_, outputs = G.load_graph(buf)
        return cgtools.get_oprs_seq(outputs)

    # Each function must be run once before it can be dumped.
    f(tensor(1.0))
    oprs_1 = _dump_and_get_oprs(f)

    g(tensor(1.0))
    oprs_2 = _dump_and_get_oprs(g)

    assert len(oprs_1) - len(oprs_2) == 2
def test_catch_input_name(tensor_name, var_name):
    """Verify which name a traced input carries into the dumped graph.

    ``f`` doubles its argument. It is traced with ``capture_as_const=True``
    and called on a tensor named ``tensor_name``, then dumped with
    ``keep_opr_name=True`` and ``keep_var_name=2``. After reloading, the
    first input of the last operator must be named ``var_name``.
    NOTE(review): both arguments are presumably supplied by a pytest
    parametrize decorator that is not visible here.
    """

    def f(x):
        return 2 * x

    traced = trace(f, symbolic=True, capture_as_const=True)
    named_input = Tensor(np.ones(shape=(2, 3)), name=tensor_name)
    traced(named_input).numpy()

    stream = io.BytesIO()
    traced.dump(
        stream,
        optimize_for_inference=False,
        keep_opr_name=True,
        keep_var_name=2,
    )
    stream.seek(0)
    *_, loaded_outputs = G.load_graph(stream)

    last_op = cgtools.get_oprs_seq(loaded_outputs)[-1]
    assert last_op.inputs[0].name == var_name
def _dump_and_load(func, symbolic, keep_opr_name=True):
    """Trace ``func``, dump its graph to memory, reload it and return its operators.

    Args:
        func: callable that is traced with ``capture_as_const=True`` and run
            once on a (2, 3) tensor of ones.
        symbolic: forwarded to ``trace`` as its ``symbolic`` flag.
        keep_opr_name: forwarded to ``dump``.

    Returns:
        ``cgtools.get_oprs_seq`` applied to the reloaded graph's
        ``output_vars_list``.
    """
    # Clear AutoNaming before tracing; presumably this keeps generated names
    # independent of earlier tests — confirm against AutoNaming.
    AutoNaming.clear()
    traced = trace(func, symbolic=symbolic, capture_as_const=True)
    traced(Tensor(np.ones(shape=(2, 3)))).numpy()

    dump_options = dict(
        optimize_for_inference=False,
        arg_names=("x",),
        keep_opr_name=keep_opr_name,
        keep_var_name=2,
    )
    buf = io.BytesIO()
    traced.dump(buf, **dump_options)
    buf.seek(0)

    reloaded = G.load_graph(buf)
    return cgtools.get_oprs_seq(reloaded.output_vars_list)