def test_expand_dims_infer(self, axis, ref_out_shape):
    """Run ExpandDims shape inference and compare the output shape.

    Builds a ``data_1 -> expand_dims -> data_2`` graph whose ``expand_dims``
    node carries ``expand_axis=axis``, runs ``ExpandDims.infer`` on it and
    asserts the shape of the output data node equals ``ref_out_shape``.
    The input shape comes from ``nodes_attributes`` (defined elsewhere).
    """
    edges = [('data_1', 'expand_dims'), ('expand_dims', 'data_2')]
    overrides = {'expand_dims': {'expand_axis': axis}}
    graph = build_graph(nodes_attributes, edges, overrides)

    op_node = Node(graph, 'expand_dims')
    ExpandDims.infer(op_node)

    inferred_shape = op_node.out_node().shape
    expected_shape = np.array(ref_out_shape)
    self.assertTrue(np.array_equal(inferred_shape, expected_shape))
def test_expand_dims_infer(self, axis, ref_out_shape):
    """Run ExpandDims shape inference on an input with a dynamic dimension.

    The input data node ``data_1`` is given the shape
    ``[2, 3, dynamic_dimension_value, 224]`` before ``ExpandDims.infer`` runs.
    The inferred output shape is then compared against ``ref_out_shape``
    with ``strict_compare_tensors``.
    """
    edges = [('data_1', 'expand_dims'), ('expand_dims', 'data_2')]
    overrides = {'expand_dims': {'expand_axis': axis}}
    graph = build_graph(nodes_attributes, edges, overrides)

    # Third dimension is left undefined to exercise dynamic-shape handling.
    input_shape = shape_array([2, 3, dynamic_dimension_value, 224])
    Node(graph, 'data_1').shape = input_shape

    op_node = Node(graph, 'expand_dims')
    ExpandDims.infer(op_node)

    inferred_shape = op_node.out_node().shape
    expected_shape = shape_array(ref_out_shape)
    self.assertTrue(strict_compare_tensors(inferred_shape, expected_shape))
def test_expand_dims_infer_value(self, axis, in_shape, ref_out_shape):
    """Run ExpandDims inference on a constant input and check shape and value.

    ``data_1`` holds a random array of shape ``in_shape``. After
    ``ExpandDims.infer``, two things are asserted:
    - the output shape equals ``ref_out_shape``;
    - the output value equals the input reshaped to ``ref_out_shape``.
    """
    source_value = np.random.rand(*in_shape)
    edges = [('data_1', 'expand_dims'), ('expand_dims', 'data_2')]
    overrides = {
        'data_1': {'value': source_value},
        'expand_dims': {'expand_axis': axis},
    }
    graph = build_graph(nodes_attributes, edges, overrides)

    op_node = Node(graph, 'expand_dims')
    ExpandDims.infer(op_node)
    result = op_node.out_node()

    self.assertTrue(np.array_equal(result.shape, np.array(ref_out_shape)))
    # Built only after the shape check so a shape mismatch is reported first.
    expected_value = np.array(source_value.reshape(ref_out_shape))
    self.assertTrue(np.array_equal(result.value, expected_value))
def extract(cls, node):
    """Populate an ExpandDims node from an MXNet layer's ``axis`` attribute.

    The attribute is read as an int via ``get_mxnet_layer_attrs`` with
    ``None`` passed as the fallback. It is stored on the node as
    ``expand_axis``.

    Returns:
        ``cls.enabled``.
    """
    layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
    node_attrs = {'expand_axis': layer_attrs.int('axis', None)}
    ExpandDims.update_node_stat(node, node_attrs)
    return cls.enabled
def extract(cls, node):
    """Populate an ExpandDims node from an ONNX node's ``axes`` attribute.

    ``axes`` is read as an ``ints`` attribute (empty list when absent) and
    converted with ``int64_array``. The result is stored on the node as
    ``expand_axis``.

    Returns:
        ``cls.enabled``.
    """
    raw_axes = onnx_attr(node, 'axes', 'ints', default=[])
    ExpandDims.update_node_stat(node, {'expand_axis': int64_array(raw_axes)})
    return cls.enabled