示例#1
0
def _remove_unused_control_flow_inputs(graph_def: GraphDef) -> GraphDef:
    """Remove nodes that the graph optimizer marked as unused.

    A node counts as unused when its op equals c.TFJS_NODE_PLACEHOLDER_KEY
    and its name starts with 'unused_control_flow_input'. Every such node is
    dropped from the graph.

    Args:
        graph_def: graph to clean up

    Returns:
        Updated graph without the marked nodes
    """
    def is_unused(node):
        return (node.op == c.TFJS_NODE_PLACEHOLDER_KEY
                and node.name.startswith('unused_control_flow_input'))

    # Transforms given to replace_matching_nodes receive
    # (node, input_map, modifiers). A one-argument lambda would raise
    # TypeError on the first match, so accept any arguments. Returning an
    # empty list drops the matched node.
    result, _ = replace_matching_nodes(graph_def, is_unused, lambda *_: [])
    return result
示例#2
0
def replace_prelu(input_graph_def: util.GraphDef) -> util.GraphDef:
    """Substitute supported TF operations for every 'Prelu' activation node.

    Args:
        input_graph_def: TF graph definition to examine

    Returns:
        Updated copy of the input graph in which each node with op 'Prelu'
        has been rewritten by _split_prelu into supported TF operations
    """
    def _is_prelu_node(node):
        return node.op == 'Prelu'

    return util.replace_matching_nodes(input_graph_def=input_graph_def,
                                       predicate=_is_prelu_node,
                                       transform=_split_prelu)
示例#3
0
def _remove_unused_nodes(graph_def: GraphDef) -> GraphDef:
    """Remove nodes that don't have any connections in the graph.

    A node is connected if its name occurs in the input list of another
    node, whether it appears as a plain data input ("name"), as a specific
    output ("name:1") or as a control dependency ("^name"). Nodes that have
    non-empty input lists are considered connected as well.

    Args:
        graph_def: graph to prune

    Returns:
        Updated graph with all unconnected nodes removed
    """
    def _referenced_node_name(input_name):
        # Input entries may carry a leading '^' (control dependency) and/or
        # a ':<output_index>' suffix. Reduce them to the bare node name so
        # that such references still mark the producing node as connected.
        if input_name.startswith('^'):
            input_name = input_name[1:]
        return input_name.split(':', 1)[0]

    # names of all nodes that are referenced by other nodes
    has_ref = {
        _referenced_node_name(inp)
        for node in graph_def.node
        for inp in node.input
    }

    def _unconnected(n):
        return len(n.input) == 0 and n.name not in has_ref

    result, _ = replace_matching_nodes(graph_def, _unconnected, _remove_nodes)
    return result
示例#4
0
def split_all_fused_ops(input_graph_def: util.GraphDef) -> util.GraphDef:
    """Break every fused-operation node in the graph into individual ops.

    Splitting fused ops makes the graph convertible into formats that lack
    fused-operation support (e.g. TFLite without Flex enabled).

    Args:
        input_graph_def: TF graph definition to examine

    Returns:
        Updated copy of the input graph in which every node accepted by
        util.is_fused_conv2d or util.is_fused_matmul has been replaced by
        individual operations
    """
    def _is_fused_node(candidate):
        return (util.is_fused_conv2d(candidate)
                or util.is_fused_matmul(candidate))

    return util.replace_matching_nodes(input_graph_def,
                                       _is_fused_node,
                                       _split_fused_op)
def split_fused_depthwise(input_graph_def: util.GraphDef) -> util.GraphDef:
    """Decompose all fused depthwise conv2d operations into separate ops.

    Nodes accepted by util.is_fused_depthwise are rewritten by
    _split_fused_depthwise into individual operations. Fused activation
    functions that aren't supported (e.g. 'Prelu') can be replaced
    afterwards in a separate processing step.

    Args:
        input_graph_def: TF graph_def proto to be processed

    Returns:
        Updated copy of the input graph with matching nodes replaced by
        individual operations
    """
    is_match = util.is_fused_depthwise
    rewrite_node = _split_fused_depthwise
    return util.replace_matching_nodes(input_graph_def=input_graph_def,
                                       predicate=is_match,
                                       transform=rewrite_node)
示例#6
0
def split_fused_prelu(input_graph_def: util.GraphDef) -> util.GraphDef:
    """Split fused operations that carry a 'Prelu' activation into single ops.

    TFJS uses fused operations for performance, but some fused activations
    aren't supported by TF (e.g. 'Prelu'). Fused conv2d/matmul nodes using
    such an activation are therefore split back into individual ops, so the
    unsupported functions can later be replaced by equivalent supported
    constructs.

    Args:
        input_graph_def: TF graph definition to examine

    Returns:
        Updated copy of the input graph with matching nodes replaced by
        individual operations
    """
    activation = b'Prelu'

    def _has_fused_prelu(candidate):
        return (util.is_fused_conv2d(candidate, activation)
                or util.is_fused_matmul(candidate, activation))

    return util.replace_matching_nodes(input_graph_def=input_graph_def,
                                       predicate=_has_fused_prelu,
                                       transform=_split_fused_op)
    def test_replace_matching_nodes(self):
        """Check replace_matching_nodes on a sample graph.

        Case 1: without matches the result equals the input graph and no
        modifiers are reported.
        Case 2: a matched node is removed, nothing references it anymore,
        and its former consumers point to the last node of the replacement.
        """
        # case 1: unchanged copy if no matches
        graph_def = testutils.get_sample_graph_def()

        def _is_prelu(node):
            return node.op == 'Prelu'

        def _remove_node(node, input_map, modifiers):
            # parameter names chosen so the builtin `map` isn't shadowed
            return []

        updated_graph_def, modifiers = rewrite.replace_matching_nodes(
            graph_def, predicate=_is_prelu, transform=_remove_node)
        self.assertEqual(modifiers, {})
        self.assertEqual(updated_graph_def, graph_def)

        # case 2: replaces matching nodes and keeps graph valid
        name_of_node_to_replace = 'model/conv2/Relu'
        new_name_of_replaced_node = ''

        def _must_replace(node):
            return node.name == name_of_node_to_replace

        def _convert_to_log_sigmoid(node, input_map, modifiers):
            """Replace Relu with the logistic sigmoid 1/(1+exp(-x)).

            The generated node names use the 'logSigmoid/' prefix.
            """
            def _get_name(suffix):
                return rewrite.generate_name_from(node.name, input_map,
                                                  f'logSigmoid/{suffix}')

            nonlocal new_name_of_replaced_node
            # -x
            neg = rewrite.make_op_node('Neg',
                                       list(node.input),
                                       name=_get_name('Neg'))
            # exp(-x)
            exp = rewrite.make_op_node('Exp', neg, name=_get_name('Exp'))
            # constant tensor holding "1"
            res = rewrite.make_const_node(np.array([1], dtype=np.float32),
                                          name=_get_name('Var/resource'))
            # variable holding "1"
            one = rewrite.make_op_node('Identity', res, _get_name('Var'))
            # 1+exp(-x)
            add = rewrite.make_op_node('Add', [one, exp], _get_name('Add'))
            # 1/(1+exp(-x))
            inv = rewrite.make_op_node('Inv', add, _get_name('Inv'))
            new_name_of_replaced_node = inv.name  # remember the output name
            return [neg, exp, res, one, add, inv]

        updated_graph_def, _ = rewrite.replace_matching_nodes(
            graph_def,
            predicate=_must_replace,
            transform=_convert_to_log_sigmoid)

        # the target must exist in the original graph, otherwise every
        # check below passes vacuously
        original_nodes = rewrite.get_input_node_map(graph_def)
        self.assertIn(name_of_node_to_replace, original_nodes)
        # the transform must actually have run and produced an output node
        self.assertNotEqual(new_name_of_replaced_node, '')

        # replaced node must have been removed
        updated_nodes = rewrite.get_input_node_map(updated_graph_def)
        self.assertNotIn(name_of_node_to_replace, updated_nodes)
        # replaced node must not be referenced
        for node in updated_nodes.values():
            # nodes with inputs only
            if node.op not in ('Const', 'Placeholder'):
                self.assertNotIn(name_of_node_to_replace, node.input)

        # references to the replaced node must point to the last node in
        # the replacement
        replaced_references = [
            node.name for node in original_nodes.values()
            if name_of_node_to_replace in node.input
        ]
        for node_name in replaced_references:
            node = updated_nodes[node_name]
            self.assertIn(new_name_of_replaced_node, node.input)