def test_get_input_node_map_given_valid_graph(self):
     """get_input_node_map should accept valid graphs"""
     graph_def = testutils.get_sample_graph_def()
     node_map = rewrite.get_input_node_map(graph_def)
     self.assertGreater(len(node_map), 1)
     # spot-check a few nodes that must appear in the map
     for expected_name in ('model/conv1/BiasAdd',
                           'model/flatten/Reshape',
                           'model/output/MatMul'):
         self.assertIn(expected_name, node_map)
 def test_get_input_node_map_given_duplicates(self):
     """get_input_node_map should raise ValueError given duplicate names"""
     graph_def = testutils.get_sample_graph_def()
     relu = _get_node_by_name(graph_def, 'model/conv3/Relu')
     # build a replacement whose second node clashes with an existing name
     negated = rewrite.make_op_node('Neg', list(relu.input), name='kate')
     clashing = rewrite.make_op_node('Exp', negated,
                                     name='model/conv3/BiasAdd')
     updated_graph = rewrite.update_graph_def(
         graph_def, {'model/conv3/Relu': [negated, clashing]}, {})
     with self.assertRaises(ValueError):
         rewrite.get_input_node_map(updated_graph)
Example #3
0
def convert_int64_to_int32(graph_def: r.GraphDef) -> r.GraphDef:
    """Convert int64 input to int32 for TFJS compatibility

    The graph is updated in place: placeholder dtypes are rewritten and
    the change is propagated depth-first to every op that consumes an
    affected tensor. Where a consumer's input type is fixed (and not
    already int32), a Cast op is inserted instead of changing the type.

    Args:
        graph_def: GraphDef proto containing the network layout

    Returns:
        Updated graph with int64 inputs converted to int32 (the same
        proto instance that was passed in)

    Raises:
        ValueError: if an affected op takes a list of potentially mixed
            types, so the type change cannot be propagated
    """
    inputs = util.get_input_nodes(graph_def)
    # names of all int64 graph inputs - these are the ones to convert
    convert = [info.name for info in inputs if info.dtype == util.np.int64]
    if len(convert) == 0:
        return graph_def
    # quick access to nodes by name
    node_map = r.get_input_node_map(graph_def)
    # map of all node inputs to their referencing node and their argument index
    input_map = defaultdict(list)
    for node in graph_def.node:
        for index, name in enumerate(node.input):
            input_map[name].append((index, node))
    # type cast ops to add to the graph
    type_cast_ops = []
    # nodes that require a type cast op
    type_cast_candidates: Dict[str, Tuple[int, r.NodeDef]] = {}

    for node in map(lambda x: node_map[x], convert):
        # flip the input node itself to int32 first
        _set_tensor_dtype(node, _DT_INT32)
        # find all nodes that reference this input and adjust their datatype
        # attributes if required
        # technical note: referenced_by is a stack, this really is a
        # depth-first recursion
        referenced_by = input_map[node.name]
        while len(referenced_by) > 0:
            idx, ref = referenced_by.pop()
            # get the input node and the index of the output tensor
            input_node, output_idx = _get_input_node(ref, idx, node_map)
            # find the description of this node's operation
            op = op_def_registry.get(ref.op)
            desc = op.input_arg[idx]
            # find out whether we can just change the input type and which
            # attributes we might need to touch
            if desc.type != 0 and desc.type != _DT_INT32:
                # input type is fixed and cannot be changed: add a type cast
                cast_op = _make_cast_node(input_node, output_idx, _DT_INT32,
                                          desc.type)
                ref.input[idx] = cast_op.name
                type_cast_ops.append(cast_op)
                node_map[cast_op.name] = cast_op
                input_map[cast_op.name].append((idx, ref))
            elif desc.type_list_attr != '' or desc.type_attr == '':
                # input arrays of potentially mixed types cannot be handled
                raise ValueError("don't know how to handle input type changes"
                                 f' for node "{ref.name}" op={ref.op}')
            else:
                # change the type of this input
                type_attr = desc.type_attr
                ref.attr[type_attr].type = _DT_INT32
                if ref.name in type_cast_candidates:
                    # this node's type just changed, so any previously
                    # recorded cast for it is obsolete
                    del type_cast_candidates[ref.name]
                # check the other inputs for type compatibility
                for i, desc in enumerate(op.input_arg):
                    if i == idx or desc.type_attr != type_attr:
                        continue  # not a matching input
                    input_node, output_idx = _get_input_node(ref, i, node_map)
                    if input_node.name in convert:
                        continue  # Placeholder that will be converted
                    src_type = _get_output_type(input_node, output_idx)
                    if src_type == _DT_INT32:
                        continue  # type matches already
                    if input_node.op == 'Const':
                        # weight tensor: harmonize_dtypes() will fix these
                        _set_tensor_dtype(input_node, _DT_INT32)
                    else:
                        # add node as a candidate for needing type cast op
                        type_cast_candidates[input_node.name] = (i, ref)
                # process any changed outputs next
                for idx, output in enumerate(op.output_arg):
                    if output.type_attr == type_attr:
                        # this output's type changed along with the input:
                        # push all of its consumers onto the stack
                        input_name = _get_tensor_name(ref, idx)
                        referenced_by += input_map[input_name]

    for idx, ref in type_cast_candidates.values():
        # add type cast operations for all nodes that have a type mismatch
        inp_node, channel = _get_input_node(ref, idx, node_map)
        src_type = _get_output_type(inp_node, channel)
        if src_type != _DT_INT32:
            cast_op = _make_cast_node(inp_node, channel, src_type, _DT_INT32)
            ref.input[idx] = cast_op.name
            type_cast_ops.append(cast_op)
            node_map[cast_op.name] = cast_op

    graph_def.node.extend(type_cast_ops)
    return graph_def
    def test_replace_matching_nodes(self):
        graph_def = testutils.get_sample_graph_def()

        # case 1: an unchanged copy is returned if nothing matches
        def _matches_prelu(node):
            return node.op == 'Prelu'

        def _drop_node(node, node_map, modifier_map):
            return []

        copied_graph, modifiers = rewrite.replace_matching_nodes(
            graph_def, predicate=_matches_prelu, transform=_drop_node)
        self.assertEqual(modifiers, {})
        self.assertEqual(copied_graph, graph_def)

        # case 2: matching nodes are replaced and the graph stays valid
        target_node_name = 'model/conv2/Relu'
        replacement_output_name = ''

        def _is_target(node):
            return node.name == target_node_name

        def _convert_to_log_sigmoid(node, input_map, modifiers):
            """replace Relu with logarithmic sigmoid 1/(1+exp(-x))"""
            def _scoped_name(suffix):
                return rewrite.generate_name_from(node.name, input_map,
                                                  f'logSigmoid/{suffix}')

            nonlocal replacement_output_name
            # -x
            neg = rewrite.make_op_node('Neg', list(node.input),
                                       name=_scoped_name('Neg'))
            # exp(-x)
            exp = rewrite.make_op_node('Exp', neg, name=_scoped_name('Exp'))
            # constant tensor holding "1" and an Identity "variable" view
            res = rewrite.make_const_node(np.array([1], dtype=np.float32),
                                          name=_scoped_name('Var/resource'))
            one = rewrite.make_op_node('Identity', res, _scoped_name('Var'))
            # 1+exp(-x)
            add = rewrite.make_op_node('Add', [one, exp], _scoped_name('Add'))
            # 1/(1+exp-x)
            inv = rewrite.make_op_node('Inv', add, _scoped_name('Inv'))
            replacement_output_name = inv.name  # remember the output name
            return [neg, exp, res, one, add, inv]

        updated_graph_def, modifiers = rewrite.replace_matching_nodes(
            graph_def,
            predicate=_is_target,
            transform=_convert_to_log_sigmoid)

        updated_nodes = rewrite.get_input_node_map(updated_graph_def)
        # the replaced node must be gone from the updated graph ...
        self.assertNotIn(target_node_name, updated_nodes)
        # ... and no remaining node (only ops have inputs) may reference it
        for node in updated_nodes.values():
            if node.op not in ('Const', 'Placeholder'):
                self.assertNotIn(target_node_name, node.input)

        # former references to the replaced node must now point to the
        # last node of the replacement sub-graph
        original_nodes = rewrite.get_input_node_map(graph_def)
        referencing_names = (node.name for node in original_nodes.values()
                             if target_node_name in node.input)
        for node_name in referencing_names:
            self.assertIn(replacement_output_name,
                          updated_nodes[node_name].input)