def testDispatchSignatureWithUnspecifiedParameter(self):
  """Dispatch fires on `x` alone when the signature leaves `y` unconstrained."""

  @dispatch.dispatch_for_api(math_ops.add, {"x": MaskedTensor})
  def masked_add(x, y):
    if y is None:
      return x
    if isinstance(y, MaskedTensor):
      y_values, y_mask = y.values, y.mask
    else:
      y_values, y_mask = y, True
    return MaskedTensor(x.values + y_values, x.mask & y_mask)

  try:
    masked = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
    # As long as `x` is a MaskedTensor, the dispatcher will be called
    # regardless of the type of `y`: Tensor, Python list, scalar, or None.
    cases = [
        (constant_op.constant([10, 20, 30, 40, 50]), [11, 22, 33, 44, 55]),
        ([10, 20, 30, 40, 50], [11, 22, 33, 44, 55]),
        (50, [51, 52, 53, 54, 55]),
        (None, [1, 2, 3, 4, 5]),
    ]
    for other, expected in cases:
      self.assertAllEqual(math_ops.add(masked, other).values, expected)
  finally:
    # Clean up dispatch table.
    dispatch.unregister_dispatch_for(masked_add)
def testDispatchForTensorLike(self):
  """Annotating with TensorLike lets lists/arrays/tensors/variables dispatch."""
  MaskedOrTensorLike = typing.Union[MaskedTensor, core_tf_types.TensorLike]

  def _values_and_mask(t):
    # Unpacks a MaskedTensor; any other TensorLike is treated as fully valid.
    if isinstance(t, MaskedTensor):
      return t.values, t.mask
    return t, True

  @dispatch.dispatch_for_api(math_ops.add)
  def masked_add(x: MaskedOrTensorLike, y: MaskedOrTensorLike, name=None):
    with ops.name_scope(name):
      x_values, x_mask = _values_and_mask(x)
      y_values, y_mask = _values_and_mask(y)
      return MaskedTensor(x_values + y_values, x_mask & y_mask)

  try:
    x = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
    operands = [
        [10, 20, 30, 40, 50],
        np.array([10, 20, 30, 40, 50]),
        constant_op.constant([10, 20, 30, 40, 50]),
        variables.Variable([5, 4, 3, 2, 1]),
    ]
    # In graph mode the Variable above must be initialized before use.
    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())
    for y in operands:
      z = math_ops.add(x, y)
      self.assertAllEqual(z.values, x.values + y)
      self.assertAllEqual(z.mask, x.mask)
  finally:
    # Clean up dispatch table.
    dispatch.unregister_dispatch_for(masked_add)
def testDispatchForUnion(self):
  """A Union-typed signature dispatches for MaskedTensor + Tensor operands."""
  MaybeMasked = typing.Union[MaskedTensor, ops.Tensor]
  signature = {"x": MaybeMasked, "y": MaybeMasked}

  @dispatch.dispatch_for_api(math_ops.add, signature)
  def masked_add(x, y, name=None):
    with ops.name_scope(name):
      if isinstance(x, MaskedTensor):
        x_values, x_mask = x.values, x.mask
      else:
        x_values, x_mask = x, True
      if isinstance(y, MaskedTensor):
        y_values, y_mask = y.values, y.mask
      else:
        y_values, y_mask = y, True
      return MaskedTensor(x_values + y_values, x_mask & y_mask)

  try:
    masked = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
    dense = constant_op.constant([10, 20, 30, 40, 50])
    result = math_ops.add(masked, dense)
    self.assertAllEqual(result.values, masked.values + dense)
    self.assertAllEqual(result.mask, masked.mask)
  finally:
    # Clean up dispatch table.
    dispatch.unregister_dispatch_for(masked_add)
def testDispatchForSignatureFromAnnotations(self):
  """With no explicit signature, dispatch is keyed on parameter annotations."""

  @dispatch.dispatch_for_api(math_ops.add)
  def masked_add(x: MaskedTensor, y: MaskedTensor, name=None):
    with ops.name_scope(name):
      return MaskedTensor(x.values + y.values, x.mask & y.mask)

  try:
    lhs = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
    rhs = MaskedTensor([1, 1, 1, 1, 1], [1, 1, 0, 1, 0])
    total = math_ops.add(lhs, rhs)
    self.assertAllEqual(total.values, lhs.values + rhs.values)
    self.assertAllEqual(total.mask, lhs.mask & rhs.mask)
  finally:
    # Clean up dispatch table.
    dispatch.unregister_dispatch_for(masked_add)
def testDuplicateDispatchForUnaryElementwiseAPIsError(self):
  """A second unary elementwise handler for the same type raises ValueError."""

  @dispatch.dispatch_for_unary_elementwise_apis(MaskedTensor)
  def handler(api_func, x):
    return MaskedTensor(api_func(x.values), x.mask)

  duplicate_error = (r"A unary elementwise dispatch handler \(.*\) has "
                     "already been registered for .*")
  try:
    with self.assertRaisesRegex(ValueError, duplicate_error):

      @dispatch.dispatch_for_unary_elementwise_apis(MaskedTensor)
      def another_handler(api_func, x):
        return MaskedTensor(api_func(x.values), ~x.mask)

      del another_handler  # Only reached if registration did not raise.
  finally:
    dispatch.unregister_dispatch_for(handler)
def testDispatchWithIterableParams(self):
  """add_n dispatch works when `inputs` is a generator, not a sequence."""
  # The add_n API accepts any iterable for `inputs`. The handler below walks
  # `inputs` twice, so this presumably relies on the dispatch layer
  # materializing the generator first — NOTE(review): confirm.

  @dispatch.dispatch_for_api(math_ops.add_n,
                             {"inputs": typing.List[MaskedTensor]})
  def masked_add_n(inputs):
    stacked_masks = array_ops.stack([mt.mask for mt in inputs])
    summed = math_ops.add_n([mt.values for mt in inputs])
    return MaskedTensor(summed, math_ops.reduce_all(stacked_masks, axis=0))

  try:
    pieces = (MaskedTensor([i], [True]) for i in range(5))
    total = math_ops.add_n(pieces)
    self.assertAllEqual(total.values, [sum(range(5))])
    self.assertAllEqual(total.mask, [True])
  finally:
    # Clean up dispatch table.
    dispatch.unregister_dispatch_for(masked_add_n)
def testRegisterDispatchableType(self):
  """A namedtuple registered as dispatchable can appear in a signature."""
  Car = collections.namedtuple("Car", ["size", "speed"])
  # NOTE(review): this registration is never undone in this test; confirm
  # whether an unregister counterpart exists.
  dispatch.register_dispatchable_type(Car)

  @dispatch.dispatch_for_api(math_ops.add, {"x": Car, "y": Car})
  def add_car(x, y, name=None):
    with ops.name_scope(name):
      return Car(x.size + y.size, x.speed + y.speed)

  try:
    small = Car(constant_op.constant(1), constant_op.constant(3))
    big = Car(constant_op.constant(10), constant_op.constant(20))
    combined = math_ops.add(small, big)
    self.assertAllEqual(combined.size, 11)
    self.assertAllEqual(combined.speed, 23)
  finally:
    # Clean up dispatch table.
    dispatch.unregister_dispatch_for(add_car)
def testDispatchWithKwargs(self):
  """A `*args, **kwargs` target receives keyword-only calls as kwargs."""
  signature = {"x": MaskedTensor, "y": MaskedTensor}

  @dispatch.dispatch_for_api(math_ops.add, signature)
  def masked_add(*args, **kwargs):
    # `lhs`/`rhs` are bound below, before the API is invoked.
    self.assertAllEqual(kwargs["x"].values, lhs.values)
    self.assertAllEqual(kwargs["y"].values, rhs.values)
    self.assertEmpty(args)
    return "stub"

  try:
    lhs = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
    rhs = MaskedTensor([1, 1, 1, 1, 1], [1, 1, 0, 1, 0])
    self.assertEqual(math_ops.add(x=lhs, y=rhs), "stub")
  finally:
    # Clean up dispatch table.
    dispatch.unregister_dispatch_for(masked_add)
def testDispatchApiWithNoNameArg(self):
  """A dispatch target may not add a `name` parameter the API lacks."""
  # Note: The "tensor_equals" API has no "name" argument.
  signature = {"self": MaskedTensor, "other": MaskedTensor}

  # A target whose signature matches the API exactly registers fine. The
  # parameter is called `self` because that is the API's own parameter name;
  # it shadows the test case only inside this nested function.
  @dispatch.dispatch_for_api(math_ops.tensor_equals, signature)
  def masked_tensor_equals(self, other):
    del self, other

  dispatch.unregister_dispatch_for(masked_tensor_equals)  # clean up.

  # Adding `name=None` makes the signature mismatch, so registration fails.
  # (Uses assertRaisesRegex: the deprecated assertRaisesRegexp alias was
  # removed from unittest in Python 3.12.)
  with self.assertRaisesRegex(
      ValueError, r"Dispatch function's signature \(self, other, name=None\) "
      r"does not match API's signature \(self, other\)\."):

    @dispatch.dispatch_for_api(math_ops.tensor_equals, signature)
    def masked_tensor_equals_2(self, other, name=None):
      del self, other, name

    # Only reached if registration unexpectedly succeeded: unregister so the
    # stray handler does not leak into other tests; assertRaisesRegex then
    # reports the missing ValueError.
    dispatch.unregister_dispatch_for(masked_tensor_equals_2)
def testDispatchForList(self):
  """concat dispatches when `values` is a list of MaskedTensors."""

  @dispatch.dispatch_for_api(array_ops.concat,
                             {"values": typing.List[MaskedTensor]})
  def masked_concat(values, axis, name=None):
    with ops.name_scope(name):
      joined_values = array_ops.concat([v.values for v in values], axis)
      joined_mask = array_ops.concat([v.mask for v in values], axis)
      return MaskedTensor(joined_values, joined_mask)

  try:
    first = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
    second = MaskedTensor([1, 1, 1], [1, 1, 0])
    joined = array_ops.concat([first, second], axis=0)
    self.assertAllEqual(joined.values,
                        array_ops.concat([first.values, second.values], 0))
    self.assertAllEqual(joined.mask,
                        array_ops.concat([first.mask, second.mask], 0))
  finally:
    # Clean up dispatch table.
    dispatch.unregister_dispatch_for(masked_concat)
def testDispatchForOptional(self):
  """Optional-typed parameters dispatch whether or not they are supplied."""
  # Note: typing.Optional[X] == typing.Union[X, NoneType].
  signature = {
      "condition": MaskedTensor,
      "x": typing.Optional[MaskedTensor],
      "y": typing.Optional[MaskedTensor],
  }

  @dispatch.dispatch_for_api(array_ops.where_v2, signature)
  def masked_where(condition, x=None, y=None, name=None):
    del condition, x, y, name
    return "stub"

  try:
    cond = MaskedTensor([True, False, True, True, True], [1, 0, 1, 1, 1])
    # Call once with only `condition`, once with all three operands.
    for call_args in [(cond,), (cond, cond, cond)]:
      self.assertEqual(array_ops.where_v2(*call_args), "stub")
  finally:
    # Clean up dispatch table.
    dispatch.unregister_dispatch_for(masked_where)
def testDispatchForBinaryElementwiseAPIs(self):
  """One binary elementwise handler serves multiply/add/subtract/minimum."""

  @dispatch.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
  def binary_elementwise_api_handler(api_func, x, y):
    return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)

  try:
    x = MaskedTensor([1, -2, -3], [True, True, False])
    y = MaskedTensor([10, 20, 30], [True, False, True])

    # Exercise positional, keyword, and mixed argument passing, with and
    # without an explicit `name`.
    x_times_y = math_ops.multiply(x, y)
    x_plus_y = math_ops.add(x, y=y)
    x_minus_y = math_ops.subtract(x=x, y=y)
    min_x_y = math_ops.minimum(x, y, "min_x_y")
    y_times_x = math_ops.multiply(y, x, name="y_times_x")
    y_plus_x = math_ops.add(y, y=x, name="y_plus_x")
    y_minus_x = math_ops.subtract(x=y, y=x, name="y_minus_x")

    expected_values = [
        (x_times_y, [10, -40, -90]),
        (x_plus_y, [11, 18, 27]),
        (x_minus_y, [-9, -22, -33]),
        (min_x_y, [1, -2, -3]),
        (y_times_x, [10, -40, -90]),
        (y_plus_x, [11, 18, 27]),
        (y_minus_x, [9, 22, 33]),
    ]
    for result, values in expected_values:
      self.assertAllEqual(result.values, values)
    # The handler ANDs the two input masks, so every result shares one mask.
    for result, _ in expected_values:
      self.assertAllEqual(result.mask, [True, False, False])

    if not context.executing_eagerly():  # names not defined in eager mode.
      expected_names = [
          (min_x_y.values, r"^min_x_y/Minimum:.*"),
          (min_x_y.mask, r"^min_x_y/and:.*"),
          (y_times_x.values, r"^y_times_x/.*"),
          (y_plus_x.values, r"^y_plus_x/.*"),
          (y_minus_x.values, r"^y_minus_x/.*"),
      ]
      for tensor, pattern in expected_names:
        self.assertRegex(tensor.name, pattern)
  finally:
    dispatch.unregister_dispatch_for(binary_elementwise_api_handler)
def testRegisterUnaryElementwiseApiAfterHandler(self):
  """An API registered after the handler still routes to that handler."""
  # It must be fine to call register_unary_elementwise_api after
  # dispatch_for_unary_elementwise_apis has already installed a handler.

  @dispatch.dispatch_for_unary_elementwise_apis(MaskedTensor)
  def handler(api_func, x):
    return MaskedTensor(api_func(x.values), x.mask)

  try:

    @dispatch.register_unary_elementwise_api
    @dispatch.add_dispatch_support
    def some_op(x):
      return x * 2

    masked_input = MaskedTensor([1, 2, 3], [True, False, True])
    doubled = some_op(masked_input)
    self.assertAllEqual(doubled.values, [2, 4, 6])
    self.assertAllEqual(doubled.mask, [True, False, True])
  finally:
    dispatch.unregister_dispatch_for(handler)
def testDispatchForUnaryElementwiseAPIs(self):
  """One unary elementwise handler serves abs/sign/negative/invert/ones_like."""

  @dispatch.dispatch_for_unary_elementwise_apis(MaskedTensor)
  def unary_elementwise_api_handler(api_func, x):
    return MaskedTensor(api_func(x.values), x.mask)

  try:
    x = MaskedTensor([1, -2, -3], [True, True, False])

    # Exercise positional and keyword argument passing, with and without an
    # explicit `name`.
    abs_x = math_ops.abs(x)
    sign_x = math_ops.sign(x=x)
    neg_x = math_ops.negative(x, "neg_x")
    invert_x = bitwise_ops.invert(x, name="invert_x")
    ones_like_x = array_ops.ones_like(x, name="ones_like_x")
    ones_like_x_float = array_ops.ones_like(
        x, dtypes.float32, name="ones_like_x_float")

    expected_values = [
        (abs_x, [1, 2, 3]),
        (sign_x, [1, -1, -1]),
        (neg_x, [-1, 2, 3]),
        (invert_x, [-2, 1, 2]),
        (ones_like_x, [1, 1, 1]),
        (ones_like_x_float, [1., 1., 1.]),
    ]
    for result, values in expected_values:
      self.assertAllEqual(result.values, values)
    # The handler passes the input mask through unchanged.
    for result, _ in expected_values:
      self.assertAllEqual(result.mask, [True, True, False])

    if not context.executing_eagerly():  # names not defined in eager mode.
      expected_names = [
          (neg_x, r"^neg_x/Neg:.*"),
          (invert_x, r"^invert_x/.*"),
          (ones_like_x, r"^ones_like_x/.*"),
          (ones_like_x_float, r"^ones_like_x_float/.*"),
      ]
      for result, pattern in expected_names:
        self.assertRegex(result.values.name, pattern)
  finally:
    dispatch.unregister_dispatch_for(unary_elementwise_api_handler)
def testDispatchTargetWithNoNameArgument(self):
  """A target without `name` still has its ops scoped under the caller's name."""

  @dispatch.dispatch_for_api(math_ops.add, {
      "x": MaskedTensor,
      "y": MaskedTensor
  })
  def masked_add(x, y):
    return MaskedTensor(x.values + y.values, x.mask & y.mask)

  def assert_op_names(result, values_pattern, mask_pattern):
    # Op names are not defined in eager mode, so only check them in graphs.
    if not context.executing_eagerly():
      self.assertRegex(result.values.name, values_pattern)
      self.assertRegex(result.mask.name, mask_pattern)

  try:
    x = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
    y = MaskedTensor([1, 1, 1, 1, 1], [1, 1, 0, 1, 0])

    # Name passed as a keyword argument.
    keyword_named = math_ops.add(x, y, name="MyAdd")
    assert_op_names(keyword_named, r"^MyAdd/add.*", r"^MyAdd/and.*")

    # Name passed as a positional argument.
    positional_named = math_ops.add(x, y, "B")
    assert_op_names(positional_named, r"^B/add.*", r"^B/and.*")

    # No name: default op names apply.
    unnamed = math_ops.add(x, y)
    assert_op_names(unnamed, r"^add.*", r"^and.*")
  finally:
    # Clean up dispatch table.
    dispatch.unregister_dispatch_for(masked_add)
def testRegisterBinaryElementwiseApiAfterHandler(self):
  """A binary API registered after the handler still routes to that handler."""
  # It must be fine to call register_binary_elementwise_api after
  # dispatch_for_binary_elementwise_apis has already installed a handler.

  @dispatch.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
  def handler(api_func, x, y):
    return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)

  try:

    @dispatch.register_binary_elementwise_api
    @dispatch.add_dispatch_support
    def some_op(x, y):
      return x * 2 + y

    lhs = MaskedTensor([1, 2, 3], [True, False, True])
    rhs = MaskedTensor([10, 20, 30], [True, True, False])
    out = some_op(lhs, rhs)
    self.assertAllEqual(out.values, [12, 24, 36])
    self.assertAllEqual(out.mask, [True, False, False])
  finally:
    dispatch.unregister_dispatch_for(handler)
def testTypeBasedDispatchTargetsFor(self):
  """type_based_dispatch_signatures_for reports only the queried type's APIs.

  Registers two handlers keyed on MaskedTensor and two keyed on SillyTensor,
  then checks that querying for MaskedTensor returns only the former.
  """
  MaskedTensorList = typing.List[typing.Union[MaskedTensor, ops.Tensor]]
  # Record each handler as soon as it is registered, so cleanup only touches
  # handlers that actually reached the dispatch table. Otherwise a failing
  # registration would leave later names unbound, and the `finally` block
  # would raise and hide the original error.
  registered = []
  try:

    @dispatch.dispatch_for_api(math_ops.add)
    def masked_add(x: MaskedTensor, y: MaskedTensor):
      del x, y

    registered.append(masked_add)

    @dispatch.dispatch_for_api(array_ops.concat)
    def masked_concat(values: MaskedTensorList, axis):
      del values, axis

    registered.append(masked_concat)

    @dispatch.dispatch_for_api(math_ops.add)
    def silly_add(x: SillyTensor, y: SillyTensor):
      del x, y

    registered.append(silly_add)

    @dispatch.dispatch_for_api(math_ops.abs)
    def silly_abs(x: SillyTensor):
      del x

    registered.append(silly_abs)

    # Note: `expected` does not contain keys or values from SillyTensor.
    targets = dispatch.type_based_dispatch_signatures_for(MaskedTensor)
    expected = {
        math_ops.add: [{
            "x": MaskedTensor,
            "y": MaskedTensor
        }],
        array_ops.concat: [{
            "values": MaskedTensorList
        }]
    }
    self.assertEqual(targets, expected)
  finally:
    # Clean up dispatch table (registration order).
    for handler in registered:
      dispatch.unregister_dispatch_for(handler)
def testUnregisterDispatchTargetBadTargetError(self):
  """Unregistering a never-registered function raises ValueError."""

  def never_registered(x):
    return x + 1

  with self.assertRaisesRegex(ValueError, "Function .* was not registered"):
    dispatch.unregister_dispatch_for(never_registered)