Exemplo n.º 1
0
 def testAttrsCacheKeyGeneration(self):
     """Tracing an attrs instance yields an Attrs type over its field values."""
     traced = trace_type.from_object(TestAttrsClass(1, 2))
     field_types = (default_types.Generic(1), default_types.Generic(2))
     self.assertEqual(traced, default_types.Attrs(TestAttrsClass, field_types))
     # Subtyping is expected to be reflexive.
     self.assertTrue(traced.is_subtype_of(traced))
Exemplo n.º 2
0
 def testAttrsCacheKeyGeneration(self):
   """An attrs instance should map to the matching Attrs trace type."""
   traced = make_function_signature_with_context(TestAttrsClass(1, 2))
   field_types = (default_types.Generic(1), default_types.Generic(2))
   self.assertEqual(traced, default_types.Attrs(TestAttrsClass, field_types))
   # Subtyping is expected to be reflexive.
   self.assertTrue(traced.is_subtype_of(traced))
Exemplo n.º 3
0
    def testGenericSupertypes(self):
        """Generic types share a supertype only with equal Generic values."""
        one = default_types.Generic(1)
        two = default_types.Generic(2)
        another_one = default_types.Generic(1)

        # No others, itself, or an equal Generic: the supertype is `one`.
        for others in ([], [one], [another_one]):
            self.assertEqual(one, one.most_specific_common_supertype(others))

        # Differing wrapped values have no common supertype.
        self.assertIsNone(one.most_specific_common_supertype([two]))
Exemplo n.º 4
0
def from_object(obj: Any,
                context: trace.TracingContext = None) -> trace.TraceType:
    """Returns a TraceType corresponding to the object based on the context.

    Objects are dispatched in this order:
      1. Objects implementing the Tracing Protocol build their own TraceType.
      2. Objects exposing `__wrapped__` are traced via the wrapped object.
      3. Lists, namedtuples, tuples, Mappings and attrs instances are traced
         recursively into the corresponding container TraceType.
      4. Otherwise a weak reference (with `context.deletion_observer` as its
         callback) is wrapped in `default_types.Weakref`.
      5. If step 4 raises TypeError, the object falls back to
         `default_types.Generic`.

    Args:
      obj: The object to generate a TraceType for.
      context: The TracingContext to be shared during protocol calls. A new
        `InternalTracingContext` is created when omitted.

    Returns:
      A TraceType object representing the given object.

    Raises:
      TypeError: If the object cannot be represented through the generic
        tracing type either.
    """
    if context is None:
        context = InternalTracingContext()

    if isinstance(obj, trace.SupportsTracingProtocol):
        return obj.__tf_tracing_type__(context)

    if hasattr(obj, "__wrapped__"):
        return from_object(obj.__wrapped__, context)

    if isinstance(obj, list):
        return default_types.List(*(from_object(c, context) for c in obj))

    if isinstance(obj, tuple):
        if util.is_namedtuple(obj):
            return default_types.NamedTuple(
                type(obj), tuple(from_object(c, context) for c in obj))
        else:
            return default_types.Tuple(*(from_object(c, context) for c in obj))

    if isinstance(obj, collections.abc.Mapping):
        return default_types.Dict(
            {k: from_object(obj[k], context)
             for k in obj})

    if util.is_attrs(obj):
        return default_types.Attrs(
            type(obj),
            tuple(
                from_object(getattr(obj, a.name), context)
                for a in obj.__attrs_attrs__))

    try:
        # `weakref.ref` never returns None. It either returns a reference or
        # raises TypeError for objects without weak-reference support (e.g.
        # int, str). The former `ref is None` check was unreachable, and its
        # TypeError would have been swallowed by this very handler anyway.
        ref = weakref.ref(obj, context.deletion_observer)
        return default_types.Weakref(ref)
    except TypeError:
        pass  # Fall through to the Generic representation below.

    try:
        return default_types.Generic(obj)
    except Exception as e:  # Not bare: let KeyboardInterrupt/SystemExit pass.
        raise TypeError(
            f"Python object could not be represented through the generic tracing "
            f"type. Consider implementing the Tracing Protocol for it: {obj!r}"
        ) from e
Exemplo n.º 5
0
    def testReferenceSubtype(self):
        """Reference subtyping needs equal ids and a subtype wrapped value."""

        class MockSubtypeOf2(default_types.Generic):
            # Reports itself a subtype of anything whose `_object` equals 2.
            def is_subtype_of(self, other):
                return other._object == 2

        def make_ref(value, ref_id):
            return default_types.Reference(MockSubtypeOf2(value), ref_id)

        original = make_ref(3, 1)
        clone = make_ref(3, 1)
        different_id = make_ref(3, 2)
        supertype = make_ref(2, 1)
        different_type = default_types.Generic(1)

        self.assertEqual(original, clone)
        self.assertFalse(original.is_subtype_of(different_id))
        self.assertTrue(original.is_subtype_of(supertype))
        self.assertFalse(supertype.is_subtype_of(original))
        self.assertFalse(original.is_subtype_of(different_type))
Exemplo n.º 6
0
    def testReferenceSupertype(self):
        """Reference supertypes need matching ids and compatible wrapped types."""

        class Mock2AsTopType(default_types.Generic):
            # Mock2AsTopType(2) serves as the top type for differing values.
            def most_specific_common_supertype(self, types):
                if any(not isinstance(t, Mock2AsTopType) for t in types):
                    return None
                if all(self._object == t._object for t in types):
                    return self
                return Mock2AsTopType(2)

        def make_ref(value, ref_id):
            return default_types.Reference(Mock2AsTopType(value), ref_id)

        original = make_ref(3, 1)
        clone = make_ref(3, 1)
        different_id = make_ref(3, 2)
        supertype = make_ref(2, 1)
        different_type = default_types.Generic(1)

        self.assertEqual(
            supertype.most_specific_common_supertype([]), supertype)
        self.assertEqual(
            original.most_specific_common_supertype([clone]), original)
        self.assertIsNone(
            original.most_specific_common_supertype([different_id]))
        self.assertIsNone(
            original.most_specific_common_supertype([different_type]))