コード例 #1
0
ファイル: onnx_loader.py プロジェクト: zuston/analytics-zoo
    def to_keras(self):
        """Convert an ONNX model to a KerasNet model.

        The conversion makes four passes over ``self.graph``:

        1. every initializer becomes an ``OnnxInput`` holding the value
           returned by ``OnnxHelper.to_numpy`` and is recorded in
           ``self.initializer``;
        2. every graph input is registered in ``self._all_tensors``; inputs
           that are not initializers are also recorded in ``self._inputs``
           (their ``zvalue`` is what ``OnnxHelper.get_shape_from_node``
           returns until a mapper replaces it);
        3. every node in ``self.graph.node`` is converted, in listed order,
           through ``OperatorMapper`` and its first output is recorded in
           ``self._all_tensors``;
        4. the graph outputs are looked up and wired into a ``zmodels.Model``.

        Every value stored in ``self._all_tensors`` is an ``OnnxInput``, so
        node inputs and the final output lookup can uniformly use
        ``.name`` / ``.zvalue``.

        :return: a ``zmodels.Model`` built from the converted graph inputs
            and graph outputs.
        :raises ValueError: if an initializer has an empty (or blank) name.
        :raises Exception: if a node references a tensor that has not been
            produced yet, or a graph output was never produced.
        :raises AssertionError: if a node other than Dropout has more than
            one output.
        """
        # Pass 1: parse network parameters (initializers).
        for init_tensor in self.graph.initializer:
            if not init_tensor.name.strip():
                raise ValueError("Tensor's name is required.")
            self.initializer[init_tensor.name] = OnnxInput(
                name=init_tensor.name, zvalue=OnnxHelper.to_numpy(init_tensor))

        # Pass 2: register graph inputs (GraphProto.input, ValueInfoProto).
        for value_info in self.graph.input:
            input_name = value_info.name
            if input_name in self.initializer:
                # Parameters may also be listed as graph inputs; reuse the
                # OnnxInput already built from the initializer in pass 1.
                self._all_tensors[input_name] = self.initializer[input_name]
            else:
                graph_input = OnnxInput(
                    name=input_name,
                    zvalue=OnnxHelper.get_shape_from_node(value_info))
                self._inputs[input_name] = graph_input
                self._all_tensors[input_name] = graph_input

        # Pass 3: convert nodes (NodeProto) in the order they are listed;
        # the nodes form a directed acyclic graph.
        for node in self.graph.node:
            inputs = []
            for tensor_name in node.input:
                if tensor_name == "":
                    # An empty name marks an omitted optional input.
                    continue
                if tensor_name not in self._all_tensors:
                    raise Exception("Cannot find {}".format(tensor_name))
                inputs.append(self._all_tensors[tensor_name])

            mapper = OperatorMapper.of(node, self.initializer, inputs)
            for model_input in mapper.model_inputs:
                # Only the original graph inputs are model inputs; they now
                # carry the converted zoo value the Model is built from.
                if model_input.name in self._inputs:
                    self._inputs[model_input.name] = model_input.zvalue
                # Keep the OnnxInput wrapper (not the bare zvalue) so a later
                # node consuming the same tensor, and the output lookup below,
                # still receive an object with ``.name`` / ``.zvalue``.
                self._all_tensors[model_input.name] = model_input
            tensor = mapper.to_tensor()
            output_ids = list(node.output)
            assert len(output_ids) == 1 or node.op_type == "Dropout", \
                "Only support single output for now"
            # Only the first output is wired; for Dropout any extra output
            # (e.g. ONNX's optional mask) is ignored.
            self._all_tensors[output_ids[0]] = OnnxInput(name=tensor.name,
                                                         zvalue=tensor)

        # Pass 4: collect graph outputs and build the model.
        output_tensors = []
        for graph_output in self.graph.output:
            if graph_output.name not in self._all_tensors:
                raise Exception("The output {} hasn't been calculated".format(
                    graph_output.name))
            output_tensors.append(self._all_tensors[graph_output.name].zvalue)
        model = zmodels.Model(input=list(self._inputs.values()),
                              output=output_tensors)
        return model
コード例 #2
0
 def _extract_model_inputs(self):
     """Build the single constant model input for this node.

     The node's ``value`` attribute is converted with
     ``OnnxHelper.to_numpy``, wrapped in an ``OnnxInput`` named
     ``self.op_name`` and passed through ``self._to_zoo_input`` with
     ``is_constant=True``.
     NOTE(review): presumably this is the mapper for the ONNX Constant
     operator, whose payload lives in the ``value`` attribute -- confirm
     against the enclosing class.

     :return: list of OnnxInput containing exactly one element
     """
     tensor_name = self.op_name
     constant_value = OnnxHelper.to_numpy(self.onnx_attr['value'])
     wrapped = OnnxInput(name=tensor_name, zvalue=constant_value)
     zoo_input = self._to_zoo_input(wrapped, is_constant=True)
     return [zoo_input]