def testComplexStruct(self):
  """Tracing one nested structure twice gives equal, mutually-subtyping types."""
  nested = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
  first, second = [
      function_trace_type.get_arg_spec(nested, False, False, True)
      for _ in range(2)
  ]
  self.assertEqual(first, second)
  # Subtyping must hold in both directions for structurally equal traces.
  self.assertTrue(first.is_subtype_of(second))
  self.assertTrue(second.is_subtype_of(first))
def testCustomUnequableTypeSucceeds(self):
  """Objects whose __eq__ raises can still be traced and compared.

  Two traces of the same object compare equal. Comparing traces of two
  distinct objects surfaces the ValueError raised by their __eq__. Once the
  traced objects are deleted, the traces are expected to stop comparing equal.
  """

  # Hashable, but any equality check raises.
  class CustomUnequable:

    def __eq__(self, o):
      raise ValueError

    def __hash__(self):
      return 0

  object_a = CustomUnequable()
  object_b = CustomUnequable()
  trace_a_1 = function_trace_type.get_arg_spec(object_a, False, True, True)
  trace_a_2 = function_trace_type.get_arg_spec(object_a, False, True, True)
  trace_b = function_trace_type.get_arg_spec(object_b, False, True, True)
  # Both traces wrap the same object. This passes even though __eq__ raises,
  # presumably because the trace type short-circuits on identity — confirm.
  self.assertEqual(trace_a_1, trace_a_2)
  # Distinct objects force a real __eq__ call, which raises.
  with self.assertRaises(ValueError):
    trace_a_1.__eq__(trace_b)
  # After deleting the object, equality must fail in both directions.
  # NOTE(review): presumably the trace holds only a weak reference to the
  # object, so this relies on immediate collection (CPython refcounting).
  del object_a
  self.assertNotEqual(trace_a_1, trace_a_2)
  self.assertNotEqual(trace_a_2, trace_a_1)
  # Deleting the unrelated object must not change that outcome.
  del object_b
  self.assertNotEqual(trace_a_1, trace_a_2)
  self.assertNotEqual(trace_a_2, trace_a_1)
def testCompositeAndSpec(self):
  """A RaggedTensor value and its matching RaggedTensorSpec trace equally."""
  row_values = [1, 2, 3]
  splits = [0, 2, 3]
  ragged = ragged_tensor.RaggedTensor.from_row_splits(
      values=row_values, row_splits=splits)
  ragged_spec = ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)
  traced_value, traced_spec = [
      function_trace_type.get_arg_spec(item, False, False, True)
      for item in (ragged, ragged_spec)
  ]
  self.assertEqual(traced_value, traced_spec)
def testDictEquality(self):
  """Dict traces depend on the stored values but not on key insertion order."""

  def spec(value):
    return function_trace_type.get_arg_spec(value, False, False, True)

  base = spec({1: 2, 3: 4})
  other_value = spec({1: 2, 3: 2})
  zero_value = spec({1: 2, 3: 0})
  reordered = spec({3: 4, 1: 2})

  for left, right in ((base, other_value), (base, zero_value),
                      (other_value, zero_value)):
    self.assertNotEqual(left, right)
  self.assertEqual(base, reordered)
def testListEquality(self):
  """List traces depend on both element values and list length."""

  def spec(value):
    return function_trace_type.get_arg_spec(value, False, False, True)

  four_items = spec([1, 2, 3, 4])
  changed_item = spec([1, 2, 2, 4])
  three_items = spec([1, 2, 3])
  same_items = spec([1, 2, 3, 4])

  self.assertNotEqual(four_items, changed_item)
  self.assertNotEqual(four_items, three_items)
  self.assertNotEqual(changed_item, three_items)
  self.assertEqual(four_items, same_items)
def testCustomUnhashableTypeFailsGracefully(self):
  """Tracing an unhashable custom object raises InvalidArgumentError."""
  # NOTE(review): a second `testCustomUnhashableTypeFailsGracefully` with a
  # different expected message exists in this source; if both live in the
  # same class, the later definition silently replaces this one.

  # Defining __eq__ without __hash__ makes instances unhashable in Python 3.
  class CustomUnhashable:

    def __eq__(self, o):
      return True

  instance = CustomUnhashable()
  expected_message = (
      r'could not be represented through the generic tracing type')
  with self.assertRaisesRegex(errors.InvalidArgumentError, expected_message):
    function_trace_type.get_arg_spec(instance, False, True, True)
def testCustomUnhashableTypeFailsGracefully(self):
  """Tracing an unhashable custom object raises InvalidArgumentError."""
  # NOTE(review): a second `testCustomUnhashableTypeFailsGracefully` with a
  # different expected message exists in this source; if both live in the
  # same class, only the later definition actually runs.

  # Defining __eq__ without __hash__ makes instances unhashable in Python 3.
  class CustomUnhashable:

    def __eq__(self, o):
      return True

  instance = CustomUnhashable()
  expected_message = r'Could not determine tracing type of generic object'
  with self.assertRaisesRegex(errors.InvalidArgumentError, expected_message):
    function_trace_type.get_arg_spec(instance, False, True, True)
def make_cache_key(args,
                   include_tensor_ranks_only: bool = False
                   ) -> FunctionCacheKey:
  """Computes the cache key given the function arguments.

  Args:
    args: The function arguments to trace.
    include_tensor_ranks_only: Forwarded as the second positional argument of
      `function_trace_type.get_arg_spec`.

  Returns:
    A `FunctionCacheKey` pairing the traced argument spec with the result of
    `_make_execution_context()`.
  """
  spec = function_trace_type.get_arg_spec(
      args,
      include_tensor_ranks_only,
      _ENCODE_VARIABLES_BY_RESOURCE_ID,
      USE_FULL_TRACE_TYPE,
  )
  # The context is captured after tracing, matching the original ordering.
  context = _make_execution_context()
  return FunctionCacheKey(spec, context)
def testVariableAliasing(self):
  """Variable traces reflect the aliasing pattern of a tuple of variables.

  A tuple of three distinct variables must trace differently from a tuple
  repeating a single variable. Tuples with the same aliasing pattern must
  trace equally, even when the variables hold different values of the same
  shape.
  """
  v1 = resource_variable_ops.ResourceVariable([1])
  v2 = resource_variable_ops.ResourceVariable([1])
  v3 = resource_variable_ops.ResourceVariable([1])
  # The third argument (True) presumably enables encoding variables by
  # resource id — confirm against get_arg_spec's signature.
  all_unique = function_trace_type.get_arg_spec((v1, v2, v3), False, True, True)
  all_same = function_trace_type.get_arg_spec((v1, v1, v1), False, True, True)
  self.assertNotEqual(all_unique, all_same)
  # NOTE(review): `v3` is rebound here, which drops this test's last
  # reference to the original third variable while `all_unique` is still
  # compared below. If traces hold only weak references to variables, this
  # may matter; consider a fresh name (e.g. `v6`).
  v3 = resource_variable_ops.ResourceVariable([2])
  v4 = resource_variable_ops.ResourceVariable([2])
  v5 = resource_variable_ops.ResourceVariable([2])
  all_unique_again = function_trace_type.get_arg_spec((v3, v4, v5), False, True, True)
  all_same_again = function_trace_type.get_arg_spec((v4, v4, v4), False, True, True)
  # Same aliasing pattern means equal traces, regardless of variable values.
  self.assertEqual(all_unique, all_unique_again)
  self.assertEqual(all_same, all_same_again)
def testAttrsCacheKeyGeneration(self):
  """An attrs instance traces to an AttrsType of GenericTypes for its fields."""
  if attr is None:
    self.skipTest('attr module is unavailable.')
  traced = function_trace_type.get_arg_spec(
      TestAttrsClass(1, 2), False, False, True)
  field_types = tuple(
      function_trace_type.GenericType(field) for field in (1, 2))
  expected = function_trace_type.AttrsType(TestAttrsClass, field_types)
  self.assertEqual(traced, expected)
  # Subtyping must be reflexive.
  self.assertTrue(traced.is_subtype_of(traced))
def make_cache_key_from_args(args, kwargs, include_tensor_ranks_only=False):
  """Computes the cache key given inputs and execution context.

  Args:
    args: Positional inputs.
    kwargs: Keyword inputs. These are traced together with `args` as the
      single tuple `(args, kwargs)`.
    include_tensor_ranks_only: Forwarded as the second positional argument of
      `function_trace_type.get_arg_spec`.

  Returns:
    A `CacheKey` built from the traced signature followed by the six values
    returned by `_cache_key_context()`.
  """
  signature = function_trace_type.get_arg_spec(
      (args, kwargs),
      include_tensor_ranks_only,
      ENCODE_VARIABLES_BY_RESOURCE_ID,
      USE_FULL_TRACE_TYPE,
  )
  # Explicit unpacking is kept so that a context tuple of the wrong length
  # still fails here, exactly as before.
  (graph, device_fns, colocation, cross_replica, var_policy,
   xla_id) = _cache_key_context()
  return CacheKey(signature, graph, device_fns, colocation, cross_replica,
                  var_policy, xla_id)
def testIteratorAliasing(self):
  """Iterator traces depend on the aliasing pattern, not on which iterator."""

  def spec(*iterators):
    return function_trace_type.get_arg_spec(iterators, False, False, True)

  first = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
  second = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))

  # The same iterator repeated traces the same, whichever iterator it is.
  self.assertEqual(spec(first, first), spec(second, second))
  # Two distinct iterators trace the same in either order.
  self.assertEqual(spec(first, second), spec(second, first))
  # A repeated iterator differs from a pair of distinct ones.
  self.assertNotEqual(spec(first, first), spec(first, second))
def testGeneric(self):
  """Tracing an int and a DummyGenericClass instance completes without error."""
  # Factories keep each value's construction right before its own trace.
  for make_value in (lambda: 1, DummyGenericClass):
    function_trace_type.get_arg_spec(make_value(), False, True, True)
def encode_variables(var_list):
  """Computes the trace type of `var_list` and discards the result."""
  include_tensor_ranks_only = False
  encode_variables_by_resource_id = False
  use_full_trace_type = True
  function_trace_type.get_arg_spec(var_list, include_tensor_ranks_only,
                                   encode_variables_by_resource_id,
                                   use_full_trace_type)
def encode_tensor_specs(tensor_specs):
  """Computes the trace type of `tensor_specs` and discards the result."""
  include_tensor_ranks_only = False
  encode_variables_by_resource_id = False
  use_full_trace_type = True
  function_trace_type.get_arg_spec(tensor_specs, include_tensor_ranks_only,
                                   encode_variables_by_resource_id,
                                   use_full_trace_type)
def encode_model(model):
  """Computes the trace type of `model` and discards the result."""
  use_full_trace_type = function.USE_FULL_TRACE_TYPE
  function_trace_type.get_arg_spec(
      model,
      False,  # include_tensor_ranks_only
      False,  # encode variables by resource id
      use_full_trace_type)
def encode_struct(struct):
  """Computes the trace type of `struct` and discards the result."""
  use_full_trace_type = function.USE_FULL_TRACE_TYPE
  function_trace_type.get_arg_spec(
      struct,
      False,  # include_tensor_ranks_only
      False,  # encode variables by resource id
      use_full_trace_type)
def encode_tensor_specs(tensor_specs):
  """Computes the trace type of `tensor_specs` and discards the result."""
  use_full_trace_type = function.USE_FULL_TRACE_TYPE
  function_trace_type.get_arg_spec(
      tensor_specs,
      False,  # include_tensor_ranks_only
      False,  # encode variables by resource id
      use_full_trace_type)
def encode_variables(var_list):
  """Computes the trace type of `var_list` and discards the result."""
  use_full_trace_type = function.USE_FULL_TRACE_TYPE
  function_trace_type.get_arg_spec(
      var_list,
      False,  # include_tensor_ranks_only
      False,  # encode variables by resource id
      use_full_trace_type)
def testList(self):
  """Tracing a list of ints completes without raising."""
  sample = list(range(1, 4))  # [1, 2, 3]
  function_trace_type.get_arg_spec(sample, False, True, True)
def testAttrs(self):
  """Tracing a TestAttrsClass instance completes without raising.

  The test is skipped when the `attr` module is unavailable.
  """
  if attr is None:
    self.skipTest('attr module is unavailable.')
  instance = TestAttrsClass(1, 2)
  function_trace_type.get_arg_spec(instance, False, True, True)
def testDict(self):
  """Tracing an int-to-int dict completes without raising."""
  sample = {key: key for key in (1, 2, 3)}  # {1: 1, 2: 2, 3: 3}
  function_trace_type.get_arg_spec(sample, False, True, True)
def testTuple(self):
  """Tracing a tuple of ints completes without raising."""
  sample = tuple(range(1, 4))  # (1, 2, 3)
  function_trace_type.get_arg_spec(sample, False, True, True)
def testTensor(self):
  """Tracing a zero-filled tensor of shape [10] completes without raising."""
  zeros = array_ops.zeros([10])
  function_trace_type.get_arg_spec(zeros, False, True, True)
def encode_model(model):
  """Computes the trace type of `model` and discards the result."""
  include_tensor_ranks_only = False
  encode_variables_by_resource_id = False
  use_full_trace_type = True
  function_trace_type.get_arg_spec(model, include_tensor_ranks_only,
                                   encode_variables_by_resource_id,
                                   use_full_trace_type)
def encode_struct(struct):
  """Computes the trace type of `struct` and discards the result."""
  include_tensor_ranks_only = False
  encode_variables_by_resource_id = False
  use_full_trace_type = True
  function_trace_type.get_arg_spec(struct, include_tensor_ranks_only,
                                   encode_variables_by_resource_id,
                                   use_full_trace_type)