예제 #1
0
def _extract_type_spec_recursively(value):
  """Return (collection of) `TypeSpec`(s) for `value` if it includes `Tensor`s.

  If `value` is a `Tensor` or `CompositeTensor`, return its `TypeSpec`. If
  `value` is a collection containing `Tensor` values, recursively supplant them
  with their respective `TypeSpec`s in a collection of parallel structure.

  If `value` is none of the above, return it unchanged.

  Args:
    value: a Python `object` to (possibly) turn into a (collection of)
      `tf.TypeSpec`(s).

  Returns:
    spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to `value`
      or `value`, if no `Tensor`s are found.
  """
  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access
  if isinstance(value, tf.Variable):
    return resource_variable_ops.VariableSpec(
        value.shape, dtype=value.dtype, trainable=value.trainable)
  if tf.is_tensor(value):
    return tf.TensorSpec(value.shape, value.dtype)
  if isinstance(value, (list, tuple)):
    specs = [_extract_type_spec_recursively(v) for v in value]
    # An element was converted iff the recursive call returned a new object.
    has_tensors = any(a is not b for a, b in zip(value, specs))
    has_only_tensors = all(a is not b for a, b in zip(value, specs))
    if has_tensors:
      # Mixed Tensor / non-Tensor collections are rejected because the
      # original non-Tensor elements could not be recovered from the spec.
      if not has_only_tensors:
        raise NotImplementedError(
            'Found `{}` with both Tensor and non-Tensor parts: {}'
            .format(type(value), value))
      return type(value)(specs)
  return value
예제 #2
0
  def _type_spec(self):
    """Builds the `TypeSpec` that describes this deferred tensor."""
    input_spec = self._get_input_spec()
    # Prefer the transform's own TypeSpec when it exposes one (e.g. a
    # Bijector); otherwise keep the raw callable.
    transform_or_spec = getattr(self._transform_fn, '_type_spec',
                                self._transform_fn)

    also_track_spec = None
    if self.also_track is not None:
      # Expand any tf.Module into its variables, then describe each variable
      # with a VariableSpec.
      flat_vars = tf.nest.flatten(
          tf.nest.map_structure(
              lambda x: x.variables if isinstance(x, tf.Module) else x,
              self.also_track))
      also_track_spec = tf.nest.map_structure(
          lambda v: resource_variable_ops.VariableSpec(  # pylint: disable=g-long-lambda
              v.shape, v.dtype, trainable=v.trainable),
          flat_vars)

    spec_type = (
        _DeferredTensorSpec
        if isinstance(self.pretransformed_input, tf.Variable)
        else _DeferredTensorBatchableSpec)
    return spec_type(
        input_spec, transform_or_spec, dtype=self.dtype, shape=self.shape,
        name=self.name, also_track_spec=also_track_spec)
예제 #3
0
 def testHash(self):
     """hash(VariableSpec) equals the hash of its (shape, dtype, trainable)."""
     attrs = ((1, 3), dtypes.int32, False)
     spec = resource_variable_ops.VariableSpec(*attrs)
     self.assertEqual(hash(spec), hash(attrs))
 def _type_spec(self):
     """Builds a `MeanMetricSpec` from this metric's config and weights."""
     # NOTE(review): `w.name.split(":")[0]` (the variable's base name) is
     # passed as the third positional argument to VariableSpec — confirm this
     # VariableSpec version accepts a name there; otherwise it would collide
     # with `trainable`.
     weight_specs = [
         resource_variable_ops.VariableSpec(w.shape,
                                            w.dtype,
                                            w.name.split(":")[0],
                                            trainable=False)
         for w in self.weights
     ]
     return MeanMetricSpec(self.get_config(), weight_specs)
 def _component_specs(self):
     """Returns one `VariableSpec` per tracked weight."""
     # NOTE(review): the third positional argument is the weight's base name
     # (part before ':'); verify VariableSpec supports a name parameter in
     # this position.
     return [
         resource_variable_ops.VariableSpec(w.shape,
                                            w.dtype,
                                            w.name.split(":")[0],
                                            trainable=False)
         for w in self._weights
     ]
예제 #6
0
 def _get_input_spec(self):
   """Returns a `TypeSpec` describing `self.pretransformed_input`."""
   pretransformed = self.pretransformed_input
   # CompositeTensors carry their own spec.
   if isinstance(pretransformed, tf.__internal__.CompositeTensor):
     return pretransformed._type_spec  # pylint: disable=protected-access
   # Variables are described by a VariableSpec preserving trainability.
   if isinstance(pretransformed, tf.Variable):
     return resource_variable_ops.VariableSpec(
         pretransformed.shape,
         dtype=pretransformed.dtype,
         trainable=pretransformed.trainable)
   # Anything else is treated as a plain Tensor.
   return tf.TensorSpec.from_tensor(pretransformed)
예제 #7
0
 def testRepr(self):
     """repr(VariableSpec) should report shape, dtype and trainable."""
     shape, dtype, trainable = (1, 3), dtypes.int32, False
     spec = resource_variable_ops.VariableSpec(shape, dtype, trainable)
     self.assertEqual(
         repr(spec),
         f"VariableSpec(shape={shape}, dtype={dtype}, trainable={trainable})")
예제 #8
0
 def testSerialize(self):
     """_serialize should produce a tuple that _deserialize round-trips."""
     shape, dtype, trainable = [1, 3], dtypes.int32, False
     spec = resource_variable_ops.VariableSpec(shape, dtype, trainable)
     serialized = spec._serialize()
     self.assertEqual(serialized, (shape, dtype, trainable))
     self.assertEqual(spec._deserialize(serialized), spec)
예제 #9
0
 def test_variable_args_cannot_be_used_as_signature(self):
   """Saving a concrete function with a VariableSpec signature must fail."""
   @def_function.function(input_signature=[
       resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)])
   def f(unused_v):
     return 1
   root = tracking.AutoTrackable()
   root.f = f.get_concrete_function()
   # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12;
   # use `assertRaisesRegex`.
   with self.assertRaisesRegex(ValueError,
                               "tf.Variable inputs cannot be exported"):
     save.save(root, os.path.join(self.get_temp_dir(), "saved_model"),
               signatures=root.f)
예제 #10
0
 def _get_batched_input_spec(self, batch_size):
   """Returns the batched `input_spec` for the given `batch_size`."""
   spec = self._input_spec
   # Batchable specs know how to batch themselves.
   if isinstance(spec, type_spec.BatchableTypeSpec):
     return spec._batch(batch_size)  # pylint: disable=protected-access
   # VariableSpecs are batched manually by prepending the batch dimension.
   if isinstance(spec, resource_variable_ops.VariableSpec):
     batched_shape = tf.TensorShape([batch_size]).concatenate(spec.shape)
     return resource_variable_ops.VariableSpec(
         shape=batched_shape,
         dtype=spec.dtype,
         trainable=spec.trainable)
   raise NotImplementedError(
       f'`{self.value_type.__name__}`s `TypeSpec` is not supported for '
       f'inputs of type {type(self._input_spec)}.')
예제 #11
0
 def _get_unbatched_input_spec(self):
   """Returns the `input_spec` with leading batch dimension removed."""
   if isinstance(self._input_spec, type_spec.BatchableTypeSpec):
     return self._input_spec._unbatch()  # pylint: disable=protected-access
   if isinstance(self._input_spec, resource_variable_ops.VariableSpec):
     return resource_variable_ops.VariableSpec(
         shape=(None if self._input_spec.shape is None
                else self._input_spec.shape[1:]),
         dtype=self._input_spec.dtype,
         trainable=self._input_spec.trainable)
   # No `else` needed after the early returns; this also matches the shape
   # of `_get_batched_input_spec`, which raises directly.
   raise NotImplementedError(
       f'`{self.value_type.__name__}`s `TypeSpec` is not supported for '
       f'inputs of type {type(self._input_spec)}.')
예제 #12
0
 def testFromValue(self,
                   initial_value=None,
                   shape=None,
                   dtype=dtypes.float32,
                   trainable=True):
     """VariableSpec.from_value should mirror the variable's attributes."""
     if initial_value is None:
         var = resource_variable_ops.UninitializedVariable(
             shape=shape, dtype=dtype, trainable=trainable)
     else:
         var = resource_variable_ops.ResourceVariable(
             initial_value=initial_value,
             shape=shape,
             dtype=dtype,
             trainable=trainable)
     expected = resource_variable_ops.VariableSpec(shape=shape,
                                                   dtype=dtype,
                                                   trainable=trainable)
     self.assertEqual(
         resource_variable_ops.VariableSpec.from_value(var), expected)
예제 #13
0
 def testEquality(self):
     """Specs are equal iff shape, dtype and trainable all match."""
     spec = resource_variable_ops.VariableSpec([1, 3], dtypes.float32,
                                               False)
     # Specs built with identical attributes compare equal.
     for same in (
             resource_variable_ops.VariableSpec([1, 3], dtypes.float32,
                                                False),
             resource_variable_ops.VariableSpec([1, 3], dtypes.float32,
                                                False)):
         self.assertEqual(same, spec)
     # Any differing attribute breaks equality.
     for different in (
             resource_variable_ops.VariableSpec([1, 3], dtypes.float32,
                                                True),
             resource_variable_ops.VariableSpec([3, 3], dtypes.float32,
                                                True),
             resource_variable_ops.VariableSpec([1, 3], dtypes.int32, True)):
         self.assertNotEqual(different, spec)
예제 #14
0
 def _type_spec(self):
     """Composes a `ShardedVariableSpec` from per-shard `VariableSpec`s."""
     shard_specs = [
         resource_variable_ops.VariableSpec(v.shape, v.dtype)
         for v in self._variables
     ]
     return ShardedVariableSpec(*shard_specs)
예제 #15
0
class DeferredTensorSpecTest(test_util.TestCase):
    @parameterized.named_parameters(
        ('DeferredTensorBijector',
         _make_deferred_tensor_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec),
         _make_deferred_tensor_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
        ('DeferredTensorCallable',
         _make_deferred_tensor_spec(input_spec=tf.TensorSpec([4, 2],
                                                             tf.float32),
                                    transform_or_spec=tf.math.sigmoid,
                                    shape=tf.TensorShape([None, 2]),
                                    name='one'),
         _make_deferred_tensor_spec(input_spec=tf.TensorSpec([4, 2],
                                                             tf.float32),
                                    transform_or_spec=tf.math.sigmoid,
                                    shape=tf.TensorShape([None, 2]),
                                    name='two')),
        ('NestedDeferredTensor',
         _make_deferred_tensor_spec(
             input_spec=_make_deferred_tensor_spec(
                 tf.TensorSpec([], tf.float32), transform_or_spec=tf.math.exp),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec),
         _make_deferred_tensor_spec(
             input_spec=_make_deferred_tensor_spec(
                 tf.TensorSpec([], tf.float32), transform_or_spec=tf.math.exp),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
        ('TransformedVariableBijector',
         _make_transformed_variable_spec(
             input_spec=tf.TensorSpec([4, None], tf.float32),
             transform_or_spec=_make_bijector_spec(tfb.Scale, [3.])),
         _make_transformed_variable_spec(
             input_spec=tf.TensorSpec([4, None], tf.float32),
             transform_or_spec=_make_bijector_spec(tfb.Scale, [3.]))),
        # Fixed test-case name typo: was 'TranformedVariableCallable'.
        ('TransformedVariableCallable',
         _make_transformed_variable_spec(
             input_spec=resource_variable_ops.VariableSpec(None, tf.float64),
             transform_or_spec=tf.math.sigmoid,
             dtype=tf.float64,
             name='one'),
         _make_transformed_variable_spec(
             input_spec=resource_variable_ops.VariableSpec(None, tf.float64),
             transform_or_spec=tf.math.sigmoid,
             dtype=tf.float64,
             name='two')),
    )
    def testEquality(self, v1, v2):
        """Equal specs satisfy ==, negate !=, and hash identically."""
        # pylint: disable=g-generic-assert
        self.assertEqual(v1, v2)
        self.assertEqual(v2, v1)
        self.assertFalse(v1 != v2)
        self.assertFalse(v2 != v1)
        self.assertEqual(hash(v1), hash(v2))

    # Each case provides a pair of specs differing in exactly one aspect
    # (input spec, transform spec, dtype, callable, also_track, value type).
    @parameterized.named_parameters(
        ('DifferentInputSpecs',
         _make_deferred_tensor_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec),
         _make_deferred_tensor_spec(
             input_spec=tf.TensorSpec([None, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
        ('DifferentBijectorSpecs',
         _make_deferred_tensor_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec,
             shape=tf.TensorShape([None, 2]),
             name='one'),
         _make_deferred_tensor_spec(input_spec=tf.TensorSpec([4, 2],
                                                             tf.float32),
                                    transform_or_spec=tfb.Exp()._type_spec,
                                    shape=tf.TensorShape([None, 2]),
                                    name='two')),
        ('DifferentDtypes',
         _make_transformed_variable_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float64),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec,
             dtype=tf.float64),
         _make_transformed_variable_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
        ('DifferentCallables',
         _make_transformed_variable_spec(input_spec=tf.TensorSpec([4, 2],
                                                                  tf.float64),
                                         transform_or_spec=tf.math.sigmoid,
                                         dtype=tf.float64,
                                         name='one'),
         _make_transformed_variable_spec(input_spec=tf.TensorSpec([4, 2],
                                                                  tf.float64),
                                         transform_or_spec=tf.math.softplus,
                                         dtype=tf.float64,
                                         name='two')),
        ('DifferentAlsoTrack',
         _make_deferred_tensor_spec(input_spec=tf.TensorSpec([], tf.float32),
                                    transform_or_spec=tf.math.exp),
         _make_deferred_tensor_spec(input_spec=tf.TensorSpec([], tf.float32),
                                    transform_or_spec=tf.math.exp,
                                    also_track_spec=[
                                        resource_variable_ops.VariableSpec(
                                            [3, 2], tf.float32)
                                    ])),
        ('DifferentValueType',
         _make_deferred_tensor_spec(input_spec=tf.TensorSpec([], tf.float32),
                                    transform_or_spec=tf.math.exp),
         _make_transformed_variable_spec(input_spec=tf.TensorSpec([],
                                                                  tf.float32),
                                         transform_or_spec=tf.math.exp)),
    )
    def testInequality(self, v1, v2):
        """Differing specs must compare unequal, symmetrically."""
        # pylint: disable=g-generic-assert
        self.assertNotEqual(v1, v2)
        self.assertNotEqual(v2, v1)
        self.assertFalse(v1 == v2)
        self.assertFalse(v2 == v1)

    # Pairs of specs expected to be mutually compatible: identical specs, or
    # specs whose shapes differ only by unknown (None) dimensions or names.
    @parameterized.named_parameters(
        ('DeferredTensorBijector',
         _make_deferred_tensor_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec),
         _make_deferred_tensor_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
        ('DeferredTensorCallable',
         _make_deferred_tensor_spec(input_spec=tf.TensorSpec([4, 2],
                                                             tf.float32),
                                    transform_or_spec=tf.math.sigmoid,
                                    shape=tf.TensorShape([4, 2]),
                                    name='one'),
         _make_deferred_tensor_spec(input_spec=tf.TensorSpec([4, 2],
                                                             tf.float32),
                                    transform_or_spec=tf.math.sigmoid,
                                    shape=tf.TensorShape([None, 2]),
                                    name='two')),
        ('TransformedVariableBijector',
         _make_transformed_variable_spec(
             input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec),
         _make_transformed_variable_spec(
             input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
        ('TransformedVariableCallable',
         _make_transformed_variable_spec(input_spec=tf.TensorSpec([4, 2],
                                                                  tf.float32),
                                         transform_or_spec=tf.math.sigmoid,
                                         name='one'),
         _make_transformed_variable_spec(input_spec=tf.TensorSpec([4, 2],
                                                                  tf.float32),
                                         transform_or_spec=tf.math.sigmoid,
                                         name='two')),
    )
    def testIsCompatibleWith(self, v1, v2):
        """Compatibility must hold in both directions."""
        self.assertTrue(v1.is_compatible_with(v2))
        self.assertTrue(v2.is_compatible_with(v1))

    # Pairs of specs expected to be incompatible: conflicting known
    # dimensions, dtypes, callables, also_track specs, or value types.
    @parameterized.named_parameters(
        ('IncompatibleInputSpecs',
         _make_deferred_tensor_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec),
         _make_deferred_tensor_spec(
             input_spec=tf.TensorSpec([None, 3], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
        ('DifferentDtypes',
         _make_transformed_variable_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec,
             dtype=tf.float64),
         _make_transformed_variable_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
        ('DifferentCallables',
         _make_transformed_variable_spec(input_spec=tf.TensorSpec([4, 2],
                                                                  tf.float64),
                                         transform_or_spec=tf.math.sigmoid,
                                         dtype=tf.float64,
                                         name='one'),
         _make_transformed_variable_spec(input_spec=tf.TensorSpec([4, 2],
                                                                  tf.float64),
                                         transform_or_spec=tf.math.softplus,
                                         dtype=tf.float64,
                                         name='two')),
        ('DifferentAlsoTrack',
         _make_deferred_tensor_spec(input_spec=tf.TensorSpec([], tf.float32),
                                    transform_or_spec=tf.math.exp),
         _make_deferred_tensor_spec(input_spec=tf.TensorSpec([], tf.float32),
                                    transform_or_spec=tf.math.exp,
                                    also_track_spec=[
                                        resource_variable_ops.VariableSpec(
                                            [3, 2], tf.float32)
                                    ])),
        ('DifferentValueType',
         _make_deferred_tensor_spec(input_spec=tf.TensorSpec([], tf.float32),
                                    transform_or_spec=tf.math.exp),
         _make_transformed_variable_spec(input_spec=tf.TensorSpec([],
                                                                  tf.float32),
                                         transform_or_spec=tf.math.exp)),
    )
    def testIsNotCompatibleWith(self, v1, v2):
        """Incompatibility must hold in both directions."""
        self.assertFalse(v1.is_compatible_with(v2))
        self.assertFalse(v2.is_compatible_with(v1))

    # Each case is (v1, v2, expected): the expected result relaxes mismatched
    # known dimensions to None.
    # NOTE(review): `most_specific_compatible_type` is deprecated in newer TF
    # in favor of `most_specific_common_supertype` — confirm which API this
    # codebase targets before migrating.
    @parameterized.named_parameters(
        ('DeferredTensor',
         _make_deferred_tensor_spec(input_spec=tf.TensorSpec([None, 2],
                                                             tf.float32),
                                    transform_or_spec=tf.math.sigmoid),
         _make_deferred_tensor_spec(input_spec=tf.TensorSpec([4, 2],
                                                             tf.float32),
                                    transform_or_spec=tf.math.sigmoid),
         _make_deferred_tensor_spec(input_spec=tf.TensorSpec([None, 2],
                                                             tf.float32),
                                    transform_or_spec=tf.math.sigmoid)),
        ('TransformedVariableBijector',
         _make_transformed_variable_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=_make_bijector_spec(
                 tfb.Shift, [[2.]], use_variable=True, variable_shape=[1, 1])),
         _make_transformed_variable_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=_make_bijector_spec(tfb.Shift, [[3.]],
                                                   use_variable=True,
                                                   variable_shape=[1, None])),
         _make_transformed_variable_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=_make_bijector_spec(tfb.Shift, [[3.]],
                                                   use_variable=True,
                                                   variable_shape=[1, None]))),
        ('TransformedVariableCallable',
         _make_transformed_variable_spec(
             input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32),
             transform_or_spec=tf.math.sigmoid),
         _make_transformed_variable_spec(
             input_spec=resource_variable_ops.VariableSpec(None, tf.float32),
             transform_or_spec=tf.math.sigmoid),
         _make_transformed_variable_spec(
             input_spec=resource_variable_ops.VariableSpec(None, tf.float32),
             transform_or_spec=tf.math.sigmoid)))
    def testMostSpecificCompatibleType(self, v1, v2, expected):
        """The common supertype is symmetric and matches `expected`."""
        self.assertEqual(v1.most_specific_compatible_type(v2), expected)
        self.assertEqual(v2.most_specific_compatible_type(v1), expected)

    # Pairs with no common supertype: the call must raise ValueError for
    # conflicting spec classes, bijector args, dtypes, or callables.
    @parameterized.named_parameters(
        ('IncompatibleInputSpecs',
         _make_deferred_tensor_spec(
             input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec),
         _make_deferred_tensor_spec(
             input_spec=tf.TensorSpec([None, 3], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
        ('IncompatibleBijectorSpecs',
         _make_deferred_tensor_spec(
             input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Exp(validate_args=True)._type_spec),
         _make_deferred_tensor_spec(
             input_spec=tf.TensorSpec([None, 3], tf.float32),
             transform_or_spec=tfb.Exp(validate_args=False)._type_spec)),
        ('DifferentDtypes',
         _make_transformed_variable_spec(
             input_spec=tf.TensorSpec([], tf.float32),
             transform_or_spec=tfb.Sigmoid()._type_spec,
             dtype=tf.float64),
         _make_transformed_variable_spec(
             input_spec=tf.TensorSpec([], tf.float32),
             transform_or_spec=tfb.Sigmoid()._type_spec)),
        ('DifferentCallables',
         _make_transformed_variable_spec(input_spec=tf.TensorSpec([4, 2],
                                                                  tf.float64),
                                         transform_or_spec=tf.math.sigmoid,
                                         dtype=tf.float64,
                                         name='one'),
         _make_transformed_variable_spec(input_spec=tf.TensorSpec([4, 2],
                                                                  tf.float64),
                                         transform_or_spec=tf.math.softplus,
                                         dtype=tf.float64,
                                         name='two')),
    )
    def testMostSpecificCompatibleTypeException(self, v1, v2):
        """Incompatible specs must raise ValueError, in both directions."""
        with self.assertRaises(ValueError):
            v1.most_specific_compatible_type(v2)
        with self.assertRaises(ValueError):
            v2.most_specific_compatible_type(v1)

    @parameterized.named_parameters(
        ('DeferredTensor',
         _make_deferred_tensor_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)),
        ('TransformedVariable',
         _make_transformed_variable_spec(
             input_spec=tf.TensorSpec([4, 2], tf.float32),
             transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec,
             dtype=tf.float64)))
    def testRepr(self, spec):
        """repr(spec) should list component specs, id params, and name."""
        # Build the expected repr from the spec's own internals so the test
        # tracks any change to the spec's fields.
        kwargs = dict(spec._specs, **spec._unique_id_params, name=spec.name)  # pylint: disable=protected-access
        kwargs_str = ', '.join(f'{k}={v}' for k, v in kwargs.items())
        expected = f'{type(spec).__name__}({kwargs_str})'
        self.assertEqual(repr(spec), expected)