コード例 #1
0
    def testCondPacked(self):
        """Tests `control_flow_ops.cond` with packed `MaskedTensorV2` values.

        Covers three cases:
          * packed values returned directly from the branches;
          * branches that read fields of an unpacked value, and of its packed
            counterpart, captured from the enclosing scope;
          * a packed value whose fields are read only in the branch taken
            when the predicate is False.
        """
        x = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False])
        y = MaskedTensorV2([5, 6, 7, 8], [False, True, True, False])
        x = extension_type.pack(x)
        y = extension_type.pack(y)

        # Packed values returned directly from the branches come back intact.
        x_2 = control_flow_ops.cond(constant_op.constant(True), lambda: x,
                                    lambda: y)
        y_2 = control_flow_ops.cond(constant_op.constant(False), lambda: x,
                                    lambda: y)

        self.assertAllEqual(x.values, x_2.values)
        self.assertAllEqual(x.mask, x_2.mask)
        self.assertAllEqual(y.values, y_2.values)
        self.assertAllEqual(y.mask, y_2.mask)

        # Branches reading fields of an unpacked value captured from the
        # enclosing scope.
        a = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False])
        a_size = control_flow_ops.cond(constant_op.constant(True),
                                       lambda: array_ops.size(a.mask),
                                       lambda: array_ops.size(a.values))
        self.assertAllEqual(a_size, 4)

        # Same check through the packed form of `a`. The packed value must
        # actually be read inside the branches for this case to be tested.
        b = extension_type.pack(a)
        b_size = control_flow_ops.cond(constant_op.constant(True),
                                       lambda: array_ops.size(b.mask),
                                       lambda: array_ops.size(b.values))
        self.assertAllEqual(b_size, 4)

        # Note: the following example would fail (with `Retval[0] does not have a
        # value`) if `ExtensionType.__getattr__` cached the results of unpacking
        # the value.  See the comment in `ExtensionType.__getattr__` for details.
        c = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False])
        c = extension_type.pack(c)
        d = control_flow_ops.cond(constant_op.constant(False),
                                  lambda: array_ops.size(c.mask),
                                  lambda: array_ops.size(c.values))
        self.assertAllEqual(d, 4)
コード例 #2
0
    def testPackedEncoding(self):
        """Checks how packing changes the flattened encoding of a MaskedTensorV2.

        The unpacked value flattens (with composite expansion) into two
        components and its packed form into one. Unpacking restores a value
        with the original structure, and field access works on both forms.
        `MaskedTensorV1` is expected to be rejected by `pack` with a
        ValueError mentioning a missing `__name__` field.
        """

        def check_packing_form(mt, num_components):
            # Component count first, then the contents of each field.
            self.assertLen(nest.flatten(mt, expand_composites=True),
                           num_components)
            self.assertIsInstance(mt.values, ops.Tensor)
            self.assertAllEqual(mt.values, [1, 2, 3, 4])
            self.assertIsInstance(mt.mask, ops.Tensor)
            self.assertAllEqual(mt.mask, [True, True, False, True])

        original = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True])
        self.assertLen(nest.flatten(original, expand_composites=True), 2)

        packed = extension_type.pack(original)
        check_packing_form(packed, 1)

        round_tripped = extension_type.unpack(packed)
        check_packing_form(round_tripped, 2)

        # Unpacking restores the original structure; the packed form differs.
        nest.assert_same_structure(original, round_tripped,
                                   expand_composites=True)
        with self.assertRaisesRegex(ValueError, "don't have the same"):  # pylint: disable=g-error-prone-assert-raises
            nest.assert_same_structure(original, packed,
                                       expand_composites=True)

        unnamed = MaskedTensorV1([1, 2, 3, 4], [True, True, False, True])
        expected_error = (
            'ExtensionTypes must have a __name__ field in order to be packed.')
        with self.assertRaisesRegex(ValueError, expected_error):
            extension_type.pack(unnamed)
コード例 #3
0
    def testPassIntoTfFunction(self):
        """A tf.function accepts both unpacked and packed MaskedTensorV2s."""

        @def_function.function
        def fill_masked(mt):
            # Entries whose mask is False are replaced by 99.
            return mt.with_default(99)

        expected = [1, 2, 99, 4]
        unpacked = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True])
        self.assertAllEqual(expected, fill_masked(unpacked))

        packed = extension_type.pack(unpacked)
        self.assertAllEqual(expected, fill_masked(packed))
コード例 #4
0
    def testAttributeAccessors(self):
        """Both forms expose `values` and `mask` as tensors with the expected contents."""
        unpacked = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True])
        packed = extension_type.pack(unpacked)

        expected_values = [1, 2, 3, 4]
        expected_mask = [True, True, False, True]
        for candidate in (unpacked, packed):
            self.assertIsInstance(candidate.values, ops.Tensor)
            self.assertAllEqual(candidate.values, expected_values)
            self.assertIsInstance(candidate.mask, ops.Tensor)
            self.assertAllEqual(candidate.mask, expected_mask)
コード例 #5
0
    def testWhileLoopPacked(self):
        """Carries a packed MaskedTensorV2 through `while_loop_v2`.

        The loop runs 10 times. Each iteration doubles `values` and re-packs
        the result, while the mask passes through unchanged.
        """
        initial = extension_type.pack(
            MaskedTensorV2([1, 2, 3, 4], [True, False, True, False]))

        def keep_going(step, _):
            return step < 10

        def double_values(step, mt):
            return step + 1, extension_type.pack(
                MaskedTensorV2(mt.values * 2, mt.mask))

        _, result = control_flow_ops.while_loop_v2(keep_going, double_values,
                                                   [0, initial])
        self.assertIsInstance(result, MaskedTensorV2)
        # Ten doublings multiply every value by 2**10 == 1024.
        self.assertAllEqual(result.values, [1024, 2048, 3072, 4096])
        self.assertAllEqual(result.mask, [True, False, True, False])
コード例 #6
0
    def testAttributesAreImmutable(self):
        """Assigning or deleting attributes raises AttributeError on both forms."""
        unpacked = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True])
        packed = extension_type.pack(unpacked)

        for mt in (unpacked, packed):
            # (expected error regex, mutation that must raise). setattr and
            # delattr are the functional forms of `mt.attr = v` / `del mt.attr`.
            # Each lambda runs within the same iteration, so capturing `mt`
            # by reference is safe here.
            attempts = [
                ("cannot assign to field 'score'",
                 lambda: setattr(mt, 'score', 12)),
                ("cannot assign to field 'values'",
                 lambda: setattr(mt, 'values',
                                 constant_op.constant([4, 3, 2, 1]))),
                ("cannot delete field 'values'",
                 lambda: delattr(mt, 'values')),
            ]
            for pattern, mutate in attempts:
                with self.assertRaisesRegex(AttributeError, pattern):
                    mutate()
コード例 #7
0
    def testAttributesAreImmutable(self):
        """Attribute mutation outside the constructor raises AttributeError.

        Checked on both the unpacked value and its packed form, for
        assignment as well as deletion.
        """
        unpacked = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True])
        packed = extension_type.pack(unpacked)

        # pylint: disable=line-too-long
        score_error = 'Cannot mutate attribute `score` outside the custom constructor of ExtensionType'
        values_error = 'Cannot mutate attribute `values` outside the custom constructor of ExtensionType'
        # pylint: enable=line-too-long

        for mt in (unpacked, packed):
            with self.assertRaisesRegex(AttributeError, score_error):
                mt.score = 12
            with self.assertRaisesRegex(AttributeError, values_error):
                mt.values = constant_op.constant([4, 3, 2, 1])
            # Deletion is expected to report the same error as assignment.
            with self.assertRaisesRegex(AttributeError, values_error):
                del mt.values
コード例 #8
0
 def body(i, x):
     """Advances a counter and doubles the values of `x`, re-packing the result.

     Presumably used as a `while_loop` body; the caller is not visible here.

     Args:
       i: the counter; `i + 1` is returned.
       x: an object exposing `values` and `mask` (presumably a packed or
         unpacked `MaskedTensorV2` -- TODO confirm against the caller).

     Returns:
       A pair `(i + 1, packed)`. `packed` is `extension_type.pack` applied to a
       `MaskedTensorV2` built from `x.values * 2` and `x.mask`.
     """
     next_i = i + 1
     doubled = MaskedTensorV2(x.values * 2, x.mask)
     return next_i, extension_type.pack(doubled)
コード例 #9
0
 def mask_neg_values_packed(x):
     """Pairs `x` with the mask `x > 0` and returns the packed MaskedTensorV2.

     The mask is True only where `x > 0`, so zeros are False in the mask,
     just like negative entries.
     """
     keep = x > 0
     masked = MaskedTensorV2(x, keep)
     return extension_type.pack(masked)