def from_object(obj: Any,
                context: trace.TracingContext = None) -> trace.TraceType:
  """Returns a TraceType corresponding to the object based on the context.

  Objects are dispatched in this order:

  1. Objects implementing the Tracing Protocol (`__tf_tracing_type__`).
  2. Wrapper objects exposing `__wrapped__`, which are unwrapped recursively.
  3. Lists.
  4. Tuples, with namedtuples handled separately.
  5. Mappings.
  6. attrs classes.
  7. Weak-referenceable objects.
  8. Finally, a `default_types.Literal` of the object itself.

  Args:
    obj: The object to generate a TraceType for.
    context: The TracingContext to be shared during protocol calls. A new
      InternalTracingContext is created when None is given.

  Returns:
    A TraceType object representing the given object.

  Raises:
    TypeError: If the object can be neither weakly referenced nor
      represented as a `default_types.Literal`.
  """
  if context is None:
    context = InternalTracingContext()

  if isinstance(obj, trace.SupportsTracingProtocol):
    return obj.__tf_tracing_type__(context)

  if hasattr(obj, "__wrapped__"):
    return from_object(obj.__wrapped__, context)

  if isinstance(obj, list):
    return default_types.List(*(from_object(c, context) for c in obj))

  if isinstance(obj, tuple):
    if util.is_namedtuple(obj):
      named_tuple_type = type(obj)
      return default_types.NamedTuple.from_type_and_attributes(
          named_tuple_type, tuple(from_object(c, context) for c in obj))
    return default_types.Tuple(*(from_object(c, context) for c in obj))

  if isinstance(obj, collections.abc.Mapping):
    return default_types.Dict(
        {k: from_object(obj[k], context) for k in obj})

  if util.is_attrs(obj):
    return default_types.Attrs.from_type_and_attributes(
        type(obj),
        tuple(
            from_object(getattr(obj, a.name), context)
            for a in obj.__attrs_attrs__))

  # `weakref.ref` never returns None for a live referent, and `obj` is alive
  # here because this frame holds it. A "deleted object" check is therefore
  # unnecessary. The `Weakref` construction stays inside the `try` so that a
  # TypeError from it still falls back to the literal representation.
  try:
    return default_types.Weakref(weakref.ref(obj, context.deletion_observer))
  except TypeError:
    # Some builtins (e.g. int, str, None) cannot be weakly referenced;
    # represent them by value instead.
    pass

  try:
    return default_types.Literal(obj)
  except Exception as e:  # Do not convert KeyboardInterrupt/SystemExit.
    raise TypeError(
        f"Python object could not be represented through the generic tracing "
        f"type. Consider implementing the Tracing Protocol for it: {obj!r}"
    ) from e
def create_trace_type(obj: Any,
                      context: SignatureContext) -> trace.TraceType:
  """Returns a TraceType corresponding to the object based on the context.

  Objects are dispatched in this order:

  1. Objects implementing the Tracing Protocol (`__tf_tracing_type__`).
  2. Lists.
  3. Tuples.
  4. Mappings.
  5. attrs classes.
  6. Wrapper objects exposing `__wrapped__`, which are unwrapped recursively.
  7. Weak-referenceable objects.
  8. Finally, a `default_types.Generic` of the object itself.

  Args:
    obj: The object to generate a TraceType for.
    context: The SignatureContext to be shared during protocol calls.

  Returns:
    A TraceType object representing the given object.

  Raises:
    TypeError: If the object can be neither weakly referenced nor
      represented as a `default_types.Generic`.
  """
  if isinstance(obj, trace.SupportsTracingProtocol):
    return obj.__tf_tracing_type__(context)

  if isinstance(obj, list):
    return default_types.List(*(create_trace_type(c, context) for c in obj))

  if isinstance(obj, tuple):
    return default_types.Tuple(*(create_trace_type(c, context) for c in obj))

  if isinstance(obj, collections.abc.Mapping):
    return default_types.Dict(
        {k: create_trace_type(obj[k], context) for k in obj})

  if hasattr(type(obj), "__attrs_attrs__"):
    # Materialize the attribute types eagerly (instead of passing a one-shot
    # generator) so that failures surface here and the sequence can be
    # iterated more than once.
    return default_types.Attrs(
        type(obj),
        tuple(
            create_trace_type(getattr(obj, a.name), context)
            for a in obj.__attrs_attrs__))

  if hasattr(obj, "__wrapped__"):
    return create_trace_type(obj.__wrapped__, context)

  # `weakref.ref` never returns None for a live referent, and `obj` is alive
  # here because this frame holds it. A "deleted object" check is therefore
  # unnecessary. The `Weakref` construction stays inside the `try` so that a
  # TypeError from it still falls back to the generic representation.
  try:
    return default_types.Weakref(weakref.ref(obj, context.deletion_observer))
  except TypeError:
    # Some builtins (e.g. int, str, None) cannot be weakly referenced;
    # represent them by value instead.
    pass

  try:
    return default_types.Generic(obj)
  except Exception as e:  # Do not convert KeyboardInterrupt/SystemExit.
    raise TypeError(
        f"Python object could not be represented through the generic tracing "
        f"type. Consider implementing the Tracing Protocol for it: {obj!r}"
    ) from e