Exemplo n.º 1
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Surround shape-reinterpreting nodes with Transpose operations in a graph with 'NHWC' layout.

        Every node carrying the ``reinterp_shape=True`` attribute is visited twice. In the first pass a
        Transpose is inserted in front of input 0 when that input has rank 4 or more and is not marked
        as being in the correct layout. In the second pass a Transpose is inserted behind output 0 under
        the analogous condition for the output. The graph is modified in place; nothing is returned.
        """
        # the layout is verified here because this transformation is called explicitly from the pipeline
        if graph.graph['layout'] != 'NHWC':
            return

        # Pass 1: reshape from 4D-5D -> ND. Insert Transpose(NC(D)HW->N(D)HWC) before Reshape
        for node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
            reshape = Node(graph, node_id)
            assert 0 in reshape.in_nodes(), 'Node {} does not have 0 input. \n{}'.format(
                node_id, graph.dump_graph_for_graphviz())
            in_shape = reshape.in_node(0).shape
            if is_input_data_in_correct_layout(reshape, 0) or len(in_shape) < 4:
                continue

            to_nhwc = PermuteAttrs().get_nchw_to_nhwc_permutation(len(in_shape))
            order = Const(graph, {'value': to_nhwc.perm}).create_node()
            producer_name = reshape.in_port(0).get_source().node.name
            transpose = Transpose(graph, {'name': producer_name + '/Transpose'}).create_node()
            reshape.in_port(0).get_connection().insert_node(transpose)
            order.out_port(0).connect(transpose.in_port(1))
            order.infer(order)

            # The Transpose is deliberately not inferred: its input data node should be in NCHW layout,
            # but at this moment it is still NHWC because data node attributes have not been permuted yet,
            # while its output is already in NHWC layout.
            transpose['need_shape_inference'] = False
            # the Transpose output data node is marked as correct so that its shape is not permuted
            mark_output_as_in_correct_layout(transpose, 0)

            # both inputs of the reshaping node stay in NHWC layout
            for port_idx in (0, 1):
                mark_input_as_in_correct_layout(reshape, port_idx)

        # Pass 2: reshape from ND -> 4D-5D. Insert Transpose(N(D)HWC->NC(D)HW) after Reshape
        for node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
            reshape = Node(graph, node_id)
            assert 0 in reshape.out_nodes(), 'Node {} does not have 0 output. \n{}'.format(
                node_id, graph.dump_graph_for_graphviz())
            out_shape = reshape.out_node(0).shape
            if is_output_data_in_correct_layout(reshape, 0) or len(out_shape) < 4:
                continue

            to_nchw = PermuteAttrs().get_nhwc_to_nchw_permutation(len(out_shape))
            order = Const(graph, {'value': to_nchw.perm}).create_node()
            transpose = Transpose(graph, {'name': reshape.id + '/Transpose'}).create_node()
            reshape.out_port(0).get_connection().insert_node(transpose)
            order.out_port(0).connect(transpose.in_port(1))

            # The Reshape and the Transpose work in the original (NHWC) layout, so the Transpose is what
            # converts the data to NCHW. The 'correct_data_layout' attribute is intentionally NOT set for
            # the Transpose output data node, so the shape of that data node will be permuted.
            for port_idx in (0, 1):
                mark_input_as_in_correct_layout(transpose, port_idx)

            # the reshaping node itself stays in NHWC layout
            mark_output_as_in_correct_layout(reshape, 0)
            mark_input_as_in_correct_layout(reshape, 1)

            # The Transpose is not re-inferred: its output data node should be in NHWC layout to keep the
            # rest of the graph consistent.
            transpose['need_shape_inference'] = False
Exemplo n.º 2
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Rewrite the matched 'mean' node in place so that it becomes an AvgPool over the reduced axes.

        The attributes of the 'mean' node are overwritten with pooling attributes (stride, pad, window,
        spatial dims, ...), the edge from the 'axis' node is removed, and the original op name is kept
        under 'TF_op'. When the node does not keep dimensions, a new data node with the reduced axes
        restored as size 1 is placed after the pooling node, followed by a Reshape that produces the
        original output shape.

        :param graph: graph being transformed (modified in place)
        :param match: matched sub-graph; the keys 'axis', 'input' and 'mean' are used
        :return: None; the method exits early when the axis value or the input shape is unknown
        """
        if match['axis'].value is None or match['input'].shape is None:
            return
        dims = len(match['input'].shape)
        ones = np.ones(dims, dtype=np.int64)
        axis = np.array(match['axis'].value)
        # a scalar axis is promoted to a one-element array so it can be used for fancy indexing below
        axis = axis if axis.ndim != 0 else np.array([axis], dtype=np.int64)

        # raw attribute dictionary of the 'mean' node
        mean = graph.node[match['mean'].node]
        mean['stride'] = np.array(ones)
        # TODO: need to check axis with real layout
        spatial_dims = np.array(axis)
        mean['spatial_dims'] = spatial_dims
        mean['pad'] = np.zeros((dims, 2), np.int64)
        mean['pad_spatial_shape'] = np.array(mean['pad'][spatial_dims])
        # the pooling window covers the full extent of every reduced axis and is 1 elsewhere
        window = np.array(ones)
        window[spatial_dims] = match['input'].shape[spatial_dims]
        mean['window'] = window
        mean['TF_op'] = mean['op']
        mean['op'] = 'AvgPool'
        mean['pool_method'] = 'avg'
        mean['rounding_type'] = 'ceil'
        mean['exclude_pad'] = 'true'
        mean['kernel_spatial'] = window[spatial_dims]
        graph.remove_edge(match['axis'].node, match['mean'].node)
        mean['permute_attrs'] = PermuteAttrs().update_attrs(attrs=[
            ('pad', 'input:0'),
            ('stride', 'input:0'),
            ('window', 'input:0'),
            ('spatial_dims', 'input:0'),
        ])

        # '== False' is kept on purpose: only an explicit false value enables the branch
        if match['mean'].keep_dims == False:  # noqa: E712
            output = match['mean'].out_node()
            pool_node = match['mean']

            # Keep dims for AvgPool: restore every reduced axis as a dimension of size 1.
            # The axes are normalized to non-negative values and processed in ascending order, because
            # each insertion shifts the positions of all following dimensions; inserting in the given
            # order (or with negative indices) would produce a wrong shape for e.g. axis [2, 1] or [-1].
            shape = np.array(output.shape)
            for idx in sorted(int(a) % dims for a in spatial_dims):
                shape = np.insert(shape, idx, 1)

            graph.remove_edge(pool_node.id, output.id)
            # Create new data for pool with all dims
            pool_data = Op.create_data_node(graph, pool_node,
                                            {'shape': np.array(shape)})
            # Create and connect reshape node that restores the original (reduced) output shape
            reshape_op = Reshape(graph, {'dim': np.array(output.shape)})
            reshape_node = reshape_op.create_node(
                [pool_data],
                dict(name='Reshape_',
                     permute_attrs=PermuteAttrs().update_attrs(
                         attrs=[('dim', 'output:0')])))
            graph.create_edge(reshape_node, output)
Exemplo n.º 3
0
def apply_nhwc_to_nchw_permutation(graph: Graph):
    """
    Attach the NHWC -> NCHW permutation to the edges of data nodes that have no permutation yet.

    Nothing is done when the graph layout is 'NCHW'. A data node is skipped when its 'nchw_layout'
    attribute is set, or when any of its incoming or outgoing edges already has a 'permutation'
    attribute. For the remaining data nodes the permutation is set on all incoming and outgoing edges.
    """
    if graph.graph['layout'] == 'NCHW':
        return

    for data_node in graph.get_data_nodes():
        if data_node.has_and_set('nchw_layout'):
            continue

        # NHWC to NCHW permutation for N dims, where N is the rank of the data node
        rank = len(data_node.shape)
        nhwc_to_nchw = PermuteAttrs().get_nhwc_to_nchw_permutation(rank)

        # gather the attributes of every edge touching the data node: incoming edges first, then outgoing
        nx_graph = data_node.graph
        edge_attr_dicts = [nx_graph.get_edge_data(producer.id, data_node.id)[0]
                           for producer in data_node.in_nodes()]
        edge_attr_dicts += [nx_graph.get_edge_data(data_node.id, consumer.id)[0]
                            for consumer in data_node.out_nodes()]

        # a single edge with a permutation means the data node has already been processed
        if any(['permutation' in edge_attrs for edge_attrs in edge_attr_dicts]):
            continue

        # set the permutation on all in/out edges
        for producer in data_node.in_nodes():
            PermuteAttrs.set_permutation(producer, data_node, nhwc_to_nchw)
        for consumer in data_node.out_nodes():
            PermuteAttrs.set_permutation(data_node, consumer, nhwc_to_nchw)
    def find_and_replace_pattern(self, graph: Graph):
        """
        Set the NHWC -> NCHW permutation on the edges of every data node that has no permutation yet.

        Data nodes with the 'nchw_layout' attribute set are skipped, as are data nodes for which at least
        one incoming or outgoing edge already carries a 'permutation' attribute.
        """
        for data_node in graph.get_data_nodes():
            if data_node.has_and_set('nchw_layout'):
                continue

            # the permutation is built for as many dimensions as the data node shape has
            nhwc_to_nchw = PermuteAttrs().get_nhwc_to_nchw_permutation(len(data_node.shape))

            # attributes of all edges around the data node, inputs before outputs
            attrs_of_in_edges = [data_node.graph.get_edge_data(src.id, data_node.id)[0]
                                 for src in data_node.in_nodes()]
            attrs_of_out_edges = [data_node.graph.get_edge_data(data_node.id, dst.id)[0]
                                  for dst in data_node.out_nodes()]
            already_permuted = [
                'permutation' in edge_attrs for edge_attrs in attrs_of_in_edges + attrs_of_out_edges
            ]
            if any(already_permuted):
                continue

            # no edge has a permutation yet: assign it to all in/out edges
            for src in data_node.in_nodes():
                PermuteAttrs.set_permutation(src, data_node, nhwc_to_nchw)
            for dst in data_node.out_nodes():
                PermuteAttrs.set_permutation(data_node, dst, nhwc_to_nchw)
Exemplo n.º 5
0
    def test_transpose_insert(self, nhwc_to_nchw_order, nchw_to_nhwc_order, add_permutation_attrs):
        """
        Run PreserveRuntimeInfo on a small graph and compare the result with the reference graph.

        When ``add_permutation_attrs`` is set, permutation attributes are attached to the Parameter and
        Result sides before the transformation, the reference graph is built from ``edges_with_transpose``
        and the 'old_api_map_order' runtime info of both nodes is checked afterwards.
        """
        graph_nodes = {}
        graph_nodes.update(valued_const_with_data('transpose_parameter_order', np.array(nhwc_to_nchw_order)))
        graph_nodes.update(valued_const_with_data('transpose_result_order', np.array(nchw_to_nhwc_order)))
        graph_nodes.update(nodes)

        # rank 3 is used when no permutation attributes are requested
        rank = len(nhwc_to_nchw_order) if add_permutation_attrs else 3
        shape = np.array(range(rank))
        if nhwc_to_nchw_order is None:
            add_shape = shape
        else:
            add_shape = shape[nhwc_to_nchw_order]

        op_nodes = {
            **regular_op_with_shaped_data('placeholder1', shape,
                                          {'type': 'Parameter', 'rt_info': RTInfo(), 'shape': shape}),
            **regular_op_with_shaped_data('result', shape, {'type': 'Result', 'rt_info': RTInfo(), 'shape': shape}),
            **regular_op_with_shaped_data('add', add_shape,
                                          {'type': 'Add', 'op': 'Add', 'infer': copy_shape_infer}),
        }
        graph_nodes.update(op_nodes)

        graph = build_graph(graph_nodes, edges)
        ref_edges = edges_with_transpose if add_permutation_attrs else edges
        graph_ref = build_graph(graph_nodes, ref_edges)

        param_node = Node(graph, 'placeholder1')
        result_node = Node(graph, 'result')

        if add_permutation_attrs:
            param_node['permute_attrs'] = PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])
            # a separate permutation object is created for each data node
            param_node.out_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(rank)
            result_node.in_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(rank)

        PreserveRuntimeInfo().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(flag, resp)

        # the transformation is expected to leave no permutation attributes on the Parameter side
        self.assertFalse(param_node.has_valid('permute_attrs'))
        self.assertFalse(param_node.out_node(0).has_valid('permutation'))

        if not add_permutation_attrs:
            return

        param_api_map = param_node.rt_info.info[('old_api_map_order', 0)].info
        self.assertTrue(np.array_equal(param_api_map['inverse_order'], nchw_to_nhwc_order))

        result_api_map = result_node.rt_info.info[('old_api_map_order', 0)].info
        self.assertTrue(np.array_equal(result_api_map['order'], nhwc_to_nchw_order))
Exemplo n.º 6
0
def tf_placeholder_ext(pb):
    """
    Build the attribute dictionary of an 'Input' node from the node definition ``pb``.

    The data type and the shape are read from the "dtype" and "shape" entries of ``pb.attr``. The
    returned dictionary also contains the shape inference function and the rule for permuting the
    'shape' attribute together with output 0.
    """
    data_type = tf_dtype_extractor(pb.attr["dtype"].type)
    shape = tf_tensor_shape(pb.attr["shape"].shape)

    attrs = {'data_type': data_type, 'shape': shape, 'type': 'Input'}
    # the output shape is simply the 'shape' attribute of the node itself
    attrs['infer'] = lambda node: single_output_infer(node, lambda n: n.shape)
    attrs['permute_attrs'] = PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])
    return attrs
Exemplo n.º 7
0
 def extract(node):
     """
     Fill ``node`` with Parameter attributes taken from its protobuf definition.

     The data type and the shape come from the "dtype" and "shape" entries of ``node.pb.attr``; the
     'shape' attribute is registered for permutation together with output 0.
     Returns the ``enabled`` flag of the enclosing class.
     """
     data_type = tf_dtype_extractor(node.pb.attr["dtype"].type)
     shape = tf_tensor_shape(node.pb.attr["shape"].shape)
     shape_permutation = PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])
     Parameter.update_node_stat(node, {
         'data_type': data_type,
         'shape': shape,
         'permute_attrs': shape_permutation,
     })
     return __class__.enabled
Exemplo n.º 8
0
def convert_graph_inputs_to_parameters(internal_graph, internal_graph_proto):
    """
    Create a Parameter node in ``internal_graph`` for every entry of ``internal_graph_proto['input_arg']``.

    :param internal_graph: graph that receives the new Parameter nodes (modified in place)
    :param internal_graph_proto: mapping whose 'input_arg' entry lists the input definitions; the
                                 ``name`` and ``type`` of each definition are used
    :return: tuple of two lists in the order of 'input_arg': the created Parameter nodes and their ids
    """
    body_parameters = []
    body_parameter_names = []
    for pb_node in internal_graph_proto['input_arg']:
        param_id = internal_graph.unique_id(pb_node.name)
        internal_graph.add_node(param_id,
                                name=param_id,
                                kind='op',
                                op='Parameter',
                                pb=None,
                                shape=None)
        # The node is wrapped by the id it was added under. Using pb_node.name here would address a
        # different (or missing) node whenever unique_id() had to return an id other than the name.
        parameter_node = Node(internal_graph, param_id)
        Parameter.update_node_stat(
            parameter_node, {
                'data_type':
                tf_dtype_extractor(pb_node.type),
                'permute_attrs':
                PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])
            })
        body_parameters.append(parameter_node)
        body_parameter_names.append(param_id)
    return body_parameters, body_parameter_names
Exemplo n.º 9
0
    def extract(cls, loop_node):
        """
        Build the body graph of a Loop node from the 'body' and 'cond' functions of the graph library.

        The body graph receives a copy of the main graph attributes (the library itself is shared, not
        copied), one Parameter node per input argument of the body function, and the nodes of both the
        body and the condition functions. Afterwards the Loop ports are connected with the body
        Parameter and Result nodes, back edges are added and the attributes of the body nodes are
        extracted.

        :param loop_node: node to be turned into a Loop; its ``pb`` attribute must provide the 'body'
                          and 'cond' function names
        :return: ``cls.enabled``
        :raises AssertionError: when the main graph has no 'library' or the library lacks one of the
                                two required functions
        """
        Loop.update_node_stat(loop_node, {})
        loop_name = loop_node.soft_get('name', loop_node.id)

        # check that required body and condition functions exist in the graph library
        main_graph = loop_node.graph
        body_graph_name = loop_node.pb.attr['body'].func.name
        cond_graph_name = loop_node.pb.attr['cond'].func.name
        assert 'library' in main_graph.graph, 'The graph does not contain a library that is required ' \
                                              'by node with name "{}".'.format(loop_name)
        library_graph = main_graph.graph['library']

        assert body_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
                                                 'that is required by node ' \
                                                 'with name "{}".'.format(body_graph_name, loop_name)
        body_graph_proto = library_graph[body_graph_name]

        assert cond_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
                                                 'that is required by node ' \
                                                 'with name "{}".'.format(cond_graph_name, loop_name)
        cond_graph_proto = library_graph[cond_graph_name]

        body_graph = Graph()
        # fill the body graph
        for attr_key in main_graph.graph.keys():
            if attr_key != 'library':
                body_graph.graph[attr_key] = copy.deepcopy(main_graph.graph[attr_key])
            else:
                # it is sufficient to have a link to the library
                body_graph.graph['library'] = main_graph.graph['library']
        loop_node['body'] = body_graph

        # create Parameter nodes for the body graph
        body_parameters = []
        body_parameter_names = []
        for pb_node in body_graph_proto['input_arg']:
            param_id = body_graph.unique_id(pb_node.name)
            body_graph.add_node(param_id, name=param_id, kind='op', op='Parameter', pb=None, shape=None)
            # The node is wrapped by the id it was added under. Using pb_node.name here would address a
            # different (or missing) node whenever unique_id() had to return an id other than the name.
            parameter_node = Node(body_graph, param_id)
            Parameter.update_node_stat(parameter_node,
                                       {'data_type': tf_dtype_extractor(pb_node.type),
                                        'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])}
                                       )
            body_parameters.append(parameter_node)
            body_parameter_names.append(param_id)

        # update the loop body graph with the body function graph
        body_results = []
        update_body_graph(body_graph, body_graph_proto, body_parameter_names, body_results)

        # update the loop body graph with the condition function graph
        update_body_graph(body_graph, cond_graph_proto, body_parameter_names, body_results)

        # add 'internal_layer_id' attribute which is a must have attribute for the loop body node
        for idx, body_node in enumerate(body_graph.get_op_nodes()):
            body_node['internal_layer_id'] = idx

        body_graph.stage = 'front'

        # Currently,
        # Loop Inputs Order:
        #   0    - current iteration
        #   1    - trip count
        #   2..  - "loop carried" dependencies variables
        #
        # Body Inputs Order:
        #   0    - current iteration
        #   1    - trip count
        #   2..  - "loop carried" dependencies variables
        #
        # Body Outputs Order:
        #   0      - current iteration
        #   1      - trip count
        #   2..    - "loop carried" dependencies variables
        #
        # Loop Outputs Order:
        #   0    - current iteration
        #   1    - trip count
        #   2..  - "loop carried" dependencies variables
        #
        # so inputs must be reordered and execution condition must be created in the front transformation
        # to be aligned with the specification

        # connect external input ports with body parameter nodes except current iteration
        # since it must be disconnected from external port
        for idx in range(1, len(body_parameters)):
            Loop.connect_body_input(loop_node, idx, body_parameters[idx])

        # mark current iteration input Parameter node and execution condition Result node
        # (the condition result is the last collected one because the condition function is added last)
        Loop.mark_current_iteration_parameter_node(loop_node, body_parameters[0])
        Loop.mark_execution_condition_result_node(loop_node, body_results[-1])

        # connect back edges in the body except current iteration
        for idx in range(1, len(body_parameters)):
            Loop.add_back_edge(loop_node, body_parameters[idx], body_results[idx])

        # connect body outputs with Loop operation output ports except the execution condition result
        for idx, body_result in enumerate(body_results[:-1]):
            Loop.connect_body_output(loop_node, idx, body_result)

        # run function to parse body nodes attributes similar to the main graph
        extract_node_attrs(body_graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors)))
        return cls.enabled