예제 #1
0
    def testDispatchSignatureWithUnspecifiedParameter(self):
        # Only `x` is constrained by the signature, so `y` may be of any type.
        @dispatch.dispatch_for_api(math_ops.add, {"x": MaskedTensor})
        def masked_add(x, y):
            """Adds `y` (masked, dense, or None) to the MaskedTensor `x`."""
            if y is None:
                return x
            if isinstance(y, MaskedTensor):
                y_values, y_mask = y.values, y.mask
            else:
                # Non-masked operands are treated as fully valid.
                y_values, y_mask = y, True
            return MaskedTensor(x.values + y_values, x.mask & y_mask)

        try:
            masked = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
            # As long as `x` is a MaskedTensor, the dispatcher will be called
            # (regardless of the type for `y`): Tensor, list, scalar, None.
            cases = [
                (constant_op.constant([10, 20, 30, 40, 50]),
                 [11, 22, 33, 44, 55]),
                ([10, 20, 30, 40, 50], [11, 22, 33, 44, 55]),
                (50, [51, 52, 53, 54, 55]),
                (None, [1, 2, 3, 4, 5]),
            ]
            for other, expected_values in cases:
                self.assertAllEqual(
                    math_ops.add(masked, other).values, expected_values)

        finally:
            # Clean up dispatch table.
            dispatch.unregister_dispatch_for(masked_add)
예제 #2
0
    def testDispatchForTensorLike(self):
        MaskedOrTensorLike = typing.Union[MaskedTensor,
                                          core_tf_types.TensorLike]

        # The dispatch signature is taken from the parameter annotations.
        @dispatch.dispatch_for_api(math_ops.add)
        def masked_add(x: MaskedOrTensorLike,
                       y: MaskedOrTensorLike,
                       name=None):
            def split(operand):
                # Plain tensor-likes carry no mask, i.e. every entry is valid.
                if isinstance(operand, MaskedTensor):
                    return operand.values, operand.mask
                return operand, True

            with ops.name_scope(name):
                x_values, x_mask = split(x)
                y_values, y_mask = split(y)
                return MaskedTensor(x_values + y_values, x_mask & y_mask)

        try:
            masked = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
            tensor_likes = [
                [10, 20, 30, 40, 50],
                np.array([10, 20, 30, 40, 50]),
                constant_op.constant([10, 20, 30, 40, 50]),
                variables.Variable([5, 4, 3, 2, 1]),
            ]
            if not context.executing_eagerly():
                self.evaluate(variables.global_variables_initializer())
            for dense in tensor_likes:
                total = math_ops.add(masked, dense)
                self.assertAllEqual(total.values, masked.values + dense)
                self.assertAllEqual(total.mask, masked.mask)

        finally:
            # Clean up dispatch table.
            dispatch.unregister_dispatch_for(masked_add)
예제 #3
0
    def testDispatchForUnion(self):
        MaybeMasked = typing.Union[MaskedTensor, ops.Tensor]
        signature = {"x": MaybeMasked, "y": MaybeMasked}

        @dispatch.dispatch_for_api(math_ops.add, signature)
        def masked_add(x, y, name=None):
            with ops.name_scope(name):
                # Dense operands contribute an all-True mask.
                if isinstance(x, MaskedTensor):
                    x_values, x_mask = x.values, x.mask
                else:
                    x_values, x_mask = x, True
                if isinstance(y, MaskedTensor):
                    y_values, y_mask = y.values, y.mask
                else:
                    y_values, y_mask = y, True
                return MaskedTensor(x_values + y_values, x_mask & y_mask)

        try:
            lhs = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
            rhs = constant_op.constant([10, 20, 30, 40, 50])
            result = math_ops.add(lhs, rhs)
            self.assertAllEqual(result.values, lhs.values + rhs)
            self.assertAllEqual(result.mask, lhs.mask)

        finally:
            # Clean up dispatch table.
            dispatch.unregister_dispatch_for(masked_add)
예제 #4
0
    def testDispatchForSignatureFromAnnotations(self):
        # No explicit signature dict: the dispatch signature comes from the
        # handler's parameter annotations.
        @dispatch.dispatch_for_api(math_ops.add)
        def masked_add(x: MaskedTensor, y: MaskedTensor, name=None):
            with ops.name_scope(name):
                summed = x.values + y.values
                combined_mask = x.mask & y.mask
                return MaskedTensor(summed, combined_mask)

        try:
            lhs = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
            rhs = MaskedTensor([1, 1, 1, 1, 1], [1, 1, 0, 1, 0])
            result = math_ops.add(lhs, rhs)
            self.assertAllEqual(result.values, lhs.values + rhs.values)
            self.assertAllEqual(result.mask, lhs.mask & rhs.mask)

        finally:
            # Clean up dispatch table.
            dispatch.unregister_dispatch_for(masked_add)
예제 #5
0
    def testDuplicateDispatchForUnaryElementwiseAPIsError(self):
        @dispatch.dispatch_for_unary_elementwise_apis(MaskedTensor)
        def handler(api_func, x):
            return MaskedTensor(api_func(x.values), x.mask)

        expected_message = (r"A unary elementwise dispatch handler \(.*\) has "
                            "already been registered for .*")
        try:
            # A second handler for the same type must be rejected.
            with self.assertRaisesRegex(ValueError, expected_message):

                @dispatch.dispatch_for_unary_elementwise_apis(MaskedTensor)
                def duplicate_handler(api_func, x):
                    return MaskedTensor(api_func(x.values), ~x.mask)

                del duplicate_handler

        finally:
            dispatch.unregister_dispatch_for(handler)
예제 #6
0
    def testDispatchWithIterableParams(self):
        # The add_n API accepts any iterable for `inputs` (not only a
        # sequence), so dispatch must work when given a one-shot generator.
        @dispatch.dispatch_for_api(math_ops.add_n,
                                   {"inputs": typing.List[MaskedTensor]})
        def masked_add_n(inputs):
            stacked_masks = array_ops.stack([t.mask for t in inputs])
            summed_values = math_ops.add_n([t.values for t in inputs])
            # An element is valid only if it is valid in every input.
            return MaskedTensor(summed_values,
                                math_ops.reduce_all(stacked_masks, axis=0))

        try:
            one_shot = (MaskedTensor([i], [True]) for i in range(5))
            total = math_ops.add_n(one_shot)
            self.assertAllEqual(total.values, [sum(range(5))])
            self.assertAllEqual(total.mask, [True])

        finally:
            # Clean up dispatch table.
            dispatch.unregister_dispatch_for(masked_add_n)
예제 #7
0
    def testRegisterDispatchableType(self):
        # A plain namedtuple becomes a dispatch target once registered.
        Car = collections.namedtuple("Car", ["size", "speed"])
        dispatch.register_dispatchable_type(Car)

        @dispatch.dispatch_for_api(math_ops.add, {"x": Car, "y": Car})
        def add_car(x, y, name=None):
            with ops.name_scope(name):
                return Car(size=x.size + y.size, speed=x.speed + y.speed)

        try:
            small = Car(constant_op.constant(1), constant_op.constant(3))
            big = Car(constant_op.constant(10), constant_op.constant(20))
            combined = math_ops.add(small, big)
            self.assertAllEqual(combined.size, 11)
            self.assertAllEqual(combined.speed, 23)

        finally:
            # Clean up dispatch table.
            dispatch.unregister_dispatch_for(add_car)
예제 #8
0
    def testDispatchWithKwargs(self):
        @dispatch.dispatch_for_api(math_ops.add, {
            "x": MaskedTensor,
            "y": MaskedTensor
        })
        def masked_add(*args, **kwargs):
            # A keyword-only call site must reach the handler purely as
            # kwargs. `lhs`/`rhs` are bound later in the enclosing scope.
            self.assertAllEqual(kwargs["x"].values, lhs.values)
            self.assertAllEqual(kwargs["y"].values, rhs.values)
            self.assertEmpty(args)
            return "stub"

        try:
            lhs = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
            rhs = MaskedTensor([1, 1, 1, 1, 1], [1, 1, 0, 1, 0])
            self.assertEqual(math_ops.add(x=lhs, y=rhs), "stub")

        finally:
            # Clean up dispatch table.
            dispatch.unregister_dispatch_for(masked_add)
예제 #9
0
    def testDispatchApiWithNoNameArg(self):
        """A dispatch target may not add a `name` param the API lacks."""
        # Note: The "tensor_equals" API has no "name" argument.
        signature = {"self": MaskedTensor, "other": MaskedTensor}

        # A target whose signature matches the API exactly registers fine.
        @dispatch.dispatch_for_api(math_ops.tensor_equals, signature)
        def masked_tensor_equals(self, other):
            del self, other

        dispatch.unregister_dispatch_for(masked_tensor_equals)  # clean up.

        # `assertRaisesRegexp` is a deprecated unittest alias that was removed
        # in Python 3.12; `assertRaisesRegex` is the supported spelling (and
        # the one the other tests here use).
        with self.assertRaisesRegex(
                ValueError,
                r"Dispatch function's signature \(self, other, name=None\) "
                r"does not match API's signature \(self, other\)\."):

            @dispatch.dispatch_for_api(math_ops.tensor_equals, signature)
            def masked_tensor_equals_2(self, other, name=None):
                del self, other, name

            del masked_tensor_equals_2  # avoid pylint unused variable warning.
예제 #10
0
    def testDispatchForList(self):
        @dispatch.dispatch_for_api(array_ops.concat,
                                   {"values": typing.List[MaskedTensor]})
        def masked_concat(values, axis, name=None):
            # Concatenate values and masks independently along `axis`.
            with ops.name_scope(name):
                joined_values = array_ops.concat(
                    [v.values for v in values], axis)
                joined_mask = array_ops.concat(
                    [v.mask for v in values], axis)
                return MaskedTensor(joined_values, joined_mask)

        try:
            first = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
            second = MaskedTensor([1, 1, 1], [1, 1, 0])
            result = array_ops.concat([first, second], axis=0)
            self.assertAllEqual(
                result.values,
                array_ops.concat([first.values, second.values], 0))
            self.assertAllEqual(
                result.mask, array_ops.concat([first.mask, second.mask], 0))

        finally:
            # Clean up dispatch table.
            dispatch.unregister_dispatch_for(masked_concat)
예제 #11
0
    def testDispatchForOptional(self):
        # typing.Optional[X] means typing.Union[X, NoneType], so `x` and `y`
        # may be left out entirely or given as MaskedTensors.
        optional_masked = typing.Optional[MaskedTensor]

        @dispatch.dispatch_for_api(
            array_ops.where_v2, {
                "condition": MaskedTensor,
                "x": optional_masked,
                "y": optional_masked,
            })
        def masked_where(condition, x=None, y=None, name=None):
            del condition, x, y, name
            return "stub"

        try:
            cond = MaskedTensor([True, False, True, True, True],
                                [1, 0, 1, 1, 1])
            self.assertEqual(array_ops.where_v2(cond), "stub")
            self.assertEqual(array_ops.where_v2(cond, cond, cond), "stub")

        finally:
            # Clean up dispatch table.
            dispatch.unregister_dispatch_for(masked_where)
예제 #12
0
    def testDispatchForBinaryElementwiseAPIs(self):
        @dispatch.dispatch_for_binary_elementwise_apis(MaskedTensor,
                                                       MaskedTensor)
        def binary_elementwise_api_handler(api_func, x, y):
            # Apply the op to the values; an element stays valid only if it
            # is valid in both operands.
            return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)

        try:
            lhs = MaskedTensor([1, -2, -3], [True, True, False])
            rhs = MaskedTensor([10, 20, 30], [True, False, True])
            # Mix positional and keyword arguments across the calls.
            x_times_y = math_ops.multiply(lhs, rhs)
            x_plus_y = math_ops.add(lhs, y=rhs)
            x_minus_y = math_ops.subtract(x=lhs, y=rhs)
            min_x_y = math_ops.minimum(lhs, rhs, "min_x_y")
            y_times_x = math_ops.multiply(rhs, lhs, name="y_times_x")
            y_plus_x = math_ops.add(rhs, y=lhs, name="y_plus_x")
            y_minus_x = math_ops.subtract(x=rhs, y=lhs, name="y_minus_x")
            expectations = [
                (x_times_y, [10, -40, -90]),
                (x_plus_y, [11, 18, 27]),
                (x_minus_y, [-9, -22, -33]),
                (min_x_y, [1, -2, -3]),
                (y_times_x, [10, -40, -90]),
                (y_plus_x, [11, 18, 27]),
                (y_minus_x, [9, 22, 33]),
            ]
            for result, expected_values in expectations:
                self.assertAllEqual(result.values, expected_values)
            for result, _ in expectations:
                self.assertAllEqual(result.mask, [True, False, False])
            if not context.executing_eagerly():
                # Tensor names are not defined in eager mode.
                self.assertRegex(min_x_y.values.name, r"^min_x_y/Minimum:.*")
                self.assertRegex(min_x_y.mask.name, r"^min_x_y/and:.*")
                self.assertRegex(y_times_x.values.name, r"^y_times_x/.*")
                self.assertRegex(y_plus_x.values.name, r"^y_plus_x/.*")
                self.assertRegex(y_minus_x.values.name, r"^y_minus_x/.*")

        finally:
            dispatch.unregister_dispatch_for(binary_elementwise_api_handler)
예제 #13
0
    def testRegisterUnaryElementwiseApiAfterHandler(self):
        # register_unary_elementwise_api may be called *after* a handler was
        # installed via dispatch_for_unary_elementwise_apis; the newly
        # registered API must still route MaskedTensor inputs to it.

        @dispatch.dispatch_for_unary_elementwise_apis(MaskedTensor)
        def handler(api_func, x):
            return MaskedTensor(api_func(x.values), x.mask)

        try:

            @dispatch.register_unary_elementwise_api
            @dispatch.add_dispatch_support
            def some_op(x):
                return x * 2

            masked = MaskedTensor([1, 2, 3], [True, False, True])
            doubled = some_op(masked)
            self.assertAllEqual(doubled.values, [2, 4, 6])
            self.assertAllEqual(doubled.mask, [True, False, True])

        finally:
            dispatch.unregister_dispatch_for(handler)
예제 #14
0
    def testDispatchForUnaryElementwiseAPIs(self):
        @dispatch.dispatch_for_unary_elementwise_apis(MaskedTensor)
        def unary_elementwise_api_handler(api_func, x):
            # Apply the op to the values; the mask passes through unchanged.
            return MaskedTensor(api_func(x.values), x.mask)

        try:
            operand = MaskedTensor([1, -2, -3], [True, True, False])
            # Mix positional and keyword arguments across the calls.
            abs_x = math_ops.abs(operand)
            sign_x = math_ops.sign(x=operand)
            neg_x = math_ops.negative(operand, "neg_x")
            invert_x = bitwise_ops.invert(operand, name="invert_x")
            ones_like_x = array_ops.ones_like(operand, name="ones_like_x")
            ones_like_x_float = array_ops.ones_like(
                operand, dtypes.float32, name="ones_like_x_float")
            expectations = [
                (abs_x, [1, 2, 3]),
                (sign_x, [1, -1, -1]),
                (neg_x, [-1, 2, 3]),
                (invert_x, [-2, 1, 2]),
                (ones_like_x, [1, 1, 1]),
                (ones_like_x_float, [1., 1., 1.]),
            ]
            for result, expected_values in expectations:
                self.assertAllEqual(result.values, expected_values)
            for result, _ in expectations:
                self.assertAllEqual(result.mask, [True, True, False])
            if not context.executing_eagerly():
                # Tensor names are not defined in eager mode.
                self.assertRegex(neg_x.values.name, r"^neg_x/Neg:.*")
                self.assertRegex(invert_x.values.name, r"^invert_x/.*")
                self.assertRegex(ones_like_x.values.name, r"^ones_like_x/.*")
                self.assertRegex(ones_like_x_float.values.name,
                                 r"^ones_like_x_float/.*")

        finally:
            dispatch.unregister_dispatch_for(unary_elementwise_api_handler)
예제 #15
0
    def testDispatchTargetWithNoNameArgument(self):
        # The handler omits `name`, yet a `name` given to the API must still
        # scope the ops the handler creates.
        @dispatch.dispatch_for_api(math_ops.add, {
            "x": MaskedTensor,
            "y": MaskedTensor
        })
        def masked_add(x, y):
            return MaskedTensor(x.values + y.values, x.mask & y.mask)

        try:
            lhs = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
            rhs = MaskedTensor([1, 1, 1, 1, 1], [1, 1, 0, 1, 0])
            # Tensor names are not defined in eager mode.
            building_graph = not context.executing_eagerly()

            # `name` passed as a keyword argument.
            by_keyword = math_ops.add(lhs, rhs, name="MyAdd")
            if building_graph:
                self.assertRegex(by_keyword.values.name, r"^MyAdd/add.*")
                self.assertRegex(by_keyword.mask.name, r"^MyAdd/and.*")

            # `name` passed positionally.
            by_position = math_ops.add(lhs, rhs, "B")
            if building_graph:
                self.assertRegex(by_position.values.name, r"^B/add.*")
                self.assertRegex(by_position.mask.name, r"^B/and.*")

            # `name` left at its default value.
            unnamed = math_ops.add(lhs, rhs)
            if building_graph:
                self.assertRegex(unnamed.values.name, r"^add.*")
                self.assertRegex(unnamed.mask.name, r"^and.*")

        finally:
            # Clean up dispatch table.
            dispatch.unregister_dispatch_for(masked_add)
예제 #16
0
    def testRegisterBinaryElementwiseApiAfterHandler(self):
        # register_binary_elementwise_api may be called *after* a handler was
        # installed via dispatch_for_binary_elementwise_apis; the newly
        # registered API must still route MaskedTensor inputs to it.

        @dispatch.dispatch_for_binary_elementwise_apis(MaskedTensor,
                                                       MaskedTensor)
        def handler(api_func, x, y):
            return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)

        try:

            @dispatch.register_binary_elementwise_api
            @dispatch.add_dispatch_support
            def some_op(x, y):
                return x * 2 + y

            lhs = MaskedTensor([1, 2, 3], [True, False, True])
            rhs = MaskedTensor([10, 20, 30], [True, True, False])
            result = some_op(lhs, rhs)
            self.assertAllEqual(result.values, [12, 24, 36])
            self.assertAllEqual(result.mask, [True, False, False])

        finally:
            dispatch.unregister_dispatch_for(handler)
예제 #17
0
    def testTypeBasedDispatchTargetsFor(self):
        """Only MaskedTensor-based signatures are reported for MaskedTensor."""
        MaskedTensorList = typing.List[typing.Union[MaskedTensor, ops.Tensor]]
        # Track each handler as soon as it is registered so that `finally`
        # only unregisters handlers that exist. Unregistering by name would
        # raise UnboundLocalError if an earlier registration failed, masking
        # the real error and leaking the handlers already registered.
        registered = []
        try:

            @dispatch.dispatch_for_api(math_ops.add)
            def masked_add(x: MaskedTensor, y: MaskedTensor):
                del x, y

            registered.append(masked_add)

            @dispatch.dispatch_for_api(array_ops.concat)
            def masked_concat(values: MaskedTensorList, axis):
                del values, axis

            registered.append(masked_concat)

            @dispatch.dispatch_for_api(math_ops.add)
            def silly_add(x: SillyTensor, y: SillyTensor):
                del x, y

            registered.append(silly_add)

            @dispatch.dispatch_for_api(math_ops.abs)
            def silly_abs(x: SillyTensor):
                del x

            registered.append(silly_abs)

            # Note: `expected` does not contain keys or values from SillyTensor.
            targets = dispatch.type_based_dispatch_signatures_for(MaskedTensor)
            expected = {
                math_ops.add: [{
                    "x": MaskedTensor,
                    "y": MaskedTensor
                }],
                array_ops.concat: [{
                    "values": MaskedTensorList
                }]
            }
            self.assertEqual(targets, expected)

        finally:
            # Clean up dispatch table (in registration order).
            for handler in registered:
                dispatch.unregister_dispatch_for(handler)
예제 #18
0
 def testUnregisterDispatchTargetBadTargetError(self):
     # A callable that was never handed to any dispatch decorator.
     def never_registered(x):
         return x + 1

     expected_message = "Function .* was not registered"
     with self.assertRaisesRegex(ValueError, expected_message):
         dispatch.unregister_dispatch_for(never_registered)