Beispiel #1
0
 def testAttrsCacheKeyGeneration(self):
     """Checks the trace type generated for a TestAttrsClass instance."""
     generated = trace_type.from_object(TestAttrsClass(1, 2))
     field_types = (default_types.Literal(1), default_types.Literal(2))
     expected_type = default_types.Attrs.from_type_and_attributes(
         TestAttrsClass, field_types)

     self.assertEqual(generated, expected_type)
     # Every generated trace type must be a subtype of itself.
     self.assertTrue(generated.is_subtype_of(generated))
    def testTupleSerialization(self):
        """A Tuple of Literals is unchanged by a serialize/deserialize round trip."""
        elements = [default_types.Literal(value) for value in (1, 2, 3)]
        tuple_original = default_types.Tuple(*elements)

        round_tripped = serialization.deserialize(
            serialization.serialize(tuple_original))
        self.assertEqual(round_tripped, tuple_original)
    def testListSerialization(self):
        """A List of Literals is unchanged by a serialize/deserialize round trip."""
        elements = [default_types.Literal(value) for value in (1, 2, 3)]
        list_original = default_types.List(*elements)

        round_tripped = serialization.deserialize(
            serialization.serialize(list_original))
        self.assertEqual(round_tripped, list_original)
    def testDictSerialization(self):
        """A Dict of str -> Literal is unchanged by a serialization round trip."""
        entries = (('a', 1), ('b', 2), ('c', 3))
        dict_original = default_types.Dict(
            {key: default_types.Literal(value) for key, value in entries})

        round_tripped = serialization.deserialize(
            serialization.serialize(dict_original))
        self.assertEqual(round_tripped, dict_original)
    def testLiteralSupertypes(self):
        """Checks Literal.most_specific_common_supertype for equal/unequal values."""
        one = default_types.Literal(1)
        two = default_types.Literal(2)
        another_one = default_types.Literal(1)

        # No others, the literal itself, or an equal-valued literal: the result
        # is the original literal.
        for others in ([], [one], [another_one]):
            self.assertEqual(one, one.most_specific_common_supertype(others))

        # Literals holding different values have no common supertype.
        self.assertIsNone(one.most_specific_common_supertype([two]))
def from_object(obj: Any,
                context: trace.TracingContext = None) -> trace.TraceType:
    """Returns a TraceType corresponding to the object based on the context.

    The object is matched against the following cases, in order:
      1. objects implementing the Tracing Protocol (``__tf_tracing_type__``);
      2. objects exposing ``__wrapped__`` (the wrapped object is traced);
      3. lists, tuples (namedtuples become NamedTuple types), mappings and
         attrs instances, whose components are traced recursively;
      4. a generic fallback: a weak reference to the object, or failing that
         a Literal holding the object.

    Args:
      obj: The object to generate a TraceType for.
      context: The TracingContext to be shared during protocol calls. When
        omitted, a new InternalTracingContext is created.

    Returns:
      A TraceType object representing the given object.

    Raises:
      TypeError: If the object reaches the generic fallback and can be
        represented neither as a Weakref nor as a Literal.
    """
    if context is None:
        context = InternalTracingContext()

    if isinstance(obj, trace.SupportsTracingProtocol):
        return obj.__tf_tracing_type__(context)

    if hasattr(obj, "__wrapped__"):
        return from_object(obj.__wrapped__, context)

    if isinstance(obj, list):
        return default_types.List(*(from_object(c, context) for c in obj))

    if isinstance(obj, tuple):
        if util.is_namedtuple(obj):
            return default_types.NamedTuple.from_type_and_attributes(
                type(obj), tuple(from_object(c, context) for c in obj))
        return default_types.Tuple(*(from_object(c, context) for c in obj))

    if isinstance(obj, collections.abc.Mapping):
        return default_types.Dict(
            {k: from_object(v, context) for k, v in obj.items()})

    if util.is_attrs(obj):
        return default_types.Attrs.from_type_and_attributes(
            type(obj),
            tuple(
                from_object(getattr(obj, a.name), context)
                for a in obj.__attrs_attrs__))

    # Generic fallback, part 1: identify the object by a weak reference.
    # weakref.ref() raises TypeError for objects that do not support weak
    # references. It never returns None for a live object, so no "deleted
    # object" check is needed. default_types.Weakref is kept inside the try
    # on purpose: it may also raise TypeError (presumably when hashing the ref
    # of an unhashable object), and such objects must fall through to the
    # Literal path below.
    try:
        return default_types.Weakref(
            weakref.ref(obj, context.deletion_observer))
    except TypeError:
        pass

    # Generic fallback, part 2: represent the object by its value.
    try:
        return default_types.Literal(obj)
    except Exception as e:  # Not bare: let KeyboardInterrupt/SystemExit through.
        raise TypeError(
            f"Python object could not be represented through the generic tracing "
            f"type. Consider implementing the Tracing Protocol for it: {obj!r}"
        ) from e
Beispiel #7
0
    def testLiteralSerialization(self):
        """bool, int, float and str Literals survive a serialization round trip."""
        literals = [
            default_types.Literal(value) for value in (True, 1, 1.2, 'a')
        ]

        for literal in literals:
            restored = serialization.deserialize(serialization.serialize(literal))
            self.assertEqual(restored, literal)
    def testReferenceSubtype(self):
        """Checks Reference.is_subtype_of across ids, base types and non-References."""
        base = default_types.Reference(Mock2AsTopType(3), 1)
        same = default_types.Reference(Mock2AsTopType(3), 1)
        other_id = default_types.Reference(Mock2AsTopType(3), 2)
        # NOTE(review): Mock2AsTopType(2) presumably acts as the top of the
        # mock type hierarchy, which makes `wider` a supertype of `base`.
        wider = default_types.Reference(Mock2AsTopType(2), 1)
        unrelated = default_types.Literal(1)

        self.assertEqual(base, same)
        # Differing reference ids are never in a subtype relation.
        self.assertFalse(base.is_subtype_of(other_id))
        # The subtype relation holds in one direction only.
        self.assertTrue(base.is_subtype_of(wider))
        self.assertFalse(wider.is_subtype_of(base))
        # A Reference is not a subtype of a non-Reference type.
        self.assertFalse(base.is_subtype_of(unrelated))
    def testReferenceSupertype(self):
        """Checks Reference.most_specific_common_supertype for several inputs."""
        base = default_types.Reference(Mock2AsTopType(3), 1)
        same = default_types.Reference(Mock2AsTopType(3), 1)
        other_id = default_types.Reference(Mock2AsTopType(3), 2)
        wider = default_types.Reference(Mock2AsTopType(2), 1)
        unrelated = default_types.Literal(1)

        # With no other types, the result is the Reference itself.
        self.assertEqual(wider.most_specific_common_supertype([]), wider)
        # An equal Reference yields the original.
        self.assertEqual(base.most_specific_common_supertype([same]), base)
        # A different reference id or a non-Reference type has no common
        # supertype.
        for incompatible in (other_id, unrelated):
            self.assertIsNone(base.most_specific_common_supertype([incompatible]))
 def testReferencetSerialization(self):
     """A Reference wrapping a Literal survives a serialization round trip."""
     # NOTE(review): "Referencet" is a typo in the method name; it is kept as is
     # so that selecting the test by name keeps working.
     wrapped = default_types.Literal(3)
     ref_original = default_types.Reference(wrapped, 1)
     serialized = serialization.serialize(ref_original)
     self.assertEqual(serialization.deserialize(serialized), ref_original)