def test_undefined_shapes(self):
    """Rank-range-only specs convert to a shape with no usable dimension list.

    An InputSpec that bounds rank through `max_ndim` (alone, or together with
    an equal `min_ndim`) but sets neither `ndim` nor `shape` is expected to
    produce a TensorShape whose `as_list()` raises a ValueError whose message
    matches 'unknown TensorShape'.
    """
    # Constructor kwargs for each spec, checked in order: max_ndim only,
    # then min_ndim == max_ndim.
    rank_bound_kwargs = (
        {'max_ndim': 5},
        {'min_ndim': 5, 'max_ndim': 5},
    )
    for kwargs in rank_bound_kwargs:
        candidate = input_spec.InputSpec(**kwargs)
        with self.assertRaisesRegex(ValueError, 'unknown TensorShape'):
            input_spec.to_tensor_shape(candidate).as_list()
def test_defined_ndims(self):
    """Specs with an explicit `ndim` convert to a shape of exactly that rank.

    Dimensions not pinned by `axes` come back as None. Entries in `axes` fill
    in known sizes, and the `-1` key addresses the last dimension.
    A zero rank yields an empty dimension list.
    """
    # Each case pairs InputSpec constructor kwargs with the expected
    # `as_list()` result of the converted shape.
    cases = (
        ({'ndim': 5}, [None] * 5),
        ({'ndim': 0}, []),
        ({'ndim': 3, 'axes': {1: 3, -1: 2}}, [None, 3, 2]),
    )
    for spec_kwargs, expected_dims in cases:
        spec = input_spec.InputSpec(**spec_kwargs)
        converted = input_spec.to_tensor_shape(spec)
        self.assertAllEqual(expected_dims, converted.as_list())
def test_defined_shape(self):
    """A spec built from a full `shape` converts back to those same dims.

    Both known sizes and None (unknown) entries are expected to survive the
    conversion unchanged.
    """
    expected_dims = [1, None, 2, 3]
    # Pass a fresh copy so the spec never shares the list being compared.
    spec = input_spec.InputSpec(shape=list(expected_dims))
    tensor_shape = input_spec.to_tensor_shape(spec)
    self.assertAllEqual(expected_dims, tensor_shape.as_list())