Example #1
    def test_sparse_both_different(self):
        p_x = torch.FloatTensor([0.5, 0.0, 0.5])
        q_x = torch.FloatTensor([0.5, 0.5, 0.0])

        xent = categorical_cross_entropy(p_x, q_x)

        self.assertEqual(xent[0], float("inf"))
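Assuming the conventional definition H(p, q) = -Σ_i p_i · log2(q_i) (base-2 logarithms are implied by the expected values in the other tests), q_x assigns zero probability to the third outcome while p_x gives it mass 0.5, so the -0.5 · log2(0) term diverges and the cross-entropy is +inf.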
Example #2
    def test_sparse_both_same(self):
        p_x = torch.FloatTensor([0.5, 0.5, 0.0])
        q_x = torch.FloatTensor([0.5, 0.5, 0.0])

        xent = categorical_cross_entropy(p_x, q_x)

        self.assertEqual(xent[0], 1)
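With p_x == q_x the cross-entropy equals the entropy of p_x. Under the 0 · log2(0) = 0 convention, H = -(0.5 · log2(0.5) + 0.5 · log2(0.5) + 0) = 1 bit, which is what the test asserts.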
Example #3
    def test_sparse_true_dist(self):
        p_x = torch.FloatTensor([1.0, 0.0])
        q_x = torch.FloatTensor([0.25, 0.75])

        xent = categorical_cross_entropy(p_x, q_x)

        self.assertEqual(xent[0], 2)
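Here p_x is a degenerate "true" distribution, so the cross-entropy reduces to the negative log-probability that q_x assigns to the true outcome: H = -1 · log2(0.25) = 2 bits.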
Example #4
    def test_simple(self):
        p_x = torch.FloatTensor([0.5, 0.5])
        q_x = torch.FloatTensor([0.25, 0.75])

        xent = categorical_cross_entropy(p_x, q_x)

        self.assertAlmostEqual(xent[0], 1.207518749639422, delta=1e-7)
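Worked out: H = -(0.5 · log2(0.25) + 0.5 · log2(0.75)) = 1 + 0.2075187… ≈ 1.2075187 bits, matching the asserted value within the 1e-7 delta.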
Example #5
    def test_one_vs_many(self):
        p_x = torch.FloatTensor([[0.5, 0.5], [1.0, 0.0]])
        q_x = torch.FloatTensor([0.25, 0.75])

        xent = categorical_cross_entropy(p_x, q_x)

        self.assertAlmostEqual(xent[0], 1.207518749639422, delta=1e-7)
        self.assertEqual(xent[1], 2)
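This test pins down the broadcasting behaviour: a batch of p_x rows is scored against a single q_x. The implementation of categorical_cross_entropy is not shown on this page, so the following is only a minimal sketch that is consistent with all five tests above (base-2 logs, the 0 · log2(0) = 0 convention, +inf when q_x lacks mass that p_x has, and a result that is always indexable as xent[0]); it is an assumption, not necessarily the code under test.

import torch

def categorical_cross_entropy(p_x, q_x):
    # H(p, q) = -sum_i p_i * log2(q_i), in bits, one value per row of p_x.
    if p_x.dim() == 1:
        # Treat a single distribution as a batch of one, so the
        # result is always indexable as xent[0].
        p_x = p_x.unsqueeze(0)
    log_q = torch.log2(q_x)  # -inf wherever q_x == 0
    # Zero out terms where p == 0 (the 0 * log 0 = 0 convention);
    # elsewhere -p * log2(q) is +inf exactly when q == 0.
    terms = torch.where(p_x > 0, -p_x * log_q, torch.zeros_like(p_x))
    return terms.sum(dim=-1)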
Example #6
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--source-list')
    parser.add_argument('--file-list')
    parser.add_argument('--vocab')
    parser.add_argument('--unk', default='<unk>')
    args = parser.parse_args()

    with open(args.vocab) as f:
        word_vocab = vocab.vocab_from_kaldi_wordlist(f, args.unk)

    # Train-side unigram distribution: pool all training documents
    # into a single bag of words and normalize it.
    documents = documents_from_fn(args.source_list)
    bows = bow_from_documents(documents, word_vocab).float()
    unigram_ps = bows_to_ps(bows.sum(dim=0, keepdim=True)).squeeze()

    # Per-document unigram distributions for the test set.
    test_documents = documents_from_fn(args.file_list)
    test_bows = bow_from_documents(test_documents, word_vocab).float()
    test_unigrams = bows_to_ps(test_bows)

    print(unigram_ps.size())
    print(test_unigrams.size())

    # Score each test document's unigram distribution against the
    # training unigram distribution.
    cross_entropies = analysis.categorical_cross_entropy(
        test_unigrams, unigram_ps)

    # Average cross-entropy per token, weighted by document length.
    test_lengths = test_bows.sum(dim=1)
    avg_entropy = cross_entropies @ test_lengths / test_bows.sum()
    print("{:.4f} {:.2f}".format(avg_entropy, 2**avg_entropy))