예제 #1
0
    def testMultipleDispatchers(self):
        """Two targets keyed on different args; ambiguity raises ValueError.

        One target fires when `x` (arg 0) is a RaggedTensor, the other when
        `y` (arg 1) is. Calls matching neither return `NotImplemented`, and a
        call matching both is rejected as ambiguous.
        """
        dispatcher = dispatch.PythonAPIDispatcher(
            'tf.foo', ['x', 'y', 'name'], (None, ))

        is_ragged = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor)
        x_is_ragged = dispatch.PySignatureChecker([(0, is_ragged)])
        y_is_ragged = dispatch.PySignatureChecker([(1, is_ragged)])

        target_x = lambda x, y, name=None: 'f1'
        target_y = lambda x, y, name=None: 'f2'

        rt = ragged_factory_ops.constant([[1, 2], [3]])

        dispatcher.Register(x_is_ragged, target_x)
        dispatcher.Register(y_is_ragged, target_y)

        run = dispatcher.Dispatch
        expectations = [
            ((rt, 5), None, 'f1'),
            (('foo', rt), None, 'f2'),
            (('foo', ), {'y': rt}, 'f2'),  # keyword args are checked too
            (('foo', 'bar'), None, NotImplemented),
        ]
        for args, kwargs, expected in expectations:
            self.assertEqual(run(args, kwargs), expected)

        # Both targets accept (rt, rt), so dispatch must refuse to choose.
        ambiguous_pattern = ('Multiple dispatch targets .*'
                             r'match the arguments to tf\.foo')
        with self.assertRaisesRegex(ValueError, ambiguous_pattern):
            run((rt, rt), None)
예제 #2
0
    def testSortByCost(self):
        """PySignatureChecker evaluates its parameter checks cheapest-first."""
        int_check = dispatch.MakeInstanceChecker(int)
        float_check = dispatch.MakeInstanceChecker(float)
        number_check = dispatch.MakeUnionChecker([int_check, float_check])
        int_list_check = dispatch.MakeListChecker(int_check)
        number_list_check = dispatch.MakeListChecker(number_check)
        checker = dispatch.PySignatureChecker([
            (0, number_list_check),
            (1, number_check),
            (2, int_list_check),
            (3, int_check),
        ])

        # The repr lists args in the order they will be checked, so it reveals
        # the cost-based ordering (costs as noted per entry).
        expected_repr = (
            '<PySignatureChecker '
            'args[3]:int, '  # int_check: cost=1
            'args[1]:Union[int, float], '  # number_check: cost=3
            'args[2]:List[int], '  # int_list_check: cost=10
            'args[0]:List[Union[int, float]]>'  # number_list_check: cost=30
        )  # pyformat: disable
        self.assertEqual(repr(checker), expected_repr)
예제 #3
0
    def testList(self):
        """A List[RaggedTensor] checker needs a non-empty list of RaggedTensors."""
        element_check = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor)
        checker = dispatch.PySignatureChecker(
            [(0, dispatch.MakeListChecker(element_check))])

        rt = ragged_factory_ops.constant([[1, 2], [3]])

        cases = [
            (([rt], ), True),
            (([], ), False),  # an empty list does not match
            ((rt, ), False),  # a bare RaggedTensor is not a list
            (([rt, rt + 3, rt * 2], ), True),
            (([rt, rt.values, rt * 2], ), False),  # rt.values fails the check
        ]  # pyformat: disable
        self.check_signatures(checker, cases)

        expected_repr = '<PySignatureChecker args[0]:List[RaggedTensor]>'
        self.assertEqual(repr(checker), expected_repr)
예제 #4
0
def _make_signature_checker(api_signature, signature):
    """Builds a PySignatureChecker for the given type signature.

  Args:
    api_signature: The `inspect.Signature` of the API whose signature is
      being checked.
    signature: Dictionary mapping parameter names to type annotations.

  Returns:
    A `PySignatureChecker`.
  """
    if not (isinstance(signature, dict)
            and all(isinstance(k, (str, int)) for k in signature)):
        raise TypeError(
            "signatures must be dictionaries mapping parameter names "
            "to type annotations.")
    checkers = []

    param_names = list(api_signature.parameters)
    for param_name, param_type in signature.items():
        # Convert positional parameters to named parameters.
        if (isinstance(param_name, int)
                and param_name < len(api_signature.parameters)):
            param_name = list(
                api_signature.parameters.values())[param_name].name

        # Check that the parameter exists, and has an appropriate kind.
        param = api_signature.parameters.get(param_name, None)
        if param is None:
            raise ValueError("signature includes annotation for unknown "
                             f"parameter {param_name!r}.")
        if param.kind not in (tf_inspect.Parameter.POSITIONAL_ONLY,
                              tf_inspect.Parameter.POSITIONAL_OR_KEYWORD):
            raise ValueError(
                "Dispatch currently only supports type annotations "
                "for positional parameters; can't handle annotation "
                f"for {param.kind!r} parameter {param_name}.")

        checker = make_type_checker(param_type)
        index = param_names.index(param_name)
        checkers.append((index, checker))

    return _api_dispatcher.PySignatureChecker(checkers)
예제 #5
0
    def testSimpleSignature(self):
        """Checks on args 0 (int) and 2 (RaggedTensor) must both pass."""
        arg0_is_int = dispatch.MakeInstanceChecker(int)
        arg2_is_ragged = dispatch.MakeInstanceChecker(
            ragged_tensor.RaggedTensor)
        checker = dispatch.PySignatureChecker(
            [(0, arg0_is_int), (2, arg2_is_ragged)])
        rt = ragged_factory_ops.constant([[1, 2], [3]])

        cases = [
            ((1, 2, rt), True),
            ((1, 2, 3), False),  # arg 2 is not a RaggedTensor
            ((1, 2), False),  # arg 2 is missing
            ((), False),
            ((5, 'x', rt, None), True),  # unchecked extra args are fine
            (([5], 'x', rt, None), False),  # arg 0 is a list, not an int
            ((5, 'x', [rt], None), False),  # arg 2 is a list
        ]  # pyformat: disable
        self.check_signatures(checker, cases)

        expected_repr = (
            '<PySignatureChecker args[0]:int, args[2]:RaggedTensor>')
        self.assertEqual(repr(checker), expected_repr)
예제 #6
0
    def testUnion(self):
        """Both args 0 and 1 are checked against Union[RaggedTensor, Tensor]."""
        union_check = dispatch.MakeUnionChecker([
            dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor),
            dispatch.MakeInstanceChecker(ops.Tensor),
        ])
        checker = dispatch.PySignatureChecker(
            [(position, union_check) for position in (0, 1)])

        t = constant_op.constant([[1, 2], [3, 4]])
        rt = ragged_factory_ops.constant([[1, 2], [3]])

        cases = [
            # NOTE(review): (t, t) satisfies the union yet is expected not to
            # match -- presumably a plain Tensor alone never triggers dispatch.
            ((t, t), False),
            ((t, rt), True),
            ((rt, t), True),
            ((rt, rt), True),
            ((rt, [rt]), False),  # a list is not in the union
            ((rt, rt, 1, 2, None), True),  # unchecked extra args are fine
        ]  # pyformat: disable
        self.check_signatures(checker, cases)

        expected_repr = (
            '<PySignatureChecker args[0]:Union[RaggedTensor, Tensor], '
            'args[1]:Union[RaggedTensor, Tensor]>')
        self.assertEqual(repr(checker), expected_repr)
예제 #7
0
    def testListAndUnionDispatch(self):
        """Dispatch on `x: Union[RT, Tensor]` and `ys: List[Union[RT, Tensor]]`."""
        dispatcher = dispatch.PythonAPIDispatcher(
            'tf.foo', ['x', 'ys', 'name'], (None, ))

        ragged_check = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor)
        dense_check = dispatch.MakeInstanceChecker(ops.Tensor)
        either = dispatch.MakeUnionChecker([ragged_check, dense_check])
        signature = dispatch.PySignatureChecker(
            [(0, either), (1, dispatch.MakeListChecker(either))])
        dispatcher.Register(signature, lambda x, ys, name=None: 'f1')

        rt = ragged_factory_ops.constant([[1, 2], [3]])
        t = constant_op.constant(5)

        # (args, kwargs) pairs expected to reach the registered target.
        handled_calls = [
            ((rt, [t]), None),
            ((rt, [rt]), None),
            ((t, [rt]), None),
            ((rt, []), None),
            ((t, [t, t, rt, t]), None),
            ((rt, [t], 'my_name'), None),
            ((), {'x': rt, 'ys': [t]}),
            ((), {'x': rt, 'ys': [t], 'name': 'x'}),
        ]
        for args, kwargs in handled_calls:
            self.assertEqual(dispatcher.Dispatch(args, kwargs), 'f1')

        # Positional args expected to be declined with NotImplemented.
        declined_calls = [
            (t, [t]),
            (t, []),
            ('foo', [rt]),
            ('foo', 'bar'),
            ('foo', 'bar', 'baz'),
        ]
        for args in declined_calls:
            self.assertEqual(dispatcher.Dispatch(args, None), NotImplemented)
예제 #8
0
    def testBasicDispatch(self):
        """A single target keyed on `x` being a RaggedTensor."""
        dispatcher = dispatch.PythonAPIDispatcher(
            'tf.foo', ['x', 'y', 'name'], (None, ))

        x_is_ragged = dispatch.PySignatureChecker(
            [(0, dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor))])
        dispatcher.Register(x_is_ragged, lambda x, y, name=None: 'f1')

        rt = ragged_factory_ops.constant([[1, 2], [3]])

        # Positional and keyword forms that should all reach the target.
        handled_calls = [
            ((rt, 5), None),
            ((rt, 5, 'my_name'), None),
            ((), {'x': rt, 'y': 5}),
            ((), {'x': rt, 'y': 5, 'name': 'x'}),
        ]
        for args, kwargs in handled_calls:
            self.assertEqual(dispatcher.Dispatch(args, kwargs), 'f1')

        # `x` is not ragged in any of these, so dispatch declines.
        for args in [('foo', rt), ('foo', 'bar'), ('foo', 'bar', 'baz')]:
            self.assertEqual(dispatcher.Dispatch(args, None), NotImplemented)