Beispiel #1
0
 def test_slice_ignorable_args_for_slice(self):
     """Trailing ``aten::slice`` arguments equal to schema defaults are not printed.

     Builds a graph that slices a 2x6 constant tensor twice, checks the generated
     Python code omits trailing slice arguments that match the schema defaults,
     and checks that an export/import round trip of the function produces the
     same result as the original.
     """
     graph_str = """graph():
         %13 : int = prim::Constant[value=0]()
         %10 : bool = prim::Constant[value=0]()
         %8 : NoneType = prim::Constant()
         %0 : int = prim::Constant[value=1]()
         %1 : int = prim::Constant[value=2]()
         %2 : int = prim::Constant[value=3]()
         %3 : int = prim::Constant[value=4]()
         %4 : int = prim::Constant[value=9]()
         %5 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
         %6 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
         %7 : int[][] = prim::ListConstruct(%5, %6)
         %val.1 : Tensor = aten::tensor(%7, %8, %8, %10)
         %16 : Tensor = aten::slice(%val.1, %13, %1, %8, %0)
         %20 : Tensor = aten::slice(%16, %0, %8, %0, %0)
         return (%20)"""
     graph = parse_ir(graph_str)
     function = self.createFunctionFromGraph(graph)
     function_copy = self.getExportImportCopy(function)
     src = str(function.code)
     # Schema, with the defaults that make trailing arguments ignorable:
     # aten::slice(Tensor self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor
     # - %16 is slice(dim=0, start=2, end=%8, step=%0). Its trailing end=%8 (None)
     #   and step=%0 (1) equal the defaults, so only "0, 2" is printed.
     # - %20 is slice(dim=1, start=%8, end=1, step=%0). Its trailing step=%0 (1)
     #   equals the default, so printing stops after end=1.
     #   start=None must still be printed because end=1 follows it positionally.
     FileCheck().check(
         "torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)"
     ).run(src)
     # The serialized-and-reloaded copy must compute the same tensor.
     self.assertEqual(function(), function_copy())
    def run_test(self, graph_ir, example_inputs):
        """Check ONNX Runtime output against the TorchScript interpreter for one IR graph.

        ``graph_ir`` is parsed into a TorchScript graph. That graph is first
        executed with the JIT interpreter on ``example_inputs``, then converted
        to an ONNX model (at ``self.opset_version``) and run through an ONNX
        Runtime session (using ``self.ort_providers``) on the same inputs.
        The two result sets are compared with ``rtol=1e-3`` and ``atol=1e-7``.
        """
        ts_graph = parse_ir(graph_ir)
        # Reference results come from interpreting the graph directly.
        expected_outs = torch._C._jit_interpret_graph(ts_graph, example_inputs)

        export_type = torch.onnx.OperatorExportTypes.ONNX
        onnx_model = _jit_graph_to_onnx_model(ts_graph, export_type, self.opset_version)
        session = onnxruntime.InferenceSession(onnx_model, providers=self.ort_providers)
        actual_outs = run_ort(session, example_inputs)

        ort_compare_with_pytorch(actual_outs, expected_outs, rtol=1e-3, atol=1e-7)
Beispiel #3
0
 def test_becomes_wildcard_annotations(self):
     """Passing a value into ``aten::split`` should put it in the wildcard set.

     In the graph below, %x.1 is consumed by ``aten::split``. The test expects
     alias analysis to report two things: the list produced by the split may
     contain %x.1, and %x.1 may alias the first graph input (%a.1).
     """
     graph_str = """
     graph(%a.1 : Tensor, %b.1 : Tensor):
         %11 : NoneType = prim::Constant()
         %8 : int = prim::Constant[value=0]()
         %7 : int = prim::Constant[value=1]()
         %x.1 : Tensor = aten::add(%a.1, %b.1, %7)
         %y.1 : Tensor[] = aten::split(%x.1, %7, %8)
         return ()
     """
     graph = parse_ir(graph_str)
     alias_db = graph.alias_db()
     split_node = graph.findNode("aten::split")
     split_input = next(split_node.inputs())
     split_list = split_node.output()
     first_graph_input = next(graph.inputs())

     # The split's input enters the wildcard set, and its output list is
     # initialized as containing the wildcard set, so the list may hold it.
     self.assertTrue(alias_db.may_contain_alias(split_input, split_list))
     # Once %x.1 is in the wildcard set, it may alias the other members of
     # that set, which include the graph inputs.
     self.assertTrue(alias_db.may_contain_alias(split_input, first_graph_input))