def validate(nodes, skip_op_check, strip_debug_ops):
  """Validate if the node's op is compatible with TensorFlow.js.

  Args:
    nodes: tf.NodeDef TensorFlow NodeDef objects from GraphDef. Only the
      `op` attribute of each node is read.
    skip_op_check: Bool whether to skip the op check.
    strip_debug_ops: Bool whether to allow unsupported debug ops.

  Returns:
    A set of the op names used by `nodes` that do not appear in the
    supported op list. Returns an empty set when `skip_op_check` is true.
  """
  if skip_op_check:
    return set()

  # The supported names are every `tfOpName` listed in the JSON files found
  # in the `op_list` resource directory.
  names = set()
  for filename in resource_loader.list_dir('op_list'):
    if os.path.splitext(filename)[1] == '.json':
      with resource_loader.open_file(
          os.path.join('op_list', filename)) as json_data:
        names.update(op['tfOpName'] for op in json.load(json_data))

  if strip_debug_ops:
    # Treat these debug ops as supported when the caller opts in.
    names |= {'Assert', 'CheckNumerics', 'Print'}

  return {node.op for node in nodes if node.op not in names}
def testReadingNonExistentFileRaisesError(self):
  # Opening a resource path that does not exist must raise IOError
  # (an alias of OSError on Python 3).
  self.assertRaises(IOError, resource_loader.open_file, '___non_existent')
def testReadingFileInOpList(self):
  # The first entry of the bundled arithmetic op list must carry both a
  # TensorFlow op name and a category.
  with resource_loader.open_file('op_list/arithmetic.json') as op_file:
    first_op = json.load(op_file)[0]
    for required_key in ('tfOpName', 'category'):
      self.assertIn(required_key, first_op)