def _maybe_as_partial_spec(spec: array_specs.Array):
    """Convert `spec` to a `PartialArray` when its shape contains -1.

    Args:
        spec: The array spec to inspect. Only its `shape`, `dtype` and `name`
            attributes are read.

    Returns:
        `spec` itself when -1 does not appear in `spec.shape`; otherwise a new
        `partial_array_specs.PartialArray` built from the same shape, dtype
        and name.

    Raises:
        ValueError: If the shape contains -1 and `spec` is an
            `array_specs.BoundedArray`.
    """
    # NOTE(review): -1 presumably marks a dimension of unknown size -- confirm
    # against the PartialArray definition.
    if -1 in spec.shape:
        # The PartialArray constructor is only given shape/dtype/name, so a
        # bounded spec is rejected explicitly rather than converted.
        if isinstance(spec, array_specs.BoundedArray):
            raise ValueError('Partial bounded arrays are not yet handled.')
        return partial_array_specs.PartialArray(
            spec.shape, spec.dtype, spec.name)
    return spec
def testValidateShape(
        self, value, is_valid,
        error_format=partial_array_specs._INVALID_SHAPE):
    """Check `PartialArray.validate` against a `(-1, 2)` int32 spec.

    Args:
        value: Candidate passed to `validate`; its `shape` attribute is used
            to build the expected error message in the failing case.
        is_valid: Whether `validate` is expected to succeed for `value`.
        error_format: %-style template formatted with
            `(value.shape, spec.shape)` to produce the exact expected
            `ValueError` message.
    """
    spec = partial_array_specs.PartialArray((-1, 2), np.int32)

    if not is_valid:
        # The message must match literally, so build it up front.
        expected_message = error_format % (value.shape, spec.shape)
        with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
            spec.validate(value)
        return

    # A valid value must pass validation without raising.
    spec.validate(value)
def testSerialization(self):
    """A `PartialArray` compares equal to itself after a pickle round trip."""
    original = partial_array_specs.PartialArray([-1, 5], np.float32, "test")
    restored = pickle.loads(pickle.dumps(original))
    self.assertEqual(restored, original)
def testGenerateValue(self):
    """A value produced by `generate_value` passes the spec's own `validate`."""
    spec = partial_array_specs.PartialArray((2, -1), np.int32)
    # validate() raising here would fail the test; no further assertion needed.
    spec.validate(spec.generate_value())
def testShapeValueError(self, spec_shape, is_valid):
    """Check whether constructing a `PartialArray` accepts `spec_shape`.

    Args:
        spec_shape: Shape passed to the `PartialArray` constructor.
        is_valid: If true, construction must succeed; otherwise it must raise
            `ValueError`.
    """
    if not is_valid:
        with self.assertRaises(ValueError):
            partial_array_specs.PartialArray(spec_shape, np.int32)
        return

    # Construction alone is the check: any exception fails the test.
    partial_array_specs.PartialArray(spec_shape, np.int32)