示例#1
0
    def _get_unimplemented_ops(self, pb_path):
        """
        Returns a sorted list of op names required by a protobuf file but not
        implemented by OpsBridge.

        The file is read as text. Every stripped line beginning with ``op:``
        contributes the quoted value after the colon (e.g. ``op: "Const"``
        -> ``Const``) to the set of required ops. An op counts as supported
        when an ``OpsBridge`` instance exposes a non-underscore attribute of
        that name that does not live in the instance ``__dict__``.

        Arguments:
            pb_path: Path to a text-format protobuf file.

        Returns:
            Sorted list of unimplemented op names, without duplicates.
        """
        # get required ops; iterate the file lazily rather than reading all
        # lines into memory first
        required_ops = set()
        with open(pb_path) as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line.startswith('op:'):
                    continue
                # Take everything after "op:" instead of splitting on a single
                # space: the old split raised IndexError for 'op:"X"' and gave
                # an empty/garbled name when several spaces followed the colon.
                op_name = line[len('op:'):].strip().strip('"\'')
                if op_name:
                    required_ops.add(op_name)

        # get supported ops: public attributes reachable via dir() that are
        # not instance attributes (presumably class-level converter methods;
        # instance state set in __init__ is excluded)
        bridge = OpsBridge()
        supported_ops = {
            name for name in dir(bridge)
            if not name.startswith('_') and name not in bridge.__dict__
        }

        # get unimplemented ops
        return sorted(required_ops - supported_ops)
示例#2
0
 def reset(self):
     """
     Puts the importer back into its initial, empty state.

     Clears the name-to-op lookup, the list of init ops and the stored
     graph_def, and installs a brand-new OpsBridge instance.
     """
     # Mapping from names to ops; starts out empty.
     self.name_op_map = {}
     # A new bridge object so nothing carries over from a previous import.
     self.ops_bridge = OpsBridge()
     # Collected init ops (presumably variable initializers -- confirm).
     self.init_ops = []
     # No graph definition loaded yet.
     self.graph_def = None
示例#3
0
 def reset(self):
     """
     Puts the importer back into its initial, empty state.

     Forgets any loaded TF graph and graph_def, empties the name-to-op
     lookup, installs a brand-new OpsBridge converter and clears the
     checkpoint path used for weight import.
     """
     # Nothing loaded from TF yet: graph first, then graph_def.
     self._graph = self._graph_def = None
     # Empty name -> op lookup plus a fresh ops bridge converter.
     self._name_op_map = {}
     self._ops_bridge = OpsBridge()
     # No checkpoint selected for weight import.
     self._checkpoint_path = None