Ejemplo n.º 1
0
    def test_from_file(self):
        """A plain word list yields 10 words plus 3 special tokens; OOS words
        enlarge the full vocabulary but not the shortlist or the classes."""
        self.vocabulary_file.seek(0)
        vocabulary = Vocabulary.from_file(self.vocabulary_file, 'words')
        # Without OOS words all three counts coincide.
        self.assertEqual(vocabulary.num_words(), 13)
        self.assertEqual(vocabulary.num_shortlist_words(), 13)
        self.assertEqual(vocabulary.num_classes(), 13)

        # Two of these are already in the shortlist, two are genuinely new.
        oos_words = ['yksi', 'kaksi', 'yksitoista', 'kaksitoista']
        self.vocabulary_file.seek(0)
        vocabulary = Vocabulary.from_file(self.vocabulary_file,
                                          'words',
                                          oos_words=oos_words)
        # The two new words are counted in num_words() only.
        self.assertEqual(vocabulary.num_words(), 15)
        self.assertEqual(vocabulary.num_shortlist_words(), 13)
        self.assertEqual(vocabulary.num_classes(), 13)
Ejemplo n.º 2
0
 def test_get_oos_probs(self):
     """OOS word probabilities are normalized over the OOS set only;
     shortlist words get probability 1.0."""
     oos_words = ['yksitoista', 'kaksitoista']
     self.vocabulary_file.seek(0)
     vocabulary = Vocabulary.from_file(self.vocabulary_file,
                                       'words',
                                       oos_words=oos_words)
     # Shortlist words and special tokens get counts 1..13; only the two
     # OOS counts (3 and 7) influence the expected result.
     shortlist = ['yksi', 'kaksi', 'kolme', 'neljä', 'viisi', 'kuusi',
                  'seitsemän', 'kahdeksan', 'yhdeksän', 'kymmenen',
                  '<s>', '</s>', '<unk>']
     word_counts = {word: count
                    for count, word in enumerate(shortlist, start=1)}
     word_counts['yksitoista'] = 3
     word_counts['kaksitoista'] = 7
     vocabulary.compute_probs(word_counts)
     oos_logprobs = vocabulary.get_oos_probs()
     # 13 shortlist entries at 1.0, then the OOS words at 3/10 and 7/10.
     assert_almost_equal(oos_logprobs, [1.0] * 13 + [0.3, 0.7])
Ejemplo n.º 3
0
    def test_get_class_memberships(self):
        """Membership probabilities come renormalized from the classes file,
        then change when recomputed from actual word counts."""
        vocabulary = Vocabulary.from_file(self.classes_file,
                                          'srilm-classes',
                                          oos_words=['yksitoista'])
        words = ['yksi', 'kaksi', 'kolme', 'neljä', 'viisi', 'kuusi',
                 'seitsemän', 'kahdeksan', 'yhdeksän', 'kymmenen',
                 '<s>', '</s>', '<unk>']
        word_ids = numpy.array([vocabulary.word_to_id[word] for word in words])

        class_ids, probs = vocabulary.get_class_memberships(word_ids)
        assert_equal(class_ids, vocabulary.word_id_to_class_id[word_ids])
        # Probabilities listed in the file are renormalized within each class.
        pair_total = 0.599 + 0.400
        quad_total = 0.281 + 0.226 + 0.262 + 0.228
        assert_almost_equal(probs, [
            1.0, 0.999, 0.599 / pair_total, 0.400 / pair_total, 1.0,
            0.281 / quad_total, 0.226 / quad_total, 0.262 / quad_total,
            0.228 / quad_total, 1.0, 1.0, 0.001, 1.0
        ])

        # Recomputing from observed counts replaces the file probabilities.
        word_counts = compute_word_counts([self.sentences3_file])
        vocabulary.compute_probs(word_counts)
        class_ids, probs = vocabulary.get_class_memberships(word_ids)
        assert_almost_equal(probs, [
            1.0, 1.0 / 6.0, 0.5, 0.5, 1.0, 0.25, 0.25, 0.25, 0.25, 1.0, 1.0,
            5.0 / 6.0, 1.0
        ])
Ejemplo n.º 4
0
    def test_compute_probs(self):
        """compute_probs() derives both unigram probabilities (over the whole
        corpus) and within-class probabilities from training word counts.

        The original test repeated the same two assertions verbatim for 13
        words; this version drives them from a table and uses subTest so a
        failure identifies the offending word.
        """
        self.classes_file.seek(0)
        vocabulary = Vocabulary.from_file(self.classes_file, 'srilm-classes')
        word_counts = compute_word_counts(
            [self.sentences1_file, self.sentences2_file])
        vocabulary.compute_probs(word_counts)

        # 10 * <s> + 10 * </s> + 20 words.
        total_count = 40.0
        # word -> (expected unigram probability, expected in-class probability)
        expected = {
            'yksi': (2.0 / total_count, 1.0),
            'kaksi': (2.0 / total_count, 2.0 / 12.0),
            'kolme': (2.0 / total_count, 0.5),
            'neljä': (2.0 / total_count, 0.5),
            'viisi': (2.0 / total_count, 1.0),
            'kuusi': (2.0 / total_count, 0.25),
            'seitsemän': (2.0 / total_count, 0.25),
            'kahdeksan': (2.0 / total_count, 0.25),
            'yhdeksän': (2.0 / total_count, 0.25),
            'kymmenen': (2.0 / total_count, 1.0),
            '<s>': (10.0 / total_count, 1.0),
            '</s>': (10.0 / total_count, 10.0 / 12.0),
            # <unk> never occurs in the training data.
            '<unk>': (0.0, 1.0),
        }
        for word, (unigram_prob, class_prob) in expected.items():
            with self.subTest(word=word):
                word_id = vocabulary.word_to_id[word]
                self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                                       unigram_prob)
                self.assertAlmostEqual(vocabulary.get_word_prob(word_id),
                                       class_prob)
Ejemplo n.º 5
0
 def test_class_ids(self):
     """Words assigned to the same class in the file share a class ID."""
     self.classes_file.seek(0)
     vocabulary = Vocabulary.from_file(self.classes_file, 'srilm-classes')

     def class_id_of(word):
         # Resolve a surface word to its class ID via the word ID.
         return vocabulary.word_id_to_class_id[vocabulary.word_to_id[word]]

     yksi_class_id = class_id_of('yksi')
     kaksi_class_id = class_id_of('kaksi')
     kolme_class_id = class_id_of('kolme')
     nelja_class_id = class_id_of('neljä')
     eos_class_id = class_id_of('</s>')

     self.assertNotEqual(yksi_class_id, kaksi_class_id)
     self.assertEqual(kolme_class_id, nelja_class_id)
     self.assertNotEqual(kolme_class_id, eos_class_id)
     self.assertEqual(kaksi_class_id, eos_class_id)
Ejemplo n.º 6
0
    def test_from_state(self):
        """Round-tripping a vocabulary through an HDF5 state preserves all of
        its mappings and probabilities."""
        self.classes_file.seek(0)
        vocabulary1 = Vocabulary.from_file(self.classes_file, 'srilm-classes')
        word_counts = compute_word_counts(
            [self.sentences1_file, self.sentences2_file])
        vocabulary1.compute_probs(word_counts)

        # Use an in-memory HDF5 file so nothing touches the file system.
        f = h5py.File('in-memory.h5', driver='core', backing_store=False)
        vocabulary1.get_state(f)
        vocabulary2 = Vocabulary.from_state(f)

        # Array-valued attributes compare element-wise.
        for attribute in ('id_to_word', 'word_id_to_class_id',
                          '_unigram_probs'):
            self.assertTrue(
                numpy.array_equal(getattr(vocabulary1, attribute),
                                  getattr(vocabulary2, attribute)))
        self.assertDictEqual(vocabulary1.word_to_id, vocabulary2.word_to_id)
        self.assertListEqual(list(vocabulary1._word_classes),
                             list(vocabulary2._word_classes))
Ejemplo n.º 7
0
def main():
    """Entry point of the word-class optimization tool.

    Parses command-line arguments, builds (or reads) a vocabulary, collects
    bigram statistics from the training set, and iteratively moves words
    between classes until no move improves the log likelihood.
    """
    parser = argparse.ArgumentParser(prog='wctool')

    argument_group = parser.add_argument_group("files")
    argument_group.add_argument(
        '--training-set',
        metavar='FILE',
        type=TextFileType('r'),
        nargs='+',
        required=True,
        help='text or .gz files containing training data (one sentence per '
        'line)')
    argument_group.add_argument(
        '--vocabulary',
        metavar='FILE',
        type=TextFileType('r'),
        default=None,
        help='text or .gz file containing a list of words to include in class '
        'forming, and possibly their initial classes')
    argument_group.add_argument(
        '--vocabulary-format',
        metavar='FORMAT',
        type=str,
        default='words',
        help='vocabulary format, one of "words" (one word per line, default), '
        '"classes" (word and class ID per line), "srilm-classes" (class '
        'name, membership probability, and word per line)')
    argument_group.add_argument(
        '--output-file',
        metavar='FILE',
        type=TextFileType('w'),
        default='-',
        help='where to write the word classes (default stdout)')
    argument_group.add_argument(
        '--output-format',
        metavar='FORMAT',
        type=str,
        default='srilm-classes',
        help='format of the output file, one of "classes" (word and class ID '
        'per line), "srilm-classes" (default; class name, membership '
        'probability, and word per line)')
    argument_group.add_argument(
        '--output-frequency',
        metavar='N',
        type=int,
        # Was default='1' (a string); an int matches type=int directly.
        default=1,
        help='save classes N times per optimization iteration (default 1)')

    argument_group = parser.add_argument_group("optimization")
    argument_group.add_argument(
        '--num-classes',
        metavar='N',
        type=int,
        default=2000,
        help='number of classes to form, if vocabulary is not specified '
        '(default 2000)')
    argument_group.add_argument(
        '--method',
        metavar='NAME',
        type=str,
        default='bigram-theano',
        help='method for creating word classes, one of "bigram-theano", '
        '"bigram-numpy" (default "bigram-theano")')

    argument_group = parser.add_argument_group("logging and debugging")
    argument_group.add_argument(
        '--log-file',
        metavar='FILE',
        type=str,
        default='-',
        help='path where to write log file (default is standard output)')
    argument_group.add_argument(
        '--log-level',
        metavar='LEVEL',
        type=str,
        default='info',
        help='minimum level of events to log, one of "debug", "info", "warn" '
        '(default "info")')
    argument_group.add_argument(
        '--log-interval',
        metavar='N',
        type=int,
        default=1000,
        help='print statistics after every Nth word; quiet if less than one '
        '(default 1000)')

    args = parser.parse_args()

    log_file = args.log_file
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        raise ValueError("Invalid logging level requested: " + args.log_level)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    if args.log_file == '-':
        logging.basicConfig(stream=sys.stdout,
                            format=log_format,
                            level=log_level)
    else:
        logging.basicConfig(filename=log_file,
                            format=log_format,
                            level=log_level)

    if args.vocabulary is None:
        # No vocabulary given: derive one from the training data, then rewind
        # the training files so they can be read again for statistics.
        word_counts = compute_word_counts(args.training_set)
        vocabulary = Vocabulary.from_word_counts(word_counts, args.num_classes)
        for subset_file in args.training_set:
            subset_file.seek(0)
    else:
        vocabulary = Vocabulary.from_file(args.vocabulary,
                                          args.vocabulary_format)

    print("Number of words in vocabulary:", vocabulary.num_shortlist_words())
    print("Number of word classes:", vocabulary.num_classes())
    print("Number of normal word classes:", vocabulary.num_normal_classes)

    logging.info("Reading word unigram and bigram statistics.")
    statistics = BigramStatistics(args.training_set, vocabulary)

    if args.method == 'bigram-theano':
        optimizer = TheanoBigramOptimizer(statistics, vocabulary)
    elif args.method == 'bigram-numpy':
        optimizer = NumpyBigramOptimizer(statistics, vocabulary)
    else:
        raise ValueError("Invalid method requested: " + args.method)

    iteration = 1
    while True:
        logging.info("Starting iteration %d.", iteration)
        num_words = 0
        num_moves = 0
        for word in vocabulary.words():
            start_time = time()
            num_words += 1
            if optimizer.move_to_best_class(word):
                num_moves += 1
            duration = time() - start_time
            if (args.log_interval >= 1) and \
               (num_words % args.log_interval == 0):
                logging.info(
                    "[%d] (%.1f %%) of iteration %d -- moves = %d, cost = %.2f, duration = %.1f ms",
                    num_words,
                    num_words / vocabulary.num_shortlist_words() * 100,
                    iteration, num_moves, optimizer.log_likelihood(),
                    # Bug fix: seconds convert to milliseconds with * 1000,
                    # not * 100 as before.
                    duration * 1000)
            if is_scheduled(num_words, args.output_frequency,
                            vocabulary.num_shortlist_words()):
                save(optimizer, args.output_file, args.output_format)

        # Converged: a full pass over the vocabulary produced no moves.
        if num_moves == 0:
            break
        iteration += 1

    logging.info("Optimization finished.")
    save(optimizer, args.output_file, args.output_format)