def testListTupleInequality(self):
        """Asserts List/Tuple equality rules for Literal components.

        Collections of the same kind built from equal Literal components must
        compare equal; a List and a Tuple built from those same components
        must compare unequal in both operand orders.
        """

        def literal_components():
            # Fresh Literal(1), Literal(2), Literal(3) for each collection.
            return [default_types.Literal(value) for value in (1, 2, 3)]

        first_list = default_types.List(*literal_components())
        second_list = default_types.List(*literal_components())

        first_tuple = default_types.Tuple(*literal_components())
        second_tuple = default_types.Tuple(*literal_components())

        # Same kind, same components.
        self.assertEqual(first_list, second_list)
        self.assertEqual(first_tuple, second_tuple)

        # Different kind, checked with each type on the left-hand side.
        self.assertNotEqual(first_list, first_tuple)
        self.assertNotEqual(first_tuple, first_list)
    def testListTupleInequality(self):
        """Asserts List/Tuple equality rules for Generic components.

        Collections of the same kind built from equal Generic components must
        compare equal; a List and a Tuple built from those same components
        must compare unequal in both operand orders.
        """

        def generic_components():
            # Fresh Generic(1), Generic(2), Generic(3) for each collection.
            return [default_types.Generic(value) for value in (1, 2, 3)]

        first_list = default_types.List(*generic_components())
        second_list = default_types.List(*generic_components())

        first_tuple = default_types.Tuple(*generic_components())
        second_tuple = default_types.Tuple(*generic_components())

        # Same kind, same components.
        self.assertEqual(first_list, second_list)
        self.assertEqual(first_tuple, second_tuple)

        # Different kind, checked with each type on the left-hand side.
        self.assertNotEqual(first_list, first_tuple)
        self.assertNotEqual(first_tuple, first_list)
    def testTupleSupertype(self):
        """Asserts Tuple.most_specific_common_supertype on mock components.

        Expectations checked here:
          * with no other types, the result equals the tuple itself;
          * Tuple(1, 2, 3) against Tuple(2, 2, 2) yields None;
          * Tuple(2, 2, 2) against Tuple(1, 2, 3) yields Tuple(3, 3, 3).
        """

        def mock_tuple(*values):
            # Wrap each value in the mock component type, in the given order.
            return default_types.Tuple(
                *(MockSupertypes2With3(value) for value in values))

        mixed = mock_tuple(1, 2, 3)
        all_twos = mock_tuple(2, 2, 2)

        # No other types: result equals the tuple itself.
        self.assertEqual(mixed, mixed.most_specific_common_supertype([]))

        # Mixed tuple on the left: no common supertype.
        no_supertype = mixed.most_specific_common_supertype([all_twos])
        self.assertIsNone(no_supertype)

        # All-twos tuple on the left: every component becomes 3.
        merged = all_twos.most_specific_common_supertype([mixed])
        self.assertEqual(merged, mock_tuple(3, 3, 3))
def from_object(obj: Any,
                context: trace.TracingContext = None) -> trace.TraceType:
    """Returns a TraceType corresponding to the object based on the context.

    The object is matched against the cases below in order; the first match
    determines the result:

      1. Objects satisfying `trace.SupportsTracingProtocol` produce their own
         type via `__tf_tracing_type__`.
      2. Objects exposing `__wrapped__` are unwrapped and retried.
      3. Lists, tuples (including namedtuples), mappings and attrs instances
         are converted recursively, element by element.
      4. Anything else becomes a `default_types.Weakref` when it supports weak
         references, otherwise a `default_types.Literal`.

    Args:
      obj: The object to generate a TraceType for.
      context: The TracingContext to be shared during protocol calls. When
        None, a new InternalTracingContext is created for this call.

    Returns:
      A TraceType object representing the given object.

    Raises:
      TypeError: If the object cannot be weakly referenced and cannot be
        represented as a `default_types.Literal` either.
    """

    if context is None:
        context = InternalTracingContext()

    if isinstance(obj, trace.SupportsTracingProtocol):
        return obj.__tf_tracing_type__(context)

    if hasattr(obj, "__wrapped__"):
        return from_object(obj.__wrapped__, context)

    if isinstance(obj, list):
        return default_types.List(*(from_object(c, context) for c in obj))

    if isinstance(obj, tuple):
        if util.is_namedtuple(obj):
            # Namedtuples keep their concrete type alongside the field types.
            named_tuple_type = type(obj)
            return default_types.NamedTuple.from_type_and_attributes(
                named_tuple_type, tuple(from_object(c, context) for c in obj))
        return default_types.Tuple(*(from_object(c, context) for c in obj))

    if isinstance(obj, collections.abc.Mapping):
        return default_types.Dict(
            {k: from_object(obj[k], context)
             for k in obj})

    if util.is_attrs(obj):
        return default_types.Attrs.from_type_and_attributes(
            type(obj),
            tuple(
                from_object(getattr(obj, a.name), context)
                for a in obj.__attrs_attrs__))

    try:
        # weakref.ref either returns a reference object or raises TypeError
        # for objects that do not support weak references; it never returns
        # None, so no None check is needed here.
        ref = weakref.ref(obj, context.deletion_observer)
        return default_types.Weakref(ref)
    except TypeError:
        try:
            return default_types.Literal(obj)
        except Exception as literal_error:
            # Catch Exception rather than everything so KeyboardInterrupt and
            # SystemExit are not disguised as a TypeError.
            raise TypeError(
                f"Python object could not be represented through the generic tracing "
                f"type. Consider implementing the Tracing Protocol for it: {obj!r}"
            ) from literal_error
    def testTupleSerialization(self):
        """Asserts a Tuple of Literals survives a serialization round trip."""
        source_tuple = default_types.Tuple(
            *(default_types.Literal(value) for value in (1, 2, 3)))

        # Serialize, then deserialize, and compare with the starting value.
        serialized = serialization.serialize(source_tuple)
        restored = serialization.deserialize(serialized)

        self.assertEqual(restored, source_tuple)
Esempio n. 6
0
    def testTupleSupertype(self):
        """Asserts Tuple.most_specific_common_supertype on a local stub type.

        Expectations checked here:
          * with no other types, the result equals the tuple itself;
          * Tuple(1, 2, 3) against Tuple(2, 2, 2) yields None;
          * Tuple(2, 2, 2) against Tuple(1, 2, 3) yields Tuple(3, 3, 3).
        """

        class Supertypable(default_types.Generic):
            # Stub component: with no others it is its own supertype; when it
            # wraps 2 and the first other wraps an int the supertype wraps 3;
            # in every remaining case there is no supertype.
            def most_specific_common_supertype(self, others):
                if not others:
                    return self

                wraps_two = self._object == 2
                if wraps_two and isinstance(others[0]._object, int):
                    return Supertypable(3)
                return None

        def stub_tuple(*values):
            # Wrap each value in the stub component type, in the given order.
            return default_types.Tuple(
                *(Supertypable(value) for value in values))

        mixed = stub_tuple(1, 2, 3)
        all_twos = stub_tuple(2, 2, 2)

        # No other types: result equals the tuple itself.
        self.assertEqual(mixed, mixed.most_specific_common_supertype([]))

        # Mixed tuple on the left: no common supertype.
        no_supertype = mixed.most_specific_common_supertype([all_twos])
        self.assertIsNone(no_supertype)

        # All-twos tuple on the left: every component becomes 3.
        merged = all_twos.most_specific_common_supertype([mixed])
        self.assertEqual(merged, stub_tuple(3, 3, 3))
Esempio n. 7
0
def create_trace_type(obj: Any, context: SignatureContext) -> trace.TraceType:
    """Returns a TraceType corresponding to the object based on the context.

    The object is matched against the cases below in order; the first match
    determines the result:

      1. Objects satisfying `trace.SupportsTracingProtocol` produce their own
         type via `__tf_tracing_type__`.
      2. Lists, tuples, mappings and instances of attrs classes are converted
         recursively, element by element.
      3. Objects exposing `__wrapped__` are unwrapped and retried.
      4. Anything else becomes a `default_types.Weakref` when it supports weak
         references, otherwise a `default_types.Generic`.

    Args:
      obj: The object to generate a TraceType for.
      context: The TracingContext to be shared during protocol calls.

    Returns:
      A TraceType object representing the given object.

    Raises:
      TypeError: If the object cannot be weakly referenced and cannot be
        represented as a `default_types.Generic` either.
    """

    if isinstance(obj, trace.SupportsTracingProtocol):
        return obj.__tf_tracing_type__(context)

    if isinstance(obj, list):
        return default_types.List(*(create_trace_type(c, context)
                                    for c in obj))

    if isinstance(obj, tuple):
        return default_types.Tuple(*(create_trace_type(c, context)
                                     for c in obj))

    if isinstance(obj, collections.abc.Mapping):
        return default_types.Dict(
            {k: create_trace_type(obj[k], context)
             for k in obj})

    # attrs-decorated classes carry their field list on the class itself.
    if hasattr(type(obj), "__attrs_attrs__"):
        return default_types.Attrs(
            type(obj), (create_trace_type(getattr(obj, a.name), context)
                        for a in obj.__attrs_attrs__))

    if hasattr(obj, "__wrapped__"):
        return create_trace_type(obj.__wrapped__, context)

    try:
        # weakref.ref either returns a reference object or raises TypeError
        # for objects that do not support weak references; it never returns
        # None, so no None check is needed here.
        ref = weakref.ref(obj, context.deletion_observer)
        return default_types.Weakref(ref)
    except TypeError:
        try:
            return default_types.Generic(obj)
        except Exception as generic_error:
            # Catch Exception rather than everything so KeyboardInterrupt and
            # SystemExit are not disguised as a TypeError.
            raise TypeError(
                f"Python object could not be represented through the generic tracing "
                f"type. Consider implementing the Tracing Protocol for it: {obj!r}"
            ) from generic_error