def test_expand_dims_infer(self, axis, ref_out_shape):
    """Run ExpandDims shape inference for ``axis`` and compare to ``ref_out_shape``.

    The input shape of ``data_1`` is not set here; presumably it comes from
    ``nodes_attributes`` (defined elsewhere) - confirm against that fixture.
    """
    edges = [('data_1', 'expand_dims'), ('expand_dims', 'data_2')]
    overrides = {'expand_dims': {'expand_axis': axis}}
    graph = build_graph(nodes_attributes, edges, overrides)

    op_node = Node(graph, 'expand_dims')
    ExpandDims.infer(op_node)

    inferred_shape = op_node.out_node().shape
    self.assertTrue(np.array_equal(inferred_shape, np.array(ref_out_shape)))
def test_expand_dims_infer(self, axis, ref_out_shape):
    """Check ExpandDims shape inference when the input has a dynamic dimension.

    The input ``data_1`` is given shape ``[2, 3, <dynamic>, 224]``. After
    inference, the output shape of ``expand_dims`` must strictly match
    ``ref_out_shape``, including the position of the dynamic dimension.
    """
    edges = [('data_1', 'expand_dims'), ('expand_dims', 'data_2')]
    overrides = {'expand_dims': {'expand_axis': axis}}
    graph = build_graph(nodes_attributes, edges, overrides)

    input_data = Node(graph, 'data_1')
    input_data.shape = shape_array([2, 3, dynamic_dimension_value, 224])

    op_node = Node(graph, 'expand_dims')
    ExpandDims.infer(op_node)

    inferred_shape = op_node.out_node().shape
    self.assertTrue(strict_compare_tensors(inferred_shape, shape_array(ref_out_shape)))
def test_expand_dims_infer_value(self, axis, in_shape, ref_out_shape):
    """Check that ExpandDims propagates both shape and constant value.

    A random tensor of shape ``in_shape`` is attached as the value of
    ``data_1``. After inference, the output must have shape ``ref_out_shape``
    and hold the input reshaped to that shape.
    """
    source_value = np.random.rand(*in_shape)
    edges = [('data_1', 'expand_dims'), ('expand_dims', 'data_2')]
    overrides = {
        'data_1': {'value': source_value},
        'expand_dims': {'expand_axis': axis},
    }
    graph = build_graph(nodes_attributes, edges, overrides)

    op_node = Node(graph, 'expand_dims')
    ExpandDims.infer(op_node)

    result = op_node.out_node()
    self.assertTrue(np.array_equal(result.shape, np.array(ref_out_shape)))
    expected_value = np.array(source_value.reshape(ref_out_shape))
    self.assertTrue(np.array_equal(result.value, expected_value))
def replace_op(self, graph: Graph, node: Node):
    """Replace ``node`` by per-input ExpandDims ops feeding a Concat along ``node.axis``.

    One Concat node is created with as many input ports as ``node`` has.
    For every input port of ``node``, an ExpandDims node (``expand_axis`` =
    ``[node.axis]``) takes over that port's incoming connection, and its
    output is wired to the same-numbered input port of the Concat.

    Returns:
        ``[concat.id]``. The replacer then redirects the edge leaving output
        port 0 of the matched node so it leaves output port 0 of the Concat
        node instead. The explicit form is ``[(concat.id, 0)]``.
    """
    concat_attrs = {
        'axis': node.axis,
        'in_ports_count': len(node.in_ports()),
        'name': node.name + '/Concat_',
    }
    concat = Concat(graph, concat_attrs).create_node()

    for port_idx in node.in_ports():
        # A fresh axis array per node, so ExpandDims nodes don't share one object.
        unsqueeze_attrs = {
            'expand_axis': int64_array([node.axis]),
            'name': node.name + '/ExpandDims_',
        }
        unsqueeze = ExpandDims(graph, unsqueeze_attrs).create_node()

        # Move the producer of this input onto the ExpandDims node...
        node.in_port(port_idx).get_connection().set_destination(unsqueeze.in_port(0))
        # ...and feed the result into the matching Concat input.
        unsqueeze.out_port(0).connect(concat.in_port(port_idx))

    return [concat.id]
def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    """Replace ``node`` by ExpandDims ops on each input followed by a Concat.

    A single Const holding ``node.axis`` is created and shared as the second
    input of every ExpandDims. Each ExpandDims consumes one input of ``node``,
    taken from the producer's output port recorded in the edge's ``'out'``
    attribute. All ExpandDims outputs are concatenated along ``node.axis``.

    Returns:
        ``[concat.id]``. The replacer then redirects the edge leaving output
        port 0 of the matched node so it leaves output port 0 of the Concat
        node instead. The explicit form is ``[(concat.id, 0)]``.
    """
    axis_const = Const(graph, {'value': node.axis}).create_node([])

    unsqueezed = [
        ExpandDims(graph, {'name': node.name + '/ExpandDims_'}).create_node([
            (node.in_node(port), edge['out']),
            axis_const,
        ])
        for port, edge in node.in_edges().items()
    ]

    concat = Concat(graph, {'name': node.name + '/Concat_', 'axis': node.axis}).create_node(unsqueezed)
    return [concat.id]
def extract(cls, node):
    """Read the ONNX ``axes`` attribute into ExpandDims' ``expand_axis``.

    ``axes`` is read as a list of ints (defaulting to ``[]`` when absent) and
    stored as an ``int64`` numpy array.

    Returns:
        ``cls.enabled``.
    """
    raw_axes = onnx_attr(node, 'axes', 'ints', default=[])
    ExpandDims.update_node_stat(node, {'expand_axis': np.array(raw_axes, dtype=np.int64)})
    return cls.enabled
def extract(cls, node):
    """Read the MXNet ``axis`` attribute into ExpandDims' ``expand_axis``.

    The attribute is parsed as an int via the layer attrs of
    ``node.symbol_dict``; ``None`` is passed as the fallback. Presumably that
    is the value used when ``axis`` is absent - confirm against
    ``get_mxnet_layer_attrs``.

    Returns:
        ``cls.enabled``.
    """
    layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
    ExpandDims.update_node_stat(node, {'expand_axis': layer_attrs.int('axis', None)})
    return cls.enabled