def extract(cls, node):
    """Copy the ``axes`` attribute of an MXNet ``slice_like`` symbol onto *node*.

    The attribute is parsed as a tuple of ints (empty when absent), converted
    to a list and stored on the node via ``SliceLike.update_node_stat``.

    Returns:
        ``cls.enabled``, as expected from an extractor.
    """
    layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
    parsed_axes = list(layer_attrs.tuple("axes", int, []))
    SliceLike.update_node_stat(node, {'axes': parsed_axes})
    return cls.enabled
def test_4(self):
    """SliceLike inference with a negative axis (``axes=(-1,)``).

    Checks that both the inferred shape and the constant-folded value of
    ``out_data`` match the expected 3x3 result.
    """
    graph = build_graph(nodes_attributes, edges, {'slice_like': {'axes': (-1,)}})
    SliceLike.infer(Node(graph, 'slice_like'))

    expected_shape = int64_array([3, 3])
    expected_value = np.array([[1, 2, 3], [5, 6, 7], [9, 10, 11]])
    out_data = graph.node['out_data']
    checks = (
        (out_data['shape'], expected_shape),
        (out_data['value'], expected_value),
    )
    # Shape is verified before value, matching the original assertion order.
    for actual, expected in checks:
        self.assertTrue(np.array_equal(actual, expected))
def replace_sub_graph(self, graph: Graph, match: dict):
    """Append a sliced PriorBox-variants constant as a new input of the Concat fed by ``reshape3``.

    Steps:
      1. Recover the four scalar multipliers (``mul_scalar1x``, ``mul_scalar1y``,
         ``mul_scalar2x``, ``mul_scalar2y``) from every match of
         ``self.variants_pattern``. A multiplier that is not matched keeps its
         default: 0.1 for the ``1x``/``1y`` entries and 0.2 for the ``2x``/``2y``
         entries.
      2. Tile ``[1x, 1y, 2x, 2y]`` to the size of the SliceLike's first input
         constant and reshape it to that constant's shape.
      3. Feed the tiled constant through a new SliceLike that has the same
         ``axes`` and the same shape-like input as the matched SliceLike.
      4. Connect the new SliceLike's output to a freshly added input port of the
         Concat that consumes ``match['reshape3']``.

    Args:
        graph: graph being transformed.
        match: matched sub-graph; must contain ``'slice_like'`` and ``'reshape3'``.
    """
    slice_like = match['slice_like']
    const = slice_like.in_nodes()[0]
    crop_shape = slice_like.in_nodes()[1]

    # Defaults are used when a multiplier is not found by the variants pattern.
    variants_dict = {
        'mul_scalar1x': 0.1,
        'mul_scalar2x': 0.2,
        'mul_scalar1y': 0.1,
        'mul_scalar2y': 0.2,
    }
    for matches in find_pattern_matches(graph, self.variants_pattern['nodes'], self.variants_pattern['edges'],
                                        None, None):
        # Keys are used as graph node ids and values are compared to pattern node names.
        for graph_node_id, pattern_node_name in matches.items():
            if pattern_node_name in variants_dict:
                # The multiplier value is the first element of the node's second input.
                variants_dict[pattern_node_name] = Node(graph, graph_node_id).in_nodes()[1].value[0]

    # Repeat one [1x, 1y, 2x, 2y] group for every 4 elements of the constant.
    # Integer division avoids a float round trip. If the size is not a multiple
    # of 4, the reshape below fails, as it did before.
    group_count = const.value.size // 4
    variants = mo_array([
        variants_dict['mul_scalar1x'], variants_dict['mul_scalar1y'],
        variants_dict['mul_scalar2x'], variants_dict['mul_scalar2y'],
    ] * group_count).reshape(const.value.shape)

    priorbox_variants = Const(
        graph, dict(value=variants, name=const.id + '/priorbox_variants')).create_node()
    variants_slice_like = SliceLike(
        graph, dict(axes=slice_like.axes, name=slice_like.id + '/variants_slice_like')).create_node()
    variants_slice_like.in_port(0).connect(priorbox_variants.out_port(0))
    variants_slice_like.in_port(1).connect(crop_shape.out_port(0))

    concat = match['reshape3'].out_port(0).get_destination().node
    assert concat.op == 'Concat', \
        'Expected the consumer of {} to be a Concat, got op {}'.format(match['reshape3'].id, concat.op)

    # The new port index equals the current number of connected inputs.
    concat_nodes_count = len(concat.in_nodes())
    concat.add_input_port(concat_nodes_count)
    concat.in_port(concat_nodes_count).get_connection().set_source(
        variants_slice_like.out_port(0))