def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        """Give the matched 'reduce' node an explicit second input that carries the reduction axes.

        Nothing is changed unless the node has exactly one connected input port.

        * If the node has an 'axis' attribute, its value is wrapped into a Const node which is connected to
          input port 1, and the attribute is deleted from the node.
        * Otherwise the axes are produced in the graph as Range(0, Rank(input), 1), i.e. the reduction goes over
          all the dimensions of the input tensor.

        :param graph: graph that is being transformed
        :param match: matched sub-graph; only the 'reduce' entry is used
        :return: None
        """
        reduce_node = match['reduce']
        connected_inputs = [p for p in reduce_node.in_ports().values() if not p.disconnected()]
        if len(connected_inputs) != 1:
            # the axes input is either already provided or the node is in an unexpected state: leave it alone
            return

        if reduce_node.has('axis'):
            # Per the original author's note: the attribute may hold None. The Const is created regardless, and
            # the infer function is expected to deal with it, because the input shape is known only at that stage.
            axis_const = Const(graph, {'value': reduce_node.axis}).create_node()
            reduce_node.add_input_port(1, skip_if_exist=True)
            axis_const.out_port(0).connect(reduce_node.in_port(1))
            del graph.node[reduce_node.id]['axis']
            return

        # No 'axis' attribute at all: by default reduce over every dimension of the input tensor.
        base_name = reduce_node.name

        range_start = Const(graph, {'name': base_name + '/range_begin_', 'value': 0}).create_node()
        range_step = Const(graph, {'name': base_name + '/range_step_', 'value': 1}).create_node()
        input_rank = Rank(graph, {'name': base_name + '/range_end_'}).create_node()
        all_axes = Range(graph, {'name': base_name + '/axes_'}).create_node()

        # Range inputs: 0 - start, 1 - limit (rank of the data input), 2 - step
        range_start.out_port(0).connect(all_axes.in_port(0))
        input_rank.out_port(0).connect(all_axes.in_port(1))
        range_step.out_port(0).connect(all_axes.in_port(2))

        reduce_node.add_input_port(1, skip_if_exist=True)
        all_axes.out_port(0).connect(reduce_node.in_port(1))

        # Rank is computed from the very same tensor that is being reduced
        data_source_port = reduce_node.in_port(0).get_connection().get_source()
        data_source_port.connect(input_rank.in_port(0))
Example #2
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Replace the matched sub-graph with a Range(0, limit, 1) built from the shape of the split input.

        The consumers of the matched 'cast' node output are re-attached to the output of the new Range node,
        after which all matched nodes are removed from the graph.

        The Shape node is connected to the exact output port that feeds 'split'. Output port 0 of the producer
        is not assumed, so a producer with several outputs is handled correctly.

        :param graph: graph that is being transformed
        :param match: matched sub-graph; the 'split' and 'cast' entries are used explicitly, every matched node
                      is removed at the end
        :return: None
        """
        # Resolve the real producer port once: it is needed both for naming and for the Shape connection.
        source_port = match['split'].in_port(0).get_connection().get_source()
        source_node = source_port.node
        cast_node = match['cast']

        range_node = Range(graph, {
            'name': source_node.id + '/Range'
        }).create_node()
        start_node = Const(graph, {
            'name': range_node.id + '/Start',
            'value': int64_array(0)
        }).create_node()

        step_node = Const(graph, {
            'name': range_node.id + '/Step',
            'value': int64_array(1)
        }).create_node()
        # NOTE(review): the Shape node name is derived from start_node.id rather than source_node.id;
        # kept as is to preserve the generated node names - confirm whether this is intended.
        input_shape_node = Shape(graph, {
            'name': start_node.id + '/Shape'
        }).create_node()
        # Take the shape of the tensor that actually feeds 'split' (was: hard-coded source_node.out_port(0)).
        input_shape_node.in_port(0).connect(source_port)

        # Presumably extracts the batch dimension from the shape as a 1D tensor (helper is defined elsewhere);
        # it is then squeezed over axis 0 before being used as the Range limit.
        limit_node_1D = node_to_get_batch_value(input_shape_node)
        limit_node = create_op_node_with_second_input(
            graph, Squeeze, int64_array([0]),
            {'name': source_node.id + '/batch_0D_value'}, limit_node_1D)

        # Range inputs: 0 - start, 1 - limit, 2 - step
        range_node.in_port(0).connect(start_node.out_port(0))
        range_node.in_port(1).connect(limit_node.out_port(0))
        range_node.in_port(2).connect(step_node.out_port(0))
        # Everything that consumed the 'cast' output now consumes the Range output.
        cast_node.out_port(0).get_connection().set_source(
            range_node.out_port(0))

        graph.remove_nodes_from([node.id for node in match.values()])
Example #3
0
    def find_and_replace_pattern(self, graph: Graph):
        """Re-wire every global Pooling operation to a Reduce operation over axes Range(2, Rank(input), 1).

        For each 'Pooling' op node with global_pool=True a reduce node is created (its class is looked up in
        self.pool_method_to_reduce_type by the 'pool_method' attribute, keep_dims=True). The consumers of the
        pooling output are moved to the reduce output and the pooling data input is re-connected to the reduce
        node. The pooling node itself is only disconnected here; it is not removed by this method.

        :param graph: graph that is being transformed
        :return: None
        :raises AssertionError: if the graph layout is 'NHWC', if a pooling node has no valid 'pool_method'
                                attribute, or if the method has no entry in self.pool_method_to_reduce_type
        """
        pooling_nodes = graph.get_op_nodes(type='Pooling', global_pool=True)
        if len(pooling_nodes) == 0:
            return

        # The reduction axes below start from 2, which the assertion ties to a non-NHWC layout.
        assert graph.graph['layout'] != 'NHWC', \
            'Global pooling transformation depends on layout (NHWC not enabled)'

        for pooling_node in pooling_nodes:
            node_name = pooling_node.soft_get('name', pooling_node.id)
            assert pooling_node.has_valid('pool_method'), \
                'Global Pooling {} has no `pool_method` attribute'.format(node_name)
            pool_method = pooling_node['pool_method']
            assert pool_method in self.pool_method_to_reduce_type, \
                'Unexpected Global Pooling method `{}` for node `{}`'.format(pool_method, node_name)

            reduce_attrs = {'name': node_name + '/reduce', 'keep_dims': True}
            reduce_node = self.pool_method_to_reduce_type[pool_method](graph, reduce_attrs).create_node()

            # Move the pooling consumers to the reduce output, then feed reduce from the pooling producer.
            pooling_node.out_port(0).get_connection().set_source(reduce_node.out_port(0))
            producer_port = pooling_node.in_port(0).get_connection().get_source()
            pooling_node.in_port(0).disconnect()
            producer_port.connect(reduce_node.in_port(0))

            # Axes to reduce: Range(start=2, limit=rank of the input, step=1)
            first_axis = Const(graph, {'value': int64_array(2)}).create_node()
            input_rank = Rank(graph, {'name': node_name + '/input_rank'}).create_node()
            axis_step = Const(graph, {'value': int64_array(1)}).create_node()
            reduce_axes = Range(graph, {'name': node_name + '/global_pooling_reduce_axis'}).create_node()

            reduce_axes.in_port(0).connect(first_axis.out_port(0))
            producer_port.connect(input_rank.in_port(0))
            reduce_axes.in_port(1).connect(input_rank.out_port(0))
            reduce_axes.in_port(2).connect(axis_step.out_port(0))

            reduce_axes.out_port(0).connect(reduce_node.in_port(1))

            log.debug('Global {} pooling was converted to reduce: `{}`'.format(
                pool_method, node_name))