Example #1
0
    def test_expand_dims_infer(self, axis, ref_out_shape):
        # Shape-inference check: wire data_1 -> expand_dims -> data_2, set the
        # op's 'expand_axis' attribute to `axis`, run ExpandDims.infer and
        # compare the output node's shape with `ref_out_shape`.
        edges = [('data_1', 'expand_dims'),
                 ('expand_dims', 'data_2')]
        attr_overrides = {'expand_dims': {'expand_axis': axis}}
        graph = build_graph(nodes_attributes, edges, attr_overrides)

        op_node = Node(graph, 'expand_dims')
        ExpandDims.infer(op_node)

        inferred_shape = op_node.out_node().shape
        expected_shape = np.array(ref_out_shape)
        shapes_match = np.array_equal(inferred_shape, expected_shape)
        self.assertTrue(shapes_match)
Example #2
0
    def test_expand_dims_infer(self, axis, ref_out_shape):
        # Shape-inference check with a dynamic dimension in the input shape:
        # wire data_1 -> expand_dims -> data_2, set 'expand_axis' to `axis`,
        # override the shape of data_1, run ExpandDims.infer and compare the
        # output node's shape with `ref_out_shape` via strict_compare_tensors.
        edges = [('data_1', 'expand_dims'),
                 ('expand_dims', 'data_2')]
        attr_overrides = {'expand_dims': {'expand_axis': axis}}
        graph = build_graph(nodes_attributes, edges, attr_overrides)

        # Third dimension of the input is the dynamic one.
        input_data_node = Node(graph, 'data_1')
        input_data_node.shape = shape_array([2, 3, dynamic_dimension_value, 224])

        op_node = Node(graph, 'expand_dims')
        ExpandDims.infer(op_node)

        inferred_shape = op_node.out_node().shape
        expected_shape = shape_array(ref_out_shape)
        self.assertTrue(strict_compare_tensors(inferred_shape, expected_shape))
Example #3
0
    def test_expand_dims_infer_value(self, axis, in_shape, ref_out_shape):
        # Value-propagation check: give data_1 a random value of shape
        # `in_shape`, set 'expand_axis' to `axis`, run ExpandDims.infer and
        # verify both the output shape and the output value (the input value
        # reshaped to `ref_out_shape`).
        in_value = np.random.rand(*in_shape)
        edges = [('data_1', 'expand_dims'),
                 ('expand_dims', 'data_2')]
        attr_overrides = {
            'data_1': {'value': in_value},
            'expand_dims': {'expand_axis': axis},
        }
        graph = build_graph(nodes_attributes, edges, attr_overrides)

        op_node = Node(graph, 'expand_dims')
        ExpandDims.infer(op_node)

        # Shape first, then value -- same order as the assertions are reported.
        shape_ok = np.array_equal(op_node.out_node().shape, np.array(ref_out_shape))
        self.assertTrue(shape_ok)

        value_ok = np.array_equal(op_node.out_node().value,
                                  np.array(in_value.reshape(ref_out_shape)))
        self.assertTrue(value_ok)