예제 #1
0
    def test_undefined_shapes(self):
        """Rank-bound-only specs cannot be turned into a listable shape.

        Each spec here is built from `min_ndim`/`max_ndim` only (no `ndim`
        and no `shape`). Converting it with `to_tensor_shape` and listing
        the dimensions must raise a ValueError whose message mentions an
        'unknown TensorShape'. This holds even when both bounds are equal.
        """
        rank_bounds_only = (
            {'max_ndim': 5},
            # Equal bounds are still expected to give an unknown shape.
            {'min_ndim': 5, 'max_ndim': 5},
        )
        for spec_kwargs in rank_bounds_only:
            spec = input_spec.InputSpec(**spec_kwargs)
            with self.assertRaisesRegex(ValueError, 'unknown TensorShape'):
                input_spec.to_tensor_shape(spec).as_list()
예제 #2
0
    def test_defined_ndims(self):
        """Specs with an explicit `ndim` convert to shapes of that rank.

        Unpinned dimensions come back as None. Entries in `axes` pin
        individual dimensions, and a negative key (-1) addresses the last
        one, as the third case checks. A rank of 0 yields an empty list.
        """
        cases = (
            # (InputSpec keyword arguments, expected TensorShape.as_list())
            ({'ndim': 5}, [None] * 5),
            ({'ndim': 0}, []),
            ({'ndim': 3, 'axes': {1: 3, -1: 2}}, [None, 3, 2]),
        )
        for spec_kwargs, expected_dims in cases:
            spec = input_spec.InputSpec(**spec_kwargs)
            actual_dims = input_spec.to_tensor_shape(spec).as_list()
            self.assertAllEqual(expected_dims, actual_dims)
예제 #3
0
 def test_defined_shape(self):
     """A spec with an explicit `shape` converts back to those same dims."""
     expected_dims = [1, None, 2, 3]
     # Hand the spec its own copy so the expectation cannot be aliased.
     spec = input_spec.InputSpec(shape=list(expected_dims))
     converted = input_spec.to_tensor_shape(spec)
     self.assertAllEqual(expected_dims, converted.as_list())