예제 #1
0
def from_object(obj: Any,
                context: trace.TracingContext = None) -> trace.TraceType:
    """Returns a TraceType corresponding to the object based on the context.

    Dispatch order: objects implementing the Tracing Protocol, then
    `__wrapped__` proxies (unwrapped recursively), then lists, tuples
    (namedtuples keep their concrete type), mappings and attrs classes (all
    converted recursively), then weak-referenceable objects, and finally a
    `default_types.Literal` fallback for everything else.

    Args:
      obj: The object to generate a TraceType for.
      context: The TracingContext to be shared during protocol calls. A fresh
        `InternalTracingContext` is created when None.

    Returns:
      A TraceType object representing the given object.

    Raises:
      TypeError: If `obj` is neither weak-referenceable nor representable as a
        `default_types.Literal`.
    """
    if context is None:
        context = InternalTracingContext()

    if isinstance(obj, trace.SupportsTracingProtocol):
        return obj.__tf_tracing_type__(context)

    if hasattr(obj, "__wrapped__"):
        return from_object(obj.__wrapped__, context)

    if isinstance(obj, list):
        return default_types.List(*(from_object(c, context) for c in obj))

    if isinstance(obj, tuple):
        if util.is_namedtuple(obj):
            named_tuple_type = type(obj)
            return default_types.NamedTuple.from_type_and_attributes(
                named_tuple_type, tuple(from_object(c, context) for c in obj))
        return default_types.Tuple(*(from_object(c, context) for c in obj))

    if isinstance(obj, collections.abc.Mapping):
        return default_types.Dict({k: from_object(obj[k], context) for k in obj})

    if util.is_attrs(obj):
        return default_types.Attrs.from_type_and_attributes(
            type(obj),
            tuple(
                from_object(getattr(obj, a.name), context)
                for a in obj.__attrs_attrs__))

    # Only the weakref creation belongs in the `try`: a TypeError here means
    # `obj` does not support weak references (e.g. int or str instances), so
    # it is represented by value instead.
    try:
        ref = weakref.ref(obj, context.deletion_observer)
    except TypeError:
        try:
            return default_types.Literal(obj)
        except Exception as e:
            raise TypeError(
                f"Python object could not be represented through the generic tracing "
                f"type. Consider implementing the Tracing Protocol for it: {obj!r}"
            ) from e

    # `weakref.ref` never returns None, and `obj` is still strongly referenced
    # by this frame, so the new reference is always live; a "deleted object"
    # check is not needed here.
    return default_types.Weakref(ref)
예제 #2
0
def create_trace_type(obj: Any, context: SignatureContext) -> trace.TraceType:
    """Returns a TraceType corresponding to the object based on the context.

    Dispatch order: objects implementing the Tracing Protocol, then lists,
    tuples, mappings and attrs classes (all converted recursively), then
    `__wrapped__` proxies (unwrapped recursively), then weak-referenceable
    objects, and finally a `default_types.Generic` fallback for everything
    else.

    Args:
      obj: The object to generate a TraceType for.
      context: The context shared during protocol calls; its
        `deletion_observer` is attached as the weak-reference callback.

    Returns:
      A TraceType object representing the given object.

    Raises:
      TypeError: If `obj` is neither weak-referenceable nor representable as a
        `default_types.Generic`.
    """
    if isinstance(obj, trace.SupportsTracingProtocol):
        return obj.__tf_tracing_type__(context)

    if isinstance(obj, list):
        return default_types.List(*(create_trace_type(c, context)
                                    for c in obj))

    if isinstance(obj, tuple):
        return default_types.Tuple(*(create_trace_type(c, context)
                                     for c in obj))

    if isinstance(obj, collections.abc.Mapping):
        return default_types.Dict(
            {k: create_trace_type(obj[k], context) for k in obj})

    if hasattr(type(obj), "__attrs_attrs__"):
        # Materialize the attribute types into a tuple rather than passing a
        # generator: a generator is single-use and hashes/compares by identity,
        # so if `Attrs` keeps it as-is, equal objects would yield unequal types.
        attribute_types = tuple(
            create_trace_type(getattr(obj, a.name), context)
            for a in obj.__attrs_attrs__)
        return default_types.Attrs(type(obj), attribute_types)

    if hasattr(obj, "__wrapped__"):
        return create_trace_type(obj.__wrapped__, context)

    # Only the weakref creation belongs in the `try`: a TypeError here means
    # `obj` does not support weak references (e.g. int or str instances), so
    # it falls back to the generic by-value representation.
    try:
        ref = weakref.ref(obj, context.deletion_observer)
    except TypeError:
        try:
            return default_types.Generic(obj)
        except Exception as e:
            raise TypeError(
                f"Python object could not be represented through the generic tracing "
                f"type. Consider implementing the Tracing Protocol for it: {obj!r}"
            ) from e

    # `weakref.ref` never returns None, and `obj` is still strongly referenced
    # by this frame, so the new reference is always live; a "deleted object"
    # check is not needed here.
    return default_types.Weakref(ref)