def _insert_pooling(graph: Graph, first_node: Node, second_node: Node, spatial_dims):
    """
    Insert a point-wise pooling layer between two connected nodes.

    The edge ``first_node -> second_node`` is removed and a max ``Pooling``
    operation with a window of ones is placed in its stead. The stride of the
    new operation is taken from ``second_node.stride_prop`` and the attributes
    of the removed edge are re-used for the edge that connects the pooling
    result to ``second_node``.

    :param graph: graph that holds both nodes
    :param first_node: node that becomes the input of the new pooling
    :param second_node: consumer node; provides ``stride_prop``
    :param spatial_dims: value stored as the ``spatial_dims`` pooling attribute
    """
    log.debug("STRIDE PROP: Insert pooling between {} and {}".format(
        first_node.name, second_node.name))

    stride = second_node.stride_prop

    # Exactly one edge is expected between the two nodes; its attributes
    # (stored under key 0) are kept so they can be re-attached below.
    edges = graph.get_edge_data(first_node.id, second_node.id)
    assert len(edges) == 1
    edge_attrs = edges[0]
    graph.remove_edge(first_node.id, second_node.id)

    pooling_attrs = {
        'name': 'Pooling_',
        'spatial_dims': spatial_dims,
        'window': np.array([1, 1, 1, 1]),
        'output_spatial_shape': None,
        'stride': np.array(stride),
        'pad_spatial_shape': np.array([[0, 0], [0, 0]]),
        'pad': np.array([[0, 0], [0, 0], [0, 0], [0, 0]]),
        'pool_method': 'max',
        'is_partial_inferred': False,
    }
    pooling_op = Pooling(graph, pooling_attrs)

    # NOTE(review): create_node_with_data presumably returns the output data
    # node of the new operation - confirm against the Op implementation.
    pooling_out = pooling_op.create_node_with_data([first_node])
    _clean_fw_tensor_attrs(pooling_out)

    graph.add_edges_from([(pooling_out.id, second_node.id, edge_attrs)])
def find_and_replace_pattern(self, graph: Graph):
    """
    Replace every ``PoolingV2`` operation in the graph with ``Pooling``.

    The ``window`` and ``stride`` attributes of the new node are read from the
    values available on input ports 1 and 2 of the ``PoolingV2`` node; the
    remaining attributes are copied from the ``PoolingV2`` node as is. The new
    node takes over the name of the old one, as well as its connection on
    input port 0 and its output port 0 connection.

    :param graph: graph to transform in place
    """
    # Attributes transferred unchanged from the PoolingV2 node, in the order
    # in which they are read.
    copied_attrs = (
        'pad',
        'spatial_dims',
        'auto_pad',
        'output_spatial_shape',
        'pad_spatial_shape',
        'pool_method',
        'permute_attrs',
    )

    for old_pool in graph.get_op_nodes(op='PoolingV2'):
        original_name = old_pool.soft_get('name', old_pool.id)

        new_attrs = {
            'window': old_pool.in_port(1).data.get_value(),
            'stride': old_pool.in_port(2).data.get_value(),
        }
        for attr_name in copied_attrs:
            new_attrs[attr_name] = getattr(old_pool, attr_name)

        new_pool = Pooling(graph, new_attrs).create_node()

        # The old node is only renamed here, not deleted.
        # NOTE(review): presumably it is cleaned up by a later pass - confirm.
        rename_nodes([
            (old_pool, original_name + '/to_be_removed'),
            (new_pool, original_name),
        ])

        # Move the data input and the output over to the new node.
        data_input = old_pool.in_port(0).get_connection()
        data_input.set_destination(new_pool.in_port(0))

        result_output = old_pool.out_port(0).get_connection()
        result_output.set_source(new_pool.out_port(0))