Beispiel #1
0
    def test_one_dimensional(self):
        """Check `_filter_top_k` on a rank-1 tensor for k = 1, 2 and 3.

        For each k, the k largest entries are expected to survive unchanged and
        every other position is expected to hold `metrics_utils.NEG_INF`.
        """
        neg = metrics_utils.NEG_INF
        values = tf.constant([0.3, 0.1, 0.2, -0.5, 42.0])
        # Expected output for each k; dict order fixes evaluation/assert order.
        expected_by_k = {
            1: [neg, neg, neg, neg, 42.0],
            2: [0.3, neg, neg, neg, 42.0],
            3: [0.3, neg, 0.2, neg, 42.0],
        }
        # Evaluate every k first, then assert, mirroring the original ordering.
        results = {
            k: self.evaluate(metrics_utils._filter_top_k(x=values, k=k))
            for k in expected_by_k
        }
        for k, expected in expected_by_k.items():
            self.assertAllClose(results[k], expected)
Beispiel #2
0
    def test_three_dimensional(self):
        """Check `_filter_top_k` with k=2 on a rank-3 tensor.

        Per the expected values below, each innermost row keeps its two largest
        entries and the remaining one becomes `metrics_utils.NEG_INF`.
        """
        neg = metrics_utils.NEG_INF
        inputs = tf.constant([
            [[0.3, 0.1, 0.2], [-0.3, -0.2, -0.1]],
            [[5.0, 0.2, 42.0], [-0.3, -0.6, -0.99]],
        ])
        expected = [
            [[0.3, neg, 0.2], [neg, -0.2, -0.1]],
            [[5.0, neg, 42.0], [-0.3, -0.6, neg]],
        ]

        filtered = self.evaluate(metrics_utils._filter_top_k(x=inputs, k=2))
        self.assertAllClose(filtered, expected)
Beispiel #3
0
    def test_three_dimensional(self):
        """Verify top-2 filtering on a rank-3 tensor, row by innermost row.

        Within every innermost row the two largest values are kept and the
        third is expected to be replaced with `metrics_utils.NEG_INF`.
        """
        neg_inf = metrics_utils.NEG_INF
        # Input and expected output, split into the two slices along axis 0.
        slice_0 = [[0.3, 0.1, 0.2], [-0.3, -0.2, -0.1]]
        slice_1 = [[5.0, 0.2, 42.0], [-0.3, -0.6, -0.99]]
        want_0 = [[0.3, neg_inf, 0.2], [neg_inf, -0.2, -0.1]]
        want_1 = [[5.0, neg_inf, 42.0], [-0.3, -0.6, neg_inf]]

        top_two = self.evaluate(
            metrics_utils._filter_top_k(x=tf.constant([slice_0, slice_1]), k=2)
        )
        self.assertAllClose(top_two, [want_0, want_1])
Beispiel #4
0
        def _filter_top_k(x):
            """Apply top-2 filtering to `x` after discarding its static shape."""
            # Passing x through tf.numpy_function drops its static shape, so the
            # filter below is exercised on a dynamically-shaped tensor.
            # NOTE(review): presumably `_identity` returns its input unchanged and
            # x is float32 (Tout is declared as tf.float32) -- confirm at caller.
            shapeless_x = tf.numpy_function(_identity, (x,), tf.float32)
            top_two = metrics_utils._filter_top_k(x=shapeless_x, k=2)
            return top_two