def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
    """Insert a Permute node ahead of the matched ``flatten/Reshape`` node.

    A new Permute operation with order ``[0, 2, 3, 1]`` is added to
    ``graph`` under the name ``match.scope + '_permute_'``.  It is then
    inserted after the node found at ``in_nodes()[0]`` of the node matching
    the pattern ``flatten/Reshape$``.

    Args:
        graph: Graph that receives the new Permute node.
        match: Sub-graph match providing ``scope`` and ``node_by_pattern``.

    Returns:
        An empty dict.
    """
    axes_order = np.array([0, 2, 3, 1])
    new_permute = Permute(graph, {'order': axes_order}).add_node(
        {'name': match.scope + '_permute_'}
    )

    flatten_reshape = match.node_by_pattern('flatten/Reshape$')

    # The Permute must be placed right after the node feeding the Reshape
    # (entry 0 of its in_nodes()).
    # NOTE(review): the trailing 0 is presumably an output port index --
    # confirm against insert_node_after.
    reshape_producer = flatten_reshape.in_nodes()[0]
    reshape_producer.insert_node_after(new_permute, 0)

    return {}
def replace_sub_graph(self, graph: Graph, match: dict):
    """Insert a Permute node after the target of every collected Convolution.

    Starting from ``match['target_node']``, ``self.dfs`` collects node names
    for the op types ``Convolution``, ``FullyConnected`` and ``ScaleShift``.
    Only the names whose node op is ``Convolution`` are kept.  For each of
    them, ``self.search_target_node`` picks a node, and a new Permute with
    order ``[0, 3, 2, 1]`` named ``'<node name>/Permute'`` is inserted after
    it.

    Args:
        graph: Graph being transformed.
        match: Mapping that must contain the key ``'target_node'``.
    """
    start_node = match['target_node']
    weighted_op_types = ('Convolution', 'FullyConnected', 'ScaleShift')
    # NOTE(review): the meaning of the trailing ``True`` flag is defined by
    # self.dfs, which is not visible here -- confirm before changing it.
    weighted_names = self.dfs(graph, start_node.name, weighted_op_types, True)

    # Filter completely before touching the graph, so the selection never
    # observes nodes added by the loop below.
    conv_names = [
        name for name in weighted_names
        if Node(graph, name).op == 'Convolution'
    ]

    for conv_name in conv_names:
        anchor = self.search_target_node(Node(graph, conv_name))
        # A fresh order array per Permute: nodes never share one ndarray.
        permute = Permute(graph, {'order': np.array([0, 3, 2, 1])}).add_node(
            {'name': '{}/Permute'.format(anchor.name)}
        )
        anchor.insert_node_after(permute, 0)