예제 #1
0
def validate(nodes, skip_op_check, strip_debug_ops):
    """Validate if the nodes' ops are compatible with TensorFlow.js.

    Args:
      nodes: tf.NodeDef TensorFlow NodeDef objects from GraphDef.
      skip_op_check: Bool whether to skip the op check.
      strip_debug_ops: Bool whether to allow unsupported debug ops.

    Returns:
      A set of the `op` names of the nodes whose op is not listed in any
      `op_list/*.json` resource. Empty when `skip_op_check` is true or when
      every node's op is supported.
    """
    if skip_op_check:
        return set()

    # Gather the `tfOpName` of every entry in the JSON op lists stored as
    # resources under 'op_list'; non-JSON files in that directory are ignored.
    supported = set()
    for filename in resource_loader.list_dir('op_list'):
        if os.path.splitext(filename)[1] != '.json':
            continue
        path = os.path.join('op_list', filename)
        with resource_loader.open_file(path) as json_data:
            supported.update(op['tfOpName'] for op in json.load(json_data))

    if strip_debug_ops:
        # Treat these debug ops as supported when the caller opts in, even if
        # they are absent from the op lists.
        supported |= {'Assert', 'CheckNumerics', 'Print'}

    # Single pass over the nodes; no intermediate list is materialized.
    return {node.op for node in nodes if node.op not in supported}
예제 #2
0
 def testReadingNonExistentFileRaisesError(self):
     """Opening a resource path that does not exist raises OSError (IOError)."""
     missing_path = '___non_existent'
     self.assertRaises(OSError, resource_loader.open_file, missing_path)
예제 #3
0
 def testReadingFileInOpList(self):
     """The arithmetic op list resource parses as JSON with the expected keys."""
     resource_path = 'op_list/arithmetic.json'
     with resource_loader.open_file(resource_path) as op_file:
         op_entries = json.load(op_file)
         leading_entry = op_entries[0]
         # Check both required keys on the first entry, in the same order.
         for required_key in ('tfOpName', 'category'):
             self.assertIn(required_key, leading_entry)