Example #1
def extract(cls, node: Node):
    attrs = get_mxnet_layer_attrs(node.symbol_dict)
    Range.update_node_stat(node, {
        'start': attrs.int('start', 0),
        'stop': attrs.int('stop', 0),
        'repeat': attrs.int('repeat', 1),
        'step': attrs.float('step', 1),
        'dtype': np.dtype(attrs.str('dtype', 'float32'))
    })
    return cls.enabled
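The extractor above reads MXNet arange-style attributes and stores them on the node. A minimal NumPy sketch of what those attributes describe (assuming MXNet's arange semantics, where `repeat` duplicates each value; the helper name is hypothetical):

    import numpy as np

    def arange_like(start=0, stop=0, step=1, repeat=1, dtype='float32'):
        # build the base range, then repeat every element `repeat` times
        values = np.arange(start, stop, step, dtype=np.dtype(dtype))
        return np.repeat(values, repeat)

    print(arange_like(start=2, stop=8, step=2, repeat=3))
    # [2. 2. 2. 4. 4. 4. 6. 6. 6.]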
Example #2
def get_range_node_of_idxs(rank: Node,
                           begin: int,
                           end: int,
                           include_begin: bool = True,
                           include_end: bool = False) -> Node:
    """
    Returns node that produces 1D output of values of range from begin to end (ex)/(in)cluding begin or end point

    :param rank: the node of 0D output shape to get rank of tensor from
    :param begin: integer value from [-rank; rank - 1]
    :param end: integer value from [-rank; +rank]
    :param include_begin: boolean flag to include or exclude start point from range output
    :param include_end: boolean flag to include or exclude end point from range output
    :return: range node producing 1D output
    """
    graph = rank.graph
    name = rank.soft_get('name', rank.id)

    start_idx = get_canonical_axis_index_node(rank, begin)
    end_idx = get_canonical_axis_index_node(rank, end)

    if not include_begin:
        const = Const(graph, {
            'value': int64_array(1),
            'name': name + '/exclude_begin/value'
        }).create_node()
        add = Add(graph, {'name': name + '/exclude_begin'}).create_node()
        start_idx.out_port(0).connect(add.in_port(0))
        const.out_port(0).connect(add.in_port(1))
        start_idx = add

    if include_end:
        const = Const(graph, {
            'value': int64_array(1),
            'name': name + '/including_end/value'
        }).create_node()
        add = Add(graph, {'name': name + '/including_end'}).create_node()
        end_idx.out_port(0).connect(add.in_port(0))
        const.out_port(0).connect(add.in_port(1))
        end_idx = add

    delta = Const(graph, {
        'name': name + '/delta',
        'value': int64_array(1)
    }).create_node()
    range_node = Range(graph, {'name': name + '/range_idxs'}).create_node()

    start_idx.out_port(0).connect(range_node.in_port(0))
    end_idx.out_port(0).connect(range_node.in_port(1))
    delta.out_port(0).connect(range_node.in_port(2))

    return range_node
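For intuition, here is a plain-Python model of the same logic (a hypothetical helper, not part of the Model Optimizer API): negative endpoints are canonicalized against the rank, then shifted by one according to the inclusion flags, exactly as the Const/Add subgraphs above do.

    def range_of_idxs(rank, begin, end, include_begin=True, include_end=False):
        # mirrors get_canonical_axis_index_node: wrap negative axes around the rank
        start_idx = begin if begin >= 0 else begin + rank
        end_idx = end if end >= 0 else end + rank
        if not include_begin:
            start_idx += 1   # the '/exclude_begin' Add node
        if include_end:
            end_idx += 1     # the '/including_end' Add node
        return list(range(start_idx, end_idx, 1))  # delta == 1

    print(range_of_idxs(rank=4, begin=1, end=-1))  # [1, 2]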
Example #3
    def find_and_replace_pattern(self, graph: Graph):
        global_poolings = graph.get_op_nodes(type='Pooling', global_pool=True)
        if len(global_poolings) == 0:
            return

        layout = graph.graph['layout']
        assert layout != 'NHWC', 'Global pooling transformation depends on layout; NHWC is not supported'

        for pooling in global_poolings:
            name = pooling.soft_get('name', pooling.id)
            assert pooling.has_valid(
                'pool_method'
            ), 'Global Pooling {} has no `pool_method` attribute'.format(name)
            method = pooling['pool_method']
            assert method in self.pool_method_to_reduce_type, \
                'Unexpected Global Pooling method `{}` for node `{}`'.format(method, name)
            reduce_op_class = self.pool_method_to_reduce_type[method]

            reduce = reduce_op_class(graph, {
                'name': name + '/reduce',
                'keep_dims': True
            }).create_node()

            pooling.out_port(0).get_connection().set_source(reduce.out_port(0))
            src = pooling.in_port(0).get_connection().get_source()
            pooling.in_port(0).disconnect()
            src.connect(reduce.in_port(0))

            start = Const(graph, {'value': int64_array(2)}).create_node()
            end = Rank(graph, {'name': name + '/input_rank'}).create_node()
            delta = Const(graph, {'value': int64_array(1)}).create_node()

            axis = Range(graph, {
                'name': name + '/global_pooling_reduce_axis'
            }).create_node()

            axis.in_port(0).connect(start.out_port(0))
            src.connect(end.in_port(0))
            axis.in_port(1).connect(end.out_port(0))
            axis.in_port(2).connect(delta.out_port(0))

            axis.out_port(0).connect(reduce.in_port(1))

            log.debug('Global {} pooling was converted to reduce: `{}`'.format(
                method, name))
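The effect of the transformation can be checked with NumPy (an illustrative sketch, assuming the 'avg' pool method maps to a mean reduction): in NCHW layout the Range produces the axes range(2, rank), i.e. all spatial dimensions, and the reduce keeps them as size-1 dims.

    import numpy as np

    x = np.random.rand(2, 3, 5, 7).astype(np.float32)  # N, C, H, W
    axes = tuple(range(2, x.ndim))                     # (2, 3): start=2, delta=1, end=rank
    reduced = np.mean(x, axis=axes, keepdims=True)     # keep_dims=True
    assert reduced.shape == (2, 3, 1, 1)               # spatial dims collapsed to 1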
Example #4
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        node = match['reduce']
        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        if len(connected_in_ports) == 1:
            # if 'axis' is None we still add a second input to the layer: a 1D array with a single element equal
            # to None. The infer function handles this case, because the input shape is known only at that stage
            if node.has('axis'):
                const = Const(graph, {'value': node.axis}).create_node()
                node.add_input_port(1, skip_if_exist=True)
                const.out_port(0).connect(node.in_port(1))
                del graph.node[node.id]['axis']
            else:
                # The default (if there is no 'axis') is to reduce over all the dimensions of the input tensor.
                node_name = node.name

                begin_of_range = Const(graph, dict(name=node_name + '/range_begin_', value=0)).create_node()
                step = Const(graph, dict(name=node_name + '/range_step_', value=1)).create_node()
                end_of_range = Rank(graph, dict(name=node_name + '/range_end_')).create_node()
                axes = Range(graph, dict(name=node_name + '/axes_')).create_node()

                begin_of_range.out_port(0).connect(axes.in_port(0))
                end_of_range.out_port(0).connect(axes.in_port(1))
                step.out_port(0).connect(axes.in_port(2))

                node.add_input_port(1, skip_if_exist=True)
                axes.out_port(0).connect(node.in_port(1))
                node.in_port(0).get_connection().get_source().connect(end_of_range.in_port(0))
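In the no-'axis' branch, the Range(0, rank, 1) subgraph enumerates every dimension of the input, so the reduce covers the whole tensor. A quick NumPy check of that default (illustrative only):

    import numpy as np

    x = np.random.rand(2, 3, 4)
    axes = np.arange(0, x.ndim, 1)    # begin=0, end=rank, step=1 -> [0, 1, 2]
    assert np.allclose(x.sum(axis=tuple(axes)), x.sum())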
Example #5
    def replace_sub_graph(self, graph: Graph, match: dict):
        source_connection = match['split'].in_port(0).get_connection()
        source_node = source_connection.get_source().node
        cast_node = match['cast']

        range_node = Range(graph, {
            'name': source_node.id + '/Range'
        }).create_node()
        start_node = Const(graph, {
            'name': range_node.id + '/Start',
            'value': int64_array(0)
        }).create_node()

        step_node = Const(graph, {
            'name': range_node.id + '/Step',
            'value': int64_array(1)
        }).create_node()
        input_shape_node = Shape(graph, {
            'name': start_node.id + '/Shape'
        }).create_node()
        input_shape_node.in_port(0).connect(source_node.out_port(0))

        limit_node_1D = node_to_get_batch_value(input_shape_node)
        limit_node = create_op_node_with_second_input(
            graph, Squeeze, int64_array([0]),
            {'name': source_node.id + '/batch_0D_value'}, limit_node_1D)

        range_node.in_port(0).connect(start_node.out_port(0))
        range_node.in_port(1).connect(limit_node.out_port(0))
        range_node.in_port(2).connect(step_node.out_port(0))
        cast_node.out_port(0).get_connection().set_source(
            range_node.out_port(0))

        graph.remove_nodes_from([node.id for node in match.values()])
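The replacement subgraph computes Range(0, batch, 1), i.e. one index per element of the batch dimension. Roughly, in NumPy terms (illustrative sketch):

    import numpy as np

    x = np.zeros((5, 3, 224, 224))
    batch = x.shape[0]                               # Shape -> batch value, squeezed to 0D
    indices = np.arange(0, batch, 1, dtype=np.int64)
    print(indices)                                   # [0 1 2 3 4]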
Example #6
def extract(cls, node: Node):
    Range.update_node_stat(node, {})
    return cls.enabled
Example #7
def extract(cls, node: Node):
    Range.update_node_stat(
        node,
        {'output_type': tf_dtype_extractor(node.pb.attr['Tidx'].type)})
    return cls.enabled
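In TensorFlow, the Range op records the dtype of its start/limit/delta inputs (and of its output) in the `Tidx` attribute, which `tf_dtype_extractor` maps to the node's `output_type`. For instance:

    import tensorflow as tf

    # building a Range with int64 inputs sets Tidx to DT_INT64
    r = tf.range(0, 10, 2, dtype=tf.int64)
    print(r.dtype)  # <dtype: 'int64'>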
Example #8
def extract(cls, node: Node):
    # the output_type attribute will be deduced during shape inference
    Range.update_node_stat(node, {})
    return cls.enabled