def testSample(self):
    """Ensure samples can be drawn and are valid one-hot rows.

    The loss wraps 2 rows of logits over 3 classes, so one draw must have
    shape (2, 3) with exactly one active class per row.
    """
    with ops.Graph().as_default(), self.test_session() as sess:
        logits = np.asarray([
            [0., 0., 0.],  #
            [1., -1., 0.]
        ]).astype(np.float32)
        loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
            array_ops.constant(logits))
        sample = loss.sample(42)
        sample = sess.run(sample)
        self.assertEqual(sample.shape, (2, 3))
        # Shape alone does not prove the sample is one-hot: also require every
        # entry to be 0 or 1 and every row to contain exactly one 1.
        self.assertTrue(np.all((sample == 0) | (sample == 1)))
        self.assertAllEqual(np.sum(sample, axis=1), [1, 1])
    def testEvaluateOnSample(self):
        """Ensure the negative log probability of a sample can be evaluated.

        The sample is random, so its exact loss cannot be asserted. Instead the
        value must be finite and lie between the smallest and largest total
        negative log probability that any sample could achieve for these logits.
        """
        with ops.Graph().as_default(), self.test_session() as sess:
            logits = np.asarray([
                [0., 0., 0.],  #
                [1., -1., 0.]
            ]).astype(np.float32)
            loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
                array_ops.constant(logits))
            neg_log_prob = loss.evaluate_on_sample(42)
            neg_log_prob = sess.run(neg_log_prob)

            # Reference softmax probabilities for each row.
            probs = np.exp(logits) / np.sum(
                np.exp(logits), axis=1, keepdims=True)
            # These bounds assume the loss sums per-row negative log
            # probabilities, which is what testEvaluateOnTargets checks for
            # fixed targets. Every row contributes at least -log(max prob) and
            # at most -log(min prob), whichever class happens to be sampled.
            lower = -np.sum(np.log(np.max(probs, axis=1)))
            upper = -np.sum(np.log(np.min(probs, axis=1)))
            tolerance = 1e-4  # float32 TF result vs. numpy reference

            self.assertTrue(np.isfinite(neg_log_prob))
            self.assertGreaterEqual(neg_log_prob, lower - tolerance)
            self.assertLessEqual(neg_log_prob, upper + tolerance)
 def testMultiMinibatchRegistration(self):
     """Check that one loss can hold the logits of several minibatches.

     The first tower's logits construct the loss. Each later tower's logits
     are attached with register_additional_minibatch. Afterwards the loss
     must list every tower's logits in order and report the tower count.
     """
     with ops.Graph().as_default():
         num_towers = 5
         registered_logits = []
         for tower_index in range(num_towers):
             tower_output = random_ops.random_uniform(shape=[2, 3])
             registered_logits.append(tower_output)
             if tower_index == 0:
                 # The first tower creates the loss itself.
                 loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
                     tower_output)
             else:
                 loss.register_additional_minibatch(tower_output)
         self.assertListEqual(loss.input_minibatches, registered_logits)
         self.assertEqual(loss.num_registered_minibatches, num_towers)
    def testEvaluateOnTargets(self):
        """Check that evaluate() matches a NumPy reference for fixed targets.

        The reference is the negated sum, over both rows, of the log of the
        softmax probability each row assigns to its target class.
        """
        with ops.Graph().as_default(), self.test_session() as sess:
            logits = np.array([[0., 0., 0.], [1., -1., 0.]], dtype=np.float32)
            targets = np.array([2, 1], dtype=np.int32)
            num_classes = logits.shape[1]

            loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
                array_ops.constant(logits),
                targets=array_ops.one_hot(targets, num_classes))
            actual = sess.run(loss.evaluate())

            # NumPy reference: softmax every row, then pick the probability
            # of each row's target class using fancy indexing.
            exp_logits = np.exp(logits)
            probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
            target_probs = probs[np.arange(len(targets)), targets]
            expected_log_prob = np.sum(np.log(target_probs))

            self.assertAllClose(actual, -expected_log_prob)