def testCompositeAndSpec(self):
  """A RaggedTensor value and a matching RaggedTensorSpec trace identically."""
  ragged = ragged_tensor.RaggedTensor.from_row_splits(
      values=[1, 2, 3], row_splits=[0, 2, 3])
  ragged_spec = ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)
  value_type = trace_type.from_object(ragged)
  spec_type = trace_type.from_object(ragged_spec)
  self.assertEqual(value_type, spec_type)
def testCustomUnequableTypeSucceeds(self):
  """Objects whose __eq__ raises can still be traced.

  Trace types of the same live object compare equal. Comparing trace types
  of different objects reaches the user-defined __eq__. Once the traced
  objects are deleted, the trace types stop comparing equal.
  """

  class CustomUnequable:
    # Equality always raises; hashing is constant.
    def __eq__(self, o):
      raise ValueError

    def __hash__(self):
      return 0

  first = CustomUnequable()
  second = CustomUnequable()
  first_type = trace_type.from_object(first)
  first_type_again = trace_type.from_object(first)
  second_type = trace_type.from_object(second)

  # Same object: equal without the raising __eq__ surfacing an error.
  self.assertEqual(first_type, first_type_again)
  # Different objects: the comparison reaches the user __eq__, which raises.
  with self.assertRaises(ValueError):
    first_type.__eq__(second_type)

  # After the traced object is deleted, its trace types compare unequal in
  # both directions, and deleting the other object does not change that.
  del first
  self.assertNotEqual(first_type, first_type_again)
  self.assertNotEqual(first_type_again, first_type)
  del second
  self.assertNotEqual(first_type, first_type_again)
  self.assertNotEqual(first_type_again, first_type)
def testComplexStruct(self):
  """A nested dict/tuple structure traces to equal, mutually-subtyped types."""
  nested = {
      (1, 2, 3): {(1, 2): {12: 2}},
      (3, 2, 3): (2, {2: 3}),
  }
  lhs = trace_type.from_object(nested)
  rhs = trace_type.from_object(nested)
  self.assertEqual(lhs, rhs)
  # Subtyping must hold in both directions for equal types.
  self.assertTrue(lhs.is_subtype_of(rhs))
  self.assertTrue(rhs.is_subtype_of(lhs))
def testDictEquality(self):
  """Dict trace types compare by contents, independent of insertion order."""
  base = trace_type.from_object({1: 2, 3: 4})
  changed_value = trace_type.from_object({1: 2, 3: 2})
  other_value = trace_type.from_object({1: 2, 3: 0})
  reordered = trace_type.from_object({3: 4, 1: 2})

  self.assertNotEqual(base, changed_value)
  self.assertNotEqual(base, other_value)
  self.assertNotEqual(changed_value, other_value)
  self.assertEqual(base, reordered)
def testListEquality(self):
  """List trace types differ on any element or length change."""
  base = trace_type.from_object([1, 2, 3, 4])
  changed_element = trace_type.from_object([1, 2, 2, 4])
  shorter = trace_type.from_object([1, 2, 3])
  same_contents = trace_type.from_object([1, 2, 3, 4])

  self.assertNotEqual(base, changed_element)
  self.assertNotEqual(base, shorter)
  self.assertNotEqual(changed_element, shorter)
  self.assertEqual(base, same_contents)
def testTupleEquality(self):
  """Tuple trace types differ on any element or length change."""
  base = trace_type.from_object((1, 2, 3, 4))
  changed_element = trace_type.from_object((1, 2, 2, 4))
  shorter = trace_type.from_object((1, 2, 3))
  same_contents = trace_type.from_object((1, 2, 3, 4))

  self.assertNotEqual(base, changed_element)
  self.assertNotEqual(base, shorter)
  self.assertNotEqual(changed_element, shorter)
  self.assertEqual(base, same_contents)
def testCustomUnhashableTypeFailsGracefully(self):
  """Tracing an unhashable custom object raises a descriptive TypeError."""

  class CustomUnhashable:
    # Defining __eq__ without __hash__ makes instances unhashable.
    def __eq__(self, o):
      return True

  unhashable = CustomUnhashable()
  expected_message = (
      r'could not be represented through the generic tracing type')
  with self.assertRaisesRegex(TypeError, expected_message):
    trace_type.from_object(unhashable)
def testWrappedNamedTuple(self):
  """A tuple subclass exposing `__wrapped__` traces like the wrapped value."""
  ActualType = collections.namedtuple('ActualType', ['a', 'b', 'c'])

  class MockWrapper(tuple):
    # Generated through trackable data structures:
    # //tensorflow/python/training/tracking/data_structures.py
    # With design pattern similar to Python functools:
    # https://docs.python.org/3/library/functools.html?highlight=__wrapped__#functools.update_wrapper
    __wrapped__ = ActualType(1, 2, 3)

  wrapper_type = trace_type.from_object(MockWrapper())
  actual_type = trace_type.from_object(ActualType(1, 2, 3))
  self.assertEqual(wrapper_type, actual_type)
def testVariableAliasing(self):
  """Variable trace types encode the aliasing pattern, not object identity.

  A tuple of three distinct variables and a tuple repeating one variable
  trace to different types. The same two patterns built from a fresh set of
  variables trace to types equal to the first pair.
  """
  v1 = resource_variable_ops.ResourceVariable([1])
  v2 = resource_variable_ops.ResourceVariable([1])
  v3 = resource_variable_ops.ResourceVariable([1])
  all_unique = trace_type.from_object((v1, v2, v3))
  all_same = trace_type.from_object((v1, v1, v1))
  self.assertNotEqual(all_unique, all_same)

  # Use fresh names rather than rebinding v3. Rebinding would release the
  # test's handle on a variable that `all_unique` was traced from while that
  # type is still compared below. Some trace types may be sensitive to the
  # traced object's lifetime, and that must not leak into this test.
  v4 = resource_variable_ops.ResourceVariable([2])
  v5 = resource_variable_ops.ResourceVariable([2])
  v6 = resource_variable_ops.ResourceVariable([2])
  all_unique_again = trace_type.from_object((v4, v5, v6))
  all_same_again = trace_type.from_object((v5, v5, v5))
  self.assertEqual(all_unique, all_unique_again)
  self.assertEqual(all_same, all_same_again)
def testAttrsCacheKeyGeneration(self):
  """An attrs instance traces to an Attrs type over Generic field types."""
  actual = trace_type.from_object(TestAttrsClass(1, 2))
  field_types = (default_types.Generic(1), default_types.Generic(2))
  expected = default_types.Attrs(TestAttrsClass, field_types)

  self.assertEqual(actual, expected)
  # Subtyping is reflexive.
  self.assertTrue(actual.is_subtype_of(actual))
def testGetPlaceholderValue(self):
  """The placeholder value of a traced nested structure equals the input."""
  original = [
      1,
      2,
      (3, [4, 5]),
      {6: [7]},
      TestAttrsClass(8, (10, 11)),
  ]
  placeholder = trace_type.from_object(original)._placeholder_value()
  self.assertEqual(original, placeholder)
def make_cache_key(
    args: Any,
    captures: Any = None,
) -> Tuple[function_cache.FunctionCacheKey,
           trace_type.WeakrefDeletionObserver]:
  """Computes the cache key given the function arguments.

  Args:
    args: The call arguments to trace.
    captures: Captured values to include in the key; an empty dict is used
      when None. Presumably a dict, since its trace type's `.mapping` is
      read below.

  Returns:
    A pair of the `FunctionCacheKey` (argument signature, capture snapshot
    and current function context) and the deletion observer of the tracing
    context used to build it.
  """
  # Args and captures are traced through one shared context, args first.
  context = trace_type.InternalTracingContext()
  args_signature = trace_type.from_object(args, context)
  captures_type = trace_type.from_object(
      {} if captures is None else captures, context)
  captures_signature = function_cache.CaptureSnapshot(captures_type.mapping)
  cache_key = function_cache.FunctionCacheKey(
      args_signature, captures_signature, make_function_context())
  return cache_key, context.deletion_observer
def make_cache_key(
    args
) -> Tuple[function_cache.FunctionCacheKey,
           trace_type.WeakrefDeletionObserver]:
  """Computes the cache key given the function arguments.

  Args:
    args: The call arguments to trace.

  Returns:
    A pair of the `FunctionCacheKey` (argument signature and current function
    context) and the deletion observer of the tracing context used to build
    it.
  """
  context = trace_type.InternalTracingContext()
  # Arguments evaluate left to right: args are traced before the function
  # context is captured.
  cache_key = function_cache.FunctionCacheKey(
      trace_type.from_object(args, context), make_function_context())
  return cache_key, context.deletion_observer
def testIteratorAliasing(self):
  """Iterator trace types reflect the aliasing pattern, not identity."""
  first_iter = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
  second_iter = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
  trace = trace_type.from_object

  # Same pattern (one iterator repeated) -> equal, whichever iterator is used.
  self.assertEqual(
      trace((first_iter, first_iter)), trace((second_iter, second_iter)))
  # Two distinct iterators -> equal regardless of order.
  self.assertEqual(
      trace((first_iter, second_iter)), trace((second_iter, first_iter)))
  # Repeated vs. distinct -> different.
  self.assertNotEqual(
      trace((first_iter, first_iter)), trace((first_iter, second_iter)))
def testTensor(self):
  """Smoke test: a dense tensor can be traced without raising."""
  trace_type.from_object(array_ops.zeros([10]))
def testGeneric(self):
  """Smoke test: a plain int and a custom object can both be traced."""
  trace_type.from_object(1)
  dummy = DummyGenericClass()
  trace_type.from_object(dummy)
def encode_struct(struct):
  """Runs trace_type.from_object on `struct`, discarding the trace type."""
  _ = trace_type.from_object(struct)
def encode_tensor_specs(tensor_specs):
  """Runs trace_type.from_object on `tensor_specs`, discarding the result."""
  _ = trace_type.from_object(tensor_specs)
def encode_variables(var_list):
  """Runs trace_type.from_object on `var_list`, discarding the result."""
  _ = trace_type.from_object(var_list)
def encode_tensors(tensors):
  """Runs trace_type.from_object on `tensors`, discarding the result."""
  _ = trace_type.from_object(tensors)
def testAttrs(self):
  """Smoke test: an attrs-class instance can be traced."""
  instance = TestAttrsClass(1, 2)
  trace_type.from_object(instance)
def testList(self):
  """Smoke test: a flat list of ints can be traced."""
  values = [1, 2, 3]
  trace_type.from_object(values)
def testDict(self):
  """Smoke test: a flat int-to-int dict can be traced."""
  mapping = {key: key for key in range(1, 4)}  # {1: 1, 2: 2, 3: 3}
  trace_type.from_object(mapping)
def testTuple(self):
  """Smoke test: a flat tuple of ints can be traced."""
  values = tuple(range(1, 4))  # (1, 2, 3)
  trace_type.from_object(values)