示例#1
0
    def setUpClass(self):
        """Build one shared vocabulary plus a trigram and a bigram counter.

        Both counters count against the same vocabulary object and are
        trained on the same two strings, 'abcd' and 'egdbe'.
        """
        shared_vocab = NgramModelVocabulary(2, "abcdeadbe")
        self.vocab = shared_vocab

        # Trigram counter first, then bigram, each trained right after creation.
        for attr_name, ngram_order in (("trigram_counter", 3),
                                       ("bigram_counter", 2)):
            ngram_counter = NgramCounter(ngram_order, shared_vocab)
            setattr(self, attr_name, ngram_counter)
            ngram_counter.train_counts(['abcd', 'egdbe'])
示例#2
0
def create_counter_and_train_text(train_text_list, vocab=None, order=order,
                                  log_every=1):
    """Create an NgramCounter and train it on each text in turn.

    Args:
        train_text_list: Training texts. Each one is passed to
            ``NgramCounter.train_counts`` wrapped in a one-element list.
            When ``vocab`` is None this is read twice (once to build the
            vocabulary, once to train), so pass a sequence rather than a
            one-shot iterator.
        vocab: Vocabulary to count against. When None, one is built from
            ``train_text_list`` with ``create_ngrams_vocabulary``.
        order: Ngram order of the counter. The default is the module-level
            ``order`` value captured when this function is defined.
        log_every: Print a progress line every ``log_every`` texts. The
            default of 1 logs every step. Pass 0 or None to disable progress
            output.

    Returns:
        The trained NgramCounter.
    """
    if vocab is None:
        vocab = create_ngrams_vocabulary(train_text_list)
    counter = NgramCounter(order, vocab)
    for step, text in enumerate(train_text_list):
        # Train one text at a time so progress can be reported per step.
        counter.train_counts([text])
        # The former `i % 1 == 0` check was always true; the interval is now
        # configurable (default 1 keeps the old output).
        if log_every and step % log_every == 0:
            print('in step: {}, len text: {}'.format(step, len(text)))
    return counter
示例#3
0
class RepoLanguageModel:
    """Ngram language model trained on the Java sources of a repository.

    ``create_model(model_class, n)`` runs the whole pipeline:
    1. Collect every ``*.java`` file under ``path``.
    2. Tokenize the files into ';'-separated token lists (the corpus).
    3. Count ngrams of order ``n``.
    4. Wrap the counts in ``model_class``.
    """

    def __init__(self, path):
        # Root directory searched recursively for ``*.java`` files.
        self.path = path
        # List of token lists, one per ';'-separated statement.
        self.corpus = []
        # Filled in by create_corpus / train_ngrams / create_model. They are
        # initialised here so the attributes always exist.
        self.vocabulary = None
        self.ngrams = None
        self.model = None

    def get_all_repo_files(self):
        """Return the paths of all ``*.java`` files at or below ``self.path``."""
        import os

        # os.path.join supplies the separator that plain concatenation
        # dropped when ``self.path`` had no trailing slash.
        pattern = os.path.join(self.path, '**', '*.java')
        return list(glob.iglob(pattern, recursive=True))

    def create_corpus(self, files):
        """Tokenize ``files`` and store the result as the training corpus.

        Files that cannot be read or tokenized are printed under a
        "Bad files:" header and skipped. Each file's token stream is split on
        ';' tokens (the ';' itself is dropped) to form ``self.corpus``.
        ``self.vocabulary`` is then rebuilt from every token in the corpus.
        """
        parsed_files = self._tokenize_files(files)
        self.corpus = self._split_statements(tqdm_notebook(parsed_files))
        self.vocabulary = NgramModelVocabulary(
            1, [token for statement in self.corpus for token in statement])

    @staticmethod
    def _tokenize_files(files):
        """Return one list of javalang token values per parseable file."""
        # Characters removed from the source text before tokenizing:
        # '#', single/double quotes, backticks and backslashes.
        stoptrans = str.maketrans('', '', '#\'`\\"')
        parsed_files = []
        print('Bad files: ')
        for filename in tqdm_notebook(files):
            try:
                with open(filename, 'r') as source:
                    text = source.read()
                parsed_files.append([
                    t.value for t in javalang.tokenizer.tokenize(
                        text.translate(stoptrans))
                ])
            except Exception:
                # Best effort: report the file and continue. Unlike a bare
                # `except:`, this lets KeyboardInterrupt/SystemExit through.
                print(filename)
        return parsed_files

    @staticmethod
    def _split_statements(token_lists):
        """Split each token list on ';' into per-statement token lists.

        ';' tokens are not kept. An empty statement (e.g. from ';;') is kept
        as an empty list. Any non-empty run of tokens after a file's last
        ';' (e.g. closing braces) is kept as a final statement rather than
        being silently dropped.
        """
        corpus = []
        for tokens in token_lists:
            statement = []
            for token in tokens:
                if token != ';':
                    statement.append(token)
                else:
                    corpus.append(statement)
                    statement = []
            if statement:
                corpus.append(statement)
        return corpus

    def train_ngrams(self, corpus, n):
        """Count ngrams of order ``n`` over ``corpus`` into ``self.ngrams``.

        Args:
            corpus: Token lists to train on. None falls back to
                ``self.corpus``.
            n: Ngram order passed to NgramCounter.
        """
        # Honour the argument; previously it was ignored in favour of
        # self.corpus.
        if corpus is None:
            corpus = self.corpus
        self.ngrams = NgramCounter(n, self.vocabulary)
        self.ngrams.train_counts(corpus)

    def create_model(self, model_class, n):
        """Build the corpus, count ngrams of order ``n`` and set ``self.model``.

        ``model_class`` is called with the trained NgramCounter.
        """
        print('create corpus')
        self.create_corpus(self.get_all_repo_files())
        print('train ngrams')
        self.train_ngrams(self.corpus, n)
        self.model = model_class(self.ngrams)
        print('model is ready')
示例#4
0
    def test_NgramCounter_breaks_given_empty_vocab(self):
        """Training against a vocabulary with no usable items must fail.

        Presumably the first NgramModelVocabulary argument is a minimum
        count. If so, no symbol in "abc" reaches 2 and the vocabulary is
        effectively empty.
        """
        expected_message = ("Cannot start counting ngrams until "
                            "vocabulary contains more than one item.")
        vocab_without_items = NgramModelVocabulary(2, "abc")
        counter = NgramCounter(2, vocab_without_items,
                               pad_left=False, pad_right=False)

        with self.assertRaises(EmptyVocabularyError) as caught:
            counter.train_counts(['ad', 'hominem'])

        self.assertEqual(expected_message, str(caught.exception))
示例#5
0
class NgramModelBaseTest(unittest.TestCase):
    """Shared fixture for tests of the ngram model classes.

    ``self.model`` is not set here; presumably each subclass provides it,
    since ``total_vocab_score`` reads it.
    """

    @classmethod
    def setUpClass(cls):
        # Base vocabulary: the letters a-d plus the UNK label (5 items).
        cls.vocab = NgramModelVocabulary(1, "abcd")
        # The counter's own vocabulary adds two padding symbols (7 items).
        training_texts = ['abcd', 'egadbe']
        cls.counter = NgramCounter(2, cls.vocab)
        cls.counter.train_counts(training_texts)

    def total_vocab_score(self, context):
        """Return the summed model score of every word given ``context``.

        Tests use this to check that the scores sum to 1. The loop runs over
        the counter's vocabulary (not the base one) so padding symbols are
        included. The unknown label is scored separately and added last.
        """
        model, counter = self.model, self.counter
        scores = [model.score(word, context) for word in counter.vocabulary]
        scores.append(model.score(counter.unk_label, context))
        return sum(scores)
示例#6
0
 def test_NgramCounter_breaks_given_invalid_order(self):
     """An order of 0 is rejected with a descriptive ValueError."""
     expected_error_msg = "Order of NgramCounter cannot be less than 1. Got: 0"
     with self.assertRaises(ValueError) as raised:
         NgramCounter(0, self.vocab)
     actual_error_msg = str(raised.exception)
     self.assertEqual(actual_error_msg, expected_error_msg)
示例#7
0
class NgramCounterTests(unittest.TestCase):
    """Unit tests for NgramCounter."""

    @classmethod
    def setUpClass(cls):
        # With the first argument at 2, "a" (seen twice in the vocabulary
        # text) is kept while "c" (seen once) maps to <UNK>; see
        # test_check_against_vocab.
        cls.vocab = NgramModelVocabulary(2, "abcdeadbe")

        counters = {}
        for ngram_order in (3, 2):
            counter = NgramCounter(ngram_order, cls.vocab)
            counter.train_counts(['abcd', 'egdbe'])
            counters[ngram_order] = counter
        cls.trigram_counter = counters[3]
        cls.bigram_counter = counters[2]

    def test_NgramCounter_order_attr(self):
        order = self.trigram_counter.order
        self.assertEqual(order, 3)

    def test_NgramCounter_breaks_given_invalid_order(self):
        expected_error_msg = "Order of NgramCounter cannot be less than 1. Got: 0"
        with self.assertRaises(ValueError) as raised:
            NgramCounter(0, self.vocab)
        actual_error_msg = str(raised.exception)
        self.assertEqual(actual_error_msg, expected_error_msg)

    def test_NgramCounter_breaks_given_empty_vocab(self):
        expected_message = ("Cannot start counting ngrams until "
                            "vocabulary contains more than one item.")
        vocab_without_items = NgramModelVocabulary(2, "abc")
        counter = NgramCounter(2, vocab_without_items,
                               pad_left=False, pad_right=False)

        with self.assertRaises(EmptyVocabularyError) as caught:
            counter.train_counts(['ad', 'hominem'])

        self.assertEqual(expected_message, str(caught.exception))

    def test_check_against_vocab(self):
        check = self.bigram_counter.check_against_vocab
        # Known symbols pass through; rare ones become the unknown label.
        self.assertEqual("a", check("a"))
        self.assertEqual("<UNK>", check("c"))

    def test_ngram_conditional_freqdist(self):
        expected_trigram_contexts = [
            ("<s>", "<s>"), ("<s>", "a"), ("a", "b"), ("b", "<UNK>"),
            ("<UNK>", "d"), ("d", "</s>"), ("<s>", "e"), ("e", "<UNK>"),
            ("d", "b"), ("b", "e"), ("e", "</s>"),
        ]
        expected_bigram_contexts = [
            ("a",), ("b",), ("d",), ("e",), ("<UNK>",), ("<s>",), ("</s>",),
        ]

        counts_by_order = self.trigram_counter.ngrams
        bigrams, trigrams = counts_by_order[2], counts_by_order[3]

        self.assertCountEqual(expected_bigram_contexts, bigrams.conditions())
        self.assertCountEqual(expected_trigram_contexts, trigrams.conditions())

    def test_bigram_counts_seen_ngrams(self):
        bigrams = self.bigram_counter.ngrams[2]
        # (context, word, expected count): "ab" from 'abcd', and "b" followed
        # by the unknown "c" from 'abcd'.
        expectations = [(('a',), 'b', 1), (('b',), '<UNK>', 1)]
        for context, word, expected_count in expectations:
            self.assertEqual(expected_count, bigrams[context][word])

    def test_bigram_counts_unseen_ngrams(self):
        bigrams = self.bigram_counter.ngrams[2]
        self.assertEqual(0, bigrams[('b',)]['c'])

    def test_unigram_counts_seen_words(self):
        # "b" appears once in each training string.
        self.assertEqual(2, self.bigram_counter.unigrams['b'])

    def test_unigram_counts_completely_unseen_words(self):
        self.assertEqual(0, self.bigram_counter.unigrams['z'])

    def test_unigram_counts_unknown_words(self):
        # Unlike the "unseen" case, which has no recorded counts at all, these
        # are the out-of-vocabulary symbols ("c" and "g") that were counted
        # under the unknown label.
        self.assertEqual(2, self.bigram_counter.unigrams['<UNK>'])
示例#8
0
 def setUpClass(self):
     """Create a vocabulary over "abcd" and a bigram counter trained on it."""
     base_vocab = NgramModelVocabulary(1, "abcd")
     # 5 items in the base vocabulary: a, b, c, d and UNK.
     self.vocab = base_vocab
     bigram_counter = NgramCounter(2, base_vocab)
     # The counter's vocabulary holds 7 items: the base 5 plus 2 padding symbols.
     self.counter = bigram_counter
     bigram_counter.train_counts(['abcd', 'egadbe'])
示例#9
0
 def train_ngrams(self, corpus, n):
     """Count ngrams of order ``n`` over ``corpus`` into ``self.ngrams``.

     Args:
         corpus: Token lists to train on. None falls back to ``self.corpus``.
         n: Ngram order passed to NgramCounter.
     """
     # Honour the argument; previously it was ignored and self.corpus was
     # always used.
     if corpus is None:
         corpus = self.corpus
     self.ngrams = NgramCounter(n, self.vocabulary)
     self.ngrams.train_counts(corpus)