def validate(graph_def, skip_op_check, strip_debug_ops):
    """Validate if the node's op is compatible with TensorFlow.js.

    Gathers the nodes of the top-level graph together with the nodes of every
    function in the graph's function library, and reports the ops among them
    whose names are not listed under ``tfOpName`` in any ``op_list/*.json``
    resource file.

    Args:
      graph_def: tf.GraphDef TensorFlow GraphDef proto object, which represents
        the model topology.
      skip_op_check: Bool whether to skip the op check.
      strip_debug_ops: Bool whether to allow unsupported debug ops
        ('Assert', 'CheckNumerics', 'Print').

    Returns:
      A set of the op names used by the graph that are not supported. Always
      empty when `skip_op_check` is true.
    """
    # Short-circuit before walking the graph: when the check is skipped the
    # nodes are never inspected, so collecting them would be wasted work.
    if skip_op_check:
        return set()

    # Top-level nodes plus the bodies of all library functions.
    nodes = list(graph_def.node)
    for func in graph_def.library.function:
        nodes.extend(func.node_def)

    # Union of every supported op name across the op list files; entries in
    # the resource directory that are not .json files are ignored.
    names = set()
    for filename in resource_loader.list_dir('op_list'):
        if os.path.splitext(filename)[1] != '.json':
            continue
        with resource_loader.open_file(
                os.path.join('op_list', filename)) as json_data:
            names.update(op['tfOpName'] for op in json.load(json_data))

    if strip_debug_ops:
        # Treat these debug ops as supported when stripping is requested.
        names |= {'Assert', 'CheckNumerics', 'Print'}

    return {node.op for node in nodes if node.op not in names}
def validate(nodes, skip_op_check, strip_debug_ops):
    """Validate if the node's op is compatible with TensorFlow.js.

    Reports the ops among `nodes` whose names are not listed under
    ``tfOpName`` in any ``op_list/*.json`` resource file.

    Args:
      nodes: tf.NodeDef TensorFlow NodeDef objects from GraphDef.
      skip_op_check: Bool whether to skip the op check.
      strip_debug_ops: Bool whether to allow unsupported debug ops
        ('Assert', 'CheckNumerics', 'Print').

    Returns:
      A set of the op names used by `nodes` that are not supported. Always
      empty when `skip_op_check` is true.
    """
    if skip_op_check:
        return set()

    # Union of every supported op name across the op list files; entries in
    # the resource directory that are not .json files are ignored.
    names = set()
    for filename in resource_loader.list_dir('op_list'):
        if os.path.splitext(filename)[1] != '.json':
            continue
        with resource_loader.open_file(
                os.path.join('op_list', filename)) as json_data:
            names.update(op['tfOpName'] for op in json.load(json_data))

    if strip_debug_ops:
        # Treat these debug ops as supported when stripping is requested.
        names |= {'Assert', 'CheckNumerics', 'Print'}

    return {node.op for node in nodes if node.op not in names}
def testListingFilesInOpList(self):
    """Check the 'op_list' resource directory is non-empty and all .json."""
    op_list_entries = resource_loader.list_dir('op_list')
    # At least one op definition file must be bundled.
    self.assertGreater(len(op_list_entries), 0)
    for entry in op_list_entries:
        self.assertTrue(entry.endswith('.json'))