def fuse_linear_ops(graph: nx.MultiDiGraph):
    """
    Fuses linear operations (Mul, Add) into Convolution/FullyConnected layers.
    """
    fuse_count = 0

    # Fusion in backward direction
    nodes = pseudo_topological_sort(graph)
    for idx in nodes:
        node = Node(graph, idx)
        is_fused = False

        # Fuse Mul to Convolution/FC
        if node.soft_get('op') == 'Mul' and get_value_id(node) is not None and \
                node.soft_get('can_be_fused') == True:
            fuse_nodes = backward_bfs(node, [], ['Convolution', 'Deconvolution', 'FullyConnected'])
            is_fused = _fuse_mul(graph, node, fuse_nodes)

        # Fuse Add to Convolution/FC
        if node.soft_get('op') == 'Add' and get_value_id(node) is not None and \
                node.soft_get('can_be_fused') == True:
            fuse_nodes = backward_bfs(node, [], ['Convolution', 'Deconvolution', 'FullyConnected'])
            is_fused = _fuse_add(graph, node, fuse_nodes)

        fuse_count += is_fused

    # Fusion in forward direction
    nodes = pseudo_topological_sort(graph, reverse=True)
    for idx in nodes:
        node = Node(graph, idx)
        is_fused = False

        # Fuse Mul to Convolution/FC
        if node.soft_get('op') == 'Mul' and get_value_id(node) is not None and \
                node.soft_get('can_be_fused') == True:
            fuse_nodes = forward_bfs(node, [], ['Convolution', 'Deconvolution', 'FullyConnected'])
            is_fused = _fuse_mul(graph, node, fuse_nodes, False)

        # Fuse Add to Convolution/FC
        if node.soft_get('op') == 'Add' and get_value_id(node) is not None and \
                node.soft_get('can_be_fused') == True:
            fuse_nodes = forward_bfs(node, [], ['FullyConnected'])
            is_fused = _fuse_add(graph, node, fuse_nodes, False)

        fuse_count += is_fused

    log.debug("Fused {} nodes".format(fuse_count))
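# The Mul fusion above relies on convolution being linear: scaling the output
# per channel equals pre-scaling the filters and bias. Below is a minimal numpy
# sketch of the arithmetic behind _fuse_mul, with a 1x1 convolution modeled as
# a matrix product (the names here are illustrative, not the pass's helpers).

import numpy as np

def _mul_fusion_sketch():
    C_out, C_in = 3, 4
    W = np.random.rand(C_out, C_in)        # convolution weights
    b = np.random.rand(C_out)              # convolution bias
    x = np.random.rand(C_in)               # input "pixel"
    scale = np.random.rand(C_out)          # per-channel Mul constant

    # Convolution followed by Mul ...
    y_mul_after = scale * (W @ x + b)
    # ... equals one convolution with pre-scaled weights and bias,
    # which is exactly what fusing Mul into Convolution/FC produces.
    y_fused = (scale[:, None] * W) @ x + scale * b

    assert np.allclose(y_mul_after, y_fused)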
def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
    intervals = {}
    for n in pseudo_topological_sort(graph):
        node = Node(graph, n)
        if not node.has('op') or (node.op != 'FakeQuantWithMinMaxVars' and node.op != 'Quantize'):
            continue

        if node.op == 'Quantize':
            # Check if input range matches output range
            low_match = np.all(node.in_node(1).value == node.in_node(3).value)
            high_match = np.all(node.in_node(2).value == node.in_node(4).value)
            if not low_match or not high_match:
                continue

        prev_node = node.in_node().in_node()
        prev_node_id = prev_node.id
        prev_node_out_shape = prev_node.out_node()['shape']
        C = prev_node_out_shape[1]
        assert node.in_node(1).value.size == 1
        assert node.in_node(2).value.size == 1
        min_str = ', '.join([str(node.in_node(1).value.flatten()[0])] * C)
        max_str = ', '.join([str(node.in_node(2).value.flatten()[0])] * C)
        intervals[prev_node_id] = {'min': min_str, 'max': max_str}

    if intervals:
        if 'statistics' not in graph.graph:
            graph.graph['statistics'] = intervals
        else:
            graph.graph['statistics'].update(intervals)
        remove_op_nodes(graph, {'op': 'FakeQuantWithMinMaxVars'})
        remove_op_nodes(graph, {'op': 'Quantize'})
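# A small illustration (values made up) of how the pass expands the scalar
# min/max, asserted above to be single-element, into the per-channel
# comma-separated statistics strings stored in graph.graph['statistics'].

import numpy as np

def _statistics_string_sketch():
    C = 4                        # channel count: shape[1] of the producer's output
    min_value = np.array([0.0])  # stands in for node.in_node(1).value
    max_value = np.array([6.0])  # stands in for node.in_node(2).value

    min_str = ', '.join([str(min_value.flatten()[0])] * C)
    max_str = ', '.join([str(max_value.flatten()[0])] * C)

    assert min_str == '0.0, 0.0, 0.0, 0.0'
    assert max_str == '6.0, 6.0, 6.0, 6.0'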
def find_and_replace_pattern(self, graph: Graph):
    for n in pseudo_topological_sort(graph):
        if graph.node[n]['kind'] == 'data' or graph.node[n]['op'] != 'Switch':
            continue
        switch_op_node = Node(graph, n)
        # Detach the boolean predicate input (port 1) before removing
        # the Switch node itself from the data flow.
        pred_id_data_node = switch_op_node.in_node(1)
        graph.remove_edge(pred_id_data_node.id, switch_op_node.id)
        remove_op_node_with_data_node(graph, switch_op_node)
def fuse_mul_add_sequence(graph: Graph):
    """
    Finds the first valid Mul/Add node and passes it to _fuse_linear_sequence,
    where the full sequence will be found.
    """
    while True:
        is_fused = False
        for idx in list(pseudo_topological_sort(graph)):
            if idx in graph:
                node = Node(graph, idx)
                if node.soft_get('op') in ['Mul', 'Add'] and get_value_id(node) is not None and \
                        node.soft_get('can_be_fused') == True:
                    is_fused |= _fuse_linear_sequence(graph, node)
        if not is_fused:
            break
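# A Mul/Add chain can be collapsed because composing affine operations yields
# another affine operation: a * (m * x + c) + d == (a * m) * x + (a * c + d).
# A quick numpy check of that identity (illustrative only; _fuse_linear_sequence
# performs the equivalent rewrite on graph nodes):

import numpy as np

def _linear_sequence_sketch():
    x = np.random.rand(8)
    m, c = 2.5, 0.3     # first Mul/Add pair
    a, d = 0.7, -1.2    # second Mul/Add pair

    chained = a * (m * x + c) + d        # Mul -> Add -> Mul -> Add
    fused = (a * m) * x + (a * c + d)    # single equivalent Mul + Add

    assert np.allclose(chained, fused)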
def _stride_propagation(graph: Graph, spatial_dims):
    """
    Performs stride propagation for all op nodes.
    """
    nodes = [Node(graph, x) for x in pseudo_topological_sort(graph, reverse=True)
             if Node(graph, x).kind == 'op' and Node(graph, x).soft_get('type') != 'Const']

    for node in nodes:
        if node.soft_get('type') in supported_ops:
            op = supported_ops[node.type]
            # Add node attrs
            for key in op['attrs'].keys():
                node[key] = op['attrs'][key]
            op['stride_prop'](graph, node, spatial_dims, True)
        else:
            _simple_stride_prop(graph, node, spatial_dims, False)
def mark_const_producer_nodes(graph: nx.MultiDiGraph):
    """
    Mark nodes that produce constant values.
    :param graph: graph to operate on.
    :return: None.
    """
    nx.set_node_attributes(G=graph, name='is_const_producer', values=True)

    for n in pseudo_topological_sort(graph):
        node = Node(graph, n)
        for src, dst, attrs in graph.in_edges(n, data=True):
            if 'control_flow_edge' in attrs and attrs['control_flow_edge']:
                graph.node[src]['is_const_producer'] = False
                graph.node[dst]['is_const_producer'] = False

        if not node.has('value') or node.value is None:
            for src, _ in graph.in_edges(n):
                graph.node[src]['is_const_producer'] = False
def shape_inference(graph: Graph):
    for n in pseudo_topological_sort(graph):
        node = Node(graph, n)
        if node.has_and_set('need_shape_inference'):
            old_out_shapes = [port.data.get_shape() for port in node.out_ports().values()]
            node.infer(node)
            new_out_shapes = [port.data.get_shape() for port in node.out_ports().values()]
            for shape1, shape2 in zip(old_out_shapes, new_out_shapes):
                if shape1 is not None and not np.array_equal(shape1, shape2):
                    raise Error("Shape collision was found after partial shape inference for node {} "
                                "(old shape: {}, new shape: {})".format(node.name, shape1, shape2))
            node.need_shape_inference = False
def stride_optimization(graph: Graph):
    """
    This is the main function for the stride optimization pass.
    """
    layout = graph.graph['layout']
    if layout == 'NCHW':
        spatial_dims = np.array([2, 3])
    elif layout == 'NHWC':
        spatial_dims = np.array([1, 2])
    else:
        log.warning('STRIDE PROP: layout {} is not supported'.format(layout))
        return

    _stride_propagation(graph, spatial_dims)

    # Re-infer nodes that stride propagation left not partially inferred
    nodes = [Node(graph, x) for x in pseudo_topological_sort(graph)
             if Node(graph, x).soft_get('is_partial_inferred') == False]
    for node in nodes:
        node.infer(node)
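# Stride propagation is sound because, for a linear filter, keeping every k-th
# output of a stride-1 convolution is the same as computing the convolution
# with stride k directly, so downsampling can be pushed into an upstream op's
# stride. A 1-D numpy sanity check of that equivalence (illustrative only):

import numpy as np

def _stride_equivalence_sketch():
    x = np.random.rand(16)
    w = np.random.rand(3)

    # Stride-1 cross-correlation, then keep every 2nd output ...
    full = np.array([np.dot(x[i:i + 3], w) for i in range(len(x) - 2)])
    subsampled = full[::2]

    # ... equals the same filter computed with stride 2 directly.
    strided = np.array([np.dot(x[i:i + 3], w) for i in range(0, len(x) - 2, 2)])

    assert np.allclose(subsampled, strided)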
def grouped_convolutions_fusing(graph: Graph):
    while True:
        is_fused = False
        graph_clean_up(graph, ['TFCustomSubgraphCall', 'Shape'])
        nodes = pseudo_topological_sort(graph)
        for idx in nodes:
            node = Node(graph, idx)
            if node.kind == 'op' and len(node.out_nodes()) > 1:
                if node.soft_get('can_be_fused') == False:
                    continue

                is_valid_convolutions = True
                last_layer = None

                next_nodes = get_next_operation(node)
                # Check that all operations after this one are Convolutions
                # and that all convolutions have the same output
                if len(next_nodes) > 1 and \
                        all(_node.soft_get('type') in ['Convolution', 'Deconvolution'] for _node in next_nodes):
                    for conv in next_nodes:
                        conv_outputs = get_next_operation(conv)
                        if conv.soft_get('can_be_fused') == False:
                            is_valid_convolutions = False
                        if len(conv_outputs) != 1:
                            is_valid_convolutions = False
                        if last_layer is None:
                            last_layer = conv_outputs[0].id
                        elif conv_outputs[0].id != last_layer:
                            is_valid_convolutions = False

                    if is_valid_convolutions:
                        is_fused = concat_convolutions(graph, node, Node(graph, last_layer))
                        if is_fused:
                            break

        if not is_fused:
            break
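# The pattern concat_convolutions collapses is one producer feeding several
# parallel convolutions whose outputs meet again in a single consumer
# (typically a channel-wise concat): stacking the parallel filters along the
# output-channel axis yields one equivalent convolution. A numpy sketch with
# 1x1 convolutions modeled as matrix products (names are illustrative):

import numpy as np

def _grouped_conv_fusion_sketch():
    C_in = 4
    x = np.random.rand(C_in)
    W1 = np.random.rand(2, C_in)   # first parallel convolution
    W2 = np.random.rand(3, C_in)   # second parallel convolution

    # Two parallel convs whose results are concatenated over channels ...
    y_parallel = np.concatenate([W1 @ x, W2 @ x])
    # ... equal one convolution with filters stacked along C_out.
    y_fused = np.vstack([W1, W2]) @ x

    assert np.allclose(y_parallel, y_fused)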