def testListTupleInequality(self):
  """A List and a Tuple with equal Generic elements are never equal."""

  def build(container):
    # Fresh Generic instances per container, mirroring separate literals.
    return container(*[default_types.Generic(value) for value in (1, 2, 3)])

  list_a = build(default_types.List)
  list_b = build(default_types.List)
  tuple_a = build(default_types.Tuple)
  tuple_b = build(default_types.Tuple)

  # Same container kind with equal elements compares equal.
  self.assertEqual(list_a, list_b)
  self.assertEqual(tuple_a, tuple_b)
  # Different container kinds compare unequal, checked in both directions.
  self.assertNotEqual(list_a, tuple_a)
  self.assertNotEqual(tuple_a, list_a)
def testListTupleInequality(self):
  """A List and a Tuple with equal Literal elements are never equal."""

  def build(container):
    # Fresh Literal instances per container, mirroring separate literals.
    return container(*[default_types.Literal(value) for value in (1, 2, 3)])

  list_a = build(default_types.List)
  list_b = build(default_types.List)
  tuple_a = build(default_types.Tuple)
  tuple_b = build(default_types.Tuple)

  # Same container kind with equal elements compares equal.
  self.assertEqual(list_a, list_b)
  self.assertEqual(tuple_a, tuple_b)
  # Different container kinds compare unequal, checked in both directions.
  self.assertNotEqual(list_a, tuple_a)
  self.assertNotEqual(tuple_a, list_a)
def testListSupertype(self):
  """Checks List.most_specific_common_supertype with mock element types."""

  def mock_list(*values):
    # Wraps each value in MockSupertypes2With3 and packs them into a List.
    return default_types.List(*(MockSupertypes2With3(v) for v in values))

  list_a = mock_list(1, 2, 3)
  list_b = mock_list(2, 2, 2)

  # With no other types, the common supertype is the list itself.
  self.assertEqual(list_a, list_a.most_specific_common_supertype([]))
  # list_a against list_b has no common supertype.
  self.assertIsNone(list_a.most_specific_common_supertype([list_b]))
  # list_b against list_a yields a list of 3s.
  self.assertEqual(
      list_b.most_specific_common_supertype([list_a]),
      mock_list(3, 3, 3))
def from_object(obj: Any,
                context: trace.TracingContext = None) -> trace.TraceType:
  """Returns a TraceType corresponding to the object based on the context.

  Checks are applied in this order: objects implementing the Tracing
  Protocol, objects exposing `__wrapped__` (unwrapped recursively), lists,
  tuples (namedtuples get a NamedTuple type), mappings, attrs instances, then
  weakly-referenceable objects (traced via `Weakref`). Anything that cannot be
  weakly referenced falls back to a `Literal` type.

  Args:
    obj: The object to generate a TraceType for.
    context: The TracingContext to be shared during protocol calls. When None,
      a new `InternalTracingContext` is created.

  Returns:
    A TraceType object representing the given object.

  Raises:
    TypeError: If `obj` can neither be weakly referenced nor represented as a
      `Literal`.
  """
  if context is None:
    context = InternalTracingContext()

  if isinstance(obj, trace.SupportsTracingProtocol):
    return obj.__tf_tracing_type__(context)

  if hasattr(obj, "__wrapped__"):
    return from_object(obj.__wrapped__, context)

  if isinstance(obj, list):
    return default_types.List(*(from_object(c, context) for c in obj))

  if isinstance(obj, tuple):
    if util.is_namedtuple(obj):
      return default_types.NamedTuple.from_type_and_attributes(
          type(obj), tuple(from_object(c, context) for c in obj))
    return default_types.Tuple(*(from_object(c, context) for c in obj))

  if isinstance(obj, collections.abc.Mapping):
    return default_types.Dict({k: from_object(obj[k], context) for k in obj})

  if util.is_attrs(obj):
    return default_types.Attrs.from_type_and_attributes(
        type(obj),
        tuple(
            from_object(getattr(obj, a.name), context)
            for a in obj.__attrs_attrs__))

  try:
    # weakref.ref() never returns None; it raises TypeError when `obj` does
    # not support weak references. Only that call is guarded so errors from
    # the Weakref constructor are not silently rerouted to Literal.
    ref = weakref.ref(obj, context.deletion_observer)
  except TypeError:
    try:
      return default_types.Literal(obj)
    except Exception as e:  # Do not swallow KeyboardInterrupt/SystemExit.
      raise TypeError(
          f"Python object could not be represented through the generic tracing "
          f"type. Consider implementing the Tracing Protocol for it: {obj!r}"
      ) from e
  else:
    return default_types.Weakref(ref)
def testListSerialization(self):
  """A List of Literals survives a serialize/deserialize round trip."""
  original = default_types.List(
      *(default_types.Literal(value) for value in (1, 2, 3)))
  round_tripped = serialization.deserialize(
      serialization.serialize(original))
  self.assertEqual(round_tripped, original)
def testListSupertype(self):
  """Checks List.most_specific_common_supertype with a custom Generic."""

  class Supertypable(default_types.Generic):
    """Generic with a hand-written supertype rule used only by this test."""

    def most_specific_common_supertype(self, others):
      # With nothing to compare against, the type is its own supertype.
      if not others:
        return self
      # Only a wrapped 2 combined with an int-wrapping other yields 3.
      matches = self._object == 2 and isinstance(others[0]._object, int)
      return Supertypable(3) if matches else None

  def supertypable_list(*values):
    # Packs a Supertypable for each value into a List.
    return default_types.List(*(Supertypable(v) for v in values))

  list_a = supertypable_list(1, 2, 3)
  list_b = supertypable_list(2, 2, 2)

  self.assertEqual(list_a, list_a.most_specific_common_supertype([]))
  # list_a's first element wraps 1, so no supertype exists.
  self.assertIsNone(list_a.most_specific_common_supertype([list_b]))
  # Every element of list_b wraps 2, so each element maps to 3.
  self.assertEqual(
      list_b.most_specific_common_supertype([list_a]),
      supertypable_list(3, 3, 3))
def create_trace_type(obj: Any,
                      context: SignatureContext) -> trace.TraceType:
  """Returns a TraceType corresponding to the object based on the context.

  Checks are applied in this order: objects implementing the Tracing
  Protocol, lists, tuples, mappings, attrs instances, objects exposing
  `__wrapped__` (unwrapped recursively), then weakly-referenceable objects
  (traced via `Weakref`). Anything that cannot be weakly referenced falls
  back to a `Generic` type.

  Args:
    obj: The object to generate a TraceType for.
    context: The SignatureContext shared during protocol calls; its
      `deletion_observer` is registered as the weakref callback.

  Returns:
    A TraceType object representing the given object.

  Raises:
    TypeError: If `obj` can neither be weakly referenced nor represented as a
      `Generic`.
  """
  if isinstance(obj, trace.SupportsTracingProtocol):
    return obj.__tf_tracing_type__(context)

  if isinstance(obj, list):
    return default_types.List(*(create_trace_type(c, context) for c in obj))

  if isinstance(obj, tuple):
    return default_types.Tuple(*(create_trace_type(c, context) for c in obj))

  if isinstance(obj, collections.abc.Mapping):
    return default_types.Dict(
        {k: create_trace_type(obj[k], context) for k in obj})

  if hasattr(type(obj), "__attrs_attrs__"):
    # Materialize the attribute types eagerly. A generator here would be
    # single-use and would keep `obj` alive through its closure.
    return default_types.Attrs(
        type(obj),
        tuple(
            create_trace_type(getattr(obj, a.name), context)
            for a in obj.__attrs_attrs__))

  if hasattr(obj, "__wrapped__"):
    return create_trace_type(obj.__wrapped__, context)

  try:
    # weakref.ref() never returns None; it raises TypeError when `obj` does
    # not support weak references. Only that call is guarded so errors from
    # the Weakref constructor are not silently rerouted to Generic.
    ref = weakref.ref(obj, context.deletion_observer)
  except TypeError:
    try:
      return default_types.Generic(obj)
    except Exception as e:  # Do not swallow KeyboardInterrupt/SystemExit.
      raise TypeError(
          f"Python object could not be represented through the generic tracing "
          f"type. Consider implementing the Tracing Protocol for it: {obj!r}"
      ) from e
  else:
    return default_types.Weakref(ref)