def extract(cls, node):
        """Copy the ``axes`` attribute of the source layer onto *node*.

        Layer attributes are parsed with ``get_mxnet_layer_attrs`` from
        ``node.symbol_dict``. ``axes`` is read as a tuple of ints, with ``[]``
        passed as the fallback (presumably used when the attribute is absent).
        It is then stored as a list through ``SliceLike.update_node_stat``.

        Returns:
            ``cls.enabled``.
        """
        layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
        parsed_axes = layer_attrs.tuple("axes", int, [])

        # Push the parsed attribute into the node's stored attributes.
        SliceLike.update_node_stat(node, dict(axes=list(parsed_axes)))
        return cls.enabled
Ejemplo n.º 2
0
 def test_4(self):
     """Infer ``slice_like`` with ``axes=(-1,)`` and expect output shape [3, 3]."""
     update_attrs = {'slice_like': {'axes': (-1, )}}
     graph = build_graph(nodes_attributes, edges, update_attrs)
     SliceLike.infer(Node(graph, 'slice_like'))

     expected_shape = int64_array([3, 3])
     actual_shape = graph.node['out_data']['shape']
     self.assertTrue(np.array_equal(actual_shape, expected_shape))
Ejemplo n.º 3
0
 def test_3(self):
     """Infer ``slice_like`` with ``axes=(0,)``.

     Expects output shape [2, 4] and output value
     [[1, 2, 3, 4], [5, 6, 7, 8]].
     """
     update_attrs = {'slice_like': {'axes': (0, )}}
     graph = build_graph(nodes_attributes, edges, update_attrs)
     SliceLike.infer(Node(graph, 'slice_like'))

     expected_shape = int64_array([2, 4])
     expected_value = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])

     # Read both results before asserting on either.
     out_data = graph.node['out_data']
     actual_shape = out_data['shape']
     actual_value = out_data['value']
     self.assertTrue(np.array_equal(actual_shape, expected_shape))
     self.assertTrue(np.array_equal(actual_value, expected_value))
Ejemplo n.º 4
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Append a cropped constant "variants" tensor to the Concat after ``reshape3``.

        Steps:
        1. Build a constant tensor with the shape of the matched ``slice_like``
           node's first input (``const``). It repeats the group
           ``[1x, 1y, 2x, 2y]``.
        2. Crop the tensor with a new SliceLike node that uses the same ``axes``
           as the matched node and the matched node's second input as reference.
        3. Connect the result to a newly added input port of the Concat that
           consumes ``match['reshape3']``.

        Group values default to 0.1 (``*1x``/``*1y``) and 0.2 (``*2x``/``*2y``).
        A default is overridden by ``value[0]`` of the second input of any graph
        node matched by ``self.variants_pattern`` under the corresponding
        pattern name.

        Args:
            graph: graph being transformed in place.
            match: pattern match; must contain the 'slice_like' and 'reshape3'
                nodes.

        Raises:
            AssertionError: if the consumer of ``reshape3`` is not a Concat.
            ValueError: from ``reshape`` if ``const.value.size`` is not a
                multiple of 4.
        """
        slice_like = match['slice_like']
        const = slice_like.in_nodes()[0]
        crop_shape = slice_like.in_nodes()[1]

        # Fallback values, kept unless a pattern match below overrides them.
        variants_dict = {
            'mul_scalar1x': 0.1,
            'mul_scalar2x': 0.2,
            'mul_scalar1y': 0.1,
            'mul_scalar2y': 0.2
        }
        for matches in find_pattern_matches(graph,
                                            self.variants_pattern['nodes'],
                                            self.variants_pattern['edges'],
                                            None, None):
            # ``matches`` maps a graph node id to the pattern node name it matched.
            for k, v in matches.items():
                if v in variants_dict:
                    variants_dict[v] = Node(graph, k).in_nodes()[1].value[0]

        # One [1x, 1y, 2x, 2y] group per four elements of ``const``. Integer
        # division avoids a float round trip through ``size / 4``.
        variants = np.array([
            variants_dict['mul_scalar1x'], variants_dict['mul_scalar1y'],
            variants_dict['mul_scalar2x'], variants_dict['mul_scalar2y']
        ] * (const.value.size // 4)).reshape(const.value.shape)
        priorbox_variants = Const(
            graph, dict(value=variants,
                        name=const.id + '/priorbox_variants')).create_node()
        variants_slice_like = SliceLike(
            graph,
            dict(axes=slice_like.axes,
                 name=slice_like.id + '/variants_slice_like')).create_node()
        variants_slice_like.in_port(0).connect(priorbox_variants.out_port(0))
        variants_slice_like.in_port(1).connect(crop_shape.out_port(0))

        concat = match['reshape3'].out_port(0).get_destination().node
        assert concat.op == 'Concat'
        # The new input goes at index == current number of connected inputs.
        concat_nodes_count = len(concat.in_nodes())
        concat.add_input_port(concat_nodes_count)
        concat.in_port(concat_nodes_count).get_connection().set_source(
            variants_slice_like.out_port(0))