def find_and_replace_pattern(self, graph: Graph):
    """Insert layout-restoring Transpose nodes around nodes marked with ``reinterp_shape``.

    Nodes carrying the ``reinterp_shape`` attribute (e.g. Reshape) must keep working in the original
    NHWC layout. For every such node this pass:

    * inserts Transpose(NC(D)HW -> N(D)HWC) before input port 0 when that input has rank >= 4 and is
      not already marked as being in the correct layout;
    * inserts Transpose(N(D)HWC -> NC(D)HW) after output port 0 when that output has rank >= 4 and is
      not already marked as being in the correct layout.

    The pass does nothing unless ``graph.graph['layout']`` is 'NHWC'.

    :param graph: graph modified in place
    """
    if graph.graph['layout'] != 'NHWC':
        # we check it here because this transformation is called explicitly from the pipeline
        return

    # reshape from 4D-5D -> ND. Insert Transpose(NC(D)HW->N(D)HWC) before Reshape
    for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
        reinterp_shape_node = Node(graph, reinterp_shape_node_id)
        assert 0 in reinterp_shape_node.in_nodes(), 'Node {} does not have 0 input. \n{}'.format(
            reinterp_shape_node_id, graph.dump_graph_for_graphviz())
        input_shape = reinterp_shape_node.in_node(0).shape
        if is_input_data_in_correct_layout(reinterp_shape_node, 0) or len(input_shape) < 4:
            continue

        order_const = Const(graph, {'value': PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm
                                    }).create_node()
        permute_node = Transpose(graph, {'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'
                                         }).create_node()
        reinterp_shape_node.in_port(0).get_connection().insert_node(permute_node)
        order_const.out_port(0).connect(permute_node.in_port(1))
        # populate the order Const output data node (value/shape) right away
        order_const.infer(order_const)

        # do not infer the Transpose node because it should have input data node in NCHW layout (but currently
        # it is NHWC because data node attributes has not been permuted yet) and produce output in NHWC layout
        # (which is true at this moment)
        permute_node['need_shape_inference'] = False
        # mark the Transpose output data node having correct layout so it's shape will not be permuted
        mark_output_as_in_correct_layout(permute_node, 0)

        # keep the reinterp_shape_node in NHWC layout
        mark_input_as_in_correct_layout(reinterp_shape_node, 0)
        mark_input_as_in_correct_layout(reinterp_shape_node, 1)

    # reshape from ND -> 4D-5D. Insert Transpose(N(D)HWC->NC(D)HW) after Reshape
    for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
        reinterp_shape_node = Node(graph, reinterp_shape_node_id)
        assert 0 in reinterp_shape_node.out_nodes(), 'Node {} does not have 0 output. \n{}'.format(
            reinterp_shape_node_id, graph.dump_graph_for_graphviz())
        output_shape = reinterp_shape_node.out_node(0).shape
        if is_output_data_in_correct_layout(reinterp_shape_node, 0) or len(output_shape) < 4:
            continue

        order_const = Const(graph, {
            'value': PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm}).create_node()
        permute_node = Transpose(graph, {'name': reinterp_shape_node.id + '/Transpose'}).create_node()
        reinterp_shape_node.out_port(0).get_connection().insert_node(permute_node)
        order_const.out_port(0).connect(permute_node.in_port(1))
        # the first loop infers its order Const; do the same here so the order data node gets its value/shape
        # instead of being left empty for later passes that read the Transpose order
        order_const.infer(order_const)

        # the Reshape and Transpose operations should work in original (NHWC layout) so the Transpose
        # will convert it to the NCHW
        mark_input_as_in_correct_layout(permute_node, 0)
        mark_input_as_in_correct_layout(permute_node, 1)
        # do not set Transpose output data node 'correct_data_layout' attribute so the data node shape will be
        # permuted

        # keep the reinterp_shape_node in NHWC layout
        mark_output_as_in_correct_layout(reinterp_shape_node, 0)
        mark_input_as_in_correct_layout(reinterp_shape_node, 1)

        # do not re-infer the Transpose node because it output data node should be in NHWC layout to make the
        # rest of the graph consistent
        permute_node['need_shape_inference'] = False
def replace_pattern(self, graph: Graph, match: dict):
    """Re-express a matched reduce-mean node as an average-pooling ('AvgPool') node.

    The pooling window spans the full extent of every reduced axis and is 1 elsewhere. Strides are 1
    and padding is zero. The edge from the 'axis' input is removed. When the mean node does not keep
    the reduced dims (``keep_dims == False``), pooling writes into a new data node that keeps them,
    and a Reshape restores the original output shape.

    :param graph: graph modified in place
    :param match: matched sub-graph; the 'axis', 'input' and 'mean' entries are used
    """
    if match['axis'].value is None or match['input'].shape is None:
        return
    dims = len(match['input'].shape)
    ones = np.ones(dims, dtype=np.int64)

    # The reduction axes may be a scalar or a 1D array, and TF allows negative axes.
    # Normalize them into [0, dims) and sort them. Otherwise the np.insert-based restoration of
    # reduced dims below breaks: a negative index counts from the end of the *reduced* shape, and
    # out-of-order insertions land at the wrong positions.
    axis = np.atleast_1d(np.array(match['axis'].value)).astype(np.int64)
    axis = np.sort(np.where(axis < 0, axis + dims, axis))

    mean = graph.node[match['mean'].node]
    mean['stride'] = np.array(ones)
    # TODO: need to check axis with real layout
    spatial_dims = np.array(axis)
    mean['spatial_dims'] = spatial_dims
    mean['pad'] = np.zeros((dims, 2), np.int64)
    mean['pad_spatial_shape'] = np.array(mean['pad'][spatial_dims])
    window = np.array(ones)
    # the window covers the whole extent of every reduced axis
    window[spatial_dims] = match['input'].shape[spatial_dims]
    mean['window'] = window
    mean['TF_op'] = mean['op']
    mean['op'] = 'AvgPool'
    mean['pool_method'] = 'avg'
    mean['rounding_type'] = 'ceil'
    mean['exclude_pad'] = 'true'
    mean['kernel_spatial'] = window[spatial_dims]
    graph.remove_edge(match['axis'].node, match['mean'].node)
    mean['permute_attrs'] = PermuteAttrs().update_attrs(attrs=[('pad', 'input:0'),
                                                               ('stride', 'input:0'),
                                                               ('window', 'input:0'),
                                                               ('spatial_dims', 'input:0')])

    # explicit comparison kept on purpose: an unset (None) keep_dims is not treated as False
    if match['mean'].keep_dims == False:  # noqa: E712
        output = match['mean'].out_node()
        pool_node = match['mean']

        # Keep dims for AvgPool: re-insert a unit dim at every (sorted, non-negative) reduced axis
        shape = np.array(output.shape)
        for idx in spatial_dims:
            shape = np.insert(shape, idx, 1)

        graph.remove_edge(pool_node.id, output.id)
        # Create new data for pool with all dims
        pool_data = Op.create_data_node(graph, pool_node, {'shape': np.array(shape)})
        # Create and connect reshape node
        reshape_op = Reshape(graph, {'dim': np.array(output.shape)})
        reshape_node = reshape_op.create_node(
            [pool_data],
            dict(name='Reshape_', permute_attrs=PermuteAttrs().update_attrs(attrs=[('dim', 'output:0')])))
        graph.create_edge(reshape_node, output)
def apply_nhwc_to_nchw_permutation(graph: Graph):
    """Attach an NHWC->NCHW permutation to the edges of every data node that lacks one.

    Does nothing when ``graph.graph['layout']`` is 'NCHW'. Two kinds of data node are left untouched:
    those flagged with ``nchw_layout``, and those where any incoming or outgoing edge (key 0) already
    has a ``permutation`` attribute.

    :param graph: graph whose edges are annotated in place
    """
    if graph.graph['layout'] == 'NCHW':
        return

    for data_node in graph.get_data_nodes():
        if data_node.has_and_set('nchw_layout'):
            continue

        # NHWC -> NCHW permutation sized to the rank of this data node
        nhwc_to_nchw = PermuteAttrs().get_nhwc_to_nchw_permutation(len(data_node.shape))

        producers = data_node.in_nodes()
        consumers = data_node.out_nodes()
        edge_data = data_node.graph.get_edge_data
        # every edge is inspected (lists, not a short-circuiting generator)
        permuted_edges = (
            ['permutation' in edge_data(src.id, data_node.id)[0] for src in producers]
            + ['permutation' in edge_data(data_node.id, dst.id)[0] for dst in consumers]
        )
        if any(permuted_edges):
            continue

        for src in producers:
            PermuteAttrs.set_permutation(src, data_node, nhwc_to_nchw)
        for dst in consumers:
            PermuteAttrs.set_permutation(data_node, dst, nhwc_to_nchw)
def find_and_replace_pattern(self, graph: Graph):
    """Give every un-annotated data node's edges an NHWC->NCHW permutation.

    Two kinds of data node are skipped: those flagged with ``nchw_layout``, and those where any
    input or output edge (key 0) already carries a ``permutation`` attribute.

    :param graph: graph whose edges are annotated in place
    """
    for data_node in graph.get_data_nodes():
        if data_node.has_and_set('nchw_layout'):
            continue

        order = PermuteAttrs().get_nhwc_to_nchw_permutation(len(data_node.shape))
        lookup = data_node.graph.get_edge_data

        seen = [('permutation' in lookup(producer.id, data_node.id)[0])
                for producer in data_node.in_nodes()]
        seen.extend([('permutation' in lookup(data_node.id, consumer.id)[0])
                     for consumer in data_node.out_nodes()])
        if True in seen:
            # this data node is already handled by somebody else
            continue

        for producer in data_node.in_nodes():
            PermuteAttrs.set_permutation(producer, data_node, order)
        for consumer in data_node.out_nodes():
            PermuteAttrs.set_permutation(data_node, consumer, order)
def test_transpose_insert(self, nhwc_to_nchw_order, nchw_to_nhwc_order, add_permutation_attrs):
    """Check PreserveRuntimeInfo on a Parameter -> Add -> Result graph.

    The pass must drop the Parameter's ``permute_attrs`` and its output ``permutation``. When
    permutation attributes were set, the expected graph is ``edges_with_transpose``. In that case the
    'old_api_map_order' runtime info on Parameter and Result must hold the inverse and direct orders.
    """
    graph_nodes = {}
    graph_nodes.update(valued_const_with_data('transpose_parameter_order', np.array(nhwc_to_nchw_order)))
    graph_nodes.update(valued_const_with_data('transpose_result_order', np.array(nchw_to_nhwc_order)))
    graph_nodes.update(nodes)

    rank = len(nhwc_to_nchw_order) if add_permutation_attrs else 3
    shape = np.array(range(rank))
    if nhwc_to_nchw_order is None:
        add_shape = shape
    else:
        add_shape = shape[nhwc_to_nchw_order]

    graph_nodes.update(regular_op_with_shaped_data(
        'placeholder1', shape, {'type': 'Parameter', 'rt_info': RTInfo(), 'shape': shape}))
    graph_nodes.update(regular_op_with_shaped_data(
        'result', shape, {'type': 'Result', 'rt_info': RTInfo(), 'shape': shape}))
    graph_nodes.update(regular_op_with_shaped_data(
        'add', add_shape, {'type': 'Add', 'op': 'Add', 'infer': copy_shape_infer}))

    graph = build_graph(graph_nodes, edges)
    reference_edges = edges_with_transpose if add_permutation_attrs else edges
    graph_ref = build_graph(graph_nodes, reference_edges)

    param_node = Node(graph, 'placeholder1')
    result_node = Node(graph, 'result')

    if add_permutation_attrs:
        param_node['permute_attrs'] = PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])
        param_node.out_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(rank)
        result_node.in_node(0)['permutation'] = PermuteAttrs().get_nhwc_to_nchw_permutation(rank)

    PreserveRuntimeInfo().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'result')
    self.assertTrue(flag, resp)

    self.assertFalse(param_node.has_valid('permute_attrs'))
    self.assertFalse(param_node.out_node(0).has_valid('permutation'))

    if add_permutation_attrs:
        param_api_map = param_node.rt_info.info[('old_api_map_order', 0)].info
        self.assertTrue(np.array_equal(param_api_map['inverse_order'], nchw_to_nhwc_order))

        result_api_map = result_node.rt_info.info[('old_api_map_order', 0)].info
        self.assertTrue(np.array_equal(result_api_map['order'], nhwc_to_nchw_order))
def tf_placeholder_ext(pb):
    """Build the attribute dict of a TF Placeholder from its protobuf definition.

    :param pb: node protobuf; its ``attr["dtype"]`` and ``attr["shape"]`` entries are read
    :return: dict with 'data_type', 'shape', 'type' ('Input'), an 'infer' callback, and
             'permute_attrs' that permutes the 'shape' attribute with output port 0
    """
    proto_attrs = pb.attr
    data_type = tf_dtype_extractor(proto_attrs["dtype"].type)
    shape = tf_tensor_shape(proto_attrs["shape"].shape)
    return {
        'data_type': data_type,
        'shape': shape,
        'type': 'Input',
        # the output is inferred from the node's own 'shape' attribute
        'infer': lambda node: single_output_infer(node, lambda n: n.shape),
        'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')]),
    }
def extract(node):
    """Fill a Parameter node from the TF placeholder protobuf attached to ``node.pb``.

    The dtype and shape are read from ``node.pb.attr``. The 'shape' attribute is registered to be
    permuted along with output port 0.

    :param node: node whose ``pb`` holds the placeholder definition
    :return: the enclosing class's ``enabled`` flag
    """
    proto_attrs = node.pb.attr
    Parameter.update_node_stat(node, {
        'data_type': tf_dtype_extractor(proto_attrs["dtype"].type),
        'shape': tf_tensor_shape(proto_attrs["shape"].shape),
        'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')]),
    })
    return __class__.enabled
def convert_graph_inputs_to_parameters(internal_graph, internal_graph_proto):
    """Create a Parameter op node in ``internal_graph`` for every 'input_arg' of a function proto.

    :param internal_graph: graph receiving the new Parameter nodes
    :param internal_graph_proto: function description; its 'input_arg' entries (with ``name`` and
                                 ``type``) are converted in order
    :return: tuple (list of created Parameter Nodes, list of their node ids), in 'input_arg' order
    """
    # create Parameter nodes for the body graph
    body_parameters = []
    body_parameter_names = []
    for pb_node in internal_graph_proto['input_arg']:
        param_id = internal_graph.unique_id(pb_node.name)
        internal_graph.add_node(param_id, name=param_id, kind='op', op='Parameter', pb=None, shape=None)
        # look the node up by the id it was actually added under: unique_id() may presumably return an
        # id different from pb_node.name when that name is already taken in the graph
        parameter_node = Node(internal_graph, param_id)
        Parameter.update_node_stat(
            parameter_node, {
                'data_type': tf_dtype_extractor(pb_node.type),
                'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])
            })
        body_parameters.append(parameter_node)
        body_parameter_names.append(param_id)
    return body_parameters, body_parameter_names
def extract(cls, loop_node):
    """Convert a TF loop node into a Loop op with a body graph built from library functions.

    The 'body' and 'cond' functions named in ``loop_node.pb.attr`` are looked up in the main graph
    library. Both are merged into one body graph. Parameters and Results are then wired to the Loop
    ports, back edges are added, and attributes of the body nodes are extracted.

    :param loop_node: TF loop node; ``pb.attr['body']`` and ``pb.attr['cond']`` name the functions
    :return: ``cls.enabled``
    :raises AssertionError: if the graph library or one of the required functions is missing
    """
    Loop.update_node_stat(loop_node, {})
    loop_name = loop_node.soft_get('name', loop_node.id)

    # check that required body and condition functions exist in the graph library
    main_graph = loop_node.graph
    body_graph_name = loop_node.pb.attr['body'].func.name
    cond_graph_name = loop_node.pb.attr['cond'].func.name
    assert 'library' in main_graph.graph, 'The graph does not contain a library that is required ' \
                                          'by node with name "{}".'.format(loop_name)
    library_graph = main_graph.graph['library']

    assert body_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
                                             'that is required by node ' \
                                             'with name "{}".'.format(body_graph_name, loop_name)
    body_graph_proto = library_graph[body_graph_name]

    assert cond_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \
                                             'that is required by node ' \
                                             'with name "{}".'.format(cond_graph_name, loop_name)
    cond_graph_proto = library_graph[cond_graph_name]

    body_graph = Graph()
    # fill the body graph
    for attr_key in main_graph.graph.keys():
        if attr_key != 'library':
            body_graph.graph[attr_key] = copy.deepcopy(main_graph.graph[attr_key])
        else:
            # it is sufficient to have a link to the library
            body_graph.graph['library'] = main_graph.graph['library']
    loop_node['body'] = body_graph

    # create Parameter nodes for the body graph
    body_parameters = []
    body_parameter_names = []
    for pb_node in body_graph_proto['input_arg']:
        param_id = body_graph.unique_id(pb_node.name)
        body_graph.add_node(param_id, name=param_id, kind='op', op='Parameter', pb=None, shape=None)
        # look the node up by the id it was actually added under (unique_id() may presumably differ
        # from pb_node.name if that name is already taken)
        parameter_node = Node(body_graph, param_id)
        Parameter.update_node_stat(parameter_node,
                                   {'data_type': tf_dtype_extractor(pb_node.type),
                                    'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])}
                                   )
        body_parameters.append(parameter_node)
        body_parameter_names.append(param_id)

    # update the loop body graph with the body function graph
    body_results = []
    update_body_graph(body_graph, body_graph_proto, body_parameter_names, body_results)

    # update the loop body graph with the condition function graph
    update_body_graph(body_graph, cond_graph_proto, body_parameter_names, body_results)

    # add 'internal_layer_id' attribute which is a must have attribute for the loop body node
    for idx, body_node in enumerate(body_graph.get_op_nodes()):
        body_node['internal_layer_id'] = idx

    body_graph.stage = 'front'

    # Currently,
    # Loop Inputs Order:
    #   0      - current iteration
    #   1      - trip count
    #   2..    - "loop carried" dependencies variables
    #
    # Body Inputs Order:
    #   0      - current iteration
    #   1      - trip count
    #   2..    - "loop carried" dependencies variables
    #
    # Body Outputs Order:
    #   0      - current iteration
    #   1      - trip count
    #   2..    - "loop carried" dependencies variables
    #
    # Loop Outputs Order:
    #   0      - current iteration
    #   1      - trip count
    #   2..    - "loop carried" dependencies variables
    #
    # so inputs must be reordered and execution condition must be created in the front transformation
    # to be aligned with the specification

    # connect external input ports with body parameter nodes except current iteration
    # since it must be disconnected from external port
    for idx in range(1, len(body_parameters)):
        Loop.connect_body_input(loop_node, idx, body_parameters[idx])

    # mark current iteration input Parameter node and execution condition Result node
    Loop.mark_current_iteration_parameter_node(loop_node, body_parameters[0])
    Loop.mark_execution_condition_result_node(loop_node, body_results[-1])

    # connect back edges in the body except current iteration
    for idx in range(1, len(body_parameters)):
        Loop.add_back_edge(loop_node, body_parameters[idx], body_results[idx])

    # connect body outputs with Loop operation output ports except the execution condition result
    for idx, body_result in enumerate(body_results[:-1]):
        Loop.connect_body_output(loop_node, idx, body_result)

    # run function to parse body nodes attributes similar to the main graph
    extract_node_attrs(body_graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors)))
    return cls.enabled