def testAttrsCacheKeyGeneration(self):
  """Tracing TestAttrsClass(1, 2) yields an Attrs type over Generic fields."""
  instance = TestAttrsClass(1, 2)
  actual = trace_type.from_object(instance)
  field_types = (default_types.Generic(1), default_types.Generic(2))
  self.assertEqual(actual, default_types.Attrs(TestAttrsClass, field_types))
  # The generated type must be a subtype of itself.
  self.assertTrue(actual.is_subtype_of(actual))
def testAttrsCacheKeyGeneration(self):
  """Signature for TestAttrsClass(1, 2) is an Attrs type over Generic fields."""
  instance = TestAttrsClass(1, 2)
  actual = make_function_signature_with_context(instance)
  field_types = (default_types.Generic(1), default_types.Generic(2))
  self.assertEqual(actual, default_types.Attrs(TestAttrsClass, field_types))
  # The generated type must be a subtype of itself.
  self.assertTrue(actual.is_subtype_of(actual))
def testGenericSupertypes(self):
  """Generic(1) is its own common supertype with equal Generics only."""
  one = default_types.Generic(1)
  two = default_types.Generic(2)
  another_one = default_types.Generic(1)
  # With no peers, itself, or an equal Generic, the supertype is Generic(1).
  for peers in ([], [one], [another_one]):
    self.assertEqual(one, one.most_specific_common_supertype(peers))
  # Generics wrapping different values have no common supertype.
  self.assertIsNone(one.most_specific_common_supertype([two]))
def from_object(obj: Any,
                context: trace.TracingContext = None) -> trace.TraceType:
  """Returns a TraceType corresponding to the object based on the context.

  Objects implementing `trace.SupportsTracingProtocol` supply their own type.
  Objects exposing `__wrapped__` are traced through the wrapped value. Lists,
  tuples (including namedtuples), mappings and attrs instances are traced
  recursively into the matching `default_types` container. Any other object
  is represented by a `default_types.Weakref` when it supports weak
  references, and otherwise by `default_types.Generic`.

  Args:
    obj: The object to generate a TraceType for.
    context: The TracingContext to be shared during protocol calls. A new
      `InternalTracingContext` is created when this is None.

  Returns:
    A TraceType object representing the given object.

  Raises:
    TypeError: If `obj` cannot be weakly referenced and constructing
      `default_types.Generic(obj)` fails.
  """
  if context is None:
    context = InternalTracingContext()

  if isinstance(obj, trace.SupportsTracingProtocol):
    return obj.__tf_tracing_type__(context)

  if hasattr(obj, "__wrapped__"):
    return from_object(obj.__wrapped__, context)

  if isinstance(obj, list):
    return default_types.List(*(from_object(c, context) for c in obj))

  if isinstance(obj, tuple):
    if util.is_namedtuple(obj):
      return default_types.NamedTuple(
          type(obj), tuple(from_object(c, context) for c in obj))
    return default_types.Tuple(*(from_object(c, context) for c in obj))

  if isinstance(obj, collections.abc.Mapping):
    return default_types.Dict({k: from_object(obj[k], context) for k in obj})

  if util.is_attrs(obj):
    return default_types.Attrs(
        type(obj),
        tuple(
            from_object(getattr(obj, a.name), context)
            for a in obj.__attrs_attrs__))

  # `weakref.ref` either returns a live reference or raises TypeError for
  # objects that do not support weak references; it never returns None, so
  # only the TypeError path needs handling. `context.deletion_observer` is
  # registered as the callback invoked when the referent is collected.
  try:
    ref = weakref.ref(obj, context.deletion_observer)
  except TypeError:
    pass  # Not weak-referenceable: fall back to a Generic type below.
  else:
    return default_types.Weakref(ref)

  try:
    return default_types.Generic(obj)
  except Exception as e:  # pylint: disable=broad-except
    raise TypeError(
        f"Python object could not be represented through the generic tracing "
        f"type. Consider implementing the Tracing Protocol for it: {obj!r}"
    ) from e
def testReferenceSubtype(self):
  """Checks Reference.is_subtype_of using a mock base type.

  The mock base reports itself a subtype of any type whose `_object` is 2.
  The `different_id` case uses such a base, so it can only be rejected
  because of its differing identifier, not because of its base.
  """

  class MockSubtypeOf2(default_types.Generic):
    # Mock base: a subtype of any type whose wrapped `_object` equals 2.

    def is_subtype_of(self, other):
      return other._object == 2

  original = default_types.Reference(MockSubtypeOf2(3), 1)
  clone = default_types.Reference(MockSubtypeOf2(3), 1)
  # Base 2 satisfies the mock's base check, so only the identifier differs.
  different_id = default_types.Reference(MockSubtypeOf2(2), 2)
  supertype = default_types.Reference(MockSubtypeOf2(2), 1)
  different_type = default_types.Generic(1)

  self.assertEqual(original, clone)
  self.assertFalse(original.is_subtype_of(different_id))
  self.assertTrue(original.is_subtype_of(supertype))
  self.assertFalse(supertype.is_subtype_of(original))
  self.assertFalse(original.is_subtype_of(different_type))
def testReferenceSupertype(self):
  """Checks Reference.most_specific_common_supertype using a mock base."""

  class Mock2AsTopType(default_types.Generic):
    # Mock base: no supertype with non-mock types, itself when all wrapped
    # values match, and Mock2AsTopType(2) otherwise.

    def most_specific_common_supertype(self, types):
      if any(not isinstance(other, Mock2AsTopType) for other in types):
        return None
      if all(self._object == other._object for other in types):
        return self
      return Mock2AsTopType(2)

  make_ref = default_types.Reference
  original = make_ref(Mock2AsTopType(3), 1)
  clone = make_ref(Mock2AsTopType(3), 1)
  different_id = make_ref(Mock2AsTopType(3), 2)
  supertype = make_ref(Mock2AsTopType(2), 1)
  different_type = default_types.Generic(1)

  self.assertEqual(supertype.most_specific_common_supertype([]), supertype)
  self.assertEqual(original.most_specific_common_supertype([clone]), original)
  # Neither a differing identifier nor a non-Reference type has a supertype.
  for unrelated in (different_id, different_type):
    self.assertIsNone(original.most_specific_common_supertype([unrelated]))