Пример #1
0
 def test_tile_infer_correct_2d_tensor(self):
     """Rank-2 input [3, 7] tiled with repeats [5, 1] must infer output shape [15, 7]."""
     overrides = {
         'data': {'shape': np.array([3, 7])},
         'tile_values': {'value': np.array([5, 1])},
     }
     graph = build_graph(nodes_attributes, edges, overrides)
     node = Node(graph, 'tile')
     Tile.infer(node)
     expected = np.array([15, 7])
     self.assertTrue(np.all(expected == graph.node['tile_out']['shape']))
Пример #2
0
    def replace_pattern(graph: Graph, match: dict):
        """Split a Tile that repeats along several axes into a chain of single-axis Tiles.

        Only Tile nodes with a valid ``tile_array`` attribute are handled. When more than one
        entry of ``tile_array`` differs from 1, a new Tile node (carrying ``axis`` and
        ``tiles`` attributes) is created for each such axis. The new nodes are chained in
        axis order: the original input connection is moved to the first one, and the
        original output consumers are moved to the last one. The original Tile node is only
        disconnected here; removing it is left to whatever graph clean-up runs afterwards
        (not visible in this block).

        :param graph: graph being transformed.
        :param match: pattern match; ``match['tile']`` is the Tile node to split.
        """
        tile = match['tile']

        if not tile.has_valid('tile_array'):
            return

        tile_array = tile.tile_array
        assert len(tile_array) == len(tile.in_port(0).data.get_shape())

        non_one_tile = np.argwhere(tile_array != 1).flatten()

        # Zero or one non-trivial axis is already expressible by a single Tile.
        if len(non_one_tile) <= 1:
            return

        last_tile = None
        for axis in non_one_tile:
            new_tile = Tile(graph, {'name': tile.name + '/Tile_{}/'.format(axis), 'axis': axis,
                                    'tiles': tile_array[axis], 'need_shape_inference': True}).create_node()
            if last_tile is None:
                # The first Tile of the chain takes over the original input.
                tile.in_port(0).get_connection().set_destination(new_tile.in_port(0))
            else:
                last_tile.out_port(0).connect(new_tile.in_port(0))
            last_tile = new_tile

        # Move every consumer of the original Tile to the end of the new chain.
        tile.out_port(0).get_connection().set_source(last_tile.out_port(0))
Пример #3
0
 def test_tile_infer_one_input_correct_missing_tiles(self):
     """Single-input Tile with `axis` but no `tiles` must leave the output shape as None."""
     topology = [('data', 'tile'), ('tile', 'tile_out')]
     graph = build_graph(nodes_attributes, topology, {'tile': {'axis': 1}})
     node = Node(graph, 'tile')
     Tile.infer(node)
     self.assertIsNone(graph.node['tile_out']['shape'])
Пример #4
0
 def test_tile_infer_values_test(self):
     """A constant input is propagated and equals numpy.tile with repeats [3, 1, 1, 1]."""
     src = np.arange(-30, 60, 0.25).reshape([2, 4, 3, -1])
     repeats = np.array([3, 1, 1, 1])
     overrides = {
         'data': {'shape': np.array(src.shape), 'value': src},
         'tile_values': {'value': repeats},
     }
     graph = build_graph(nodes_attributes, edges, overrides)
     node = Node(graph, 'tile')
     Tile.infer(node)
     reference = np.tile(src, repeats)
     self.assertTrue(np.all(reference == graph.node['tile_out']['value']))
Пример #5
0
 def test_tile_infer_correct(self):
     """Repeats [7, 1, 1, 1] must produce output shape [70, 20, 30, 40]."""
     repeats = np.array([7, 1, 1, 1])
     graph = build_graph(nodes_attributes, edges, {'tile_values': {'value': repeats}})
     node = Node(graph, 'tile')
     Tile.infer(node)
     expected = np.array([70, 20, 30, 40])
     self.assertTrue(np.all(expected == graph.node['tile_out']['shape']))
Пример #6
0
 def test_tile_infer_undefined_tile_values(self):
     """If the repeats input carries no value, the output shape must stay None."""
     topology = [('data', 'tile'), ('tile_values', 'tile'), ('tile', 'tile_out')]
     graph = build_graph(nodes_attributes, topology, {'tile_values': {'value': None}})
     node = Node(graph, 'tile')
     Tile.infer(node)
     self.assertIsNone(graph.node['tile_out']['shape'])
Пример #7
0
 def test_tile_infer_three_non_one(self):
     """Tiling three axes at once ([2, 1, 5, 2]) must give output shape [20, 20, 150, 80]."""
     repeats = np.array([2, 1, 5, 2])
     graph = build_graph(nodes_attributes, edges, {'tile_values': {'value': repeats}})
     node = Node(graph, 'tile')
     Tile.infer(node)
     expected = np.array([20, 20, 150, 80])
     self.assertTrue(np.all(expected == graph.node['tile_out']['shape']))
Пример #8
0
 def test_tile_infer_values_const_propagation(self):
     """
     Constant values must still be propagated when every axis is tiled ([4, 3, 2, 5]),
     i.e. in the multi-axis case that is otherwise noted as not supported.
     """
     src = np.arange(-30, 60, 0.25).reshape([2, 4, 3, -1])
     repeats = np.array([4, 3, 2, 5])
     overrides = {
         'data': {'shape': np.array(src.shape), 'value': src},
         'tile_values': {'value': repeats},
     }
     graph = build_graph(nodes_attributes, edges, overrides)
     node = Node(graph, 'tile')
     Tile.infer(node)
     reference = np.tile(src, repeats)
     self.assertTrue(np.all(reference == graph.node['tile_out']['value']))
Пример #9
0
 def test_tile_infer_shapes_mismatch(self):
     """Length-3 repeats on this explicit edge list must leave the output shape None."""
     topology = [('data', 'tile'), ('tile_values', 'tile'), ('tile', 'tile_out')]
     repeats_attrs = {
         'value': np.array([1, 2, 1]),
         'shape': np.array([3]),
     }
     graph = build_graph(nodes_attributes, topology, {'tile_values': repeats_attrs})
     node = Node(graph, 'tile')
     Tile.infer(node)
     self.assertIsNone(graph.node['tile_out']['shape'])
Пример #10
0
 def test_tile_infer_shapes_alignment(self):
     """Repeats [1, 2, 3] shorter than the input rank align to the trailing dimensions."""
     repeats_attrs = {
         'value': np.array([1, 2, 3]),
         'shape': np.array([3]),
     }
     graph = build_graph(nodes_attributes, edges, {'tile_values': repeats_attrs})
     node = Node(graph, 'tile')
     Tile.infer(node)
     expected = np.array([10, 20, 60, 120])
     self.assertTrue(np.all(expected == graph.node['tile_out']['shape']))
Пример #11
0
 def test_tile_infer_none_input_shape(self):
     """An input whose shape is None must leave the output shape None."""
     topology = [('data', 'tile'), ('tile_values', 'tile'), ('tile', 'tile_out')]
     overrides = {
         'data': {'shape': None},
         'tile_values': {'value': np.array([1, 7, 1, 1])},
     }
     graph = build_graph(nodes_attributes, topology, overrides)
     node = Node(graph, 'tile')
     Tile.infer(node)
     self.assertIsNone(graph.node['tile_out']['shape'])
Пример #12
0
 def test_tile_infer_all_ones(self):
     """All-ones repeats keep the shape and leave the node with axis=0, tiles=1."""
     topology = [('data', 'tile'), ('tile_values', 'tile'), ('tile', 'tile_out')]
     repeats = np.array([1, 1, 1, 1])
     graph = build_graph(nodes_attributes, topology, {'tile_values': {'value': repeats}})
     node = Node(graph, 'tile')
     Tile.infer(node)
     expected = np.array([10, 20, 30, 40])
     self.assertTrue(np.all(expected == graph.node['tile_out']['shape']))
     self.assertEqual(node.axis, 0)
     self.assertEqual(node.tiles, 1)
Пример #13
0
 def test_tile_infer_two_non_one(self):
     """Two tiled axes: shape is inferred, but type is None and axis/tiles are not set."""
     topology = [('data', 'tile'), ('tile_values', 'tile'), ('tile', 'tile_out')]
     repeats = np.array([2, 1, 1, 2])
     graph = build_graph(nodes_attributes, topology, {'tile_values': {'value': repeats}})
     node = Node(graph, 'tile')
     Tile.infer(node)
     self.assertIsNone(graph.node['tile']['type'])
     expected = np.array([20, 20, 30, 80])
     self.assertTrue(np.all(expected == graph.node['tile_out']['shape']))
     for attr_name in ('axis', 'tiles'):
         self.assertFalse(node.has_and_set(attr_name))
Пример #14
0
 def test_tile_infer_one_input_correct(self):
     """Single-input Tile driven by axis=1, tiles=7 must give [10, 140, 30, 40]."""
     topology = [('data', 'tile'), ('tile', 'tile_out')]
     tile_attrs = {'axis': 1, 'tiles': 7}
     graph = build_graph(nodes_attributes, topology, {'tile': tile_attrs})
     node = Node(graph, 'tile')
     Tile.infer(node)
     expected = np.array([10, 140, 30, 40])
     self.assertTrue(np.all(expected == graph.node['tile_out']['shape']))
     self.assertEqual(node.axis, 1)
     self.assertEqual(node.tiles, 7)
Пример #15
0
    def replace_pattern(graph: Graph, match: dict):
        """Broadcast the inputs of an element-wise op to its output shape by inserting Tile ops.

        Scalar (rank-0) inputs are first reshaped in place to an all-ones shape of the output
        rank; a constant value, if present, is reshaped too. If any input rank still differs
        from the output rank, a warning is logged and the graph is otherwise left unchanged.

        Otherwise, for every input dimension equal to 1 whose output dimension is greater
        than 1, a Tile op (``axis``/``tiles`` attributes) followed by a new data node is
        appended. The op's input edge for that port is then moved to the last data node of
        the chain, keeping the original edge attributes.

        :param graph: graph being transformed.
        :param match: pattern match; ``match['op']`` is the element-wise op node.
        """
        node = match['op']
        # Snapshot (port, data node) pairs once. Ports are keys of in_nodes() and are not
        # guaranteed to equal the positional index, so all later lookups go through this list.
        inputs = list(node.in_nodes().items())
        shapes = [in_node.shape for _, in_node in inputs]
        out_shape = node.out_node().shape
        tname = node.name + '/Broadcast/'
        tile = Tile(graph, dict(name=tname))

        # Working with scalar values: give them an all-ones shape of the output rank.
        for idx, (_, in_node) in enumerate(inputs):
            if len(shapes[idx]) == 0:
                shapes[idx] = np.ones(len(out_shape), dtype=np.int64)
                in_node.shape = shapes[idx].copy()
                if in_node.value is not None:
                    in_node.value = np.reshape(in_node.value, shapes[idx])

        if not all(len(shape) == len(out_shape) for shape in shapes):
            log.warning("Cannot apply broadcast for Eltwise layer {} "
                        "because not all input shapes {} have the same number of elements "
                        "as output shape {}.".format(node.soft_get('name'),
                                                     shapes,
                                                     out_shape
                                                     )
                        )
            return

        for idx, (port, old_input) in enumerate(inputs):
            new_input = old_input
            for axis, out_dim in enumerate(out_shape):
                if shapes[idx][axis] == 1 and out_dim > 1:
                    new_op = tile.create_node([new_input], dict(axis=axis, tiles=out_dim))
                    # add a data node following the new operation node
                    data_id = graph.unique_id(node.name)
                    graph.add_node(data_id, kind='data', shape=None, value=None)
                    new_data = Node(graph, data_id)
                    graph.add_edge(new_op.id, new_data.id, out=0)
                    new_op.infer(new_op)
                    new_input = new_data
            if new_input is not old_input:
                # The same data node may feed several ports of this op (e.g. x + x), so pick
                # the multigraph edge whose 'in' attribute matches the port being rewired. Copy
                # its attributes to the new edge, then remove exactly that edge.
                parallel_edges = graph[old_input.id][node.id]
                edge_key = next((key for key, attrs in parallel_edges.items() if attrs.get('in') == port),
                                next(iter(parallel_edges)))
                graph.add_edge(new_input.id, node.id, **parallel_edges[edge_key])
                graph.remove_edge(old_input.id, node.id, key=edge_key)
Пример #16
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Rewrite an attribute-driven Tile (``axis``/``tiles``) as a Tile fed by a repeats Const.

        The repeats vector has one entry per input dimension, all 1 except ``tiles`` at
        position ``axis``. It is stored in a Const and fed into port 1 of a new Tile that
        reuses the original node's name. The original node's output consumers and its
        port-0 input are moved onto the new Tile.
        """
        old_tile = match['tile']
        node_name = old_tile.soft_get('name', old_tile.id)
        axis, tiles = old_tile.axis, old_tile.tiles

        in_shape = old_tile.in_port(0).data.get_shape()
        assert in_shape is not None

        # Only the requested axis is repeated; every other dimension is kept as is.
        repeats_value = int64_array(np.ones(in_shape.size))
        repeats_value[axis] = tiles

        repeats_const = Const(graph, {'value': repeats_value, 'name': node_name + '/tiles'}).create_node()
        new_tile = Tile(graph, {'name': node_name}).create_node()

        old_tile.out_port(0).get_connection().set_source(new_tile.out_port(0))
        old_tile.in_port(0).get_connection().set_destination(new_tile.in_port(0))
        repeats_const.out_port(0).connect(new_tile.in_port(1))
Пример #17
0
 def extract(node):
     """Mark `node` as a Tile operation; no framework attributes are read.

     Returns the `enabled` flag of the enclosing extractor class.
     """
     extra_attrs = {}
     Tile.update_node_stat(node, extra_attrs)
     return __class__.enabled
Пример #18
0
    def mxrepeat_decomposition(node: Node):
        """Replace a repeat node with an Unsqueeze -> Tile -> Reshape subgraph.

        The input is unsqueezed right after the canonicalized `node.axis` and tiled
        `node.repeats` times along that new dimension. It is then reshaped back to the input
        rank, with dimension `axis` multiplied by `node.repeats`. Ranks, shapes and index
        arithmetic are built as graph nodes (Rank, Shape, Add, Sub, Broadcast, Gather, Mul)
        rather than taken from static shape values.

        The original node is renamed to `<name>/to_be_removed`. Its port-0 input connection
        moves to the Unsqueeze and its output consumers move to the final Reshape, which takes
        the original name.

        :param node: the repeat node; its `axis` and `repeats` attributes are read.
        """
        graph = node.graph
        name = node.soft_get('name', node.id)

        # Free the original name so that the final Reshape can take it over.
        rename_node(node, name + '/to_be_removed')

        # Unsqueeze
        input_rank = Rank(graph, {'name': name + '/Rank'}).create_node()
        node.in_port(0).get_source().connect(input_rank.in_port(0))

        # Canonical axis as a graph node. Presumably this resolves a negative node.axis
        # against the runtime rank -- confirm in get_canonical_axis_index_node.
        axis = get_canonical_axis_index_node(input_rank, node.axis)
        # The new dimension is inserted right after `axis`: unsqueeze_axis = axis + 1.
        unsqueeze_axis = create_op_node_with_second_input(
            graph,
            Add,
            int64_array([1]), {'name': name + '/Unsqueeze/Axis'},
            input_node=axis)

        unsqueeze = Unsqueeze(graph, {
            'name': name + '/Unsqueeze'
        }).create_node()
        unsqueeze.in_port(1).connect(unsqueeze_axis.out_port(0))

        # Tile (1, 1, ..., repeats, ..., 1)
        # we generate tile array according to the following table
        # (the unsqueezed tensor has rank + 1 dimensions, so tile_array has rank + 1 entries):

        # parts:       |      first      |  repeats |  second     |
        # i:           | 0, 1, ..., axis,| axis + 1,| ..., rank   |
        # tile_array:  | 1, 1, ...,  1  ,| repeats ,| ...,   1    |

        one = Const(graph, {
            'name': name + '/Broadcast/One',
            'value': int64_array([1])
        }).create_node()
        # First part: [1] broadcast to shape [axis + 1].
        first_ones = Broadcast(graph, {
            'name': name + '/Broadcast/Ones_first_part'
        }).create_node()
        first_ones.in_port(0).connect(one.out_port(0))
        first_ones.in_port(1).connect(unsqueeze_axis.out_port(0))

        # Reused below both in the tile array and in the Mul for the reshaped dimension.
        repeats = Const(graph, {
            'name': name + '/repeats',
            'value': int64_array([node.repeats])
        }).create_node()

        # Second part: [1] broadcast to shape [rank - (axis + 1)].
        second_ones = Broadcast(graph, {
            'name': name + '/Broadcast/Ones_second_part'
        }).create_node()
        second_part_broadcast_shape = Sub(
            graph, {
                'name': name + '/Broadcast/Shape/second_part'
            }).create_node()
        second_part_broadcast_shape.in_port(0).connect(input_rank.out_port(0))
        second_part_broadcast_shape.in_port(1).connect(
            unsqueeze_axis.out_port(0))
        second_ones.in_port(0).connect(one.out_port(0))
        second_ones.in_port(1).connect(second_part_broadcast_shape.out_port(0))

        tile_repeats = new_shape_node_from_shape_nodes(
            [first_ones, repeats, second_ones])
        tile = Tile(graph, {'name': name + '/Tile'}).create_node()
        tile.in_port(1).connect(tile_repeats.out_port(0))

        # Reshape (input_shape[:axis], input_shape[axis] * repeats, input_shape[axis+1:])
        # we generate reshape dim array according to the following table
        # (dim_array has rank entries, indices 0 .. rank - 1):

        # parts:       |    first   |                rep           |  second       |
        # i:           | 0, 1, ... ,|               axis,          | ..., rank - 1 |
        # dim_array:   | inp_sh[i] ,| input_shape[axis] * repeats ,| inp_sh[i]     |

        input_shape = Shape(graph, {'name': name + '/Shape'}).create_node()
        node.in_port(0).get_source().connect(input_shape.in_port(0))

        # NOTE(review): the range bounds here and for the second part use the raw node.axis,
        # while the Gather below uses the canonical `axis` node. Behaviour for a negative
        # node.axis depends on how get_shape_values_by_range_idxs treats negative bounds -- confirm.
        first_input_shape_part = get_shape_values_by_range_idxs(
            input_shape,
            input_rank,
            begin=0,
            end=node.axis,
            include_begin=True,
            include_end=False)

        # input_shape[axis]: Gather on the shape vector with gather-axis 0 (const on port 2).
        original_axis_dim = create_op_with_const_inputs(
            graph,
            Gather, {2: int64_array(0)}, {'name': name + '/OriginalDim'},
            input_node=input_shape)
        original_axis_dim.in_port(1).connect(axis.out_port(0))

        repeated_dimention = Mul(graph, {
            'name': name + '/RepeatedDim'
        }).create_node()
        repeated_dimention.in_port(0).connect(original_axis_dim.out_port(0))
        repeated_dimention.in_port(1).connect(repeats.out_port(0))

        second_input_shape_part = get_shape_values_by_range_idxs(
            input_shape,
            input_rank,
            begin=node.axis,
            end=-1,
            include_begin=False,
            include_end=True)

        output_shape = new_shape_node_from_shape_nodes([
            first_input_shape_part, repeated_dimention, second_input_shape_part
        ])

        # The final node of the decomposition carries the original node's name.
        reshape = Reshape(graph, {'name': name}).create_node()
        rename_node(reshape, name)
        reshape.in_port(1).connect(output_shape.out_port(0))

        # Final connections
        node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
        tile.in_port(0).connect(unsqueeze.out_port(0))
        reshape.in_port(0).connect(tile.out_port(0))
        node.out_port(0).get_connection().set_source(reshape.out_port(0))
Пример #19
0
 def extract(cls, node):
     """Mark `node` as a Tile operation without reading any framework attributes.

     Returns the extractor class's `enabled` flag.
     """
     Tile.update_node_stat(node, dict())
     enabled = cls.enabled
     return enabled
Пример #20
0
 def extract(cls, node: Node):
     """Read the MXNet `reps` attribute (tuple of ints, default None) into Tile node attributes.

     Returns the extractor class's `enabled` flag.
     """
     layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
     reps = layer_attrs.tuple('reps', int, None)
     Tile.update_node_stat(node, {'reps': reps})
     return cls.enabled