def testSampledCandidates(self):
  """Checks the candidates returned by all_candidate_sampler.

  A fixed 3x2 int64 batch of true classes is fed to the sampler; the sampled
  candidates must evaluate to exactly [0, 1, 2, 3, 4] and carry the static
  shape [NUM_SAMPLED].
  """
  want = [0, 1, 2, 3, 4]
  with self.cached_session():
    labels = constant_op.constant(
        [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
    sampler_out = candidate_sampling_ops.all_candidate_sampler(
        labels, self.NUM_TRUE, self.NUM_SAMPLED, True)
    # First output of the sampler is the sampled-candidates tensor.
    candidates = sampler_out[0]
    got = self.evaluate(candidates)
    self.assertAllEqual(got, want)
    self.assertEqual(candidates.get_shape(), [self.NUM_SAMPLED])
def testSampledLogExpectedCount(self):
  """Checks the log of the sampled expected counts from all_candidate_sampler.

  The log of every sampled expected count must evaluate to 0.0 (NUM_SAMPLED
  zeros), and both the raw and the log tensors must have static shape
  [NUM_SAMPLED].
  """
  with self.cached_session():
    labels = constant_op.constant(
        [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
    sampler_out = candidate_sampling_ops.all_candidate_sampler(
        labels, self.NUM_TRUE, self.NUM_SAMPLED, True)
    # Third output of the sampler is the sampled expected-count tensor.
    expected_count = sampler_out[2]
    log_count = math_ops.log(expected_count)
    got = self.evaluate(log_count)
    self.assertAllEqual(got, [0.0] * self.NUM_SAMPLED)
    for tensor in (expected_count, log_count):
      self.assertEqual(tensor.get_shape(), [self.NUM_SAMPLED])
def testTrueLogExpectedCount(self):
  """Checks the log of the true expected counts from all_candidate_sampler.

  The log of the true expected counts must evaluate to a BATCH_SIZE x
  NUM_TRUE grid of 0.0, and both the raw and the log tensors must have static
  shape [BATCH_SIZE, NUM_TRUE].
  """
  with self.cached_session():
    labels = constant_op.constant(
        [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
    sampler_out = candidate_sampling_ops.all_candidate_sampler(
        labels, self.NUM_TRUE, self.NUM_SAMPLED, True)
    # Second output of the sampler is the true expected-count tensor.
    expected_count = sampler_out[1]
    log_count = math_ops.log(expected_count)
    got = self.evaluate(log_count)
    zero_grid = [[0.0] * self.NUM_TRUE] * self.BATCH_SIZE
    self.assertAllEqual(got, zero_grid)
    static_shape = [self.BATCH_SIZE, self.NUM_TRUE]
    for tensor in (expected_count, log_count):
      self.assertEqual(tensor.get_shape(), static_shape)
def testAccidentalHits(self):
  """Checks compute_accidental_hits against all_candidate_sampler output.

  The true classes are a fixed 3x2 int64 batch and the sampled candidates
  come from all_candidate_sampler. The test asserts that:
    * each of the three returned tensors (indices, ids, weights) is rank 1;
    * at least one accidental hit is reported, and the three outputs have
      equal length;
    * every reported id is among the TRUE_LABELS of its row;
    * every weight is below -1e37.
  """
  with self.cached_session():
    true_classes = constant_op.constant(
        [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
    sampled_candidates, _, _ = candidate_sampling_ops.all_candidate_sampler(
        true_classes, self.NUM_TRUE, self.NUM_SAMPLED, True)
    accidental_hits = candidate_sampling_ops.compute_accidental_hits(
        true_classes, sampled_candidates, self.NUM_TRUE)
    indices, ids, weights = self.evaluate(accidental_hits)

    for hit_tensor in accidental_hits:
      self.assertEqual(1, hit_tensor.get_shape().ndims)

    # Without these checks the per-hit loop below would pass vacuously on an
    # empty result, or silently truncate if the outputs disagreed in length.
    self.assertGreater(len(indices), 0)
    self.assertEqual(len(indices), len(ids))
    self.assertEqual(len(indices), len(weights))

    for index, id_, weight in zip(indices, ids, weights):
      # NOTE(review): ids are compared directly to label values, which
      # presumes the reported ids coincide with candidate values for this
      # sampler output -- confirm against compute_accidental_hits docs.
      self.assertIn(id_, self.TRUE_LABELS[index])
      self.assertLess(weight, -1.0e37)