Exemplo n.º 1
0
def optimize_graph(graph: tf.Graph, level=None) -> GraphDef:
    """Run the inference optimisation passes on a (possibly patched) graph

    The graph may have been edited beforehand to swap out known, but
    unsupported operations. A signature is derived from the graph's input
    and output nodes, the outputs are marked as train op, and the TF
    optimizer is run with a fixed list of passes. Afterwards, unused
    control flow inputs are removed from the result.

    Args:
        graph: Tensorflow v1 graph (or wrapped v2 function) to be optimised
        level: optional optimisation level; currently unsupported and
            not used by this function

    Returns:
        Optimised ``GraphDef`` message for inference or format conversion
    """
    signature_def = _build_signature_def(
        graph, get_input_nodes(graph), get_output_nodes(graph))
    _mark_outputs_as_train_op(graph, signature_def)

    # passes are handed to the optimizer configuration in exactly this order
    optimizer_passes = [
        'debug_stripper', 'remap', 'constfold', 'arithmetic', 'dependency',
    ]
    config = ConfigProto()
    _set_optimization_options(config, optimizer_passes)

    return _remove_unused_control_flow_inputs(
        _run_tf_optimizer(config, graph, signature_def))
Exemplo n.º 2
0
    def test_get_input_nodes(self):
        """get_input_nodes should describe every input of the sample graph"""
        def _expected_shape(node):
            # dimension sizes that are not positive are compared as None
            dims = node.attr['shape'].shape.dim
            return [dim.size if dim.size > 0 else None for dim in dims]

        graph = testutils.get_sample_graph()
        actual = util.get_input_nodes(graph)
        expected = testutils.get_inputs(graph.as_graph_def())
        # equal lengths are asserted first, so pairing up with zip is safe
        self.assertEqual(len(actual), len(expected))
        for result, node in zip(actual, expected):
            self.assertEqual(result.name, node.name)
            self.assertEqual(result.shape, _expected_shape(node))
            self.assertEqual(result.tensor, node.name + ':0')
Exemplo n.º 3
0
 def test_rename_input_nodes(self):
     """rename_input_nodes should rename inputs and keep the model usable"""
     sample_path = testutils.get_path_to(testutils.SIMPLE_MODEL_FILE_NAME)
     graph_def = testutils.get_sample_graph_def(sample_path)
     renamed = util.rename_input_nodes(graph_def, {'x': 'scalar'})
     # the graph passed in must reflect the change as well
     self.assertEqual(graph_def, renamed)
     # the first input has to carry the new name
     first_input = util.get_input_nodes(renamed)[0]
     self.assertEqual(first_input.name, 'scalar')
     # the renamed graph must still be executable: the expected output for
     # the sample model is five times the input
     model = testutils.graph_to_model(renamed)
     factor = 18
     outputs = model(tf.constant([[factor]], dtype=tf.float32))
     y = outputs[0].numpy()[0]
     self.assertAlmostEqual(y, factor*5, delta=0.1)
Exemplo n.º 4
0
def _build_signatures(graph: tf.Graph, signature_def_map: dict) -> dict:
    """
    Turn a signature map into a proper argument for ``add_meta_graph`` by
    creating tensor info for each output name given in ``signature_def_map``.

    All signatures are validated before any signature is built: each entry
    must contain a non-empty ``SIGNATURE_OUTPUTS`` item and may only name
    outputs of the graph. An empty or ``None`` signature key is replaced by
    the default serving signature key.

    Args:
        graph: TF Graph instance
        signature_def_map: Dictionary that maps signature keys to a mapping
            holding the output names under the ``SIGNATURE_OUTPUTS`` key

    Returns:
        Signature definition map for passing to
        ``SavedModelBuilder.add_meta_graph``

    Raises:
        ValueError: if a signature lacks the ``SIGNATURE_OUTPUTS`` key, lists
            no outputs, or names something that is not an output tensor
    """
    inputs = {
        info.name: util._build_tensor_info(graph, info)
        for info in util.get_input_nodes(graph)
    }
    default_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    # Tensor names have the form "<node>:<output index>". Split at the last
    # colon instead of cutting two characters so that indices with more
    # than one digit (e.g. "node:10") are stripped correctly as well.
    output_tensors = {
        name.rsplit(':', 1)[0] for name in util.get_output_tensors(graph)
    }
    for key, value in signature_def_map.items():
        signature_name = key or default_key
        if SIGNATURE_OUTPUTS not in value:
            raise ValueError(f'Signature "{signature_name}" is invalid: '
                             f'the key "{SIGNATURE_OUTPUTS}" is missing')
        if not value[SIGNATURE_OUTPUTS]:
            raise ValueError(f'Signature key "{SIGNATURE_OUTPUTS}" must not '
                             'be empty')
        for name in value[SIGNATURE_OUTPUTS]:
            if name not in output_tensors:
                # sorted and joined explicitly: stable message, and no
                # mangled "et(" text if the graph has no outputs at all
                valid_outputs = ', '.join(
                    repr(output) for output in sorted(output_tensors))
                raise ValueError(f'Signature "{signature_name}" is '
                                 f'invalid: "{name}" is not an output tensor.'
                                 f' Valid outputs are {valid_outputs}')
    return {
        key or default_key: _build_signature(graph, inputs, info)
        for key, info in signature_def_map.items()
    }
Exemplo n.º 5
0
def convert_int64_to_int32(graph_def: r.GraphDef) -> r.GraphDef:
    """Convert int64 input to int32 for TFJS compatibility

    The dtype of every int64 input node is changed to int32 and the change
    is propagated to the nodes that consume the affected tensors: type
    attributes are switched to int32 where the operation allows it and
    cast nodes are inserted where an input type is fixed or where another
    input of the same operation would no longer match.

    The graph is modified in place; the returned object is the
    ``graph_def`` that was passed in.

    Args:
        graph_def: GraphDef proto containing the network layout
    Returns:
        Updated graph with int64 inputs converted to int32
    Raises:
        ValueError: if an affected input of a consuming node is described
            by a type list attribute or has no type attribute at all
    """
    inputs = util.get_input_nodes(graph_def)
    # names of all inputs that are currently int64
    convert = [info.name for info in inputs if info.dtype == util.np.int64]
    if len(convert) == 0:
        # no int64 inputs: the graph is returned untouched
        return graph_def
    # quick access to nodes by name
    node_map = r.get_input_node_map(graph_def)
    # map of all node inputs to their referencing node and their argument index
    input_map = defaultdict(list)
    for node in graph_def.node:
        for index, name in enumerate(node.input):
            input_map[name].append((index, node))
    # type cast ops to add to the graph
    type_cast_ops = []
    # nodes that require a type cast op; keyed by the name of the node that
    # produces the mismatching tensor, the value is the consuming node and
    # the index of the affected input
    type_cast_candidates: Dict[str, Tuple[int, r.NodeDef]] = {}

    for node in map(lambda x: node_map[x], convert):
        _set_tensor_dtype(node, _DT_INT32)
        # find all nodes that reference this input and adjust their datatype
        # attributes if required
        # technical note: referenced_by is a stack, this really is a
        # depth-first recursion
        # NOTE(review): referenced_by is the list object stored in input_map
        # (no copy is made), so pop() and += below also change that entry
        referenced_by = input_map[node.name]
        while len(referenced_by) > 0:
            idx, ref = referenced_by.pop()
            # get the input node and the index of the output tensor
            input_node, output_idx = _get_input_node(ref, idx, node_map)
            # find the description of this node's operation
            # NOTE(review): assumes ref.op is known to the registry and that
            # positions in node.input map one-to-one to op.input_arg --
            # TODO confirm for unregistered ops and ops with list inputs
            op = op_def_registry.get(ref.op)
            desc = op.input_arg[idx]
            # find out whether we can just change the input type and which
            # attributes we might need to touch
            # (presumably a type of 0 means "not fixed, given by an
            # attribute" -- TODO confirm)
            if desc.type != 0 and desc.type != _DT_INT32:
                # input type is fixed and cannot be changed: add a type cast
                # from int32 back to the type the operation expects
                cast_op = _make_cast_node(input_node, output_idx, _DT_INT32,
                                          desc.type)
                ref.input[idx] = cast_op.name
                type_cast_ops.append(cast_op)
                # register the new node so later lookups can resolve it
                node_map[cast_op.name] = cast_op
                input_map[cast_op.name].append((idx, ref))
            elif desc.type_list_attr != '' or desc.type_attr == '':
                # input arrays of potentially mixed types cannot be handled
                raise ValueError("don't know how to handle input type changes"
                                 f' for node "{ref.name}" op={ref.op}')
            else:
                # change the type of this input
                type_attr = desc.type_attr
                ref.attr[type_attr].type = _DT_INT32
                # drop any pending cast request recorded under this node's
                # name now that its type attribute is int32
                if ref.name in type_cast_candidates:
                    del type_cast_candidates[ref.name]
                # check the other inputs for type compatibility
                # (this loop rebinds `desc`; the outer value is not used
                # again in this iteration)
                for i, desc in enumerate(op.input_arg):
                    if i == idx or desc.type_attr != type_attr:
                        continue  # not a matching input
                    input_node, output_idx = _get_input_node(ref, i, node_map)
                    if input_node.name in convert:
                        continue  # Placeholder that will be converted
                    src_type = _get_output_type(input_node, output_idx)
                    if src_type == _DT_INT32:
                        continue  # type matches already
                    if input_node.op == 'Const':
                        # weight tensor: harmonize_dtypes() will fix these
                        _set_tensor_dtype(input_node, _DT_INT32)
                    else:
                        # add node as a candidate for needing type cast op
                        type_cast_candidates[input_node.name] = (i, ref)
                # process any changed outputs next
                # (this loop rebinds `idx`, which now counts outputs)
                for idx, output in enumerate(op.output_arg):
                    if output.type_attr == type_attr:
                        # outputs sharing the changed type attribute changed
                        # as well: push their consumers onto the stack
                        input_name = _get_tensor_name(ref, idx)
                        referenced_by += input_map[input_name]

    for idx, ref in type_cast_candidates.values():
        # add type cast operations for all nodes that have a type mismatch
        inp_node, channel = _get_input_node(ref, idx, node_map)
        # the type is checked again because the source may have been
        # changed to int32 after it was recorded as a candidate
        src_type = _get_output_type(inp_node, channel)
        if src_type != _DT_INT32:
            cast_op = _make_cast_node(inp_node, channel, src_type, _DT_INT32)
            ref.input[idx] = cast_op.name
            type_cast_ops.append(cast_op)
            node_map[cast_op.name] = cast_op

    # the cast nodes are added to the graph only after both passes are done
    graph_def.node.extend(type_cast_ops)
    return graph_def