Esempio n. 1
0
def make_type_checker(annotation):
    """Returns a PyTypeChecker that matches values of `annotation`.

    Accepted annotations:
      * `Union[...]` (including `Optional[...]`): one checker is built per
        member and combined with `MakeUnionChecker`.
      * `List[T]`: a `MakeListChecker` wrapping the checker for `T`.
      * plain type objects: an instance checker, memoized in
        `_is_instance_checker_cache`.
      * `None`: treated as `type(None)`.

    Raises:
      AssertionError: if a `List[...]` annotation does not have exactly one
        type parameter.
      ValueError: if `annotation` is none of the forms above.
    """
    if type_annotations.is_generic_union(annotation):
        member_checkers = [
            make_type_checker(member)
            for member in type_annotations.get_generic_type_args(annotation)
        ]
        return _api_dispatcher.MakeUnionChecker(member_checkers)

    if type_annotations.is_generic_list(annotation):
        list_args = type_annotations.get_generic_type_args(annotation)
        if len(list_args) != 1:
            raise AssertionError(
                "Expected List[...] to have a single type parameter")
        (element_annotation,) = list_args
        return _api_dispatcher.MakeListChecker(
            make_type_checker(element_annotation))

    if isinstance(annotation, type):
        # Reuse a previously built instance checker for this exact type.
        if annotation not in _is_instance_checker_cache:
            _is_instance_checker_cache[annotation] = (
                _api_dispatcher.MakeInstanceChecker(annotation))
        return _is_instance_checker_cache[annotation]

    if annotation is None:
        # A bare `None` annotation stands for NoneType.
        return make_type_checker(type(None))

    raise ValueError(
        f"Type annotation {annotation} is not currently supported"
        " by dispatch.  Supported annotations: type objects, "
        " List[...], and Union[...]")
Esempio n. 2
0
    def testUnionChecker(self):
        """Checks Check(), cost() and repr() of MakeUnionChecker results."""
        int_checker = dispatch.MakeInstanceChecker(int)
        float_checker = dispatch.MakeInstanceChecker(float)
        str_checker = dispatch.MakeInstanceChecker(str)
        none_checker = dispatch.MakeInstanceChecker(type(None))
        tensor_checker = dispatch.MakeInstanceChecker(ops.Tensor)
        ragged_checker = dispatch.MakeInstanceChecker(
            ragged_tensor.RaggedTensor)

        t = constant_op.constant([1, 2, 3])
        rt = ragged_factory_ops.constant([[1, 2], [3, 4, 5]])

        with self.subTest('Union[int, float, str]'):
            checker = dispatch.MakeUnionChecker(
                [int_checker, float_checker, str_checker])
            self.assertEqual(checker.Check(3), MATCH)
            self.assertEqual(checker.Check(3.0), MATCH)
            self.assertEqual(checker.Check('x'), MATCH)
            self.assertEqual(checker.Check(None), NO_MATCH)
            self.assertEqual(checker.Check(t), NO_MATCH)
            self.assertEqual(checker.cost(), 4)
            self.assertEqual(repr(checker),
                             '<PyTypeChecker Union[int, float, str]>')

        with self.subTest('Optional[int] (aka Union[int, None])'):
            checker = dispatch.MakeUnionChecker([int_checker, none_checker])
            self.assertEqual(checker.Check(3), MATCH)
            self.assertEqual(checker.Check(3.0), NO_MATCH)
            self.assertEqual(checker.Check(None), MATCH)
            self.assertEqual(checker.Check(t), NO_MATCH)
            self.assertEqual(checker.cost(), 3)
            self.assertEqual(repr(checker),
                             '<PyTypeChecker Union[int, NoneType]>')

        with self.subTest('Union[Tensor, RaggedTensor]'):
            checker = dispatch.MakeUnionChecker(
                [tensor_checker, ragged_checker])
            self.assertEqual(checker.Check(3), NO_MATCH)
            self.assertEqual(checker.Check(3.0), NO_MATCH)
            self.assertEqual(checker.Check(None), NO_MATCH)
            self.assertEqual(checker.Check(t), MATCH)
            # A RaggedTensor reports MATCH_DISPATCHABLE rather than plain MATCH.
            self.assertEqual(checker.Check(rt), MATCH_DISPATCHABLE)
            self.assertEqual(checker.cost(), 3)
            self.assertEqual(repr(checker),
                             '<PyTypeChecker Union[Tensor, RaggedTensor]>')
Esempio n. 3
0
def make_type_checker(annotation):
    """Builds a PyTypeChecker for the given type annotation.

    Supported annotations:
      * plain type objects -> an instance checker, memoized in
        `_is_instance_checker_cache`;
      * `None` -> the checker for `type(None)`;
      * `List[T]` -> a list checker wrapping the checker for `T`;
      * `Union[...]` / `Optional[...]` -> a union checker.  When two or more
        union members are simple types, they share a single multi-type
        instance checker, memoized under the tuple of those types.

    Args:
      annotation: The type annotation to build a checker for.

    Returns:
      A checker object created by `_api_dispatcher`.

    Raises:
      AssertionError: If a `List[...]` annotation does not have exactly one
        type parameter.
      ValueError: If `annotation` is not one of the supported forms.
    """
    if type_annotations.is_generic_union(annotation):
        type_args = type_annotations.get_generic_type_args(annotation)

        def is_simple_type(t):
            # Treat `t` as simple only if the non-union branches below would
            # handle it with a plain instance check.  `isinstance(t, type)`
            # alone is not enough: on some Python versions (3.9/3.10)
            # parameterized builtins such as `list[int]` also pass it, but the
            # list branch is tested before the `type` branch and must win.
            return (isinstance(t, type) and
                    not type_annotations.is_generic_list(t) and
                    not type_annotations.is_generic_union(t))

        simple_types = [t for t in type_args if is_simple_type(t)]
        other_types = [t for t in type_args if not is_simple_type(t)]

        # If the union contains two or more simple types, then use a single
        # InstanceChecker to check them.  Sorting by id makes the cache key
        # independent of the order the types appear in the annotation.
        simple_types = tuple(sorted(simple_types, key=id))
        if len(simple_types) > 1:
            if simple_types not in _is_instance_checker_cache:
                checker = _api_dispatcher.MakeInstanceChecker(*simple_types)
                _is_instance_checker_cache[simple_types] = checker
            options = ([_is_instance_checker_cache[simple_types]] +
                       [make_type_checker(t) for t in other_types])
            return _api_dispatcher.MakeUnionChecker(options)

        options = [make_type_checker(t) for t in type_args]
        return _api_dispatcher.MakeUnionChecker(options)

    elif type_annotations.is_generic_list(annotation):
        type_args = type_annotations.get_generic_type_args(annotation)
        if len(type_args) != 1:
            raise AssertionError(
                "Expected List[...] to have a single type parameter")
        elt_type = make_type_checker(type_args[0])
        return _api_dispatcher.MakeListChecker(elt_type)

    elif isinstance(annotation, type):
        if annotation not in _is_instance_checker_cache:
            checker = _api_dispatcher.MakeInstanceChecker(annotation)
            _is_instance_checker_cache[annotation] = checker
        return _is_instance_checker_cache[annotation]

    elif annotation is None:
        # A bare `None` annotation stands for NoneType.
        return make_type_checker(type(None))

    else:
        raise ValueError(
            f"Type annotation {annotation} is not currently supported"
            " by dispatch.  Supported annotations: type objects, "
            " List[...], and Union[...]")
Esempio n. 4
0
    def testSortByCost(self):
        """PySignatureChecker checks its parameters cheapest-first."""
        int_checker = dispatch.MakeInstanceChecker(int)
        float_checker = dispatch.MakeInstanceChecker(float)
        int_or_float = dispatch.MakeUnionChecker([int_checker, float_checker])
        int_list = dispatch.MakeListChecker(int_checker)
        int_or_float_list = dispatch.MakeListChecker(int_or_float)

        signature = [
            (0, int_or_float_list),
            (1, int_or_float),
            (2, int_list),
            (3, int_checker),
        ]
        checker = dispatch.PySignatureChecker(signature)

        # The repr lists the parameters in the order they will be checked.
        expected = (
            '<PySignatureChecker '
            'args[3]:int, '                      # cost=1
            'args[1]:Union[int, float], '        # cost=3
            'args[2]:List[int], '                # cost=10
            'args[0]:List[Union[int, float]]>'   # cost=30
        )  # pyformat: disable
        self.assertEqual(repr(checker), expected)
Esempio n. 5
0
    def testUnion(self):
        """Two parameters, each accepting a RaggedTensor or a Tensor."""
        ragged_or_dense = dispatch.MakeUnionChecker([
            dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor),
            dispatch.MakeInstanceChecker(ops.Tensor),
        ])
        checker = dispatch.PySignatureChecker(
            [(position, ragged_or_dense) for position in range(2)])

        dense = constant_op.constant([[1, 2], [3, 4]])
        ragged = ragged_factory_ops.constant([[1, 2], [3]])

        # The signature matches only when both checked arguments are
        # Tensor/RaggedTensor and at least one is ragged.  Extra positional
        # arguments beyond the checked ones do not prevent a match.
        self.check_signatures(checker, [
            ((dense, dense), False),
            ((dense, ragged), True),
            ((ragged, dense), True),
            ((ragged, ragged), True),
            ((ragged, [ragged]), False),
            ((ragged, ragged, 1, 2, None), True),
        ])  # pyformat: disable

        expected_repr = (
            '<PySignatureChecker args[0]:Union[RaggedTensor, Tensor], '
            'args[1]:Union[RaggedTensor, Tensor]>')
        self.assertEqual(repr(checker), expected_repr)
Esempio n. 6
0
    def testListAndUnionDispatch(self):
        """Dispatch for `(x: Union[RaggedTensor, Tensor], ys: List[...])`."""
        dispatcher = dispatch.PythonAPIDispatcher('tf.foo',
                                                  ['x', 'ys', 'name'],
                                                  (None, ))

        ragged_or_dense = dispatch.MakeUnionChecker([
            dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor),
            dispatch.MakeInstanceChecker(ops.Tensor),
        ])
        list_of_ragged_or_dense = dispatch.MakeListChecker(ragged_or_dense)
        target = lambda x, ys, name=None: 'f1'
        signature = dispatch.PySignatureChecker([
            (0, ragged_or_dense),
            (1, list_of_ragged_or_dense),
        ])
        dispatcher.Register(signature, target)

        ragged = ragged_factory_ops.constant([[1, 2], [3]])
        dense = constant_op.constant(5)

        # Each of these calls contains a RaggedTensor in a checked argument,
        # so it is routed to the registered target.
        dispatched_calls = [
            ((ragged, [dense]), None),
            ((ragged, [ragged]), None),
            ((dense, [ragged]), None),
            ((ragged, []), None),
            ((dense, [dense, dense, ragged, dense]), None),
            ((ragged, [dense], 'my_name'), None),
            ((), {'x': ragged, 'ys': [dense]}),
            ((), {'x': ragged, 'ys': [dense], 'name': 'x'}),
        ]
        for args, kwargs in dispatched_calls:
            self.assertEqual(dispatcher.Dispatch(args, kwargs), 'f1')

        # No RaggedTensor present, or an argument fails its type check:
        # dispatch falls through with NotImplemented.
        fallthrough_args = [
            (dense, [dense]),
            (dense, []),
            ('foo', [ragged]),
            ('foo', 'bar'),
            ('foo', 'bar', 'baz'),
        ]
        for args in fallthrough_args:
            self.assertEqual(dispatcher.Dispatch(args, None), NotImplemented)
Esempio n. 7
0
    def testListChecker(self):
        """Checks Check(), cost() and repr() of MakeListChecker results."""
        int_checker = dispatch.MakeInstanceChecker(int)
        tensor_checker = dispatch.MakeInstanceChecker(ops.Tensor)
        ragged_checker = dispatch.MakeInstanceChecker(
            ragged_tensor.RaggedTensor)
        np_int_checker = dispatch.MakeInstanceChecker(np.integer)

        # Candidate values: lists and tuples of various element types, plus
        # non-list containers (iterators, dict, numpy array) used as negatives.
        t = constant_op.constant([1, 2, 3])
        rt = ragged_factory_ops.constant([[1, 2], [3, 4, 5]])
        a = [1, 2, 3]
        b = ['a', 2, t]
        c = [t, t * 2, t - 2]
        d = [t, rt]
        e = []
        f = (1, 2, 3)
        g = (rt, )
        h = {1: 2, 3: 4}
        i = np.array([1, 2, 3])

        with self.subTest('List[int]'):
            checker = dispatch.MakeListChecker(int_checker)
            self.assertEqual(checker.Check(a), MATCH)
            self.assertEqual(checker.Check(b), NO_MATCH)
            self.assertEqual(checker.Check(c), NO_MATCH)
            self.assertEqual(checker.Check(d), NO_MATCH)
            # An empty list matches any List[...] checker.
            self.assertEqual(checker.Check(e), MATCH)
            # Tuples are accepted like lists; iterators, dicts and numpy
            # arrays are not.
            self.assertEqual(checker.Check(f), MATCH)
            self.assertEqual(checker.Check(iter(a)), NO_MATCH)
            self.assertEqual(checker.Check(iter(b)), NO_MATCH)
            self.assertEqual(checker.Check(reversed(e)), NO_MATCH)
            self.assertEqual(checker.Check(h), NO_MATCH)
            self.assertEqual(checker.Check(i), NO_MATCH)
            self.assertEqual(checker.cost(), 10)
            self.assertEqual(repr(checker), '<PyTypeChecker List[int]>')

        with self.subTest('List[Tensor]'):
            checker = dispatch.MakeListChecker(tensor_checker)
            self.assertEqual(checker.Check(a), NO_MATCH)
            self.assertEqual(checker.Check(b), NO_MATCH)
            self.assertEqual(checker.Check(c), MATCH)
            self.assertEqual(checker.Check(d), NO_MATCH)
            self.assertEqual(checker.Check(e), MATCH)
            self.assertEqual(checker.cost(), 10)
            self.assertEqual(repr(checker), '<PyTypeChecker List[Tensor]>')

        with self.subTest('List[Union[Tensor, RaggedTensor]]'):
            checker = dispatch.MakeListChecker(
                dispatch.MakeUnionChecker([tensor_checker, ragged_checker]))
            self.assertEqual(checker.Check(a), NO_MATCH)
            self.assertEqual(checker.Check(b), NO_MATCH)
            self.assertEqual(checker.Check(c), MATCH)
            # A list/tuple containing a RaggedTensor reports
            # MATCH_DISPATCHABLE rather than plain MATCH.
            self.assertEqual(checker.Check(d), MATCH_DISPATCHABLE)
            self.assertEqual(checker.Check(e), MATCH)
            self.assertEqual(checker.Check(f), NO_MATCH)
            self.assertEqual(checker.Check(g), MATCH_DISPATCHABLE)
            # NOTE(review): list cost here (30) appears to be 10x the element
            # checker's cost (10 for a single type, 30 for this union).
            self.assertEqual(checker.cost(), 30)
            self.assertEqual(
                repr(checker),
                '<PyTypeChecker List[Union[Tensor, RaggedTensor]]>')

        with self.subTest('List[Union[int, np.integer]]'):
            # Note: np.integer is a subtype of int in *some* Python versions.
            checker = dispatch.MakeListChecker(
                dispatch.MakeUnionChecker([int_checker, np_int_checker]))
            self.assertEqual(checker.Check(a), MATCH)
            self.assertEqual(checker.Check(np.array(a)), NO_MATCH)
            self.assertEqual(checker.Check(np.array(a) * 1.5), NO_MATCH)