def test_slice_ignorable_args_for_slice(self):
    """Check that printed code omits trailing default args of ``aten::slice``.

    Parses a graph holding two chained ``aten::slice`` calls, turns it into
    a function, and uses FileCheck to verify the generated source. The
    function is also round-tripped through export/import and both versions
    must return equal results.
    """
    # One IR statement per line, indented under the ``graph():`` header.
    graph_str = """graph():
        %13 : int = prim::Constant[value=0]()
        %10 : bool = prim::Constant[value=0]()
        %8 : NoneType = prim::Constant()
        %0 : int = prim::Constant[value=1]()
        %1 : int = prim::Constant[value=2]()
        %2 : int = prim::Constant[value=3]()
        %3 : int = prim::Constant[value=4]()
        %4 : int = prim::Constant[value=9]()
        %5 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
        %6 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
        %7 : int[][] = prim::ListConstruct(%5, %6)
        %val.1 : Tensor = aten::tensor(%7, %8, %8, %10)
        %16 : Tensor = aten::slice(%val.1, %13, %1, %8, %0)
        %20 : Tensor = aten::slice(%16, %0, %8, %0, %0)
        return (%20)"""
    graph = parse_ir(graph_str)
    function = self.createFunctionFromGraph(graph)
    function_copy = self.getExportImportCopy(function)
    src = str(function.code)
    # For a signature:
    # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor
    # We ignore trailing arguments after start=2 for dim 0
    # and after end=1 for dim 1
    # because in %16, %8 (end=None) and %0 (step=1) are default values for
    # the schema, and in %20 the trailing %0 (step=1) is a default as well.
    FileCheck().check(
        "torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)"
    ).run(src)
    # The export/import round trip must not change the computed result.
    self.assertEqual(function(), function_copy())
def run_test(self, graph_ir, example_inputs):
    """Compare JIT-interpreter outputs of a graph with its ONNX Runtime outputs.

    Args:
        graph_ir: Textual IR handed to ``parse_ir``.
        example_inputs: Inputs passed both to the JIT graph interpreter and
            to the ONNX Runtime session.

    The comparison is delegated to ``ort_compare_with_pytorch`` with
    ``rtol=1e-3`` and ``atol=1e-7``.
    """
    parsed_graph = parse_ir(graph_ir)

    # Reference results are computed first, before the graph is handed to
    # the ONNX export helper.
    expected_outputs = torch._C._jit_interpret_graph(parsed_graph, example_inputs)

    export_type = torch.onnx.OperatorExportTypes.ONNX
    model_proto = _jit_graph_to_onnx_model(
        parsed_graph,
        export_type,
        self.opset_version,
    )

    session = onnxruntime.InferenceSession(
        model_proto,
        providers=self.ort_providers,
    )
    actual_outputs = run_ort(session, example_inputs)

    ort_compare_with_pytorch(
        actual_outputs,
        expected_outputs,
        rtol=1e-3,
        atol=1e-7,
    )
def test_becomes_wildcard_annotations(self):
    """Check alias-db results for the input and output of ``aten::split``.

    Builds a graph that adds two tensor inputs and splits the sum, then
    asserts via ``may_contain_alias`` that the split input may alias both
    the split output list and the first graph input.
    """
    # One IR statement per line, indented under the graph header.
    graph_str = """
    graph(%a.1 : Tensor, %b.1 : Tensor):
        %11 : NoneType = prim::Constant()
        %8 : int = prim::Constant[value=0]()
        %7 : int = prim::Constant[value=1]()
        %x.1 : Tensor = aten::add(%a.1, %b.1, %7)
        %y.1 : Tensor[] = aten::split(%x.1, %7, %8)
        return ()
    """
    graph = parse_ir(graph_str)
    alias_db = graph.alias_db()
    split_node = graph.findNode("aten::split")
    # First input of the split node, i.e. %x.1 in the graph above.
    split_input = next(split_node.inputs())
    # split input enters wildcard set, list initialized as containing wildcard set
    self.assertTrue(alias_db.may_contain_alias(split_input, split_node.output()))
    # because %x.1 enters wildcard set, it now aliases other members of wildcard set (graph inputs)
    self.assertTrue(alias_db.may_contain_alias(split_input, next(graph.inputs())))