def testSample(self):
  """Check that a drawn sample has the same (rows, classes) shape as the logits.

  A 2x3 logits matrix is wrapped in the loss, `loss.sample(42)` is run in a
  fresh graph/session, and the resulting array's shape must be (2, 3).
  """
  two_by_three = np.array(
      [[0., 0., 0.],
       [1., -1., 0.]],
      dtype=np.float32)
  with ops.Graph().as_default(), self.test_session() as sess:
    loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
        array_ops.constant(two_by_three))
    # NOTE(review): 42 is presumably a sampling seed — confirm against the
    # loss's `sample` signature.
    drawn = sess.run(loss.sample(42))
    self.assertEqual(drawn.shape, (2, 3))
def testEvaluateOnSample(self):
  """Ensure the negative log probability of a drawn sample can be evaluated.

  The value depends on a random sample, so its exact magnitude is not
  checked. Instead, the evaluated result must be finite: every logit is
  finite, so every class has nonzero probability and the negative log
  probability of any sampled class is finite. This catches NaN/Inf produced
  by a broken log-probability computation, which a bare "doesn't crash"
  check would let through.
  """
  with ops.Graph().as_default(), self.test_session() as sess:
    logits = np.asarray([
        [0., 0., 0.],
        [1., -1., 0.],
    ]).astype(np.float32)
    loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
        array_ops.constant(logits))
    neg_log_prob = loss.evaluate_on_sample(42)
    neg_log_prob = sess.run(neg_log_prob)
    self.assertTrue(np.all(np.isfinite(neg_log_prob)))
def testMultiMinibatchRegistration(self):
  """Check that several minibatches can be registered on a single loss.

  The first tower's logits construct the loss; each subsequent tower's
  logits are added with `register_additional_minibatch`. Afterwards the loss
  must report every registered tensor, in registration order, and the
  matching count.
  """
  num_towers = 5
  with ops.Graph().as_default():
    registered = []
    loss = None
    for tower_index in range(num_towers):
      tower_input = random_ops.random_uniform(shape=[2, 3])
      registered.append(tower_input)
      if tower_index == 0:
        loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
            tower_input)
      else:
        loss.register_additional_minibatch(tower_input)
    self.assertListEqual(loss.input_minibatches, registered)
    self.assertEqual(loss.num_registered_minibatches, num_towers)
def testEvaluateOnTargets(self):
  """Check the loss on explicit one-hot targets against a NumPy reference.

  The reference value is the negated sum, over rows, of the log of the
  softmax probability that each row assigns to its target class.
  """
  logits = np.array([[0., 0., 0.], [1., -1., 0.]], dtype=np.float32)
  targets = np.array([2, 1], dtype=np.int32)
  with ops.Graph().as_default(), self.test_session() as sess:
    loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
        array_ops.constant(logits), targets=array_ops.one_hot(targets, 3))
    neg_log_prob = sess.run(loss.evaluate())

    # Reference: row-wise softmax, then pick each row's target entry.
    unnormalized = np.exp(logits)
    probs = unnormalized / unnormalized.sum(axis=1, keepdims=True)
    picked = probs[np.arange(len(targets)), targets]
    expected_log_prob = np.log(picked).sum()
    self.assertAllClose(neg_log_prob, -expected_log_prob)