def _remove_unused_control_flow_inputs(graph_def: GraphDef) -> GraphDef:
    """Remove placeholder nodes that the graph optimizer marked as unused.

    The optimizer renames dead control-flow placeholders so their names start
    with 'unused_control_flow_input'; such nodes can be dropped from the graph.

    Args:
        graph_def: TF graph definition to clean up

    Returns:
        Updated copy of the input graph without the unused placeholder nodes
    """
    def is_unused(node):
        return (node.op == c.TFJS_NODE_PLACEHOLDER_KEY
                and node.name.startswith('unused_control_flow_input'))
    # the transform callback is invoked with (node, input_map, modifiers);
    # accept any arity and return an empty replacement list to delete the node
    result, _ = replace_matching_nodes(graph_def, is_unused, lambda *_: [])
    return result
def replace_prelu(input_graph_def: util.GraphDef) -> util.GraphDef:
    """Replace every Prelu-activation in the graph with supported TF ops.

    Arguments:
        input_graph_def: TF graph definition to examine

    Returns:
        Updated copy of the input graph with Prelu-nodes replaced by
        supported TF operations
    """
    def _is_prelu(node):
        return node.op == 'Prelu'

    return util.replace_matching_nodes(input_graph_def, _is_prelu,
                                       _split_prelu)
def _remove_unused_nodes(graph_def: GraphDef) -> GraphDef:
    """Remove nodes that don't have any connections in the graph.

    A node counts as connected if another node lists it as an input;
    nodes with a non-empty input list of their own are kept as well.
    """
    # names of all nodes that appear as an input of some other node
    referenced = {inp for node in graph_def.node for inp in node.input}

    def _is_isolated(node):
        return not node.input and node.name not in referenced

    result, _ = replace_matching_nodes(graph_def, _is_isolated, _remove_nodes)
    return result
def split_all_fused_ops(input_graph_def: util.GraphDef) -> util.GraphDef:
    """Split every fused-operation node in the graph into individual ops.

    This enables further conversion into formats that don't support fused
    operations (e.g. TFLite without Flex enabled).

    Args:
        input_graph_def: TF graph definition to examine

    Returns:
        Updated copy of the input graph with matching nodes replaced by
        individual operations
    """
    def _is_fused(node):
        if util.is_fused_conv2d(node):
            return True
        return util.is_fused_matmul(node)

    return util.replace_matching_nodes(input_graph_def, _is_fused,
                                       _split_fused_op)
def split_fused_depthwise(input_graph_def: util.GraphDef) -> util.GraphDef:
    """Decompose all fused depthwise conv2d operations into separate ops.

    Matching fused depthwise nodes are split into individual operations.
    Fused activation functions that aren't supported (e.g. 'Prelu') can be
    replaced afterwards in a separate processing step.

    Args:
        input_graph_def: TF graph_def proto to be processed

    Returns:
        Updated copy of the input graph with matching nodes replaced by
        individual operations
    """
    return util.replace_matching_nodes(input_graph_def=input_graph_def,
                                       predicate=util.is_fused_depthwise,
                                       transform=_split_fused_depthwise)
def split_fused_prelu(input_graph_def: util.GraphDef) -> util.GraphDef:
    """Split fused operations that include a 'Prelu'-activation.

    TFJS uses fused operations for performance. Some fused activations
    aren't supported by TF (e.g. 'Prelu'), so the fused ops must be split
    back into individual ops; the unsupported functions can then be
    replaced by equivalent supported constructs in a later step.

    Arguments:
        input_graph_def: TF graph definition to examine

    Results:
        Updated copy of the input graph with matching nodes replaced by
        individual operations
    """
    def _has_fused_prelu(node):
        if util.is_fused_conv2d(node, b'Prelu'):
            return True
        return util.is_fused_matmul(node, b'Prelu')

    return util.replace_matching_nodes(input_graph_def, _has_fused_prelu,
                                       _split_fused_op)
def test_replace_matching_nodes(self):
    """replace_matching_nodes: identity when no match; valid rewrite when
    a node matches."""
    # case 1: unchanged copy if no matches
    graph_def = testutils.get_sample_graph_def()

    def _is_prelu(node):
        return node.op == 'Prelu'

    # renamed parameters: the original shadowed the builtin `map`
    def _remove_node(node, input_map, modifiers):
        return []

    updated_graph_def, modifiers = rewrite.replace_matching_nodes(
        graph_def, predicate=_is_prelu, transform=_remove_node)
    self.assertEqual(modifiers, {})
    self.assertEqual(updated_graph_def, graph_def)

    # case 2: replaces matching nodes and keeps graph valid
    name_of_node_to_replace = 'model/conv2/Relu'
    new_name_of_replaced_node = ''

    def _must_replace(node):
        return node.name == name_of_node_to_replace

    def _convert_to_log_sigmoid(node, input_map, modifiers):
        """replace Relu with logarithmic sigmoid 1/(1+exp(-x))"""
        def _get_name(suffix):
            return rewrite.generate_name_from(node.name, input_map,
                                              f'logSigmoid/{suffix}')
        nonlocal new_name_of_replaced_node
        # -x
        neg = rewrite.make_op_node('Neg', list(node.input),
                                   name=_get_name('Neg'))
        # exp(-x)
        exp = rewrite.make_op_node('Exp', neg, name=_get_name('Exp'))
        # constant tensor holding "1"
        res = rewrite.make_const_node(np.array([1], dtype=np.float32),
                                      name=_get_name('Var/resource'))
        # variable holding "1"
        one = rewrite.make_op_node('Identity', res, _get_name('Var'))
        # 1+exp(-x)
        add = rewrite.make_op_node('Add', [one, exp], _get_name('Add'))
        # 1/(1+exp(-x))
        inv = rewrite.make_op_node('Inv', add, _get_name('Inv'))
        new_name_of_replaced_node = inv.name  # remember the output name
        return [neg, exp, res, one, add, inv]

    updated_graph_def, modifiers = rewrite.replace_matching_nodes(
        graph_def, predicate=_must_replace,
        transform=_convert_to_log_sigmoid)
    # replaced node must have been removed
    updated_nodes = rewrite.get_input_node_map(updated_graph_def)
    self.assertNotIn(name_of_node_to_replace, updated_nodes)
    # replaced node must not be referenced anymore
    for node in updated_nodes.values():
        # nodes with inputs only
        if node.op not in ('Const', 'Placeholder'):
            self.assertNotIn(name_of_node_to_replace, node.input)
    # references to the replaced node must point to the last node in the
    # replacement
    original_nodes = rewrite.get_input_node_map(graph_def)
    replaced_references = [
        node.name for node in original_nodes.values()
        if name_of_node_to_replace in node.input
    ]
    for node_name in replaced_references:
        node = updated_nodes[node_name]
        self.assertIn(new_name_of_replaced_node, node.input)