示例#1
0
    def make_simple_model(self) -> Model:
        """Build a minimal single-convolution model.

        The graph is ``input -> Conv(2x2) -> output``. A float input of
        shape [1, 5, 5, 3] is convolved with an all-zero weight constant of
        shape [1, 2, 2, 3] (NHWC), giving an output of shape [1, 4, 4, 1].

        Returns
        -------
        model : Model
            A model whose ``graph`` holds the output op and the ops feeding it.

        """
        g = Graph()

        # Graph input: a single float tensor of shape [1, 5, 5, 3].
        data_in = Input(
            'input',
            [1, 5, 5, 3],
            Float32(),
        )

        # Convolution weight: constant zeros laid out as NHWC.
        kernel = Constant(
            'weight',
            Float32(),
            np.zeros([1, 2, 2, 3]),
            dimension_format='NHWC',
        )

        # 2x2 convolution consuming the input ('X') and the weight ('W');
        # the declared output shape is [1, 4, 4, 1].
        conv_inputs = {'X': data_in, 'W': kernel}
        conv_op = Conv(
            'conv',
            [1, 4, 4, 1],
            Float32(),
            conv_inputs,
            kernel_shape=[2, 2],
        )

        # Single graph output fed by the convolution.
        out = Output('output', [1, 4, 4, 1], Float32(), {'input': conv_op})

        # Register the output op; judging by the method name this also adds
        # the ops it depends on (verify against Graph.add_op_and_inputs).
        g.add_op_and_inputs(out)

        result = Model()
        result.graph = g
        return result
示例#2
0
    def read(self, pb_path: str) -> Model:
        """Read a TensorFlow ``GraphDef`` file and load it as a model.

        Parameters
        ----------
        pb_path : str
            Path to the serialized TF ``GraphDef`` file.

        Returns
        -------
        model : Model
            Model whose ``graph`` is built by ``Importer.make_graph`` from the
            parsed ``GraphDef``.

        Notes
        -----
        If the file cannot be opened or read (``OSError``), a message is
        printed and an empty ``GraphDef`` is imported instead of raising.
        Other exceptions (e.g. from parsing malformed data) propagate.

        """
        model = Model()

        # Load the serialized TensorFlow graph.
        graph_def = graph_pb2.GraphDef()
        try:
            # `with` closes the file even when reading or parsing raises,
            # so the handle is never leaked.
            with open(path.abspath(pb_path), "rb") as f:
                graph_def.ParseFromString(f.read())
        except OSError:
            # Best-effort: fall back to the empty GraphDef created above.
            print("Could not open file. Creating a new one.")

        # Convert the GraphDef into the internal graph representation.
        model.graph = Importer.make_graph(graph_def)

        return model
示例#3
0
    def read(self, pb_path: str, json_path: Optional[str] = None) -> Model:
        """Read an ONNX file and load it as a model.

        Parameters
        ----------
        pb_path : str
            Path to the ONNX file.
        json_path : str, optional
            If given (and non-empty), the loaded ONNX protobuf is also dumped
            to this path as JSON indented by 4 spaces, for debugging.

        Returns
        -------
        model : Model
            Model whose ``graph`` is built by ``Importer.make_graph`` from the
            loaded ONNX model.

        """
        model = Model()

        # Load the ONNX model.
        onnx_model = onnx.load(path.abspath(pb_path))

        # Optional debug dump in JSON.
        if json_path:
            # protobuf is already a dependency of onnx, so nothing needs to
            # be installed at runtime (pip's internal API is not a supported
            # programmatic interface). Imported lazily: only this debug path
            # needs it.
            from google.protobuf.json_format import MessageToDict
            js_obj = MessageToDict(onnx_model)
            with open(json_path, 'w') as fw:
                json.dump(js_obj, fw, indent=4)

        # Model validation is currently disabled; re-enable with
        # onnx.checker.check_model(onnx_model) to reject malformed models.

        # Convert the ONNX model into the internal graph representation.
        model.graph = Importer.make_graph(onnx_model)

        return model