コード例 #1
0
    def run_node(cls, node, inputs, device='CPU', outputs_info=None):
        """Dummy ``run_node`` that exercises ONNX shape inference, then bails out.

        Builds a single-node graph from ``node`` whose inputs are typed after
        the given ``inputs``. When ``outputs_info`` is provided, it runs shape
        inference on that graph. If inference produced any ``value_info``, that
        ``value_info`` is checked against the expected outputs. No result is
        ever computed.

        Args:
            node: The NodeProto under test.
            inputs: Input arrays; each must expose ``dtype`` and ``shape``.
            device: Unused here; kept for the backend interface.
            outputs_info: Optional sequence of ``(dtype, shape)`` pairs, one per
                entry of ``node.output``. When falsy, no inference is run.

        Raises:
            AssertionError: If inferred value_info differs from the expected
                output value infos.
            BackendIsNotSupposedToImplementIt: Always, after the check.
        """
        inputs_info = [(x.dtype, x.shape) for x in inputs]
        input_value_infos = [
            helper.make_tensor_value_info(x, NP_TYPE_TO_TENSOR_TYPE[t], shape)
            for x, (t, shape) in zip(node.input, inputs_info)
        ]
        if outputs_info:
            # Built only here: ``outputs_info`` defaults to None, and zipping
            # over None would raise TypeError before the dummy exception.
            output_value_infos = [
                helper.make_tensor_value_info(x, NP_TYPE_TO_TENSOR_TYPE[t], shape)
                for x, (t, shape) in zip(node.output, outputs_info)
            ]
            # Graph outputs are left empty. Any inferred types for the node's
            # outputs are therefore expected in value_info, which is what is
            # compared below.
            graph = helper.make_graph([node], "test", input_value_infos, [])
            orig_model = helper.make_model(graph, producer_name='onnx-test')
            inferred = onnx.shape_inference.infer_shapes(
                orig_model.SerializeToString())
            # Depending on the onnx version, inference on serialized bytes
            # may hand back bytes or an already-parsed ModelProto; accept both.
            if isinstance(inferred, ModelProto):
                inferred_model = inferred
            else:
                inferred_model = ModelProto()
                inferred_model.ParseFromString(inferred)

            # Allow shape inference to not return anything, but if it
            # does then check that it's correct
            if inferred_model.graph.value_info:
                assert (list(
                    inferred_model.graph.value_info) == output_value_infos)
        raise BackendIsNotSupposedToImplementIt(
            "This is the dummy backend test that doesn't verify the results but does run the shape inference"
        )
コード例 #2
0
    def prepare(
        cls,
        model,  # type: ModelProto
        device='CPU',  # type: Text
        **kwargs  # type: Any
    ):  # type: (...) -> Optional[onnx.backend.base.BackendRep]
        """Dummy ``prepare`` that checks shape-inference coverage, then stops.

        Steps:
        - Run the base-class ``prepare``.
        - Apply ONNX shape inference to ``model``.
        - If ``do_enforce_test_coverage_safelist`` returns truthy for the
          inferred model, check each node output, except Dropout outputs past
          the first. Each checked output must have inferred info with a
          defined element type, and every dimension must carry a ``dim_value``.

        Raises:
            AssertionError: If a checked output lacks complete inferred info.
            BackendIsNotSupposedToImplementIt: Otherwise, always.
        """
        super(DummyBackend, cls).prepare(model, device, **kwargs)

        # test shape inference
        inferred = onnx.shape_inference.infer_shapes(model)
        info_by_name = dict(
            (info.name, info)
            for info in itertools.chain(inferred.graph.value_info,
                                        inferred.graph.output))

        if do_enforce_test_coverage_safelist(inferred):
            for graph_node in inferred.graph.node:
                for idx, out_name in enumerate(graph_node.output):
                    # Dropout's outputs after the first are exempt.
                    if idx > 0 and graph_node.op_type == 'Dropout':
                        continue
                    assert out_name in info_by_name
                    tensor_type = info_by_name[out_name].type.tensor_type
                    assert tensor_type.elem_type != TensorProto.UNDEFINED
                    # Each dim must be a concrete dim_value, not dim_param/unset.
                    for dimension in tensor_type.shape.dim:
                        assert dimension.WhichOneof('value') == 'dim_value'

        raise BackendIsNotSupposedToImplementIt(
            "This is the dummy backend test that doesn't verify the results but does run the checker"
        )
コード例 #3
0
ファイル: backend.py プロジェクト: yyqgood/onnx-tensorflow
    def _onnx_node_to_tensorflow_op(cls,
                                    node,
                                    tensor_dict,
                                    handlers=None,
                                    opset=None,
                                    strict=True):
        """Convert an ONNX node into a TensorFlow op via its registered handler.

        Args:
          node: ONNX node to convert; its ``domain`` and ``op_type`` select
            the handler.
          tensor_dict: Tensor dict of the graph, forwarded to the handler.
          handlers: Optional mapping of domain -> {op_type: handler}. When
            falsy, handlers are obtained from ``cls._get_handlers(opset)``.
          opset: Opset version passed to ``cls._get_handlers`` when
            ``handlers`` is not supplied.
          strict: whether to enforce semantic equivalence between the original
            model and the converted tensorflow model, defaults to True (yes,
            enforce semantic equivalence). Changing to False is strongly
            discouraged.

        Returns:
          Whatever the matching handler's ``handle`` returns (the TensorFlow op).

        Raises:
          BackendIsNotSupposedToImplementIt: If no handler is found for the
            node's domain and op_type.
        """
        registry = handlers or cls._get_handlers(opset)
        handler = None
        if registry and node.domain in registry:
            handler = registry[node.domain].get(node.op_type, None)
        if handler:
            return handler.handle(node, tensor_dict=tensor_dict, strict=strict)

        raise BackendIsNotSupposedToImplementIt(
            "{} is not implemented.".format(node.op_type))
コード例 #4
0
 def run_node(cls, node, inputs, device='CPU', outputs_info=None):
     """Dummy ``run_node``: defer to the base class, then refuse to execute.

     The base-class ``run_node`` receives the same arguments, so any checks it
     performs still run. No result is ever produced.

     Raises:
         BackendIsNotSupposedToImplementIt: Always, unless the base call
             raises first.
     """
     parent = super(DummyBackend, cls)
     parent.run_node(node, inputs, device=device, outputs_info=outputs_info)
     message = ("This is the dummy backend test that doesn't verify the results"
                " but does run the checker")
     raise BackendIsNotSupposedToImplementIt(message)
コード例 #5
0
ファイル: handler.py プロジェクト: gglin001/onnx-jax
    def handle(cls, node, **kwargs):
        """Run this handler's ``version_<SINCE_VERSION>`` method on ``node``.

        ``cls.args_check`` is called with the same arguments before the
        versioned method runs.

        Raises:
            BackendIsNotSupposedToImplementIt: If the handler has no (truthy)
                ``version_<SINCE_VERSION>`` attribute.
        """
        method_name = "version_{}".format(cls.SINCE_VERSION)
        impl = getattr(cls, method_name, None)
        if not impl:
            template = "{} version {} is not implemented."
            raise BackendIsNotSupposedToImplementIt(
                template.format(node.op_type, cls.SINCE_VERSION))
        cls.args_check(node, **kwargs)
        return impl(node, **kwargs)
コード例 #6
0
ファイル: test_backend_test.py プロジェクト: sridhar551/ONNX
    def prepare(cls, model, device='CPU', **kwargs):
        """Dummy ``prepare``: run the base class and shape inference, then stop.

        The shape-inference result is discarded. The call only makes sure
        inference runs on ``model``.

        Raises:
            BackendIsNotSupposedToImplementIt: Always, unless an earlier call
                raises first.
        """
        parent = super(DummyBackend, cls)
        parent.prepare(model, device, **kwargs)

        # test shape inference (the returned model is intentionally unused)
        onnx.shape_inference.infer_shapes(model)

        reason = ("This is the dummy backend test that doesn't verify the "
                  "results but does run the checker")
        raise BackendIsNotSupposedToImplementIt(reason)
コード例 #7
0
    def _jit(cls, node, opset=None, handlers=None, **kwargs):
        """Dispatch ``node`` to its registered handler with ``inputs=None``.

        Args:
            node: ONNX node; ``node.domain`` and ``node.op_type`` select the
                handler.
            opset: Opset passed to ``cls._get_handlers`` when ``handlers`` is
                falsy.
            handlers: Optional mapping of domain -> {op_type: handler}.
            **kwargs: Forwarded to the handler's ``handle``.

        Returns:
            The value returned by the handler's ``handle``.

        Raises:
            BackendIsNotSupposedToImplementIt: If no handler matches the node.
        """
        table = handlers or cls._get_handlers(opset)
        handler = None
        if table and node.domain in table:
            handler = table[node.domain].get(node.op_type, None)
        if not handler:
            raise BackendIsNotSupposedToImplementIt(
                f"{node.op_type} is not implemented.")
        return handler.handle(node, inputs=None, **kwargs)
コード例 #8
0
ファイル: test_backend_test.py プロジェクト: zzmcdc/onnx
 def run_node(cls,
              node,  # type: NodeProto
              inputs,  # type: Any
              device='CPU',  # type: Text
              outputs_info=None,  # type: Optional[Sequence[Tuple[numpy.dtype, Tuple[int, ...]]]]
              **kwargs  # type: Any
              ):  # type: (...) -> Optional[Tuple[Any, ...]]
     """Dummy ``run_node``: call the base implementation, then refuse to run.

     Extra ``kwargs`` are accepted but not forwarded to the base class.

     Raises:
         BackendIsNotSupposedToImplementIt: Always, unless the base call
             raises first.
     """
     base = super(DummyBackend, cls)
     base.run_node(node, inputs, device=device, outputs_info=outputs_info)
     raise BackendIsNotSupposedToImplementIt(
         "This is the dummy backend test that doesn't verify "
         "the results but does run the checker")
コード例 #9
0
ファイル: test_backend_test.py プロジェクト: onnx/onnx
 def run_node(cls,
              node: NodeProto,
              inputs: Any,
              device: str = 'CPU',
              outputs_info: Optional[Sequence[Tuple[numpy.dtype, Tuple[int, ...]]]] = None,
              **kwargs: Any
              ) -> Optional[Tuple[Any, ...]]:
     """Dummy ``run_node``: run the parent's ``run_node``, then refuse to run.

     Extra ``kwargs`` are accepted but not forwarded to the parent.

     Raises:
         BackendIsNotSupposedToImplementIt: Always, unless the parent call
             raises first.
     """
     base_run_node = super().run_node
     base_run_node(node, inputs, device=device, outputs_info=outputs_info)
     explanation = ("This is the dummy backend test that doesn't verify the "
                    "results but does run the checker")
     raise BackendIsNotSupposedToImplementIt(explanation)
コード例 #10
0
    def prepare(cls,
                model,  # type: ModelProto
                device='CPU',  # type: Text
                **kwargs  # type: Any
                ):  # type: (...) -> Optional[onnx.backend.base.BackendRep]
        """Dummy ``prepare``: base-class checks plus a shape-inference run.

        No backend rep is ever returned.

        Raises:
            BackendIsNotSupposedToImplementIt: Always, unless an earlier call
                raises first.
        """
        base = super(DummyBackend, cls)
        base.prepare(model, device, **kwargs)

        # Only checks that shape inference runs; the inferred model is discarded.
        infer = onnx.shape_inference.infer_shapes
        infer(model)

        raise BackendIsNotSupposedToImplementIt(
            "This is the dummy backend test that doesn't "
            "verify the results but does run the checker")
コード例 #11
0
    def handle(cls, node, **kwargs):
        """Dispatch ``node`` to the method matching ``cls.SINCE_VERSION``.

        The method is looked up under the name ``version_<SINCE_VERSION>``.
        The ``version_`` prefix is therefore reserved for this purpose in
        onnx-tensorflow and must not be used for anything else. When the
        method exists, ``cls.args_check`` validates the arguments before it
        is called.

        :param node: NodeProto for backend.
        :param kwargs: Other args, forwarded to ``args_check`` and the method.
        :return: TensorflowNode for backend (whatever the versioned method returns).
        :raises BackendIsNotSupposedToImplementIt: If no such method exists.
        """
        versioned_method = getattr(cls, "version_{}".format(cls.SINCE_VERSION), None)
        if not versioned_method:
            raise BackendIsNotSupposedToImplementIt(
                "{} version {} is not implemented.".format(
                    node.op_type, cls.SINCE_VERSION))

        cls.args_check(node, **kwargs)
        result = versioned_method(node, **kwargs)
        return result