コード例 #1
0
  def testSampledCandidates(self):
    """Checks the sampled candidates returned by all_candidate_sampler.

    The evaluated candidates must equal the ids 0..4, and the static shape
    of the candidate tensor must be [NUM_SAMPLED].
    """
    with self.cached_session():
      labels = constant_op.constant(
          [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
      # Only the first sampler output (the candidates) is needed here.
      candidates = candidate_sampling_ops.all_candidate_sampler(
          labels, self.NUM_TRUE, self.NUM_SAMPLED, True)[0]
      sampled_values = self.evaluate(candidates)

    self.assertAllEqual(sampled_values, [0, 1, 2, 3, 4])
    self.assertEqual(candidates.get_shape(), [self.NUM_SAMPLED])
コード例 #2
0
  def testSampledLogExpectedCount(self):
    """Checks the log of the sampled expected counts.

    Every entry of log(sampled_expected_count) must be exactly 0.0, and both
    the raw and the log tensors must have static shape [NUM_SAMPLED].
    """
    with self.cached_session():
      labels = constant_op.constant(
          [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
      sampler_outputs = candidate_sampling_ops.all_candidate_sampler(
          labels, self.NUM_TRUE, self.NUM_SAMPLED, True)
      # Third output of the sampler: expected counts for sampled candidates.
      expected_count = sampler_outputs[2]
      log_expected_count = math_ops.log(expected_count)
      log_values = self.evaluate(log_expected_count)

    self.assertAllEqual(log_values, [0.0] * self.NUM_SAMPLED)
    for tensor in (expected_count, log_expected_count):
      self.assertEqual(tensor.get_shape(), [self.NUM_SAMPLED])
コード例 #3
0
  def testTrueLogExpectedCount(self):
    """Checks the log of the true-class expected counts.

    Every entry of log(true_expected_count) must be exactly 0.0, and both the
    raw and the log tensors must have static shape [BATCH_SIZE, NUM_TRUE].
    """
    with self.cached_session():
      labels = constant_op.constant(
          [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
      # Second output of the sampler: expected counts for the true classes.
      true_count = candidate_sampling_ops.all_candidate_sampler(
          labels, self.NUM_TRUE, self.NUM_SAMPLED, True)[1]
      log_true_count = math_ops.log(true_count)
      log_values = self.evaluate(log_true_count)

    zeros = [[0.0] * self.NUM_TRUE for _ in range(self.BATCH_SIZE)]
    expected_shape = [self.BATCH_SIZE, self.NUM_TRUE]
    self.assertAllEqual(log_values, zeros)
    self.assertEqual(true_count.get_shape(), expected_shape)
    self.assertEqual(log_true_count.get_shape(), expected_shape)
コード例 #4
0
  def testAccidentalHits(self):
    """Checks compute_accidental_hits against the candidates of the sampler.

    Each returned output must be rank 1, at least one hit must be reported,
    every hit id must appear in the corresponding row of TRUE_LABELS, and
    every hit weight must be below -1e37.
    """
    with self.cached_session():
      # Build the input from TRUE_LABELS so the rows validated below are
      # exactly the rows that were fed to the op.
      true_classes = constant_op.constant(self.TRUE_LABELS, dtype=dtypes.int64)
      sampled_candidates, _, _ = candidate_sampling_ops.all_candidate_sampler(
          true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True)
      accidental_hits = candidate_sampling_ops.compute_accidental_hits(
          true_classes, sampled_candidates, self.NUM_TRUE)
      indices, ids, weights = self.evaluate(accidental_hits)

    self.assertEqual(1, accidental_hits[0].get_shape().ndims)
    self.assertEqual(1, accidental_hits[1].get_shape().ndims)
    self.assertEqual(1, accidental_hits[2].get_shape().ndims)
    # Without this guard an empty result would let the loop below pass
    # vacuously. NOTE(review): assumes the sampled candidates overlap
    # TRUE_LABELS (they do when the sampler returns ids 0..NUM_SAMPLED-1
    # covering every label) -- confirm against the class constants.
    self.assertGreater(len(indices), 0)
    for index, id_, weight in zip(indices, ids, weights):
      self.assertIn(id_, self.TRUE_LABELS[index])
      self.assertLess(weight, -1.0e37)