def testNoDispatchableTypes(self):
        """Dispatch yields NotImplemented when no argument opts into dispatch."""
        make_dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher

        # Two plain Python ints: nothing here has a dispatchable type.
        add_disp = make_dispatcher(
            "tf.math.add", math_ops.add, 2, [0, 1], [], False)
        add_outcome = add_disp.Dispatch(1, 2)
        self.assertEqual(add_outcome, NotImplemented)

        # Same expectation for tf.concat, called with a plain list and an int.
        concat_disp = make_dispatcher(
            "tf.concat", array_ops.concat, 2, [1], [0], False)
        concat_outcome = concat_disp.Dispatch([1], 0)
        self.assertEqual(concat_outcome, NotImplemented)
 def testDispatchParamOutOfRange(self):
     """The constructor rejects parameter indices that are out of range.

     The API is declared with 5 parameters; each configuration below contains
     an index (too large, or negative) that the constructor must reject.
     """
     # Each entry: (4th constructor argument, 5th constructor argument).
     invalid_index_configs = (
         ([0, 1, 5], [2, 3]),   # 5 is past the last of the 5 params.
         ([0, -3], [2, 3]),     # negative index.
         ([0, 1], [10, 3]),     # out-of-range index in the second list.
     )
     for param_indices, list_param_indices in invalid_index_configs:
         with self.assertRaisesRegex(ValueError, "index out of range"):
             _pywrap_python_api_dispatcher.PythonAPIDispatcher(
                 "some_api", None, 5, param_indices, list_param_indices, True)
Exemplo n.º 3
0
    def testMultipleDispatchers(self):
        """Each registered target is selected only when its signature matches.

        One target is registered for "x is a RaggedTensor" and another for
        "y is a RaggedTensor". The test checks which target each call reaches,
        and that a call matching both signatures raises a ValueError.
        """
        dispatcher = dispatch.PythonAPIDispatcher(
            'tf.foo', ['x', 'y', 'name'], (None,))

        ragged_check = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor)
        x_is_ragged = dispatch.PySignatureChecker([(0, ragged_check)])
        y_is_ragged = dispatch.PySignatureChecker([(1, ragged_check)])

        def f1(x, y, name=None):
            return 'f1'

        def f2(x, y, name=None):
            return 'f2'

        rt = ragged_factory_ops.constant([[1, 2], [3]])

        dispatcher.Register(x_is_ragged, f1)
        dispatcher.Register(y_is_ragged, f2)

        # (positional args, keyword args, expected Dispatch result)
        cases = [
            ((rt, 5), None, 'f1'),
            (('foo', rt), None, 'f2'),
            (('foo',), {'y': rt}, 'f2'),
            (('foo', 'bar'), None, NotImplemented),
        ]
        for args, kwargs, expected in cases:
            self.assertEqual(dispatcher.Dispatch(args, kwargs), expected)

        # With both x and y ragged, both signatures match, which is ambiguous.
        ambiguity_pattern = ('Multiple dispatch targets .*'
                             r'match the arguments to tf\.foo')
        with self.assertRaisesRegex(ValueError, ambiguity_pattern):
            dispatcher.Dispatch((rt, rt), None)
    def testDispatcherReturnsNotImplemented(self):
        """A Trace named "disabled" makes Dispatch return NotImplemented.

        Per the expectations below, any call that involves the "disabled"
        Trace returns NotImplemented. A call that involves only the ordinary
        Trace produces a Trace record of the add.
        """
        dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
            "tf.math.add", math_ops.add, 2, [0, 1], [], False)
        plain_int = 5
        declining = Trace("constant", "disabled")
        traced = Trace("constant", "z")

        declined_calls = [
            (plain_int, declining),
            (declining, plain_int),
            (declining, traced),
        ]
        for lhs, rhs in declined_calls:
            self.assertEqual(dispatcher.Dispatch(lhs, rhs), NotImplemented)

        self.assertEqual(dispatcher.Dispatch(traced, traced),
                         Trace("tf.math.add", traced, traced))
    def testSimpleDispatchWithTrace(self):
        """Any Trace argument routes tf.math.add through Trace's dispatcher."""
        dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
            "tf.math.add", math_ops.add, 2, [0, 1], [], False)
        scalar = 5
        first_trace = Trace("constant", "y")
        second_trace = Trace("constant", "z")

        operand_pairs = [
            (scalar, first_trace),
            (first_trace, scalar),
            (first_trace, second_trace),
        ]

        # Start from an empty log so only the calls below are recorded.
        Trace.log.clear()
        for lhs, rhs in operand_pairs:
            self.assertEqual(dispatcher.Dispatch(lhs, rhs),
                             Trace("tf.math.add", lhs, rhs))

        # One __tf_dispatch__ entry per Dispatch call.
        expected_log = [
            "__tf_dispatch__('Trace', 'tf.math.add')",
            "__tf_dispatch__('Trace', 'tf.math.add')",
            "__tf_dispatch__('Trace', 'tf.math.add')"
        ]
        self.assertEqual(Trace.log, expected_log)
    def testSimpleDispatchWithWeightedTensor(self):
        """Adding WeightedTensors yields summed tensors and combined weights.

        The expected weights below are 0.6 when only one operand is weighted,
        and 0.4 for weights 0.6 and 0.2.
        """
        dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
            "tf.math.add", math_ops.add, 2, [0, 1], [], False)
        five = 5
        weighted_small = WeightedTensor([1, 2, 3], 0.6)
        weighted_large = WeightedTensor([10, 20, 30], 0.2)

        # (Dispatch result, expected tensor values, expected weight)
        outcomes = [
            (dispatcher.Dispatch(five, weighted_small), [6, 7, 8], 0.6),
            (dispatcher.Dispatch(weighted_small, five), [6, 7, 8], 0.6),
            (dispatcher.Dispatch(weighted_small, weighted_large),
             [11, 22, 33], 0.4),
        ]

        # Check every tensor first, then every weight.
        for result, want_tensor, _ in outcomes:
            self.assertAllEqual(result.tensor, want_tensor)
        for result, _, want_weight in outcomes:
            self.assertEqual(result.weight, want_weight)
    def testDispatchPrecedenceRightToLeft(self):
        """Checks the order in which dispatchers run when right_to_left=True.

        The API is named "disabled" so that, per the expected results, no
        dispatcher handles the call. Every candidate is therefore tried, and
        Trace.log records the full dispatch order.
        """
        dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
            "disabled", None, 5, [4, 0, 1], [2, 3], True)

        trace = Trace("constant", "t")
        trace2_first = Trace2("constant", "t2_1")
        trace2_second = Trace2("constant", "t2_2")
        trace2b = Trace2B("constant", "t2b")
        trace3 = Trace3("constant", "t3")
        trace4 = Trace4("constant", "t4")

        # Scenario 1: three dispatchable types, none a subclass of another.
        # * Precedence is right-to-left, because right_to_left=True was passed
        #   to the PythonAPIDispatcher constructor. Arguments are scanned
        #   right-to-left, but the elements of a list argument are still
        #   scanned left-to-right.
        # * Each type is dispatched to at most once.
        Trace.log.clear()
        self.assertEqual(
            dispatcher.Dispatch(trace2_first, trace3, [],
                                [trace2_second, trace3], trace4),
            NotImplemented)
        self.assertEqual(Trace.log, [
            "__tf_dispatch__('Trace4', 'disabled')",
            "__tf_dispatch__('Trace2', 'disabled')",
            "__tf_dispatch__('Trace3', 'disabled')"
        ])

        # Scenario 2: subtypes are moved ahead of their base types. The move
        # happens *after* the right-to-left reordering, so this order is not
        # simply the reverse of the right_to_left=False order.
        Trace.log.clear()
        self.assertEqual(
            dispatcher.Dispatch(trace2_first, trace3, [trace],
                                [trace2_second, trace, trace3, trace4],
                                trace2b),
            NotImplemented)
        self.assertEqual(Trace.log, [
            "__tf_dispatch__('Trace2B', 'disabled')",
            "__tf_dispatch__('Trace2', 'disabled')",
            "__tf_dispatch__('Trace3', 'disabled')",
            "__tf_dispatch__('Trace4', 'disabled')",
            "__tf_dispatch__('Trace', 'disabled')"
        ])
Exemplo n.º 8
0
def add_type_based_api_dispatcher(target):
    """Attaches a type-based PythonAPIDispatcher to a TensorFlow API function.

    The dispatcher is stored on `target` under the attribute named by
    TYPE_BASED_DISPATCH_ATTR. An empty `collections.defaultdict(list)` is
    recorded for `target` in _TYPE_BASED_DISPATCH_SIGNATURES.

    Args:
      target: The API function. Any decorators are unwrapped with
        `tf_decorator.unwrap`; the underlying function's name, argument names
        and defaults configure the dispatcher.

    Returns:
      `target` itself. If the underlying function accepts `*args` or
      `**kwargs`, `target` is returned unchanged and no dispatcher is attached.

    Raises:
      ValueError: If `target` already has a type-based API dispatcher.
    """
    if hasattr(target, TYPE_BASED_DISPATCH_ATTR):
        raise ValueError(f"{target} already has a type-based API dispatcher.")

    _decorators, undecorated = tf_decorator.unwrap(target)
    spec = tf_inspect.getargspec(undecorated)

    # @TODO(b/194903203) Add v2 dispatch support for APIs that take varargs
    # and keywords.  Examples of APIs that take varargs and kwargs: meshgrid,
    # einsum, map_values, map_flat_values.
    takes_variadic_args = spec.varargs or spec.keywords
    if takes_variadic_args:
        return target

    dispatcher = _api_dispatcher.PythonAPIDispatcher(
        undecorated.__name__, spec.args, spec.defaults)
    setattr(target, TYPE_BASED_DISPATCH_ATTR, dispatcher)
    _TYPE_BASED_DISPATCH_SIGNATURES[target] = collections.defaultdict(list)
    return target
Exemplo n.º 9
0
    def testListAndUnionDispatch(self):
        """Checks a union checker combined with a list-of-union checker.

        The single registered target requires `x` to be a RaggedTensor or a
        Tensor, and `ys` to be a list whose elements are each a RaggedTensor or
        a Tensor. Per the expectations below, the target is reached only when a
        RaggedTensor appears somewhere in those arguments.
        """
        dispatcher = dispatch.PythonAPIDispatcher(
            'tf.foo', ['x', 'ys', 'name'], (None,))

        is_ragged = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor)
        is_tensor = dispatch.MakeInstanceChecker(ops.Tensor)
        ragged_or_tensor = dispatch.MakeUnionChecker([is_ragged, is_tensor])
        list_of_ragged_or_tensor = dispatch.MakeListChecker(ragged_or_tensor)

        def f1(x, ys, name=None):
            return 'f1'

        signature = dispatch.PySignatureChecker(
            [(0, ragged_or_tensor), (1, list_of_ragged_or_tensor)])
        dispatcher.Register(signature, f1)

        rt = ragged_factory_ops.constant([[1, 2], [3]])
        t = constant_op.constant(5)

        # (positional args, keyword args) pairs that should reach f1.
        matching_calls = [
            ((rt, [t]), None),
            ((rt, [rt]), None),
            ((t, [rt]), None),
            ((rt, []), None),
            ((t, [t, t, rt, t]), None),
            ((rt, [t], 'my_name'), None),
            ((), {'x': rt, 'ys': [t]}),
            ((), {'x': rt, 'ys': [t], 'name': 'x'}),
        ]
        for args, kwargs in matching_calls:
            self.assertEqual(dispatcher.Dispatch(args, kwargs), 'f1')

        # Positional-only calls that should not reach any target.
        non_matching_args = [
            (t, [t]),
            (t, []),
            ('foo', [rt]),
            ('foo', 'bar'),
            ('foo', 'bar', 'baz'),
        ]
        for args in non_matching_args:
            self.assertEqual(dispatcher.Dispatch(args, None), NotImplemented)
Exemplo n.º 10
0
    def testBasicDispatch(self):
        """A target registered for a ragged `x` is reached only when x is ragged."""
        dispatcher = dispatch.PythonAPIDispatcher(
            'tf.foo', ['x', 'y', 'name'], (None,))

        is_ragged = dispatch.MakeInstanceChecker(ragged_tensor.RaggedTensor)

        def f1(x, y, name=None):
            return 'f1'

        dispatcher.Register(dispatch.PySignatureChecker([(0, is_ragged)]), f1)

        rt = ragged_factory_ops.constant([[1, 2], [3]])

        # (positional args, keyword args) pairs whose `x` is ragged.
        matching_calls = [
            ((rt, 5), None),
            ((rt, 5, 'my_name'), None),
            ((), {'x': rt, 'y': 5}),
            ((), {'x': rt, 'y': 5, 'name': 'x'}),
        ]
        for args, kwargs in matching_calls:
            self.assertEqual(dispatcher.Dispatch(args, kwargs), 'f1')

        # `x` is not ragged in any of these, even when `y` is.
        non_matching_args = [
            ('foo', rt),
            ('foo', 'bar'),
            ('foo', 'bar', 'baz'),
        ]
        for args in non_matching_args:
            self.assertEqual(dispatcher.Dispatch(args, None), NotImplemented)
    def testDispatchPrecedence(self):
        """Checks the order in which dispatchers run when right_to_left=False.

        The API is named "disabled" so that, per the expected results, no
        dispatcher handles the call. Every candidate is therefore tried, and
        Trace.log records the full dispatch order.
        """
        dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
            "disabled", None, 5, [0, 1, 4], [2, 3], False)

        trace = Trace("constant", "t")
        trace2_first = Trace2("constant", "t2_1")
        trace2_second = Trace2("constant", "t2_2")
        trace2b = Trace2B("constant", "t2b")
        trace3 = Trace3("constant", "t3")
        trace4 = Trace4("constant", "t4")

        # Scenario 1: three dispatchable types, none a subclass of another.
        # * Precedence is left-to-right.
        # * Each type is dispatched to at most once.
        Trace.log.clear()
        self.assertEqual(
            dispatcher.Dispatch(trace2_first, trace3, [],
                                [trace2_second, trace3], trace4),
            NotImplemented)
        self.assertEqual(Trace.log, [
            "__tf_dispatch__('Trace2', 'disabled')",
            "__tf_dispatch__('Trace3', 'disabled')",
            "__tf_dispatch__('Trace4', 'disabled')"
        ])

        # Scenario 2: subtypes are moved ahead of their base types.
        Trace.log.clear()
        self.assertEqual(
            dispatcher.Dispatch(trace2_first, trace3, [trace],
                                [trace2_second, trace, trace3, trace4],
                                trace2b),
            NotImplemented)
        self.assertEqual(Trace.log, [
            "__tf_dispatch__('Trace2B', 'disabled')",
            "__tf_dispatch__('Trace2', 'disabled')",
            "__tf_dispatch__('Trace3', 'disabled')",
            "__tf_dispatch__('Trace4', 'disabled')",
            "__tf_dispatch__('Trace', 'disabled')"
        ])