def _maybe_as_partial_spec(spec: array_specs.Array):
    """Returns a `PartialArray` version of `spec` if any dimension is -1.

    Args:
        spec: The array spec to inspect.

    Returns:
        `spec` itself when its shape contains no -1 entry. Otherwise, a
        `partial_array_specs.PartialArray` built from the same shape, dtype
        and name.

    Raises:
        ValueError: If `spec` is an `array_specs.BoundedArray` whose shape
            contains -1; that combination is not supported here.
    """
    has_unknown_dim = -1 in spec.shape
    if has_unknown_dim:
        # Bounded specs carry extra bounds that PartialArray cannot represent.
        if isinstance(spec, array_specs.BoundedArray):
            raise ValueError('Partial bounded arrays are not yet handled.')
        return partial_array_specs.PartialArray(spec.shape, spec.dtype, spec.name)
    return spec
 def testValidateShape(
     self, value, is_valid, error_format=partial_array_specs._INVALID_SHAPE):
   # Validates `value` against a spec whose first dimension is unknown (-1).
   # Invalid values must raise ValueError with exactly the formatted message.
   partial_spec = partial_array_specs.PartialArray((-1, 2), np.int32)
   if not is_valid:
     expected_message = error_format % (value.shape, partial_spec.shape)
     with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
       partial_spec.validate(value)
     return
   # A valid value must pass validation without raising.
   partial_spec.validate(value)
 def testSerialization(self):
   # Pickling and unpickling a PartialArray must give back an equal spec.
   original = partial_array_specs.PartialArray([-1, 5], np.float32, "test")
   serialized = pickle.dumps(original)
   restored = pickle.loads(serialized)
   self.assertEqual(restored, original)
 def testGenerateValue(self):
   # A value produced by generate_value() must validate against its own spec.
   partial_spec = partial_array_specs.PartialArray((2, -1), np.int32)
   partial_spec.validate(partial_spec.generate_value())
 def testShapeValueError(self, spec_shape, is_valid):
   # Constructing a PartialArray succeeds for valid shapes. Invalid shapes
   # must raise ValueError.
   if not is_valid:
     with self.assertRaises(ValueError):
       partial_array_specs.PartialArray(spec_shape, np.int32)
     return
   partial_array_specs.PartialArray(spec_shape, np.int32)