def test_expand_dims_infer(self, axis, ref_out_shape):
    """Check the output shape computed by ``ExpandDims.infer``.

    A graph wired as ``data_1 -> expand_dims -> data_2`` is built from
    ``nodes_attributes``, with the ``expand_axis`` attribute of the
    ``expand_dims`` node overridden by ``axis``. After inference, the shape
    stored on the output node must equal ``ref_out_shape``.

    Args:
        axis: Value assigned to the ``expand_axis`` attribute.
        ref_out_shape: Expected shape of the node's output.

    NOTE(review): another definition named ``test_expand_dims_infer`` is
    visible later in this file; if both sit in the same class body the later
    one replaces this one -- confirm which variant is meant to run.
    """
    edges = [('data_1', 'expand_dims'), ('expand_dims', 'data_2')]
    attr_overrides = {'expand_dims': {'expand_axis': axis}}
    graph = build_graph(nodes_attributes, edges, attr_overrides)

    node = Node(graph, 'expand_dims')
    ExpandDims.infer(node)

    # Read the inferred shape first, then build the reference array, so the
    # evaluation order matches a single nested call.
    inferred_shape = node.out_node().shape
    expected_shape = np.array(ref_out_shape)
    self.assertTrue(np.array_equal(inferred_shape, expected_shape))
def test_expand_dims_infer(self, axis, ref_out_shape):
    """Check ``ExpandDims.infer`` on an input whose shape has a dynamic dim.

    A graph wired as ``data_1 -> expand_dims -> data_2`` is built from
    ``nodes_attributes``, with the ``expand_axis`` attribute of the
    ``expand_dims`` node overridden by ``axis``. The shape of ``data_1`` is
    then replaced by ``[2, 3, dynamic_dimension_value, 224]`` before
    inference runs. The resulting output shape is compared against
    ``ref_out_shape`` using ``strict_compare_tensors``.

    Args:
        axis: Value assigned to the ``expand_axis`` attribute.
        ref_out_shape: Expected shape of the node's output; converted with
            ``shape_array`` before the comparison.
    """
    edges = [('data_1', 'expand_dims'), ('expand_dims', 'data_2')]
    attr_overrides = {'expand_dims': {'expand_axis': axis}}
    graph = build_graph(nodes_attributes, edges, attr_overrides)

    # Build the input shape before touching the graph node, mirroring the
    # right-hand-side-first order of a one-line attribute assignment.
    input_shape = shape_array([2, 3, dynamic_dimension_value, 224])
    Node(graph, 'data_1').shape = input_shape

    node = Node(graph, 'expand_dims')
    ExpandDims.infer(node)

    inferred_shape = node.out_node().shape
    expected_shape = shape_array(ref_out_shape)
    self.assertTrue(strict_compare_tensors(inferred_shape, expected_shape))
def test_expand_dims_infer_value(self, axis, in_shape, ref_out_shape):
    """Check both shape and value propagation of ``ExpandDims.infer``.

    A random array of shape ``in_shape`` is attached as the ``value`` of
    ``data_1`` in a graph wired as ``data_1 -> expand_dims -> data_2``, and
    ``axis`` is stored as the ``expand_axis`` attribute of ``expand_dims``.
    After inference, the output node must carry shape ``ref_out_shape`` and
    a value equal to the input reshaped to ``ref_out_shape``.

    Args:
        axis: Value assigned to the ``expand_axis`` attribute.
        in_shape: Shape of the randomly generated input value.
        ref_out_shape: Expected shape of the node's output.
    """
    in_value = np.random.rand(*in_shape)

    edges = [('data_1', 'expand_dims'), ('expand_dims', 'data_2')]
    attr_overrides = {
        'data_1': {'value': in_value},
        'expand_dims': {'expand_axis': axis},
    }
    graph = build_graph(nodes_attributes, edges, attr_overrides)

    node = Node(graph, 'expand_dims')
    ExpandDims.infer(node)

    # Shape check.
    inferred_shape = node.out_node().shape
    expected_shape = np.array(ref_out_shape)
    self.assertTrue(np.array_equal(inferred_shape, expected_shape))

    # Value check: the data is expected to be the input, merely reshaped.
    inferred_value = node.out_node().value
    expected_value = np.array(in_value.reshape(ref_out_shape))
    self.assertTrue(np.array_equal(inferred_value, expected_value))