Code example #1
    def test_get_class_memberships(self):
        vocabulary = Vocabulary.from_file(self.classes_file,
                                          'srilm-classes',
                                          oos_words=['yksitoista'])
        word_ids = numpy.array([
            vocabulary.word_to_id['yksi'], vocabulary.word_to_id['kaksi'],
            vocabulary.word_to_id['kolme'], vocabulary.word_to_id['neljä'],
            vocabulary.word_to_id['viisi'], vocabulary.word_to_id['kuusi'],
            vocabulary.word_to_id['seitsemän'],
            vocabulary.word_to_id['kahdeksan'],
            vocabulary.word_to_id['yhdeksän'],
            vocabulary.word_to_id['kymmenen'], vocabulary.word_to_id['<s>'],
            vocabulary.word_to_id['</s>'], vocabulary.word_to_id['<unk>']
        ])
        class_ids, probs = vocabulary.get_class_memberships(word_ids)
        assert_equal(class_ids, vocabulary.word_id_to_class_id[word_ids])
        assert_almost_equal(probs, [
            1.0, 0.999,
            0.599 / (0.599 + 0.400), 0.400 / (0.599 + 0.400),
            1.0,
            0.281 / (0.281 + 0.226 + 0.262 + 0.228),
            0.226 / (0.281 + 0.226 + 0.262 + 0.228),
            0.262 / (0.281 + 0.226 + 0.262 + 0.228),
            0.228 / (0.281 + 0.226 + 0.262 + 0.228),
            1.0, 1.0, 0.001, 1.0
        ])

        word_counts = compute_word_counts([self.sentences3_file])
        vocabulary.compute_probs(word_counts)
        class_ids, probs = vocabulary.get_class_memberships(word_ids)
        assert_almost_equal(probs, [
            1.0, 1.0 / 6.0, 0.5, 0.5, 1.0, 0.25, 0.25, 0.25, 0.25, 1.0, 1.0,
            5.0 / 6.0, 1.0
        ])
Code example #2
    def test_compute_probs(self):
        self.classes_file.seek(0)
        vocabulary = Vocabulary.from_file(self.classes_file, 'srilm-classes')
        word_counts = compute_word_counts(
            [self.sentences1_file, self.sentences2_file])
        vocabulary.compute_probs(word_counts)

        # 10 * <s> + 10 * </s> + 20 words.
        total_count = 40.0
        word_id = vocabulary.word_to_id['yksi']
        self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                               2.0 / total_count)
        self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 1.0)
        word_id = vocabulary.word_to_id['kaksi']
        self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                               2.0 / total_count)
        self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 2.0 / 12.0)
        word_id = vocabulary.word_to_id['kolme']
        self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                               2.0 / total_count)
        self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 0.5)
        word_id = vocabulary.word_to_id['neljä']
        self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                               2.0 / total_count)
        self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 0.5)
        word_id = vocabulary.word_to_id['viisi']
        self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                               2.0 / total_count)
        self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 1.0)
        word_id = vocabulary.word_to_id['kuusi']
        self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                               2.0 / total_count)
        self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 0.25)
        word_id = vocabulary.word_to_id['seitsemän']
        self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                               2.0 / total_count)
        self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 0.25)
        word_id = vocabulary.word_to_id['kahdeksan']
        self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                               2.0 / total_count)
        self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 0.25)
        word_id = vocabulary.word_to_id['yhdeksän']
        self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                               2.0 / total_count)
        self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 0.25)
        word_id = vocabulary.word_to_id['kymmenen']
        self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                               2.0 / total_count)
        self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 1.0)
        word_id = vocabulary.word_to_id['<s>']
        self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                               10.0 / total_count)
        self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 1.0)
        word_id = vocabulary.word_to_id['</s>']
        self.assertAlmostEqual(vocabulary._unigram_probs[word_id],
                               10.0 / total_count)
        self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 10.0 / 12.0)
        word_id = vocabulary.word_to_id['<unk>']
        self.assertAlmostEqual(vocabulary._unigram_probs[word_id], 0.0)
        self.assertAlmostEqual(vocabulary.get_word_prob(word_id), 1.0)
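
The paired assertions above encode two normalizations of the same counts: _unigram_probs divides a word's count by the total corpus count, while get_word_prob() appears to divide it by the total count of the word's class. For 'kaksi', which evidently shares a class with '</s>', that gives the asserted 2/40 and 2/12. A small arithmetic sketch (the class composition is inferred from the assertions, not from theanolm):

from collections import Counter

class_counts = Counter({'kaksi': 2, '</s>': 10})  # inferred class of 'kaksi'
total_count = 40.0                                # 10*<s> + 10*</s> + 20 words

unigram_prob = class_counts['kaksi'] / total_count                  # 2/40
in_class_prob = class_counts['kaksi'] / sum(class_counts.values())  # 2/12
print(unigram_prob, in_class_prob)
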
Code example #3
    def test_bigram_statistics(self):
        self.sentences_file.seek(0)
        word_counts = compute_word_counts([self.sentences_file])
        self.vocabulary = Vocabulary.from_word_counts(word_counts)
        self.sentences_file.seek(0)
        statistics = BigramStatistics([self.sentences_file], self.vocabulary)

        unigram_counts = statistics.unigram_counts
        vocabulary = self.vocabulary
        self.assertEqual(unigram_counts[vocabulary.word_to_id['a']], 13)
        self.assertEqual(unigram_counts[vocabulary.word_to_id['b']], 8)
        self.assertEqual(unigram_counts[vocabulary.word_to_id['c']], 8)
        self.assertEqual(unigram_counts[vocabulary.word_to_id['d']], 11)
        self.assertEqual(unigram_counts[vocabulary.word_to_id['e']], 15)
        self.assertEqual(unigram_counts[vocabulary.word_to_id['<unk>']], 0)
        self.assertEqual(unigram_counts[vocabulary.word_to_id['<s>']], 11)
        self.assertEqual(unigram_counts[vocabulary.word_to_id['</s>']], 11)

        bigram_counts = statistics.bigram_counts
        a_id = vocabulary.word_to_id['a']
        b_id = vocabulary.word_to_id['b']
        self.assertEqual(bigram_counts[a_id, a_id], 3)
        self.assertEqual(bigram_counts[a_id, b_id], 2)
        self.assertEqual(bigram_counts[b_id, a_id], 1)
        self.assertEqual(bigram_counts[b_id, b_id], 0)
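
For reference, accumulating such counts is simple once each sentence is mapped to word IDs; a minimal sketch of the bookkeeping (not the BigramStatistics implementation, which may well use sparse matrices):

import numpy

def count_ngrams(id_sequences, vocab_size):
    """Accumulate unigram and bigram counts from lists of word IDs."""
    unigram_counts = numpy.zeros(vocab_size, dtype=int)
    bigram_counts = numpy.zeros((vocab_size, vocab_size), dtype=int)
    for ids in id_sequences:
        for word_id in ids:
            unigram_counts[word_id] += 1
        for left, right in zip(ids, ids[1:]):
            bigram_counts[left, right] += 1
    return unigram_counts, bigram_counts

unigram_counts, bigram_counts = count_ngrams([[0, 1, 0, 0]], vocab_size=2)
print(unigram_counts)       # [3 1]
print(bigram_counts[0, 1])  # 1
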
Code example #4
    def test_compute_word_counts(self):
        self.sentences_file.seek(0)
        word_counts = compute_word_counts([self.sentences_file])
        self.assertEqual(word_counts['a'], 13)
        self.assertEqual(word_counts['b'], 8)
        self.assertEqual(word_counts['c'], 8)
        self.assertEqual(word_counts['d'], 11)
        self.assertEqual(word_counts['e'], 15)
        self.assertEqual(word_counts['<s>'], 11)
        self.assertEqual(word_counts['</s>'], 11)
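
The counts of 11 for both <s> and </s> (one per sentence) suggest that compute_word_counts splits each line on whitespace and counts one sentence-start and one sentence-end marker per line. A plausible re-implementation sketch under that assumption, not the actual theanolm code:

from collections import Counter

def compute_word_counts_sketch(text_files):
    """Count whitespace-separated words in open text files, one sentence
    per line."""
    counts = Counter()
    for text_file in text_files:
        for line in text_file:
            words = line.split()
            if words:
                counts.update(words)
                # Assumption: one start and one end marker per sentence.
                counts['<s>'] += 1
                counts['</s>'] += 1
    return counts
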
Code example #5
    def setUp(self):
        script_path = os.path.dirname(os.path.realpath(__file__))
        sentences_path = os.path.join(script_path, 'sentences.txt')
        self.sentences_file = open(sentences_path)
        self.num_classes = 2
        word_counts = compute_word_counts([self.sentences_file])
        self.vocabulary = Vocabulary.from_word_counts(word_counts,
                                                      self.num_classes)
        self.sentences_file.seek(0)
        self.statistics = BigramStatistics([self.sentences_file],
                                           self.vocabulary)
Code example #6
    def test_from_word_counts(self):
        self.sentences1_file.seek(0)
        word_counts = compute_word_counts([self.sentences1_file])
        vocabulary = Vocabulary.from_word_counts(word_counts)
        self.assertEqual(vocabulary.num_words(), 10 + 3)
        self.assertEqual(vocabulary.num_shortlist_words(), 10 + 3)
        self.assertEqual(vocabulary.num_normal_classes, 10)
        self.assertEqual(vocabulary.num_classes(), 10 + 3)

        self.sentences1_file.seek(0)
        self.sentences2_file.seek(0)
        word_counts = compute_word_counts(
            [self.sentences1_file, self.sentences2_file])
        vocabulary = Vocabulary.from_word_counts(word_counts, 3)
        self.assertEqual(vocabulary.num_words(), 10 + 3)
        self.assertEqual(vocabulary.num_shortlist_words(), 10 + 3)
        self.assertEqual(vocabulary.num_normal_classes, 3)
        self.assertEqual(vocabulary.num_classes(), 3 + 3)

        sos_id = vocabulary.word_to_id['<s>']
        eos_id = vocabulary.word_to_id['</s>']
        unk_id = vocabulary.word_to_id['<unk>']
        self.assertEqual(sos_id, 10)
        self.assertEqual(eos_id, 11)
        self.assertEqual(unk_id, 12)
        self.assertEqual(vocabulary.word_id_to_class_id[sos_id], 3)
        self.assertEqual(vocabulary.word_id_to_class_id[eos_id], 4)
        self.assertEqual(vocabulary.word_id_to_class_id[unk_id], 5)
        word_ids = set()
        class_ids = set()
        for word in vocabulary.words():
            if not word.startswith('<'):
                word_id = vocabulary.word_to_id[word]
                word_ids.add(word_id)
                class_ids.add(vocabulary.word_id_to_class_id[word_id])
        self.assertEqual(word_ids, set(range(10)))
        self.assertEqual(class_ids, set(range(3)))
Code example #7
    def test_from_state(self):
        self.classes_file.seek(0)
        vocabulary1 = Vocabulary.from_file(self.classes_file, 'srilm-classes')
        word_counts = compute_word_counts(
            [self.sentences1_file, self.sentences2_file])
        vocabulary1.compute_probs(word_counts)

        f = h5py.File('in-memory.h5', 'w', driver='core', backing_store=False)
        vocabulary1.get_state(f)
        vocabulary2 = Vocabulary.from_state(f)
        self.assertTrue(
            numpy.array_equal(vocabulary1.id_to_word, vocabulary2.id_to_word))
        self.assertDictEqual(vocabulary1.word_to_id, vocabulary2.word_to_id)
        self.assertTrue(
            numpy.array_equal(vocabulary1.word_id_to_class_id,
                              vocabulary2.word_id_to_class_id))
        self.assertListEqual(list(vocabulary1._word_classes),
                             list(vocabulary2._word_classes))
        self.assertTrue(
            numpy.array_equal(vocabulary1._unigram_probs,
                              vocabulary2._unigram_probs))
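
The driver='core', backing_store=False combination keeps the whole HDF5 file in memory and discards it on close, so the round trip above never touches disk. The same pattern in isolation:

import h5py
import numpy

# Purely in-memory HDF5 file; no 'scratch.h5' is ever created on disk.
with h5py.File('scratch.h5', 'w', driver='core', backing_store=False) as f:
    f['probs'] = numpy.array([0.25, 0.75])
    print(f['probs'][:])
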
Code example #8
File: train.py  Project: haiphong129/theanolm
def _read_vocabulary(args, state):
    """If ``state`` contains data, reads the vocabulary from the HDF5 state.
    Otherwise reads a vocabulary file or constructs the vocabulary from the
    training set and writes it to the HDF5 state.

    If the state does not contain data and the --vocabulary argument is given,
    reads the vocabulary from the given file. The rest of the words in the
    training set will be added as out-of-shortlist words.

    If the state does not contain data and no vocabulary is given, constructs a
    vocabulary that contains all the training set words. In that case, the
    --num-classes argument can be used to control the number of classes.

    :type args: argparse.Namespace
    :param args: a collection of command line arguments

    :type state: h5py.File
    :param state: HDF5 file where the vocabulary should be saved

    :rtype: Vocabulary
    :returns: the created vocabulary
    """

    if state.keys():
        print("Reading vocabulary from existing network state.")
        sys.stdout.flush()
        result = Vocabulary.from_state(state)
        if not result.has_unigram_probs():
            # This is for backward compatibility. Remove at some point.
            print("Computing unigram word probabilities from training set.")
            sys.stdout.flush()
            word_counts = compute_word_counts(args.training_set)
            shortlist_words = list(result.id_to_word)
            shortlist_set = set(shortlist_words)
            oos_words = [
                x for x in word_counts.keys() if x not in shortlist_set
            ]
            result.id_to_word = numpy.asarray(shortlist_words + oos_words,
                                              dtype=object)
            result.word_to_id = {
                word: word_id
                for word_id, word in enumerate(result.id_to_word)
            }
            result.compute_probs(word_counts, update_class_probs=False)
            result.get_state(state)

    elif args.vocabulary is None:
        print("Constructing vocabulary from training set.")
        sys.stdout.flush()
        word_counts = compute_word_counts(args.training_set)
        result = Vocabulary.from_word_counts(word_counts, args.num_classes)
        result.get_state(state)

    else:
        print("Reading vocabulary from {}.".format(args.vocabulary))
        sys.stdout.flush()
        word_counts = compute_word_counts(args.training_set)
        oos_words = word_counts.keys()
        with open(args.vocabulary, 'rt', encoding='utf-8') as vocab_file:
            result = Vocabulary.from_file(vocab_file,
                                          args.vocabulary_format,
                                          oos_words=oos_words)

        if args.vocabulary_format == 'classes':
            print("Computing class membership probabilities and unigram "
                  "probabilities for out-of-shortlist words.")
            sys.stdout.flush()
            update_class_probs = True
        else:
            print(
                "Computing unigram probabilities for out-of-shortlist words.")
            sys.stdout.flush()
            update_class_probs = False
        result.compute_probs(word_counts,
                             update_class_probs=update_class_probs)
        result.get_state(state)

    print("Number of words in vocabulary:", result.num_words())
    print("Number of words in shortlist:", result.num_shortlist_words())
    print("Number of word classes:", result.num_classes())
    return result
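
A hedged usage sketch of the middle branch described in the docstring, driving _read_vocabulary with an empty in-memory state (the Namespace fields mirror the attributes the function reads; the training-file path is illustrative):

import argparse
import h5py

# Empty state and no --vocabulary: the vocabulary is constructed from the
# training set and written back into `state` via get_state().
state = h5py.File('state.h5', 'w', driver='core', backing_store=False)
args = argparse.Namespace(training_set=[open('train.txt')],  # illustrative
                          vocabulary=None,
                          vocabulary_format=None,
                          num_classes=100)
vocabulary = _read_vocabulary(args, state)
print(vocabulary.num_words())
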
Code example #9
File: wctool.py  Project: haiphong129/theanolm
def main():
    parser = argparse.ArgumentParser(prog='wctool')

    argument_group = parser.add_argument_group("files")
    argument_group.add_argument(
        '--training-set',
        metavar='FILE',
        type=TextFileType('r'),
        nargs='+',
        required=True,
        help='text or .gz files containing training data (one sentence per '
        'line)')
    argument_group.add_argument(
        '--vocabulary',
        metavar='FILE',
        type=TextFileType('r'),
        default=None,
        help='text or .gz file containing a list of words to include in class '
        'forming, and possibly their initial classes')
    argument_group.add_argument(
        '--vocabulary-format',
        metavar='FORMAT',
        type=str,
        default='words',
        help='vocabulary format, one of "words" (one word per line, default), '
        '"classes" (word and class ID per line), "srilm-classes" (class '
        'name, membership probability, and word per line)')
    argument_group.add_argument(
        '--output-file',
        metavar='FILE',
        type=TextFileType('w'),
        default='-',
        help='where to write the word classes (default stdout)')
    argument_group.add_argument(
        '--output-format',
        metavar='FORMAT',
        type=str,
        default='srilm-classes',
        help='format of the output file, one of "classes" (word and class ID '
        'per line), "srilm-classes" (default; class name, membership '
        'probability, and word per line)')
    argument_group.add_argument(
        '--output-frequency',
        metavar='N',
        type=int,
        default=1,
        help='save classes N times per optimization iteration (default 1)')

    argument_group = parser.add_argument_group("optimization")
    argument_group.add_argument(
        '--num-classes',
        metavar='N',
        type=int,
        default=2000,
        help='number of classes to form, if vocabulary is not specified '
        '(default 2000)')
    argument_group.add_argument(
        '--method',
        metavar='NAME',
        type=str,
        default='bigram-theano',
        help='method for creating word classes, one of "bigram-theano", '
        '"bigram-numpy" (default "bigram-theano")')

    argument_group = parser.add_argument_group("logging and debugging")
    argument_group.add_argument(
        '--log-file',
        metavar='FILE',
        type=str,
        default='-',
        help='where to write the log file (default is standard output)')
    argument_group.add_argument(
        '--log-level',
        metavar='LEVEL',
        type=str,
        default='info',
        help='minimum level of events to log, one of "debug", "info", "warn" '
        '(default "info")')
    argument_group.add_argument(
        '--log-interval',
        metavar='N',
        type=int,
        default=1000,
        help='print statistics after every Nth word; quiet if less than one '
        '(default 1000)')

    args = parser.parse_args()

    log_file = args.log_file
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        raise ValueError("Invalid logging level requested: " + args.log_level)
    log_format = '%(asctime)s %(funcName)s: %(message)s'
    if args.log_file == '-':
        logging.basicConfig(stream=sys.stdout,
                            format=log_format,
                            level=log_level)
    else:
        logging.basicConfig(filename=log_file,
                            format=log_format,
                            level=log_level)

    if args.vocabulary is None:
        word_counts = compute_word_counts(args.training_set)
        vocabulary = Vocabulary.from_word_counts(word_counts, args.num_classes)
        for subset_file in args.training_set:
            subset_file.seek(0)
    else:
        vocabulary = Vocabulary.from_file(args.vocabulary,
                                          args.vocabulary_format)

    print("Number of words in vocabulary:", vocabulary.num_shortlist_words())
    print("Number of word classes:", vocabulary.num_classes())
    print("Number of normal word classes:", vocabulary.num_normal_classes)

    logging.info("Reading word unigram and bigram statistics.")
    statistics = BigramStatistics(args.training_set, vocabulary)

    if args.method == 'bigram-theano':
        optimizer = TheanoBigramOptimizer(statistics, vocabulary)
    elif args.method == 'bigram-numpy':
        optimizer = NumpyBigramOptimizer(statistics, vocabulary)
    else:
        raise ValueError("Invalid method requested: " + args.method)

    iteration = 1
    while True:
        logging.info("Starting iteration %d.", iteration)
        num_words = 0
        num_moves = 0
        for word in vocabulary.words():
            start_time = time()
            num_words += 1
            if optimizer.move_to_best_class(word):
                num_moves += 1
            duration = time() - start_time
            if (args.log_interval >= 1) and \
               (num_words % args.log_interval == 0):
                logging.info(
                    "[%d] (%.1f %%) of iteration %d -- moves = %d, cost = %.2f, duration = %.1f ms",
                    num_words,
                    num_words / vocabulary.num_shortlist_words() * 100,
                    iteration, num_moves, optimizer.log_likelihood(),
                    duration * 1000)  # seconds to ms, per the format string
            if is_scheduled(num_words, args.output_frequency,
                            vocabulary.num_shortlist_words()):
                save(optimizer, args.output_file, args.output_format)

        if num_moves == 0:
            break
        iteration += 1

    logging.info("Optimization finished.")
    save(optimizer, args.output_file, args.output_format)
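
The save cadence depends on an is_scheduled() helper that is not shown here. Under the contract stated for --output-frequency (save N times per pass over the shortlist), one plausible sketch is a modular boundary test; this is an assumption, not the theanolm implementation:

def is_scheduled_sketch(num_words, frequency, total_words):
    """True at `frequency` roughly evenly spaced points per pass of
    `total_words` words, including the end of the pass."""
    if frequency < 1 or total_words < 1:
        return False
    # Fires whenever num_words * frequency crosses a multiple of total_words.
    return (num_words * frequency) % total_words < frequency

# With 10 shortlist words and --output-frequency 2, saves after words 5 and 10.
print([n for n in range(1, 11) if is_scheduled_sketch(n, 2, 10)])  # [5, 10]
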