def _extract_type_spec_recursively(value):
  """Return (collection of) TypeSpec(s) for `value` if it includes `Tensor`s.

  If `value` is a `Tensor` or `CompositeTensor`, return its `TypeSpec`. If
  `value` is a `tf.Variable`, return a `VariableSpec` carrying its shape,
  dtype and trainability. If `value` is a list or tuple (including a
  namedtuple) containing `Tensor` values, recursively supplant them with their
  respective `TypeSpec`s in a collection of parallel structure. If `value` is
  none of the above, return it unchanged.

  Args:
    value: a Python `object` to (possibly) turn into a (collection of)
      `tf.TypeSpec`(s).

  Returns:
    spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to `value`
      or `value`, if no `Tensor`s are found.

  Raises:
    NotImplementedError: if a list/tuple mixes elements that were converted to
      `TypeSpec`s with elements that were not.
  """
  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access
  if isinstance(value, tf.Variable):
    return resource_variable_ops.VariableSpec(
        value.shape, dtype=value.dtype, trainable=value.trainable)
  if tf.is_tensor(value):
    return tf.TensorSpec(value.shape, value.dtype)
  if isinstance(value, (list, tuple)):
    specs = [_extract_type_spec_recursively(v) for v in value]
    # Non-tensor leaves come back as the very same object, so an element was
    # converted exactly when the recursive call returned a different object.
    converted = [a is not b for a, b in zip(value, specs)]
    if any(converted):
      if not all(converted):
        raise NotImplementedError(
            'Found `{}` with both Tensor and non-Tensor parts: {}'
            .format(type(value), value))
      if isinstance(value, tuple) and hasattr(type(value), '_fields'):
        # namedtuple constructors take one positional argument per field,
        # not a single iterable.
        return type(value)(*specs)
      return type(value)(specs)
  return value
def _type_spec(self):
  """Returns the `TypeSpec` describing this deferred tensor.

  The spec bundles:
    * the spec of `pretransformed_input` (via `self._get_input_spec()`),
    * the transform's `_type_spec` attribute when it has one, otherwise the
      transform object itself,
    * a `VariableSpec` for every variable reachable from `also_track`, where
      `tf.Module` entries are expanded into their `variables`.

  A `tf.Variable` input produces a `_DeferredTensorSpec`; any other input
  produces a `_DeferredTensorBatchableSpec`.
  """

  def _unwrap_module(obj):
    # `tf.Module`s are replaced by their variables; other leaves pass through.
    return obj.variables if isinstance(obj, tf.Module) else obj

  def _to_variable_spec(var):
    return resource_variable_ops.VariableSpec(
        var.shape, var.dtype, trainable=var.trainable)

  input_spec = self._get_input_spec()
  transform_or_spec = getattr(
      self._transform_fn, '_type_spec', self._transform_fn)

  # Extract Variables from also_track.
  also_track_spec = None
  if self.also_track is not None:
    tracked_vars = tf.nest.flatten(
        tf.nest.map_structure(_unwrap_module, self.also_track))
    also_track_spec = tf.nest.map_structure(_to_variable_spec, tracked_vars)

  if isinstance(self.pretransformed_input, tf.Variable):
    spec_cls = _DeferredTensorSpec
  else:
    spec_cls = _DeferredTensorBatchableSpec
  return spec_cls(
      input_spec,
      transform_or_spec,
      dtype=self.dtype,
      shape=self.shape,
      name=self.name,
      also_track_spec=also_track_spec)
def testHash(self):
  """A `VariableSpec` hashes like its (shape, dtype, trainable) triple."""
  shape, dtype, trainable = (1, 3), dtypes.int32, False
  spec = resource_variable_ops.VariableSpec(shape, dtype, trainable)
  self.assertEqual(hash(spec), hash((shape, dtype, trainable)))
def _type_spec(self):
  """Returns a `MeanMetricSpec` built from the config and one spec per weight.

  Each weight is described by a non-trainable `VariableSpec` with the weight's
  shape and dtype and its name truncated at the first ':'.
  """
  # NOTE(review): the truncated name is passed as the third positional
  # argument of `VariableSpec`. That binds to `name` only in signatures of the
  # form `(shape, dtype, name, trainable)`; in signatures where `trainable` is
  # third it would clash with `trainable=False`. Confirm against the pinned TF
  # version.
  specs_for_weights = [
      resource_variable_ops.VariableSpec(
          weight.shape, weight.dtype, weight.name.split(":")[0],
          trainable=False)
      for weight in self.weights
  ]
  return MeanMetricSpec(self.get_config(), specs_for_weights)
def _component_specs(self):
  """Returns one non-trainable `VariableSpec` per entry of `self._weights`.

  Each spec copies the weight's shape and dtype and uses the weight's name
  truncated at the first ':'.
  """
  # NOTE(review): the name is passed positionally as the third argument of
  # `VariableSpec`; this relies on a `(shape, dtype, name, trainable)`
  # signature. Confirm against the pinned TF version.
  return [
      resource_variable_ops.VariableSpec(
          weight.shape, weight.dtype, weight.name.split(":")[0],
          trainable=False)
      for weight in self._weights
  ]
def _get_input_spec(self):
  """Returns the `TypeSpec` of `self.pretransformed_input`.

  Composite tensors report their own `_type_spec`. Variables get a
  `VariableSpec` with matching shape, dtype and trainability. Anything else is
  described with `tf.TensorSpec.from_tensor`.
  """
  source = self.pretransformed_input
  if isinstance(source, tf.__internal__.CompositeTensor):
    return source._type_spec  # pylint: disable=protected-access
  if not isinstance(source, tf.Variable):
    return tf.TensorSpec.from_tensor(source)
  return resource_variable_ops.VariableSpec(
      source.shape, dtype=source.dtype, trainable=source.trainable)
def testRepr(self):
  """`repr(VariableSpec)` lists shape, dtype and trainable as keywords."""
  shape, dtype, trainable = (1, 3), dtypes.int32, False
  spec = resource_variable_ops.VariableSpec(shape, dtype, trainable)
  self.assertEqual(
      repr(spec),
      f"VariableSpec(shape={shape}, dtype={dtype}, trainable={trainable})")
def testSerialize(self):
  """`_serialize` yields (shape, dtype, trainable) and round-trips."""
  shape, dtype, trainable = [1, 3], dtypes.int32, False
  spec = resource_variable_ops.VariableSpec(shape, dtype, trainable)

  serialized = spec._serialize()
  self.assertEqual(serialized, (shape, dtype, trainable))

  # Deserializing must rebuild a spec equal to the original.
  self.assertEqual(spec._deserialize(serialized), spec)
def test_variable_args_cannot_be_used_as_signature(self):
  """Saving a signature whose input is a `VariableSpec` raises `ValueError`."""

  @def_function.function(input_signature=[
      resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)])
  def f(unused_v):
    return 1

  root = tracking.AutoTrackable()
  root.f = f.get_concrete_function()
  # `assertRaisesRegexp` is a deprecated alias that was removed from
  # `unittest` in Python 3.12; `assertRaisesRegex` is the supported name.
  with self.assertRaisesRegex(ValueError,
                              "tf.Variable inputs cannot be exported"):
    save.save(root, os.path.join(self.get_temp_dir(), "saved_model"),
              signatures=root.f)
def _get_batched_input_spec(self, batch_size):
  """Returns the batched `input_spec` for the given `batch_size`.

  Batchable specs delegate to their own `_batch`. A `VariableSpec` is rebuilt
  with `batch_size` prepended to its shape, keeping dtype and trainability.

  Raises:
    NotImplementedError: for any other kind of input spec.
  """
  spec = self._input_spec
  if isinstance(spec, type_spec.BatchableTypeSpec):
    return spec._batch(batch_size)  # pylint: disable=protected-access
  if not isinstance(spec, resource_variable_ops.VariableSpec):
    raise NotImplementedError(
        f'`{self.value_type.__name__}`s `TypeSpec` is not supported for '
        f'inputs of type {type(spec)}.')
  batched_shape = tf.TensorShape([batch_size]).concatenate(spec.shape)
  return resource_variable_ops.VariableSpec(
      shape=batched_shape, dtype=spec.dtype, trainable=spec.trainable)
def _get_unbatched_input_spec(self):
  """Returns the `input_spec` with leading batch dimension removed.

  Batchable specs delegate to their own `_unbatch`. A `VariableSpec` is
  rebuilt with its leading dimension dropped, keeping dtype and trainability.

  Raises:
    ValueError: if the input is a rank-0 `VariableSpec`, which has no leading
      dimension to remove.
    NotImplementedError: for any other kind of input spec.
  """
  spec = self._input_spec
  if isinstance(spec, type_spec.BatchableTypeSpec):
    return spec._unbatch()  # pylint: disable=protected-access
  if not isinstance(spec, resource_variable_ops.VariableSpec):
    raise NotImplementedError(
        f'`{self.value_type.__name__}`s `TypeSpec` is not supported for '
        f'inputs of type {type(spec)}.')
  shape = spec.shape
  # Slicing a rank-0 shape with [1:] silently yields another rank-0 shape;
  # reject it instead of returning a spec that was never batched.
  if shape is not None and shape.rank == 0:
    raise ValueError(
        'Unbatching a `VariableSpec` is only supported for rank >= 1.')
  return resource_variable_ops.VariableSpec(
      shape=None if shape is None else shape[1:],
      dtype=spec.dtype,
      trainable=spec.trainable)
def testFromValue(self, initial_value=None, shape=None,
                  dtype=dtypes.float32, trainable=True):
  """`VariableSpec.from_value` matches a spec built from the same arguments.

  Without `initial_value` an `UninitializedVariable` is used; otherwise a
  `ResourceVariable` is created from it.
  """
  if initial_value is not None:
    variable = resource_variable_ops.ResourceVariable(
        initial_value=initial_value, shape=shape, dtype=dtype,
        trainable=trainable)
  else:
    variable = resource_variable_ops.UninitializedVariable(
        shape=shape, dtype=dtype, trainable=trainable)
  self.assertEqual(
      resource_variable_ops.VariableSpec.from_value(variable),
      resource_variable_ops.VariableSpec(
          shape=shape, dtype=dtype, trainable=trainable))
def testEquality(self):
  """Checks that `VariableSpec` equality depends on shape, dtype, trainable.

  Each inequality case changes exactly one of those fields relative to
  `spec`. A failing case therefore identifies which field equality ignores.
  """
  spec = resource_variable_ops.VariableSpec([1, 3], dtypes.float32, False)

  # Independently constructed specs with identical fields compare equal.
  spec2 = resource_variable_ops.VariableSpec([1, 3], dtypes.float32, False)
  self.assertEqual(spec2, spec)
  spec3 = resource_variable_ops.VariableSpec([1, 3], dtypes.float32, False)
  self.assertEqual(spec3, spec)

  # Only `trainable` differs.
  spec4 = resource_variable_ops.VariableSpec([1, 3], dtypes.float32, True)
  self.assertNotEqual(spec4, spec)
  # Only the shape differs.
  spec5 = resource_variable_ops.VariableSpec([3, 3], dtypes.float32, False)
  self.assertNotEqual(spec5, spec)
  # Only the dtype differs.
  spec6 = resource_variable_ops.VariableSpec([1, 3], dtypes.int32, False)
  self.assertNotEqual(spec6, spec)
def _type_spec(self):
  """Returns a `ShardedVariableSpec` with one `VariableSpec` per shard.

  Each shard spec records only the shard's shape and dtype.
  """
  # NOTE(review): `trainable` is not forwarded, so every shard spec uses
  # `VariableSpec`'s default rather than the shard's own flag. Confirm this
  # is intended.
  shard_specs = [
      resource_variable_ops.VariableSpec(shard.shape, shard.dtype)
      for shard in self._variables
  ]
  return ShardedVariableSpec(*shard_specs)
class DeferredTensorSpecTest(test_util.TestCase):
  """Tests for the `TypeSpec`s of deferred tensors and transformed variables.

  Covers equality/hashing, `is_compatible_with`,
  `most_specific_compatible_type` (including its failure mode) and `repr`.
  Specs are built with the `_make_deferred_tensor_spec`,
  `_make_transformed_variable_spec` and `_make_bijector_spec` helpers.
  """

  # Pairs expected to be equal and to hash identically. The '...Callable'
  # cases differ only in `name` ('one' vs 'two'), so they assert that `name`
  # does not take part in equality.
  @parameterized.named_parameters(
      ('DeferredTensorBijector',
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec),
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
      ('DeferredTensorCallable',
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tf.math.sigmoid,
           shape=tf.TensorShape([None, 2]),
           name='one'),
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tf.math.sigmoid,
           shape=tf.TensorShape([None, 2]),
           name='two')),
      # The input spec is itself a deferred-tensor spec.
      ('NestedDeferredTensor',
       _make_deferred_tensor_spec(
           input_spec=_make_deferred_tensor_spec(
               tf.TensorSpec([], tf.float32),
               transform_or_spec=tf.math.exp),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec),
       _make_deferred_tensor_spec(
           input_spec=_make_deferred_tensor_spec(
               tf.TensorSpec([], tf.float32),
               transform_or_spec=tf.math.exp),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
      ('TransformedVariableBijector',
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, None], tf.float32),
           transform_or_spec=_make_bijector_spec(tfb.Scale, [3.])),
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, None], tf.float32),
           transform_or_spec=_make_bijector_spec(tfb.Scale, [3.]))),
      ('TranformedVariableCallable',
       _make_transformed_variable_spec(
           input_spec=resource_variable_ops.VariableSpec(None, tf.float64),
           transform_or_spec=tf.math.sigmoid,
           dtype=tf.float64,
           name='one'),
       _make_transformed_variable_spec(
           input_spec=resource_variable_ops.VariableSpec(None, tf.float64),
           transform_or_spec=tf.math.sigmoid,
           dtype=tf.float64,
           name='two')),
  )
  def testEquality(self, v1, v2):  # pylint: disable=g-generic-assert
    """`==`, `!=` and `hash` agree in both argument orders."""
    self.assertEqual(v1, v2)
    self.assertEqual(v2, v1)
    self.assertFalse(v1 != v2)
    self.assertFalse(v2 != v1)
    self.assertEqual(hash(v1), hash(v2))

  # Pairs expected to be unequal. Each case's name states what differs:
  # input spec, bijector spec, dtype, callable, `also_track_spec`, or the
  # spec class itself (deferred tensor vs. transformed variable).
  @parameterized.named_parameters(
      ('DifferentInputSpecs',
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec),
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([None, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
      ('DifferentBijectorSpecs',
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec,
           shape=tf.TensorShape([None, 2]),
           name='one'),
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Exp()._type_spec,
           shape=tf.TensorShape([None, 2]),
           name='two')),
      ('DifferentDtypes',
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float64),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec,
           dtype=tf.float64),
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
      ('DifferentCallables',
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float64),
           transform_or_spec=tf.math.sigmoid,
           dtype=tf.float64,
           name='one'),
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float64),
           transform_or_spec=tf.math.softplus,
           dtype=tf.float64,
           name='two')),
      ('DifferentAlsoTrack',
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([], tf.float32),
           transform_or_spec=tf.math.exp),
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([], tf.float32),
           transform_or_spec=tf.math.exp,
           also_track_spec=[
               resource_variable_ops.VariableSpec(
                   [3, 2], tf.float32)
           ])),
      ('DifferentValueType',
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([], tf.float32),
           transform_or_spec=tf.math.exp),
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([], tf.float32),
           transform_or_spec=tf.math.exp)),
  )
  def testInequality(self, v1, v2):  # pylint: disable=g-generic-assert
    """`!=` holds and `==` fails, in both argument orders."""
    self.assertNotEqual(v1, v2)
    self.assertNotEqual(v2, v1)
    self.assertFalse(v1 == v2)
    self.assertFalse(v2 == v1)

  # Pairs expected to be mutually compatible. 'DeferredTensorCallable' pairs
  # a fully known shape [4, 2] with a partially known [None, 2].
  @parameterized.named_parameters(
      ('DeferredTensorBijector',
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec),
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
      ('DeferredTensorCallable',
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tf.math.sigmoid,
           shape=tf.TensorShape([4, 2]),
           name='one'),
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tf.math.sigmoid,
           shape=tf.TensorShape([None, 2]),
           name='two')),
      ('TransformedVariableBijector',
       _make_transformed_variable_spec(
           input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec),
       _make_transformed_variable_spec(
           input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
      ('TransformedVariableCallable',
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tf.math.sigmoid,
           name='one'),
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tf.math.sigmoid,
           name='two')),
  )
  def testIsCompatibleWith(self, v1, v2):
    """`is_compatible_with` holds in both directions."""
    self.assertTrue(v1.is_compatible_with(v2))
    self.assertTrue(v2.is_compatible_with(v1))

  # Pairs expected to be incompatible ([4, 2] vs [None, 3] input shapes,
  # dtype, callable, `also_track_spec`, or spec class).
  @parameterized.named_parameters(
      ('IncompatibleInputSpecs',
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec),
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([None, 3], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
      ('DifferentDtypes',
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec,
           dtype=tf.float64),
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
      ('DifferentCallables',
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float64),
           transform_or_spec=tf.math.sigmoid,
           dtype=tf.float64,
           name='one'),
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float64),
           transform_or_spec=tf.math.softplus,
           dtype=tf.float64,
           name='two')),
      ('DifferentAlsoTrack',
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([], tf.float32),
           transform_or_spec=tf.math.exp),
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([], tf.float32),
           transform_or_spec=tf.math.exp,
           also_track_spec=[
               resource_variable_ops.VariableSpec(
                   [3, 2], tf.float32)
           ])),
      ('DifferentValueType',
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([], tf.float32),
           transform_or_spec=tf.math.exp),
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([], tf.float32),
           transform_or_spec=tf.math.exp)),
  )
  def testIsNotCompatibleWith(self, v1, v2):
    """`is_compatible_with` fails in both directions."""
    self.assertFalse(v1.is_compatible_with(v2))
    self.assertFalse(v2.is_compatible_with(v1))

  # Triples (v1, v2, expected): the most specific compatible type of v1 and
  # v2 relaxes mismatched dimensions to None. In 'TransformedVariableBijector'
  # the expected spec reuses v2's arguments ([[3.]], variable_shape=[1, None]);
  # presumably the variable's value does not enter the spec, only its shape
  # does. TODO confirm against `_make_bijector_spec`.
  @parameterized.named_parameters(
      ('DeferredTensor',
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([None, 2], tf.float32),
           transform_or_spec=tf.math.sigmoid),
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tf.math.sigmoid),
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([None, 2], tf.float32),
           transform_or_spec=tf.math.sigmoid)),
      ('TransformedVariableBijector',
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=_make_bijector_spec(
               tfb.Shift, [[2.]], use_variable=True, variable_shape=[1, 1])),
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=_make_bijector_spec(
               tfb.Shift, [[3.]], use_variable=True,
               variable_shape=[1, None])),
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=_make_bijector_spec(
               tfb.Shift, [[3.]], use_variable=True,
               variable_shape=[1, None]))),
      ('TransformedVariableCallable',
       _make_transformed_variable_spec(
           input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32),
           transform_or_spec=tf.math.sigmoid),
       _make_transformed_variable_spec(
           input_spec=resource_variable_ops.VariableSpec(None, tf.float32),
           transform_or_spec=tf.math.sigmoid),
       _make_transformed_variable_spec(
           input_spec=resource_variable_ops.VariableSpec(None, tf.float32),
           transform_or_spec=tf.math.sigmoid)))
  def testMostSpecificCompatibleType(self, v1, v2, expected):
    """The result equals `expected` regardless of argument order."""
    self.assertEqual(v1.most_specific_compatible_type(v2), expected)
    self.assertEqual(v2.most_specific_compatible_type(v1), expected)

  # Pairs with no common compatible type.
  # NOTE(review): 'IncompatibleBijectorSpecs' also pairs a VariableSpec([4, 2])
  # input with a TensorSpec([None, 3]) input, so the ValueError may come from
  # the input specs rather than the differing `validate_args`. Consider
  # sharing one input spec there to isolate the bijector difference.
  @parameterized.named_parameters(
      ('IncompatibleInputSpecs',
       _make_deferred_tensor_spec(
           input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec),
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([None, 3], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
      ('IncompatibleBijectorSpecs',
       _make_deferred_tensor_spec(
           input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Exp(validate_args=True)._type_spec),
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([None, 3], tf.float32),
           transform_or_spec=tfb.Exp(validate_args=False)._type_spec)),
      ('DifferentDtypes',
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([], tf.float32),
           transform_or_spec=tfb.Sigmoid()._type_spec,
           dtype=tf.float64),
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([], tf.float32),
           transform_or_spec=tfb.Sigmoid()._type_spec)),
      ('DifferentCallables',
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float64),
           transform_or_spec=tf.math.sigmoid,
           dtype=tf.float64,
           name='one'),
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float64),
           transform_or_spec=tf.math.softplus,
           dtype=tf.float64,
           name='two')),
  )
  def testMostSpecificCompatibleTypeException(self, v1, v2):
    """`most_specific_compatible_type` raises ValueError in both orders."""
    with self.assertRaises(ValueError):
      v1.most_specific_compatible_type(v2)
    with self.assertRaises(ValueError):
      v2.most_specific_compatible_type(v1)

  @parameterized.named_parameters(
      ('DeferredTensor',
       _make_deferred_tensor_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
      ('TransformedVariable',
       _make_transformed_variable_spec(
           input_spec=tf.TensorSpec([4, 2], tf.float32),
           transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec,
           dtype=tf.float64)))
  def testRepr(self, spec):
    """`repr` is `ClassName(k=v, ...)` over the spec's components and name."""
    # Keyword arguments in order: `_specs` entries, then `_unique_id_params`,
    # then `name`.
    kwargs = dict(spec._specs, **spec._unique_id_params, name=spec.name)  # pylint: disable=protected-access
    kwargs_str = ', '.join(f'{k}={v}' for k, v in kwargs.items())
    expected = f'{type(spec).__name__}({kwargs_str})'
    self.assertEqual(repr(spec), expected)