def testTensorShapeUnknown(self):
  """Specs with a fully unknown shape and the same dtype trace equally."""
  context = function_trace_type.SignatureContext()
  # Build two independent int32 specs whose shape is `None` (unknown).
  first, second = (
      tensor_spec.TensorSpec(None, dtype=dtypes.int32)._tf_tracing_type(context)
      for _ in range(2))
  self.assertEqual(first, second)
def testTensorAndSpecEquality(self):
  """A tensor's tracing type matches an unnamed spec of equal shape/dtype."""
  context = function_trace_type.SignatureContext()
  zeros_type = array_ops.zeros(
      [11, 3, 5], dtype=dtypes.int32)._tf_tracing_type(context)
  unnamed_spec_type = tensor_spec.TensorSpec(
      [11, 3, 5], dtype=dtypes.int32)._tf_tracing_type(context)
  named_spec_type = tensor_spec.TensorSpec(
      [11, 3, 5], dtype=dtypes.int32, name='name')._tf_tracing_type(context)

  # Same shape and dtype, no name: the tensor and the spec trace identically.
  self.assertEqual(zeros_type, unnamed_spec_type)
  # Adding a name to the spec makes its tracing type differ from the tensor's.
  self.assertNotEqual(zeros_type, named_spec_type)
def make_cache_key(
    args,
    include_tensor_ranks_only: bool = False
) -> Tuple[FunctionCacheKey, function_trace_type.WeakrefDeletionObserver]:
  """Computes the cache key given the function arguments.

  Args:
    args: The function arguments from which the trace signature is built.
    include_tensor_ranks_only: Forwarded unchanged to `SignatureContext`.
      Presumably this makes tensor signatures depend only on rank rather than
      on full shape. TODO confirm against `SignatureContext`.

  Returns:
    A pair `(cache_key, deletion_observer)`. `cache_key` combines the function
    signature with the current execution context. `deletion_observer` is the
    observer held by the signature context used to build that signature.
  """
  context = function_trace_type.SignatureContext(include_tensor_ranks_only)
  signature = function_trace_type.make_function_signature(args, context)
  cache_key = FunctionCacheKey(signature, _make_execution_context())
  return cache_key, context.deletion_observer
def testTensorEquality(self):
  """Tensor tracing types differ by shape or dtype but not by values."""
  context = function_trace_type.SignatureContext()

  def trace(tensor):
    # Every tensor is traced against the same shared context.
    return tensor._tf_tracing_type(context)

  base = trace(array_ops.zeros([11, 3, 5], dtype=dtypes.int32))
  other_shape = trace(array_ops.zeros([11, 4, 5], dtype=dtypes.int32))
  other_dtype = trace(array_ops.zeros([11, 3, 5], dtype=dtypes.float32))
  other_values = trace(array_ops.ones([11, 3, 5], dtype=dtypes.int32))

  self.assertNotEqual(base, other_shape)
  self.assertNotEqual(base, other_dtype)
  self.assertNotEqual(other_shape, other_dtype)
  # zeros vs. ones with identical shape/dtype: contents do not affect equality.
  self.assertEqual(base, other_values)
def make_function_signature_with_context(inputs):
  """Builds a function signature for `inputs` using a fresh default context."""
  fresh_context = function_trace_type.SignatureContext()
  return function_trace_type.make_function_signature(inputs, fresh_context)