def test_sparse_both_different(self):
    """Cross-entropy diverges when the true distribution puts mass where
    the model distribution has none."""
    true_dist = torch.FloatTensor([0.5, 0.0, 0.5])
    model_dist = torch.FloatTensor([0.5, 0.5, 0.0])
    result = categorical_cross_entropy(true_dist, model_dist)
    self.assertEqual(result[0], float("inf"))
def test_sparse_both_same(self):
    """Matching distributions over two equiprobable outcomes give
    exactly 1 bit of cross-entropy."""
    true_dist = torch.FloatTensor([0.5, 0.5, 0.0])
    model_dist = torch.FloatTensor([0.5, 0.5, 0.0])
    result = categorical_cross_entropy(true_dist, model_dist)
    self.assertEqual(result[0], 1)
def test_sparse_true_dist(self):
    """With a degenerate true distribution, cross-entropy reduces to
    -log2 of the model's probability for the certain outcome: -log2(0.25) = 2."""
    true_dist = torch.FloatTensor([1, 0.0])
    model_dist = torch.FloatTensor([0.25, 0.75])
    result = categorical_cross_entropy(true_dist, model_dist)
    self.assertEqual(result[0], 2)
def test_simple(self):
    """Dense case: H(p, q) = -(0.5*log2(0.25) + 0.5*log2(0.75)) ~= 1.2075."""
    true_dist = torch.FloatTensor([0.5, 0.5])
    model_dist = torch.FloatTensor([0.25, 0.75])
    result = categorical_cross_entropy(true_dist, model_dist)
    self.assertAlmostEqual(result[0], 1.207518749639422, delta=1e-7)
def test_one_vs_many(self):
    """A batch of true distributions scored against a single model
    distribution yields one cross-entropy per batch row."""
    true_dists = torch.FloatTensor([[0.5, 0.5], [1.0, 0.0]])
    model_dist = torch.FloatTensor([0.25, 0.75])
    result = categorical_cross_entropy(true_dists, model_dist)
    self.assertAlmostEqual(result[0], 1.207518749639422, delta=1e-7)
    self.assertEqual(result[1], 2)
if __name__ == '__main__':
    # Estimate a unigram model from --source-list documents, then report the
    # token-weighted average unigram cross-entropy (and 2**H, the perplexity)
    # of the --file-list documents under that model.
    parser = argparse.ArgumentParser()
    parser.add_argument('--source-list')
    parser.add_argument('--file-list')
    parser.add_argument('--vocab')
    parser.add_argument('--unk', default='<unk>')
    args = parser.parse_args()

    with open(args.vocab) as f:
        # BUGFIX: the original rebound the name `vocab`, shadowing the
        # `vocab` module it was calling into; use a distinct local name.
        word_vocab = vocab.vocab_from_kaldi_wordlist(f, args.unk)

    # Training side: pool all source documents into one bag-of-words and
    # normalize into a single unigram distribution.
    documents = documents_from_fn(args.source_list)
    bows = bow_from_documents(documents, word_vocab).float()
    unigram_ps = bows_to_ps(bows.sum(dim=0, keepdim=True)).squeeze()

    # Test side: one unigram distribution per document.
    test_documents = documents_from_fn(args.file_list)
    test_bows = bow_from_documents(test_documents, word_vocab).float()
    test_unigrams = bows_to_ps(test_bows)

    print(unigram_ps.size())
    print(test_unigrams.size())

    cross_entropies = analysis.categorical_cross_entropy(
        test_unigrams, unigram_ps)

    # Weight each per-document cross-entropy by its token count so the
    # average is per-token, not per-document.
    test_lengths = test_bows.sum(dim=1)
    avg_entropy = cross_entropies @ test_lengths / test_bows.sum()
    print("{:.4f} {:.2f}".format(avg_entropy, 2**avg_entropy))