def test_invalid_input(self):
    """Test that functions properly fail on invalid input.

    Checks that `classifier_metrics.run_inception` raises `ValueError` for an
    input of shape [7, 50, 50, 3]. It also checks that
    `classifier_metrics._kl_divergence(p, p_logits, q)` raises `ValueError`
    when any one argument has an integer dtype or the wrong rank. Each case
    replaces exactly one argument of an otherwise unchanged baseline call.
    """
    # `assertRaisesRegex` replaces the deprecated alias `assertRaisesRegexp`,
    # which was removed from `unittest.TestCase` in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'Shapes .* are incompatible'):
      classifier_metrics.run_inception(array_ops.ones([7, 50, 50, 3]))

    # Baseline arguments: p and p_logits have shape [8, 10] (rank 2) and q
    # has shape [10] (rank 1). All use `array_ops.zeros`'s default dtype.
    p = array_ops.zeros([8, 10])
    p_logits = array_ops.zeros([8, 10])
    q = array_ops.zeros([10])

    # Integer dtype in the first argument (p).
    with self.assertRaisesRegex(ValueError, 'must be floating type'):
      classifier_metrics._kl_divergence(
          array_ops.zeros([8, 10], dtype=dtypes.int32), p_logits, q)

    # Integer dtype in the second argument (p_logits).
    with self.assertRaisesRegex(ValueError, 'must be floating type'):
      classifier_metrics._kl_divergence(
          p, array_ops.zeros([8, 10], dtype=dtypes.int32), q)

    # Integer dtype in the third argument (q).
    with self.assertRaisesRegex(ValueError, 'must be floating type'):
      classifier_metrics._kl_divergence(
          p, p_logits, array_ops.zeros([10], dtype=dtypes.int32))

    # Rank-1 first argument where rank 2 is required.
    with self.assertRaisesRegex(ValueError, 'must have rank 2'):
      classifier_metrics._kl_divergence(array_ops.zeros([8]), p_logits, q)

    # Rank-1 second argument where rank 2 is required.
    with self.assertRaisesRegex(ValueError, 'must have rank 2'):
      classifier_metrics._kl_divergence(p, array_ops.zeros([8]), q)

    # Rank-2 third argument where rank 1 is required.
    with self.assertRaisesRegex(ValueError, 'must have rank 1'):
      classifier_metrics._kl_divergence(p, p_logits, array_ops.zeros([10, 8]))
Example #2
0
    def test_invalid_input(self):
        """Test that functions properly fail on invalid input.

        Checks that `classifier_metrics.run_inception` raises `ValueError`
        for an input of shape [7, 50, 50, 3]. It also checks that
        `classifier_metrics._kl_divergence(p, p_logits, q)` raises
        `ValueError` when any one argument has an integer dtype or the wrong
        rank. Each case replaces exactly one argument of an otherwise
        unchanged baseline call.
        """
        # `assertRaisesRegex` replaces the deprecated alias
        # `assertRaisesRegexp`, which was removed from `unittest.TestCase`
        # in Python 3.12.
        with self.assertRaisesRegex(ValueError, 'Shapes .* are incompatible'):
            classifier_metrics.run_inception(array_ops.ones([7, 50, 50, 3]))

        # Baseline arguments: p and p_logits have shape [8, 10] (rank 2) and
        # q has shape [10] (rank 1). All use `array_ops.zeros`'s default
        # dtype.
        p = array_ops.zeros([8, 10])
        p_logits = array_ops.zeros([8, 10])
        q = array_ops.zeros([10])

        # Integer dtype in the first argument (p).
        with self.assertRaisesRegex(ValueError, 'must be floating type'):
            classifier_metrics._kl_divergence(
                array_ops.zeros([8, 10], dtype=dtypes.int32), p_logits, q)

        # Integer dtype in the second argument (p_logits).
        with self.assertRaisesRegex(ValueError, 'must be floating type'):
            classifier_metrics._kl_divergence(
                p, array_ops.zeros([8, 10], dtype=dtypes.int32), q)

        # Integer dtype in the third argument (q).
        with self.assertRaisesRegex(ValueError, 'must be floating type'):
            classifier_metrics._kl_divergence(
                p, p_logits, array_ops.zeros([10], dtype=dtypes.int32))

        # Rank-1 first argument where rank 2 is required.
        with self.assertRaisesRegex(ValueError, 'must have rank 2'):
            classifier_metrics._kl_divergence(
                array_ops.zeros([8]), p_logits, q)

        # Rank-1 second argument where rank 2 is required.
        with self.assertRaisesRegex(ValueError, 'must have rank 2'):
            classifier_metrics._kl_divergence(p, array_ops.zeros([8]), q)

        # Rank-2 third argument where rank 1 is required.
        with self.assertRaisesRegex(ValueError, 'must have rank 1'):
            classifier_metrics._kl_divergence(
                p, p_logits, array_ops.zeros([10, 8]))