def grouped_convolutions_fusing(graph: Graph):
    while True:
        is_fused = False
        graph_clean_up(graph, ['TFCustomSubgraphCall', 'ShapeOf', 'Shape'])
        for node in graph.pseudo_topological_sort():
            if node.kind == 'op' and len(node.out_nodes()) > 1:
                if node.soft_get('can_be_fused') == False:
                    continue

                is_valid_convolutions = True
                last_layer = None

                next_nodes = get_next_operation(node)
                # Check that all operations after this one are Convolutions
                # and that all convolutions have the same output
                if len(next_nodes) > 1 and all(_node.soft_get('type') in ['Convolution', 'Deconvolution'] for _node in next_nodes):
                    for conv in next_nodes:
                        conv_outputs = get_next_operation(conv)
                        if conv.soft_get('can_be_fused') == False:
                            is_valid_convolutions = False
                        if len(conv_outputs) != 1:
                            is_valid_convolutions = False
                        if last_layer is None:
                            last_layer = conv_outputs[0].id
                        elif conv_outputs[0].id != last_layer:
                            is_valid_convolutions = False

                    if is_valid_convolutions:
                        is_fused = concat_convolutions(graph, node, Node(graph, last_layer))
                        if is_fused:
                            break

        if not is_fused:
            break
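# A self-contained NumPy-only sketch (1x1 convolutions written as per-pixel matmuls,
# hypothetical shapes, not the Model Optimizer API) of why the Split -> parallel
# Convolutions -> Concat pattern matched above can be replaced by a single grouped
# convolution: the concatenated branch outputs equal one convolution whose weight
# matrix is block-diagonal over the channel groups.
import numpy as np


def _conv1x1_nhwc(x, w):
    # x: [N, H, W, C_in], w: [C_in, C_out]; a 1x1 convolution is a per-pixel matmul
    return np.einsum('nhwc,co->nhwo', x, w)


_x = np.random.rand(1, 4, 4, 6)                      # 6 input channels, split into two groups of 3
_w0, _w1 = np.random.rand(3, 8), np.random.rand(3, 8)

# Split by channels, convolve each half, concatenate the outputs
_split_concat = np.concatenate([_conv1x1_nhwc(_x[..., :3], _w0),
                                _conv1x1_nhwc(_x[..., 3:], _w1)], axis=-1)

# One convolution with a block-diagonal weight, which is what a grouped convolution computes
_w_blockdiag = np.zeros((6, 16))
_w_blockdiag[:3, :8] = _w0
_w_blockdiag[3:, 8:] = _w1
assert np.allclose(_split_concat, _conv1x1_nhwc(_x, _w_blockdiag))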
def _simple_stride_prop(graph: Graph, node: Node, spatial_dims, supported=True):
    """
    This function handles stride propagation for op nodes. If the node is in the supported ops dict,
    the operation is supported and we can propagate the stride directly through it (stride_prop will be
    set from the bottom stride_prop); otherwise we cannot, and the stride_prop attribute is set to 1,1,1,1.
    """
    next_ops = get_next_operation(node)
    stride_props, all_ops_are_valid = _check_next_ops(next_ops)

    if not supported or not all_ops_are_valid:
        # We have to insert pooling layers
        for op in next_ops:
            if op.has_valid('stride_prop') and not np.array_equal(op.stride_prop[spatial_dims], np.array([1, 1])) and \
                    (op.has_valid('has_stride') == False or op.soft_get('has_stride') == False):
                _insert_pooling(graph, node.out_node(), op, spatial_dims)
        # If Convolution is valid then set `stride_prop` to Convolution stride
        node['stride_prop'] = np.array([1, 1, 1, 1])
        return

    for op in next_ops:
        if op.soft_get('has_stride') == True:
            op.stride = np.array([1, 1, 1, 1])
            log.debug("STRIDE PROP: {} {} strides were moved up via {}".format(op.type, op.name, node.name))

    node['stride_prop'] = np.array(stride_props[0]) if len(stride_props) > 0 else np.array([1, 1, 1, 1])
    node['is_partial_inferred'] = False
    _clean_fw_tensor_attrs(node.out_node())
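# A small NumPy-only sketch of the invariant this pass exploits (illustrative names
# and shapes, not Model Optimizer calls): for elementwise, "stride-transparent" ops
# such as ReLU, subsampling after the op equals the op applied to the subsampled
# input, so a downstream stride can be pulled up through the op and eventually
# merged into a preceding Convolution.
import numpy as np

_x = np.random.rand(8, 8)
_relu = lambda t: np.maximum(t, 0)

_strided_after = _relu(_x)[::2, ::2]    # op first, stride (subsampling) applied after
_strided_before = _relu(_x[::2, ::2])   # stride pulled up through the op
assert np.allclose(_strided_after, _strided_before)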
def _conv_stride_prop(graph: Graph, node: Node, spatial_dims, supported=True):
    """
    This function handles convolution stride propagation. There are two cases: conv->(op) and conv->conv.
    In the first case we propagate the stride from the following op, and in the second case we also change
    the stride of the second conv.
    """
    next_ops = get_next_operation(node)
    stride_props, all_ops_are_valid = _check_next_ops(next_ops)

    def _check_convolution(node: Node):
        return node.has_valid('kernel_spatial') and np.array_equal(node.kernel_spatial, np.array([1, 1]))

    # Check that all ops are valid and have the same values
    if not all_ops_are_valid:
        # We have to insert pooling layers
        for op in next_ops:
            if op.has_valid('stride_prop') and not np.array_equal(op.stride_prop[spatial_dims], np.array([1, 1])):
                # Insert pooling
                _insert_pooling(graph, node.out_node(), op, spatial_dims)
    elif len(stride_props) > 0:
        node.stride *= stride_props[0]
        log.debug('STRIDE PROP: {} got new strides {}'.format(node.name, node.stride))
        for op in next_ops:
            if op.soft_get('has_stride') == True:
                op.stride = np.array([1, 1, 1, 1])
        node['is_partial_inferred'] = False
        node['output_spatial_shape'] = False
        _clean_fw_tensor_attrs(node.out_node())

    # If Convolution is valid then set `stride_prop` to Convolution stride
    node['stride_prop'] = np.array(node.stride) if _check_convolution(node) else np.array([1, 1, 1, 1])
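# NumPy-only sketch (hypothetical NHWC shapes) of why a stride can be pulled through
# a 1x1 convolution, which is exactly the kernel_spatial == [1, 1] check above:
# subsampling the output of a 1x1 convolution equals running the same convolution
# on the subsampled input.
import numpy as np

_x = np.random.rand(1, 8, 8, 4)                      # NHWC input
_w = np.random.rand(4, 16)                           # 1x1 kernel: C_in x C_out
_conv1x1 = lambda t: np.einsum('nhwc,co->nhwo', t, _w)

assert np.allclose(_conv1x1(_x)[:, ::2, ::2, :], _conv1x1(_x[:, ::2, ::2, :]))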
def find_and_replace_pattern(self, graph: Graph):
    from tensorflow.core.framework import types_pb2 as tf_types  # pylint: disable=no-name-in-module
    for node_name, node_attrs in list(graph.nodes(data=True)):
        node = Node(graph, node_name)
        pb = node_attrs.get('pb')
        if pb is not None and pb.op == 'Parameter' and pb.attr['dtype'].type != tf_types.DT_FLOAT:
            log.info('Placeholder "{}" has type that is different from DT_FLOAT'.format(node_name))
            next_ops = get_next_operation(node)
            # check that all output nodes are nodes of type 'ToFloat'
            if all([ChangePlaceholderTypes.is_node_casts_to_float(op) and len(op.in_nodes()) == 1
                    for op in next_ops]):
                ChangePlaceholderTypes.change_node_type(node, tf_types.DT_FLOAT)
                ChangePlaceholderTypes.remove_node_preserving_edges(node, next_ops)  # remove 'Cast' nodes

            elif all([ChangePlaceholderTypes.is_node_gather(op) for op in next_ops]):
                ChangePlaceholderTypes.change_node_type(node, tf_types.DT_FLOAT)

            else:
                raise Error(
                    ('Cannot convert type of placeholder "{}" because not all of its outputs are "Cast" to float '
                     'operations: {}. ' + refer_to_faq_msg(49)),
                    node.soft_get('name'),
                    [op.soft_get('name') for op in next_ops]
                )
def change_placeholders_types_to_FP32(graph: nx.MultiDiGraph):
    for node_name, node_attrs in list(graph.nodes(data=True)):
        node = Node(graph, node_name)
        pb = node_attrs.get('pb')
        if pb is not None and pb.op == 'Placeholder' and pb.attr['dtype'].type != tf_types.DT_FLOAT:
            log.info('Placeholder "{}" has type that is different from DT_FLOAT'.format(node_name))
            next_ops = get_next_operation(node)
            # check that all output nodes are nodes of type 'ToFloat'
            if all([is_node_casts_to_float(op) and len(op.in_nodes()) == 1 for op in next_ops]):
                change_node_type(node, tf_types.DT_FLOAT)
                remove_node_preserving_edges(node, next_ops)  # remove 'Cast' nodes

            elif all([is_node_gather(op) for op in next_ops]):
                change_node_type(node, tf_types.DT_FLOAT)

            else:
                raise Error(
                    ('Cannot convert type of placeholder "{}" because not all of its outputs are "Cast" to float '
                     'operations: {}. ' + refer_to_faq_msg(49)),
                    node.soft_get('name'),
                    [op.soft_get('name') for op in next_ops])
    return graph
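# A self-contained toy sketch of the decision applied above to every non-FP32
# Placeholder, using plain dicts instead of Model Optimizer / TensorFlow types
# (all names here are hypothetical): if every consumer is a Cast-to-float, the
# placeholder is retyped and the casts are dropped; if every consumer is a Gather,
# retyping alone is enough; otherwise the conversion fails.
def _retype_placeholder_sketch(consumers):
    if all(op['type'] == 'Cast' and op['to'] == 'float32' for op in consumers):
        return 'retype placeholder to FP32 and remove the Cast nodes'
    elif all(op['type'] == 'Gather' for op in consumers):
        return 'retype placeholder to FP32 and keep the consumers'
    raise ValueError('cannot convert placeholder type: unsupported consumers')


assert _retype_placeholder_sketch([{'type': 'Cast', 'to': 'float32'}]).startswith('retype')
assert _retype_placeholder_sketch([{'type': 'Gather'}, {'type': 'Gather'}]).startswith('retype')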
def find_and_replace_pattern(self, graph: Graph):
    for node in list(graph.nodes()):
        if node not in graph.nodes():
            continue
        permute_node = Node(graph, node)
        if permute_node.has_valid('type') and permute_node.type == 'Permute':
            list_of_permutes = [permute_node]
            # Get sequence of permutations
            node = permute_node
            while True:
                next_ops = get_next_operation(node)
                if len(next_ops) != 1:
                    break

                next_op = next_ops[0]
                if next_op.has_valid('type') and next_op.type == 'Permute':
                    list_of_permutes.append(next_op)
                    node = next_op
                else:
                    break

            final_permutation = np.array([x for x in range(len(list_of_permutes[0].order))], dtype=np.int64)
            for permute in list_of_permutes:
                if not permute.has_valid('order'):
                    raise Error("Permute node {} has wrong attribute order = None".format(permute.name))
                final_permutation = final_permutation[np.array(permute.order, dtype=np.int64)]

            if np.array_equal(final_permutation, [x for x in range(len(list_of_permutes[0].order))]):
                first_data_node, last_data_node = list_of_permutes[0].in_node(), list_of_permutes[-1].out_node()
                graph.remove_edge(first_data_node.id, list_of_permutes[0].id)
            else:
                if len(list_of_permutes) < 2:
                    continue
                first_data_node, last_data_node = list_of_permutes[0].out_node(), list_of_permutes[-1].out_node()
                list_of_permutes[0].order = final_permutation
                graph.remove_edge(first_data_node.id, first_data_node.out_node().id)

            graph.remove_edge(last_data_node.in_node().id, last_data_node.id)

            merge_data_nodes(graph, first_data_node, last_data_node)
            graph.remove_node(last_data_node.id)
            graph_clean_up_tf(graph)
def find_and_replace_pattern(self, graph: Graph):
    for permute_node in graph.get_op_nodes(type='Transpose'):
        if permute_node.id not in graph.nodes():
            continue

        list_of_permutes = [permute_node]
        # Get sequence of permutations
        node = permute_node
        while True:
            next_ops = get_next_operation(node)
            if len(next_ops) != 1:
                break

            next_op = next_ops[0]
            if next_op.soft_get('type') == 'Transpose':
                list_of_permutes.append(next_op)
                node = next_op
            else:
                break

        final_permutation = int64_array([x for x in range(len(list_of_permutes[0].in_port(1).data.get_value()))])
        for permute in list_of_permutes:
            order = permute.in_port(1).data.get_value()
            if order is None:
                raise Error("Transpose node {} has wrong order for permute = None".format(permute.name))
            final_permutation = final_permutation[int64_array(order)]

        if np.array_equal(final_permutation,
                          [x for x in range(len(list_of_permutes[0].in_port(1).data.get_value()))]):
            first_data_node, last_data_node = list_of_permutes[0].in_node(), list_of_permutes[-1].out_node()
            graph.remove_edge(first_data_node.id, list_of_permutes[0].id)
        else:
            if len(list_of_permutes) < 2:
                continue
            first_data_node, last_data_node = list_of_permutes[0].out_node(), list_of_permutes[-1].out_node()
            list_of_permutes[0].in_port(1).data.set_value(final_permutation)
            graph.remove_edge(first_data_node.id, first_data_node.out_node().id)

        graph.remove_edge(last_data_node.in_node().id, last_data_node.id)

        merge_data_nodes(graph, first_data_node, last_data_node)
        graph.remove_node(last_data_node.id)
        graph.clean_up()
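# NumPy-only sketch (hypothetical shapes) of the permutation algebra the two passes
# above rely on: applying Transpose with order p1 and then order p2 is one Transpose
# whose order is p1 indexed by p2 (the same folding as `final_permutation` above),
# and a composed order equal to range(rank) means the whole sequence is a no-op.
import numpy as np

_x = np.random.rand(2, 3, 4, 5)
_p1 = np.array([0, 2, 3, 1])
_p2 = np.array([0, 3, 1, 2])

_composed = _p1[_p2]
assert np.array_equal(_x.transpose(_p1).transpose(_p2), _x.transpose(_composed))
assert np.array_equal(_composed, np.arange(_x.ndim))  # here the two permutes cancel out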
def test_get_next_operation_3(self):
    # Placeholder-+--->Mul
    #             +-----^
    graph = build_graph(nodes_attributes,
                        [('placeholder_1', 'placeholder_1_data'),
                         ('placeholder_1', 'placeholder_2_data'),
                         ('placeholder_1_data', 'mul_1'),
                         ('placeholder_2_data', 'mul_1'),
                         ('mul_1', 'mul_1_data'),
                         ('mul_1_data', 'op_output')
                         ])

    res = get_next_operation(Node(graph, 'placeholder_1'))
    self.assertTrue(len(res) == 1 and res[0].id == 'mul_1', 'get_next_operation returned wrong op')
def test_get_next_operation_2(self):
    # Placeholder-+->Mul->Add
    #             +-------^
    graph = build_graph(nodes_attributes,
                        [('placeholder_1', 'placeholder_1_data'),
                         ('placeholder_1_data', 'mul_1'),
                         ('placeholder_1_data', 'add_1'),
                         ('mul_1', 'mul_1_data'),
                         ('mul_1_data', 'add_1'),
                         ('add_1', 'add_1_data'),
                         ('add_1_data', 'op_output')])

    res = get_next_operation(Node(graph, 'placeholder_1'))
    self.assertTrue(len(res) == 2 and all([x.id in ['add_1', 'mul_1'] for x in res]),
                    'get_next_operation returned wrong op')
def test_get_next_operation_1(self):
    # Placeholder->ScaleShift->Mul->Add
    graph = build_graph(nodes_attributes,
                        [('placeholder_1', 'placeholder_1_data'),
                         ('placeholder_1_data', 'scaleshift_1'),
                         ('scaleshift_1_w', 'scaleshift_1'),
                         ('scaleshift_1', 'scaleshift_1_data'),
                         ('scaleshift_1_data', 'mul_1'),
                         ('mul_1', 'mul_1_data'),
                         ('mul_1_data', 'add_1'),
                         ('add_1', 'add_1_data'),
                         ('add_1_data', 'op_output')
                         ])

    res = get_next_operation(Node(graph, 'mul_1'))
    self.assertTrue(len(res) == 1 and res[0].id == 'add_1', 'get_next_operation returned wrong op')
def find_and_replace_pattern(self, graph: Graph):
    for start_node in graph.pseudo_topological_sort():
        matched_nodes = []
        if self.is_node_match_for_optimization(start_node):
            next_node = start_node
            while self.is_node_match_for_optimization(next_node):
                matched_nodes.append(next_node)
                next_node[self.OPTIMIZED_NODE_FLAG] = True
                next_nodes = get_next_operation(next_node)
                if len(next_nodes) > 1:
                    log.debug('There are two consumers of the node {}. Stop matching sequence.'.format(
                        next_node.soft_get('name')))
                    break
                next_node = next_nodes[0]
            # optimize sequence of three or more Transpose-Reshape nodes
            if len(matched_nodes) >= 3:
                self.optimize_permute_reshape_sequence(graph, matched_nodes)

    # run the RemoveRedundantReshapes to remove dummy (NOP) reshapes. After that we can run Transposes fusing
    FuseReshapesSequence().find_and_replace_pattern(graph)
    RemoveRedundantReshapes().find_and_replace_pattern(graph)
    FuseTransposesSequence().find_and_replace_pattern(graph)
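# NumPy-only sketch (hypothetical shapes; optimize_permute_reshape_sequence itself
# is not shown here) of the kind of rewrite such Transpose-Reshape optimizations
# build on: a Transpose followed by a Reshape can often be re-expressed as a Reshape
# followed by a different Transpose, which lets long Transpose-Reshape chains be
# regrouped and then collapsed by the fusing passes invoked above.
import numpy as np

_x = np.random.rand(2, 12, 5)
# original order: Transpose (0, 2, 1) then Reshape to (2, 5, 3, 4)
_a = _x.transpose(0, 2, 1).reshape(2, 5, 3, 4)
# rewritten order: Reshape to (2, 3, 4, 5) then Transpose (0, 3, 1, 2)
_b = _x.reshape(2, 3, 4, 5).transpose(0, 3, 1, 2)
assert np.array_equal(_a, _b)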
def find_and_replace_pattern(self, graph: Graph):
    reshape_nodes = graph.get_op_nodes(type='Reshape')
    for node in reshape_nodes:
        if not graph.has_node(node.id):
            # the Reshape node has been removed in the previous iteration
            continue
        if len(node.out_port(0).get_destinations()) == 1:
            log.debug('First phase for Reshape: {}'.format(node.soft_get('name')))

            next_op = get_next_operation(node)[0]
            log.debug('second node: id={}, type={}'.format(next_op.soft_get('id'), next_op.soft_get('type')))
            if next_op.has_valid('type') and next_op.type == 'Reshape':
                dim_value = next_op.in_port(1).data.get_value()
                if dim_value is None or 0 in dim_value or -1 in dim_value:
                    # we do not fuse reshape sequences with special symbols: 0, -1
                    continue

                # Detected Reshape1 --> data --> Reshape2 pattern without side edges. Remove Reshape1
                log.debug('Second phase for Reshape: {}'.format(node.soft_get('name')))
                remove_op_node_with_data_node(graph, node)
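# NumPy-only sketch of why the first Reshape can be dropped above: two back-to-back
# reshapes with fully specified target shapes are equivalent to the second reshape
# alone. This is also why the pass skips target shapes containing the special values
# 0 and -1, whose meaning depends on the actual input shape.
import numpy as np

_x = np.random.rand(2, 3, 4)
assert np.array_equal(_x.reshape(6, 4).reshape(2, 12), _x.reshape(2, 12))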
def concat_convolutions(graph: Graph, start_node: Node, last_node: Node):
    """
    This function converts a group of convolutions into one
    """

    # Check that the concatenation is made in the same order
    conv_nodes = get_next_operation(start_node)
    assert len(conv_nodes) == len(last_node.in_nodes())
    gconv = conv_nodes[0]

    for id in range(len(conv_nodes)):
        conv = conv_nodes[id]
        if conv.out_node().id != last_node.in_node(id).id:
            return False
        # Check that all convolutions have the same weights shape
        if not np.array_equal(conv.in_node(1).shape, gconv.in_node(1).shape):
            log.debug('Grouped convolutions fusion: convolutions have different weights shapes')
            return False

    # Check that split and concat dims are valid
    channel_dim = gconv.channel_dims[0]
    if channel_dim != start_node.axis or channel_dim != last_node.axis:
        log.debug('Grouped convolutions fusion: split or concat has a weird axis!')
        return False

    # Check that all convolutions have the same parameters
    conv_attrs = ['pad', 'stride']
    for attr in conv_attrs:
        for id in range(len(conv_nodes)):
            conv = conv_nodes[id]
            if not np.array_equal(gconv[attr], conv[attr]):
                log.debug('Grouped convolutions fusion: attr {} does not match'.format(attr))
                return False

    # Check that all convolutions have biases (if they exist)
    has_biases = False
    for id in range(len(conv_nodes)):
        conv = conv_nodes[id]
        if len(conv.in_nodes()) == 3:
            if not has_biases:
                has_biases = True
        elif has_biases:
            return False  # All convolutions must have biases

    # Check that all biases have the same shape
    if has_biases:
        for id in range(len(conv_nodes)):
            conv = conv_nodes[id]
            if not np.array_equal(conv.in_node(2).shape, gconv.in_node(2).shape):
                log.debug('Grouped convolutions fusion: convolutions have different biases shapes {} and {}'
                          .format(conv.in_node(2).shape, gconv.in_node(2).shape))
                return False

    graph.remove_edge(gconv.in_node(0).id, gconv.id)
    graph.remove_edge(gconv.id, gconv.out_node().id)

    input = start_node.in_node(start_node.input_port)
    output = last_node.out_node()

    # Removing edges from data nodes to Split and Concat
    graph.remove_edge(input.id, start_node.id)
    graph.remove_edge(last_node.id, output.id)

    # Add edges to grouped convolution
    graph.add_edges_from([
        (input.id, gconv.id, {'in': 0}),
        (gconv.id, output.id, {'out': 0})
    ])

    # Concatenation of convolutions
    weights_node = gconv.in_node(1)
    bias_node = gconv.in_node(2) if has_biases else None

    weights_value = np.array(weights_node.value)
    bias_value = np.array(bias_node.value) if has_biases else None

    feature_dim = 3 if graph.graph['layout'] == 'NHWC' else 1

    for conv in conv_nodes[1:]:
        weights_value = np.concatenate((weights_value, conv.in_node(1).value), axis=feature_dim)
        if has_biases:
            bias_value = np.concatenate((bias_value, conv.in_node(2).value), axis=-1)  # Not validated

    weights_node.value = np.array(weights_value)
    weights_node.shape = np.array(weights_value.shape)

    if has_biases:
        bias_node.value = np.array(bias_value)
        bias_node.shape = np.array(bias_value.shape)

    log.debug('Start node: {} Last node: {} Nodes inside: {}'.format(
        start_node.id, last_node.id, len(start_node.out_nodes())))
    log.debug('Output shape: {}'.format(weights_value.shape))

    gconv.group = len(conv_nodes)
    gconv.output = weights_node.shape[feature_dim]
    gconv.output_shape[feature_dim] = weights_node.shape[feature_dim]

    return True
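# NumPy-only sketch of the weight/bias handling performed above for an NHWC graph,
# where feature_dim == 3 (hypothetical HWIO kernel shapes): two per-branch kernels
# of shape [1, 1, 3, 8] become one kernel of shape [1, 1, 3, 16] and two biases of
# shape [8] become one bias of shape [16]; the fused node then records group=2 and
# output=16, so each kernel slice still only reads its own input-channel group.
import numpy as np

_w0 = np.random.rand(1, 1, 3, 8)
_w1 = np.random.rand(1, 1, 3, 8)
_b0, _b1 = np.random.rand(8), np.random.rand(8)

_w_cat = np.concatenate((_w0, _w1), axis=3)   # feature_dim = 3 for NHWC
_b_cat = np.concatenate((_b0, _b1), axis=-1)
assert _w_cat.shape == (1, 1, 3, 16) and _b_cat.shape == (16,)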