예제 #1
0
 def testTensorShapeUnknown(self):
     """Two int32 TensorSpecs with unknown shape trace to equal types."""
     ctx = function_trace_type.SignatureContext()
     # Build two independent but identical specs, each traced in `ctx`.
     first, second = (
         tensor_spec.TensorSpec(None, dtype=dtypes.int32)._tf_tracing_type(ctx)
         for _ in range(2))
     self.assertEqual(first, second)
  def testTensorAndSpecEquality(self):
    """A tensor traces equal to a matching spec, but not to a named one."""
    ctx = function_trace_type.SignatureContext()

    traced_tensor = array_ops.zeros(
        [11, 3, 5], dtype=dtypes.int32)._tf_tracing_type(ctx)
    traced_spec = tensor_spec.TensorSpec(
        [11, 3, 5], dtype=dtypes.int32)._tf_tracing_type(ctx)
    # Same shape and dtype as above, but carrying a spec name.
    traced_named_spec = tensor_spec.TensorSpec(
        [11, 3, 5], dtype=dtypes.int32, name='name')._tf_tracing_type(ctx)

    self.assertEqual(traced_tensor, traced_spec)
    # The test asserts that the name participates in the comparison.
    self.assertNotEqual(traced_tensor, traced_named_spec)
예제 #3
0
def make_cache_key(
    args,
    include_tensor_ranks_only: bool = False
) -> Tuple[FunctionCacheKey, function_trace_type.WeakrefDeletionObserver]:
    """Computes the cache key given the function arguments.

    Args:
      args: The function arguments to encode into the signature.
      include_tensor_ranks_only: Forwarded to `SignatureContext`. Presumably
        makes tensors contribute only their rank to the signature rather than
        the full shape (confirm against `SignatureContext`).

    Returns:
      A `(FunctionCacheKey, deletion_observer)` tuple. The key pairs the
      signature of `args` with the result of `_make_execution_context()`. The
      observer is the `deletion_observer` of the context used to build the
      signature.
    """
    context = function_trace_type.SignatureContext(include_tensor_ranks_only)
    signature = function_trace_type.make_function_signature(args, context)
    cache_key = FunctionCacheKey(signature, _make_execution_context())
    return cache_key, context.deletion_observer
예제 #4
0
    def testTensorEquality(self):
        """Tracing types differ by shape or dtype, not by element values."""
        ctx = function_trace_type.SignatureContext()

        def trace(tensor):
            return tensor._tf_tracing_type(ctx)

        base = trace(array_ops.zeros([11, 3, 5], dtype=dtypes.int32))
        other_shape = trace(array_ops.zeros([11, 4, 5], dtype=dtypes.int32))
        other_dtype = trace(array_ops.zeros([11, 3, 5], dtype=dtypes.float32))
        other_values = trace(array_ops.ones([11, 3, 5], dtype=dtypes.int32))

        # A change in shape or dtype must produce a distinct tracing type.
        self.assertNotEqual(base, other_shape)
        self.assertNotEqual(base, other_dtype)
        self.assertNotEqual(other_shape, other_dtype)
        # Different element values (zeros vs ones) alone must not.
        self.assertEqual(base, other_values)
def make_function_signature_with_context(inputs):
    """Return the function signature of `inputs` under a fresh, default
    `SignatureContext`."""
    fresh_context = function_trace_type.SignatureContext()
    return function_trace_type.make_function_signature(inputs, fresh_context)