Beispiel #1
0
    def testCompositeAndSpec(self):
        """A RaggedTensor and its matching RaggedTensorSpec trace identically."""
        ragged_value = ragged_tensor.RaggedTensor.from_row_splits(
            values=[1, 2, 3], row_splits=[0, 2, 3])
        ragged_spec = ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)

        value_type = trace_type.from_object(ragged_value)
        spec_type = trace_type.from_object(ragged_spec)
        self.assertEqual(value_type, spec_type)
Beispiel #2
0
    def testCustomUnequableTypeSucceeds(self):
        """Objects whose __eq__ raises can still be traced.

        Traces of the same object compare equal without raising. Comparing
        traces of two distinct objects propagates the ValueError raised by
        CustomUnequable.__eq__. Once the traced objects are deleted, the two
        traces of object_a no longer compare equal, in either direction.
        """
        class CustomUnequable:
            # Any equality check that reaches this method fails loudly.
            def __eq__(self, o):
                raise ValueError

            # A constant hash keeps instances hashable despite the custom
            # __eq__ (defining __eq__ alone would set __hash__ to None), and
            # makes every instance collide in hash-based lookups.
            def __hash__(self):
                return 0

        object_a = CustomUnequable()
        object_b = CustomUnequable()
        trace_a_1 = trace_type.from_object(object_a)
        trace_a_2 = trace_type.from_object(object_a)
        trace_b = trace_type.from_object(object_b)
        # Same underlying object: this must not raise. Presumably the trace
        # short-circuits on identity before calling CustomUnequable.__eq__
        # -- TODO confirm.
        self.assertEqual(trace_a_1, trace_a_2)

        # Distinct objects with equal hashes: the comparison reaches
        # CustomUnequable.__eq__, whose ValueError must propagate.
        with self.assertRaises(ValueError):
            trace_a_1.__eq__(trace_b)

        # After object_a is deleted, its traces must compare unequal in both
        # directions. NOTE(review): this suggests the trace holds only a weak
        # reference to the object; confirm.
        del object_a
        self.assertNotEqual(trace_a_1, trace_a_2)
        self.assertNotEqual(trace_a_2, trace_a_1)

        # Deleting object_b as well must not make the object_a traces equal
        # again.
        del object_b
        self.assertNotEqual(trace_a_1, trace_a_2)
        self.assertNotEqual(trace_a_2, trace_a_1)
Beispiel #3
0
 def testComplexStruct(self):
     """Tracing one nested structure twice yields equal, mutually-subtyped types."""
     nested = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
     first, second = (trace_type.from_object(nested) for _ in range(2))
     self.assertEqual(first, second)
     for lhs, rhs in ((first, second), (second, first)):
         self.assertTrue(lhs.is_subtype_of(rhs))
Beispiel #4
0
 def testDictEquality(self):
     """Dict traces compare by key/value content, ignoring insertion order."""
     base, other_b, other_c, reordered = [
         trace_type.from_object(d)
         for d in ({1: 2, 3: 4}, {1: 2, 3: 2}, {1: 2, 3: 0}, {3: 4, 1: 2})
     ]
     for lhs, rhs in ((base, other_b), (base, other_c), (other_b, other_c)):
         self.assertNotEqual(lhs, rhs)
     self.assertEqual(base, reordered)
Beispiel #5
0
 def testListEquality(self):
     """List traces differ on any element or length change, match otherwise."""
     base, changed, shorter, duplicate = [
         trace_type.from_object(seq)
         for seq in ([1, 2, 3, 4], [1, 2, 2, 4], [1, 2, 3], [1, 2, 3, 4])
     ]
     for lhs, rhs in ((base, changed), (base, shorter), (changed, shorter)):
         self.assertNotEqual(lhs, rhs)
     self.assertEqual(base, duplicate)
Beispiel #6
0
 def testTupleEquality(self):
     """Tuple traces differ on any element or length change, match otherwise."""
     base, changed, shorter, duplicate = [
         trace_type.from_object(seq)
         for seq in ((1, 2, 3, 4), (1, 2, 2, 4), (1, 2, 3), (1, 2, 3, 4))
     ]
     for lhs, rhs in ((base, changed), (base, shorter), (changed, shorter)):
         self.assertNotEqual(lhs, rhs)
     self.assertEqual(base, duplicate)
Beispiel #7
0
    def testCustomUnhashableTypeFailsGracefully(self):
        """Tracing an unhashable custom object raises a descriptive TypeError."""

        class CustomUnhashable:
            # Defining __eq__ without __hash__ makes instances unhashable.
            def __eq__(self, o):
                return True

        expected_message = (
            r'could not be represented through the generic tracing type')
        unhashable = CustomUnhashable()
        with self.assertRaisesRegex(TypeError, expected_message):
            trace_type.from_object(unhashable)
Beispiel #8
0
    def testWrappedNamedTuple(self):
        """A tuple exposing __wrapped__ traces like the namedtuple it wraps."""
        ActualType = collections.namedtuple('ActualType', ['a', 'b', 'c'])

        class MockWrapper(tuple):
            # Mimics wrappers produced by the trackable data structures in
            # //tensorflow/python/training/tracking/data_structures.py, which
            # expose the wrapped object via __wrapped__ in the same way as
            # functools.update_wrapper:
            # https://docs.python.org/3/library/functools.html?highlight=__wrapped__#functools.update_wrapper
            __wrapped__ = ActualType(1, 2, 3)

        wrapper_type = trace_type.from_object(MockWrapper())
        actual_type = trace_type.from_object(ActualType(1, 2, 3))
        self.assertEqual(wrapper_type, actual_type)
Beispiel #9
0
    def testVariableAliasing(self):
        """Variable traces reflect the aliasing pattern, not identity or value."""
        var_a = resource_variable_ops.ResourceVariable([1])
        var_b = resource_variable_ops.ResourceVariable([1])
        var_c = resource_variable_ops.ResourceVariable([1])
        distinct = trace_type.from_object((var_a, var_b, var_c))
        repeated = trace_type.from_object((var_a, var_a, var_a))
        self.assertNotEqual(distinct, repeated)

        # var_c is rebound here, dropping this scope's reference to the first
        # variable bound to that name.
        var_c = resource_variable_ops.ResourceVariable([2])
        var_d = resource_variable_ops.ResourceVariable([2])
        var_e = resource_variable_ops.ResourceVariable([2])
        distinct_again = trace_type.from_object((var_c, var_d, var_e))
        repeated_again = trace_type.from_object((var_d, var_d, var_d))
        self.assertEqual(distinct, distinct_again)
        self.assertEqual(repeated, repeated_again)
Beispiel #10
0
 def testAttrsCacheKeyGeneration(self):
     """An attrs instance traces to an Attrs type built from Generic fields."""
     actual = trace_type.from_object(TestAttrsClass(1, 2))
     field_types = (default_types.Generic(1), default_types.Generic(2))
     self.assertEqual(actual, default_types.Attrs(TestAttrsClass, field_types))
     self.assertTrue(actual.is_subtype_of(actual))
Beispiel #11
0
 def testGetPlaceholderValue(self):
     """A nested structure's trace yields a placeholder equal to the structure."""
     attrs_leaf = TestAttrsClass(8, (10, 11))
     composite_value = [1, 2, (3, [4, 5]), {6: [7]}, attrs_leaf]
     placeholder = trace_type.from_object(composite_value)._placeholder_value()
     self.assertEqual(composite_value, placeholder)
Beispiel #12
0
def make_cache_key(
    args: Any,
    captures: Any = None,
) -> Tuple[function_cache.FunctionCacheKey,
           trace_type.WeakrefDeletionObserver]:
    """Builds the function-cache key for a call.

    Args:
      args: The call arguments to trace.
      captures: Captured values to trace; None is treated as an empty dict.

    Returns:
      A (FunctionCacheKey, deletion observer) pair. Arguments and captures are
      traced with one shared InternalTracingContext, and that context's
      deletion observer is returned alongside the key.
    """
    context = trace_type.InternalTracingContext()
    args_type = trace_type.from_object(args, context)
    captures_type = trace_type.from_object(
        {} if captures is None else captures, context)
    capture_snapshot = function_cache.CaptureSnapshot(captures_type.mapping)

    cache_key = function_cache.FunctionCacheKey(
        args_type, capture_snapshot, make_function_context())
    return cache_key, context.deletion_observer
Beispiel #13
0
def make_cache_key(
    args
) -> Tuple[function_cache.FunctionCacheKey, trace_type.WeakrefDeletionObserver]:
  """Builds the function-cache key for a call from its arguments.

  Returns:
    A (FunctionCacheKey, deletion observer) pair; the observer belongs to the
    InternalTracingContext used to trace `args`.
  """
  context = trace_type.InternalTracingContext()
  args_type = trace_type.from_object(args, context)
  key = function_cache.FunctionCacheKey(args_type, make_function_context())
  return key, context.deletion_observer
Beispiel #14
0
    def testIteratorAliasing(self):
        """Iterator traces reflect the aliasing pattern, not iterator identity."""
        iter_a = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
        iter_b = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))

        def traced(*items):
            # items is already a tuple, matching the original tuple literals.
            return trace_type.from_object(items)

        self.assertEqual(traced(iter_a, iter_a), traced(iter_b, iter_b))
        self.assertEqual(traced(iter_a, iter_b), traced(iter_b, iter_a))
        self.assertNotEqual(traced(iter_a, iter_a), traced(iter_a, iter_b))
Beispiel #15
0
 def testTensor(self):
     """Tracing a dense tensor completes without raising."""
     trace_type.from_object(array_ops.zeros([10]))
Beispiel #16
0
 def testGeneric(self):
     """A plain int and an arbitrary object can both be traced without error."""
     for value in (1, DummyGenericClass()):
         trace_type.from_object(value)
Beispiel #17
0
 def encode_struct(struct):
     """Traces `struct` and discards the resulting trace type."""
     _ = trace_type.from_object(struct)
Beispiel #18
0
 def encode_tensor_specs(tensor_specs):
     """Traces `tensor_specs` and discards the resulting trace type."""
     _ = trace_type.from_object(tensor_specs)
Beispiel #19
0
 def encode_variables(var_list):
     """Traces `var_list` and discards the resulting trace type."""
     _ = trace_type.from_object(var_list)
Beispiel #20
0
 def encode_tensors(tensors):
     """Traces `tensors` and discards the resulting trace type."""
     _ = trace_type.from_object(tensors)
Beispiel #21
0
 def testAttrs(self):
     """Tracing an attrs-class instance completes without raising."""
     instance = TestAttrsClass(1, 2)
     trace_type.from_object(instance)
Beispiel #22
0
 def testList(self):
     """Tracing a flat list of ints completes without raising."""
     trace_type.from_object(list(range(1, 4)))
Beispiel #23
0
 def testDict(self):
     """Tracing a flat int-to-int dict completes without raising."""
     trace_type.from_object({key: key for key in (1, 2, 3)})
Beispiel #24
0
 def testTuple(self):
     """Tracing a flat tuple of ints completes without raising."""
     trace_type.from_object(tuple(range(1, 4)))