示例#1
0
    def testMultilabelMatch3(self):
        """Checks multilabel_accuracy_match3 against a NumPy reference value.

        Five samples each get one predicted class in [1, 5) and ten target
        labels in [1, 5). A random 0/1 weight per sample zeroes out all targets
        of the unweighted samples. The per-sample reference score is
        min(#targets equal to the prediction / 3, 1), and the reference result
        is the weighted mean of those scores.
        """
        predictions = np.random.randint(1, 5, size=(5, 1, 1, 1))
        targets = np.random.randint(1, 5, size=(5, 10, 1, 1))
        weights = np.random.randint(0, 2, size=(5, 1, 1, 1))
        # With only 5 samples, every weight is 0 with probability 1/32, which
        # would make the weighted-mean reference below 0/0 = nan. Guarantee
        # at least one weighted sample so the test is not flaky.
        weights[0] = 1
        targets *= weights

        predictions_repeat = np.repeat(predictions, 10, axis=1)
        expected = (predictions_repeat == targets).astype(float)
        expected = np.sum(expected, axis=(1, 2, 3))
        expected = np.minimum(expected / 3.0, 1.)
        # tf.metrics.mean(values, weights) is a weighted mean:
        # sum(values * weights) / sum(weights). Dividing by the batch size
        # (weights.shape[0]) only agrees with it when every weight is 1.
        # NOTE(review): assumes the weights returned by
        # multilabel_accuracy_match3 are 1 exactly for samples whose targets
        # were not zeroed, i.e. they match `weights` here — confirm.
        expected = np.sum(expected * weights[:, 0, 0, 0]) / np.sum(weights)
        with self.test_session() as session:
            scores, weights_ = metrics.multilabel_accuracy_match3(
                tf.one_hot(predictions, depth=5, dtype=tf.float32),
                tf.constant(targets, dtype=tf.int32))
            a, a_op = tf.metrics.mean(scores, weights_)
            session.run(tf.local_variables_initializer())
            session.run(tf.global_variables_initializer())
            # Run the update op once to accumulate this batch, then read
            # the resulting mean.
            session.run(a_op)
            actual = session.run(a)
        self.assertAlmostEqual(actual, expected, places=6)
示例#2
0
  def testMultilabelMatch3(self):
    """Compares multilabel_accuracy_match3 with a NumPy reference value.

    Each of the 100 samples gets one predicted class in [1, 5) and ten
    target labels in [1, 5). A random 0/1 weight per sample zeroes out all
    targets of the unweighted samples. The reference score of a sample is
    min(#targets equal to its prediction / 3, 1), and the reference result
    is the weighted mean of those scores.
    """
    batch, num_targets = 100, 10
    predictions = np.random.randint(1, 5, size=(batch, 1, 1, 1))
    targets = np.random.randint(1, 5, size=(batch, num_targets, 1, 1))
    weights = np.random.randint(0, 2, size=(batch, 1, 1, 1))
    targets *= weights

    # Broadcasting the single prediction against all ten targets performs
    # the same element-wise comparison as repeating it along axis 1.
    matches = (predictions == targets).astype(float)
    per_sample = np.minimum(matches.sum(axis=(1, 2, 3)) / 3.0, 1.)
    weighted_total = np.sum(per_sample * weights[:, 0, 0, 0])
    expected = weighted_total / np.sum(weights)

    with self.test_session() as session:
      sample_scores, sample_weights = metrics.multilabel_accuracy_match3(
          tf.one_hot(predictions, depth=5, dtype=tf.float32),
          tf.constant(targets, dtype=tf.int32))
      mean_t, update_op = tf.metrics.mean(sample_scores, sample_weights)
      session.run(tf.local_variables_initializer())
      session.run(tf.global_variables_initializer())
      # A single update accumulates the whole batch; then read the mean.
      session.run(update_op)
      actual = session.run(mean_t)
    self.assertAlmostEqual(actual, expected, places=6)