コード例 #1
0
def validate(graph_def, skip_op_check, strip_debug_ops):
    """Return the ops used by a GraphDef that TensorFlow.js does not support.

    Supported op names are collected from the ``tfOpName`` field of every
    entry in each ``.json`` file under the ``op_list`` resource directory.
    Nodes from the top-level graph and from every function in the graph's
    function library are checked against that list.

    Args:
      graph_def: tf.GraphDef TensorFlow GraphDef proto object, which
        represents the model topology.
      skip_op_check: Bool whether to skip the op check. When true, an empty
        set is returned without inspecting ``graph_def`` at all.
      strip_debug_ops: Bool whether to allow unsupported debug ops
        (``Assert``, ``CheckNumerics`` and ``Print``).

    Returns:
      A set of op names used in the graph that are not in the supported list
      (empty when everything is supported or the check is skipped).
    """
    # Return before walking the graph: the node list is only needed for the
    # check itself, so there is no reason to touch graph_def when skipping.
    if skip_op_check:
        return set()

    nodes = list(graph_def.node)
    for func in graph_def.library.function:
        nodes.extend(func.node_def)

    ops = []
    for filename in resource_loader.list_dir('op_list'):
        if os.path.splitext(filename)[1] == '.json':
            with resource_loader.open_file(os.path.join(
                    'op_list', filename)) as json_data:
                ops += json.load(json_data)

    names = {x['tfOpName'] for x in ops}
    if strip_debug_ops:
        # Debug ops are tolerated because they are expected to be stripped.
        names |= {'Assert', 'CheckNumerics', 'Print'}
    # Single pass over the nodes; no intermediate list is materialized.
    return {node.op for node in nodes if node.op not in names}
コード例 #2
0
def validate(nodes, skip_op_check, strip_debug_ops):
    """Return the ops among ``nodes`` that TensorFlow.js does not support.

    Supported op names are collected from the ``tfOpName`` field of every
    entry in each ``.json`` file under the ``op_list`` resource directory.

    Args:
      nodes: tf.NodeDef TensorFlow NodeDef objects from GraphDef. Each node
        is only read through its ``op`` attribute.
      skip_op_check: Bool whether to skip the op check. When true, an empty
        set is returned without inspecting ``nodes``.
      strip_debug_ops: Bool whether to allow unsupported debug ops
        (``Assert``, ``CheckNumerics`` and ``Print``).

    Returns:
      A set of op names used by ``nodes`` that are not in the supported list
      (empty when everything is supported or the check is skipped).
    """
    if skip_op_check:
        return set()

    ops = []
    for filename in resource_loader.list_dir('op_list'):
        if os.path.splitext(filename)[1] == '.json':
            with resource_loader.open_file(os.path.join(
                    'op_list', filename)) as json_data:
                ops += json.load(json_data)

    names = {x['tfOpName'] for x in ops}
    if strip_debug_ops:
        # Debug ops are tolerated because they are expected to be stripped.
        names |= {'Assert', 'CheckNumerics', 'Print'}
    # Single pass over the nodes; no intermediate list is materialized.
    return {node.op for node in nodes if node.op not in names}
コード例 #3
0
 def testListingFilesInOpList(self):
     # The bundled op_list resource directory must be non-empty, and every
     # entry in it must carry a .json suffix.
     op_list_entries = resource_loader.list_dir('op_list')
     entry_count = len(op_list_entries)
     self.assertGreater(entry_count, 0)
     for entry in op_list_entries:
         has_json_suffix = entry.endswith('.json')
         self.assertTrue(has_json_suffix)