def testAttrsCacheKeyGeneration(self):
    """Checks the trace type generated for an attrs-class instance.

    The type produced by `trace_type.from_object` must equal an `Attrs` type
    built explicitly from the class and the `Literal` types of its two
    attribute values, and it must be a subtype of itself.
    """
    attrs_instance = TestAttrsClass(1, 2)
    generated = trace_type.from_object(attrs_instance)

    attribute_types = (default_types.Literal(1), default_types.Literal(2))
    handmade = default_types.Attrs.from_type_and_attributes(
        TestAttrsClass, attribute_types)

    self.assertEqual(generated, handmade)
    # Subtyping is expected to be reflexive.
    self.assertTrue(generated.is_subtype_of(generated))
def testTupleSerialization(self):
    """Checks that a `Tuple` trace type survives a serialize/deserialize trip."""
    elements = [default_types.Literal(value) for value in (1, 2, 3)]
    before = default_types.Tuple(*elements)

    payload = serialization.serialize(before)
    after = serialization.deserialize(payload)

    self.assertEqual(after, before)
def testListSerialization(self):
    """Checks that a `List` trace type survives a serialize/deserialize trip."""
    elements = [default_types.Literal(value) for value in (1, 2, 3)]
    before = default_types.List(*elements)

    payload = serialization.serialize(before)
    after = serialization.deserialize(payload)

    self.assertEqual(after, before)
def testDictSerialization(self):
    """Checks that a `Dict` trace type survives a serialize/deserialize trip."""
    mapping = {
        'a': default_types.Literal(1),
        'b': default_types.Literal(2),
        'c': default_types.Literal(3),
    }
    before = default_types.Dict(mapping)

    payload = serialization.serialize(before)
    after = serialization.deserialize(payload)

    self.assertEqual(after, before)
def testLiteralSupertypes(self):
    """Checks `most_specific_common_supertype` on `Literal` trace types.

    A literal is its own supertype when compared against nothing, against
    itself, or against an equal literal; there is no common supertype with a
    literal holding a different value.
    """
    one = default_types.Literal(1)
    two = default_types.Literal(2)
    one_again = default_types.Literal(1)

    # Every one of these peer lists must yield `one` as the common supertype.
    for peers in ([], [one], [one_again]):
        self.assertEqual(one, one.most_specific_common_supertype(peers))

    no_supertype = one.most_specific_common_supertype([two])
    self.assertIsNone(no_supertype)
def from_object(obj: Any,
                context: trace.TracingContext = None) -> trace.TraceType:
    """Returns a TraceType corresponding to the object based on the context.

    The object is matched against the following cases, in order:

    1. Objects satisfying `trace.SupportsTracingProtocol` produce their own
       type through `__tf_tracing_type__`.
    2. Objects exposing `__wrapped__` are unwrapped and retried.
    3. Lists, named tuples, plain tuples, mappings and attrs instances are
       converted recursively into the matching `default_types` container.
    4. Anything else is held through a weak reference when possible, and
       otherwise represented as a `Literal`.

    Args:
      obj: The object to generate a TraceType for.
      context: The TracingContext to be shared during protocol calls. A fresh
        `InternalTracingContext` is created when omitted.

    Returns:
      A TraceType object representing the given object.

    Raises:
      TypeError: If `obj` matches none of the cases above and cannot be
        represented as a `Literal` either.
    """
    if context is None:
        context = InternalTracingContext()

    if isinstance(obj, trace.SupportsTracingProtocol):
        return obj.__tf_tracing_type__(context)

    if hasattr(obj, "__wrapped__"):
        return from_object(obj.__wrapped__, context)

    if isinstance(obj, list):
        return default_types.List(*(from_object(c, context) for c in obj))

    if isinstance(obj, tuple):
        if util.is_namedtuple(obj):
            named_tuple_type = type(obj)
            return default_types.NamedTuple.from_type_and_attributes(
                named_tuple_type, tuple(from_object(c, context) for c in obj))
        return default_types.Tuple(*(from_object(c, context) for c in obj))

    if isinstance(obj, collections.abc.Mapping):
        return default_types.Dict(
            {k: from_object(obj[k], context) for k in obj})

    if util.is_attrs(obj):
        return default_types.Attrs.from_type_and_attributes(
            type(obj),
            tuple(
                from_object(getattr(obj, a.name), context)
                for a in obj.__attrs_attrs__))

    try:
        # `weakref.ref` either returns a reference object or raises TypeError
        # for objects that cannot be weakly referenced; it never returns None,
        # so no None check is needed on the result.
        ref = weakref.ref(obj, context.deletion_observer)
        return default_types.Weakref(ref)
    except TypeError:
        try:
            return default_types.Literal(obj)
        except Exception as e:
            # Catch `Exception` rather than everything so that
            # KeyboardInterrupt / SystemExit are not converted into a
            # TypeError, and chain the cause so the real failure stays visible.
            raise TypeError(
                "Python object could not be represented through the generic tracing "
                f"type. Consider implementing the Tracing Protocol for it: {obj!r}"
            ) from e
def testLiteralSerialization(self):
    """Checks `Literal` round-trips for bool, int, float and str values."""
    # Build every literal up front, then round-trip them in the same order.
    samples = [default_types.Literal(value) for value in (True, 1, 1.2, 'a')]

    for sample in samples:
        payload = serialization.serialize(sample)
        restored = serialization.deserialize(payload)
        self.assertEqual(restored, sample)
def testReferenceSubtype(self):
    """Checks equality and `is_subtype_of` for `Reference` trace types."""
    base = default_types.Reference(Mock2AsTopType(3), 1)
    twin = default_types.Reference(Mock2AsTopType(3), 1)
    other_id = default_types.Reference(Mock2AsTopType(3), 2)
    wider = default_types.Reference(Mock2AsTopType(2), 1)
    unrelated = default_types.Literal(1)

    # Same wrapped type and same identifier compare equal.
    self.assertEqual(base, twin)

    # (candidate, target, whether candidate is expected to be a subtype)
    expectations = (
        (base, other_id, False),
        (base, wider, True),
        (wider, base, False),
        (base, unrelated, False),
    )
    for candidate, target, is_subtype in expectations:
        verdict = candidate.is_subtype_of(target)
        if is_subtype:
            self.assertTrue(verdict)
        else:
            self.assertFalse(verdict)
def testReferenceSupertype(self):
    """Checks `most_specific_common_supertype` for `Reference` trace types."""
    base = default_types.Reference(Mock2AsTopType(3), 1)
    twin = default_types.Reference(Mock2AsTopType(3), 1)
    other_id = default_types.Reference(Mock2AsTopType(3), 2)
    wider = default_types.Reference(Mock2AsTopType(2), 1)
    unrelated = default_types.Literal(1)

    # With no peers, a reference is its own supertype.
    alone = wider.most_specific_common_supertype([])
    self.assertEqual(alone, wider)

    # An equal reference yields the reference itself.
    with_twin = base.most_specific_common_supertype([twin])
    self.assertEqual(with_twin, base)

    # A different identifier, or a non-Reference type, has no common supertype.
    with_other_id = base.most_specific_common_supertype([other_id])
    self.assertIsNone(with_other_id)

    with_unrelated = base.most_specific_common_supertype([unrelated])
    self.assertIsNone(with_unrelated)
def testReferencetSerialization(self):
    """Checks that a `Reference` trace type survives a serialization trip."""
    wrapped = default_types.Literal(3)
    before = default_types.Reference(wrapped, 1)

    payload = serialization.serialize(before)
    after = serialization.deserialize(payload)

    self.assertEqual(after, before)