Esempio n. 1
0
    def testDispatchForUnion(self):
        """A Union[MaskedTensor, Tensor] signature dispatches mixed operands."""
        MaybeMasked = typing.Union[MaskedTensor, ops.Tensor]
        union_signature = {"x": MaybeMasked, "y": MaybeMasked}

        def _split(operand):
            # Plain tensors carry no mask; `True` stands in for "all valid".
            if isinstance(operand, MaskedTensor):
                return operand.values, operand.mask
            return operand, True

        @dispatch.dispatch_for(math_ops.add, union_signature)
        def masked_add(x, y, name=None):
            with ops.name_scope(name):
                lhs_values, lhs_mask = _split(x)
                rhs_values, rhs_mask = _split(y)
                return MaskedTensor(lhs_values + rhs_values,
                                    lhs_mask & rhs_mask)

        try:
            masked = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
            dense = constant_op.constant([10, 20, 30, 40, 50])
            result = math_ops.add(masked, dense)
            self.assertAllEqual(result.values, masked.values + dense)
            self.assertAllEqual(result.mask, masked.mask)

        finally:
            # Remove the target so later tests see the original dispatch table.
            dispatch.unregister_dispatch_target(math_ops.add, masked_add)
Esempio n. 2
0
  def testDispatchTargetWithNoNameArgument(self):
    """A target without a `name` parameter still gets the API's name scope."""

    @dispatch.dispatch_for(math_ops.add, {"x": MaskedTensor, "y": MaskedTensor})
    def masked_add(x, y):
      return MaskedTensor(x.values + y.values, x.mask & y.mask)

    try:
      lhs = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
      rhs = MaskedTensor([1, 1, 1, 1, 1], [1, 1, 0, 1, 0])

      def check_names(result, values_pattern, mask_pattern):
        # Op names are not defined in eager mode, so only check in graphs.
        if context.executing_eagerly():
          return
        self.assertRegex(result.values.name, values_pattern)
        self.assertRegex(result.mask.name, mask_pattern)

      # `name` supplied as a keyword argument.
      check_names(math_ops.add(lhs, rhs, name="MyAdd"),
                  r"^MyAdd/add.*", r"^MyAdd/and.*")

      # `name` supplied positionally.
      check_names(math_ops.add(lhs, rhs, "B"), r"^B/add.*", r"^B/and.*")

      # No `name` supplied at all.
      check_names(math_ops.add(lhs, rhs), r"^add.*", r"^and.*")

    finally:
      # Remove the target so later tests see the original dispatch table.
      dispatch.unregister_dispatch_target(math_ops.add, masked_add)
Esempio n. 3
0
    def testDispatchSignatureWithUnspecifiedParameter(self):
        """Only `x` is constrained, so any `y` (even None) reaches the target."""

        @dispatch.dispatch_for(math_ops.add, {"x": MaskedTensor})
        def masked_add(x, y):
            if y is None:
                return x
            if isinstance(y, MaskedTensor):
                y_values, y_mask = y.values, y.mask
            else:
                # Non-masked operands are treated as fully valid.
                y_values, y_mask = y, True
            return MaskedTensor(x.values + y_values, x.mask & y_mask)

        try:
            masked = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
            dense = constant_op.constant([10, 20, 30, 40, 50])
            # As long as `x` is a MaskedTensor, the dispatcher is called
            # regardless of the type of `y`.
            cases = [
                (dense, [11, 22, 33, 44, 55]),
                ([10, 20, 30, 40, 50], [11, 22, 33, 44, 55]),
                (50, [51, 52, 53, 54, 55]),
                (None, [1, 2, 3, 4, 5]),
            ]
            for other, expected_values in cases:
                self.assertAllEqual(
                    math_ops.add(masked, other).values, expected_values)

        finally:
            # Remove the target so later tests see the original dispatch table.
            dispatch.unregister_dispatch_target(math_ops.add, masked_add)
Esempio n. 4
0
  def testTypeBasedDispatchTargetsFor(self):
    """type_based_dispatch_signatures_for reports only the queried type.

    Targets are registered for both MaskedTensor and SillyTensor. Only the
    MaskedTensor signatures may appear in the result for MaskedTensor.
    """
    MaskedTensorList = typing.List[typing.Union[MaskedTensor, ops.Tensor]]
    # (api, target) pairs that registered successfully. Cleanup walks this
    # list rather than naming each target. If a decorator raises part-way
    # through, `finally` therefore never touches an unbound local; an
    # UnboundLocalError there would mask the real failure.
    registered = []
    try:
      @dispatch.dispatch_for(math_ops.add)
      def masked_add(x: MaskedTensor, y: MaskedTensor):
        del x, y
      registered.append((math_ops.add, masked_add))

      @dispatch.dispatch_for(array_ops.concat)
      def masked_concat(values: MaskedTensorList, axis):
        del values, axis
      registered.append((array_ops.concat, masked_concat))

      @dispatch.dispatch_for(math_ops.add)
      def silly_add(x: SillyTensor, y: SillyTensor):
        del x, y
      registered.append((math_ops.add, silly_add))

      @dispatch.dispatch_for(math_ops.abs)
      def silly_abs(x: SillyTensor):
        del x
      registered.append((math_ops.abs, silly_abs))

      # Note: `expected` does not contain keys or values from SillyTensor.
      # `axis` is unannotated, so it is absent from the concat signature.
      targets = dispatch.type_based_dispatch_signatures_for(MaskedTensor)
      expected = {math_ops.add: [{"x": MaskedTensor, "y": MaskedTensor}],
                  array_ops.concat: [{"values": MaskedTensorList}]}
      self.assertEqual(targets, expected)

    finally:
      # Clean up dispatch table (only targets that were actually registered).
      for api, target in registered:
        dispatch.unregister_dispatch_target(api, target)
Esempio n. 5
0
    def testDispatchForSignatureFromAnnotations(self):
        """With no explicit signature, dispatch uses the target's annotations."""

        @dispatch.dispatch_for(math_ops.add)
        def masked_add(x: MaskedTensor, y: MaskedTensor, name=None):
            with ops.name_scope(name):
                combined_values = x.values + y.values
                combined_mask = x.mask & y.mask
                return MaskedTensor(combined_values, combined_mask)

        try:
            lhs = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
            rhs = MaskedTensor([1, 1, 1, 1, 1], [1, 1, 0, 1, 0])
            total = math_ops.add(lhs, rhs)
            self.assertAllEqual(total.values, lhs.values + rhs.values)
            self.assertAllEqual(total.mask, lhs.mask & rhs.mask)

        finally:
            # Remove the target so later tests see the original dispatch table.
            dispatch.unregister_dispatch_target(math_ops.add, masked_add)
Esempio n. 6
0
  def testDispatchWithKwargs(self):
    """Keyword-only API calls are forwarded to the target as **kwargs."""

    @dispatch.dispatch_for(math_ops.add, {"x": MaskedTensor, "y": MaskedTensor})
    def masked_add(*args, **kwargs):
      # `lhs`/`rhs` are bound later in the `try` block; the closure reads them
      # at call time. The API is called with keywords only, so nothing is
      # expected in `args`.
      self.assertAllEqual(kwargs["x"].values, lhs.values)
      self.assertAllEqual(kwargs["y"].values, rhs.values)
      self.assertEmpty(args)
      return "stub"

    try:
      lhs = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
      rhs = MaskedTensor([1, 1, 1, 1, 1], [1, 1, 0, 1, 0])
      result = math_ops.add(x=lhs, y=rhs)
      self.assertEqual(result, "stub")

    finally:
      # Remove the target so later tests see the original dispatch table.
      dispatch.unregister_dispatch_target(math_ops.add, masked_add)
Esempio n. 7
0
    def testDispatchWithIterableParams(self):
        """add_n dispatch works when `inputs` is a generator, not a sequence."""

        @dispatch.dispatch_for(math_ops.add_n,
                               {"inputs": typing.List[MaskedTensor]})
        def masked_add_n(inputs):
            stacked_masks = array_ops.stack([item.mask for item in inputs])
            summed_values = math_ops.add_n([item.values for item in inputs])
            # A position is valid only if it is valid in every input.
            return MaskedTensor(summed_values,
                                math_ops.reduce_all(stacked_masks, axis=0))

        try:
            masked_inputs = (MaskedTensor([i], [True]) for i in range(5))
            total = math_ops.add_n(masked_inputs)
            self.assertAllEqual(total.values, [sum(range(5))])
            self.assertAllEqual(total.mask, [True])

        finally:
            # Remove the target so later tests see the original dispatch table.
            dispatch.unregister_dispatch_target(math_ops.add_n, masked_add_n)
Esempio n. 8
0
    def testRegisterDispatchableType(self):
        """A namedtuple registered as dispatchable can be used for dispatch."""
        Car = collections.namedtuple("Car", ["size", "speed"])
        # NOTE(review): this registration is not reversed in `finally`; only
        # the dispatch target below is unregistered.
        dispatch.register_dispatchable_type(Car)

        @dispatch.dispatch_for(math_ops.add, {"x": Car, "y": Car})
        def add_car(x, y, name=None):
            with ops.name_scope(name):
                return Car(size=x.size + y.size, speed=x.speed + y.speed)

        try:
            small = Car(constant_op.constant(1), constant_op.constant(3))
            big = Car(constant_op.constant(10), constant_op.constant(20))
            combined = math_ops.add(small, big)
            self.assertAllEqual(combined.size, 11)
            self.assertAllEqual(combined.speed, 23)

        finally:
            # Remove the target so later tests see the original dispatch table.
            dispatch.unregister_dispatch_target(math_ops.add, add_car)
Esempio n. 9
0
  def testDispatchApiWithNoNameArg(self):
    """A target may not add a `name` parameter that the API itself lacks."""
    # Note: The "tensor_equals" API has no "name" argument.
    signature = {"self": MaskedTensor, "other": MaskedTensor}

    # A target whose parameters exactly match the API registers cleanly.
    @dispatch.dispatch_for(math_ops.tensor_equals, signature)
    def masked_tensor_equals(self, other):
      del self, other

    dispatch.unregister_dispatch_target(math_ops.tensor_equals,
                                        masked_tensor_equals)  # clean up.

    # A target with an extra `name` parameter must be rejected at decoration
    # time. `assertRaisesRegexp` was a deprecated alias removed in Python
    # 3.12, so use `assertRaisesRegex`.
    with self.assertRaisesRegex(
        ValueError, r"Dispatch function's signature \(self, other, name=None\) "
        r"does not match API's signature \(self, other\)\."):

      @dispatch.dispatch_for(math_ops.tensor_equals, signature)
      def masked_tensor_equals_2(self, other, name=None):
        del self, other, name

      del masked_tensor_equals_2  # avoid pylint unused variable warning.
Esempio n. 10
0
  def testDispatchForList(self):
    """A List[MaskedTensor] signature dispatches concat on a list argument."""

    @dispatch.dispatch_for(array_ops.concat,
                           {"values": typing.List[MaskedTensor]})
    def masked_concat(values, axis, name=None):
      with ops.name_scope(name):
        joined_values = array_ops.concat([v.values for v in values], axis)
        joined_masks = array_ops.concat([v.mask for v in values], axis)
        return MaskedTensor(joined_values, joined_masks)

    try:
      first = MaskedTensor([1, 2, 3, 4, 5], [1, 0, 1, 1, 1])
      second = MaskedTensor([1, 1, 1], [1, 1, 0])
      joined = array_ops.concat([first, second], axis=0)
      self.assertAllEqual(joined.values,
                          array_ops.concat([first.values, second.values], 0))
      self.assertAllEqual(joined.mask,
                          array_ops.concat([first.mask, second.mask], 0))

    finally:
      # Remove the target so later tests see the original dispatch table.
      dispatch.unregister_dispatch_target(array_ops.concat, masked_concat)
Esempio n. 11
0
  def testDispatchForOptional(self):
    """Optional[MaskedTensor] matches both omitted and MaskedTensor args."""
    # Note: typing.Optional[X] == typing.Union[X, NoneType].
    optional_masked = typing.Optional[MaskedTensor]
    where_signature = {
        "condition": MaskedTensor,
        "x": optional_masked,
        "y": optional_masked,
    }

    @dispatch.dispatch_for(array_ops.where_v2, where_signature)
    def masked_where(condition, x=None, y=None, name=None):
      del condition, x, y, name
      return "stub"

    try:
      cond = MaskedTensor([True, False, True, True, True], [1, 0, 1, 1, 1])
      # First with `x`/`y` omitted (None), then with all three supplied.
      for call_args in [(cond,), (cond, cond, cond)]:
        self.assertEqual(array_ops.where_v2(*call_args), "stub")

    finally:
      # Remove the target so later tests see the original dispatch table.
      dispatch.unregister_dispatch_target(array_ops.where_v2, masked_where)
Esempio n. 12
0
 def testUnregisterDispatchTargetBadDispatchTargetError(self):
     """Unregistering a never-registered target raises ValueError."""
     never_registered = lambda x: x + 1
     expected_pattern = ".* was not registered for .*"
     with self.assertRaisesRegex(ValueError, expected_pattern):
         dispatch.unregister_dispatch_target(math_ops.add, never_registered)
Esempio n. 13
0
 def testUnregisterDispatchTargetBadTargetError(self):
     """Unregistering from an API that does not support dispatch fails."""
     plain_fn = lambda x: x + 1
     expected_pattern = ".* does not support dispatch"
     with self.assertRaisesRegex(ValueError, expected_pattern):
         # `plain_fn` is passed as the API itself, which is expected to be
         # rejected with the message above.
         dispatch.unregister_dispatch_target(plain_fn, plain_fn)