def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
    node = match['reduce']

    connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
    if len(connected_in_ports) == 1:
        # if the 'axis' is None then we still add a second input to the layer with a 1D array with 1 element equal
        # to None. The infer function handles this case because the input shape is known at this stage only
        if node.has('axis'):
            const = Const(graph, {'value': node.axis}).create_node()
            node.add_input_port(1, skip_if_exist=True)
            const.out_port(0).connect(node.in_port(1))
            del graph.node[node.id]['axis']
        else:
            # The default (if there is no 'axis') is to reduce over all the dimensions of the input tensor.
            node_name = node.name

            begin_of_range = Const(graph, dict(name=node_name + '/range_begin_', value=0)).create_node()
            step = Const(graph, dict(name=node_name + '/range_step_', value=1)).create_node()
            end_of_range = Rank(graph, dict(name=node_name + '/range_end_')).create_node()
            axes = Range(graph, dict(name=node_name + '/axes_')).create_node()

            begin_of_range.out_port(0).connect(axes.in_port(0))
            end_of_range.out_port(0).connect(axes.in_port(1))
            step.out_port(0).connect(axes.in_port(2))

            node.add_input_port(1, skip_if_exist=True)
            axes.out_port(0).connect(node.in_port(1))

            node.in_port(0).get_connection().get_source().connect(end_of_range.in_port(0))
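
# For reference, a minimal standalone sketch of the "reduce over all axes" pattern
# built above: axes = Range(0, Rank(data), 1). This helper and its name are
# illustrative only (not part of the transform) and assume the same Const/Rank/Range
# ops imported by the surrounding module.
def build_all_axes_subgraph(graph: Graph, data_source_port, name: str):
    begin = Const(graph, dict(name=name + '/range_begin', value=0)).create_node()
    step = Const(graph, dict(name=name + '/range_step', value=1)).create_node()
    rank = Rank(graph, dict(name=name + '/rank')).create_node()
    axes = Range(graph, dict(name=name + '/axes')).create_node()

    data_source_port.connect(rank.in_port(0))   # rank of the tensor being reduced
    begin.out_port(0).connect(axes.in_port(0))  # start = 0
    rank.out_port(0).connect(axes.in_port(1))   # limit = rank(data)
    step.out_port(0).connect(axes.in_port(2))   # step = 1
    return axes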
def replace_sub_graph(self, graph: Graph, match: dict):
    source_connection = match['split'].in_port(0).get_connection()
    source_node = source_connection.get_source().node
    cast_node = match['cast']

    # Build Range(start=0, limit=batch_size, step=1) that enumerates the batch items
    range_node = Range(graph, {'name': source_node.id + '/Range'}).create_node()
    start_node = Const(graph, {'name': range_node.id + '/Start', 'value': int64_array(0)}).create_node()
    step_node = Const(graph, {'name': range_node.id + '/Step', 'value': int64_array(1)}).create_node()

    input_shape_node = Shape(graph, {'name': start_node.id + '/Shape'}).create_node()
    input_shape_node.in_port(0).connect(source_node.out_port(0))

    # The batch size is extracted from the input shape and squeezed to a 0D value
    # so that the 'limit' input of Range gets a scalar
    limit_node_1D = node_to_get_batch_value(input_shape_node)
    limit_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                  {'name': source_node.id + '/batch_0D_value'}, limit_node_1D)

    range_node.in_port(0).connect(start_node.out_port(0))
    range_node.in_port(1).connect(limit_node.out_port(0))
    range_node.in_port(2).connect(step_node.out_port(0))

    # Reroute consumers of the Cast output to the new Range and drop the matched sub-graph
    cast_node.out_port(0).get_connection().set_source(range_node.out_port(0))

    graph.remove_nodes_from([node.id for node in match.values()])
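
# A hedged reference (not part of the transform) of what the inserted Start/Limit/Step
# -> Range sub-graph evaluates to at runtime, assuming the first dimension of the
# input tensor is the batch dimension:
import numpy as np

def expected_range_values(input_tensor: np.ndarray) -> np.ndarray:
    # Range(start=0, limit=batch_size, step=1) simply enumerates the batch items
    return np.arange(0, input_tensor.shape[0], 1, dtype=np.int64)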
def find_and_replace_pattern(self, graph: Graph):
    global_poolings = graph.get_op_nodes(type='Pooling', global_pool=True)
    if len(global_poolings) == 0:
        return

    layout = graph.graph['layout']
    assert layout != 'NHWC', 'Global pooling transformation depends on layout (NHWC not enabled)'

    for pooling in global_poolings:
        name = pooling.soft_get('name', pooling.id)
        assert pooling.has_valid('pool_method'), 'Global Pooling {} has no `pool_method` attribute'.format(name)
        method = pooling['pool_method']
        assert method in self.pool_method_to_reduce_type, \
            'Unexpected Global Pooling method `{}` for node `{}`'.format(method, name)
        reduce_op_class = self.pool_method_to_reduce_type[method]
        reduce = reduce_op_class(graph, {'name': name + '/reduce', 'keep_dims': True}).create_node()

        # Replace the Pooling output with the Reduce output and rewire the data input
        pooling.out_port(0).get_connection().set_source(reduce.out_port(0))
        src = pooling.in_port(0).get_connection().get_source()

        pooling.in_port(0).disconnect()
        src.connect(reduce.in_port(0))

        # Reduction axes are [2, rank), i.e. all spatial dimensions of an NCHW-like input
        start = Const(graph, {'value': int64_array(2)}).create_node()
        end = Rank(graph, {'name': name + '/input_rank'}).create_node()
        delta = Const(graph, {'value': int64_array(1)}).create_node()

        axis = Range(graph, {'name': name + '/global_pooling_reduce_axis'}).create_node()

        axis.in_port(0).connect(start.out_port(0))
        src.connect(end.in_port(0))
        axis.in_port(1).connect(end.out_port(0))
        axis.in_port(2).connect(delta.out_port(0))

        axis.out_port(0).connect(reduce.in_port(1))

        log.debug('Global {} pooling was converted to reduce: `{}`'.format(method, name))
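
# The transformation above relies on a class-level mapping from the Pooling
# 'pool_method' attribute to a reduce-op class. A plausible sketch, assuming the
# ReduceMean/ReduceMax ops available in the Model Optimizer (the concrete classes
# and import path are assumptions for illustration):
from extensions.ops.ReduceOps import ReduceMax, ReduceMean

pool_method_to_reduce_type = {
    'avg': ReduceMean,  # global average pooling -> ReduceMean over the spatial axes
    'max': ReduceMax,   # global max pooling -> ReduceMax over the spatial axes
}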