Example #1
    def test_normalize_probs(self):
        word_class = WordClass(1, 10, 0.5)
        word_class.add(11, 1.0)
        word_class.add(12, 0.5)
        word_class.normalize_probs()
        self.assertEqual(word_class.get_prob(10), 0.25)
        self.assertEqual(word_class.get_prob(11), 0.5)
        self.assertEqual(word_class.get_prob(12), 0.25)
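
For reference, a minimal sketch of the ``WordClass`` interface this test exercises (a constructor taking a class ID, an initial word ID and its probability; ``add`` registering another word; ``normalize_probs`` rescaling the probabilities to sum to one; ``get_prob`` looking a probability up). This is an assumption for illustration, not the library's actual implementation.

class WordClass(object):
    """Hypothetical minimal WordClass, covering only what the test above needs."""

    def __init__(self, class_id, word_id, prob):
        self.id = class_id
        self._probs = {word_id: prob}

    def add(self, word_id, prob):
        # Register another word with an unnormalized membership probability.
        self._probs[word_id] = prob

    def normalize_probs(self):
        # Rescale the membership probabilities so that they sum to one.
        total = sum(self._probs.values())
        self._probs = {word_id: prob / total
                       for word_id, prob in self._probs.items()}

    def get_prob(self, word_id):
        return self._probs[word_id]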
Example #2
def _add_special_tokens(id_to_word, word_id_to_class_id, word_classes):
    """Makes sure that the special symbols ``<s>``, ``</s>``, and ``<unk>``
    exist in the word list ``id_to_word``. If not, creates them and their word
    classes.

    :type id_to_word: list of strs
    :param id_to_word: mapping from word IDs to word names

    :type word_id_to_class_id: list of ints
    :param word_id_to_class_id: mapping from word IDs to indices in
                                ``word_classes``

    :type word_classes: list of WordClass objects
    :param word_classes: list of all the word classes
    """

    if len(id_to_word) != len(word_id_to_class_id):
        raise ValueError(
            "Every word must be assigned to a class before adding "
            "special tokens.")

    if '<s>' not in id_to_word:
        word_id = len(id_to_word)
        assert word_id == len(word_id_to_class_id)
        class_id = len(word_classes)
        id_to_word.append('<s>')
        word_id_to_class_id.append(class_id)
        word_class = WordClass(class_id, word_id, 1.0)
        word_classes.append(word_class)

    if '</s>' not in id_to_word:
        word_id = len(id_to_word)
        assert word_id == len(word_id_to_class_id)
        class_id = len(word_classes)
        id_to_word.append('</s>')
        word_id_to_class_id.append(class_id)
        word_class = WordClass(class_id, word_id, 1.0)
        word_classes.append(word_class)

    if '<unk>' not in id_to_word:
        word_id = len(id_to_word)
        assert word_id == len(word_id_to_class_id)
        class_id = len(word_classes)
        id_to_word.append('<unk>')
        word_id_to_class_id.append(class_id)
        word_class = WordClass(class_id, word_id, 1.0)
        word_classes.append(word_class)
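
A small usage sketch (using the hypothetical ``WordClass`` above): starting from a vocabulary that already contains one word, the helper appends the three special tokens, each in a new singleton class.

id_to_word = ['hello']
word_id_to_class_id = [0]
word_classes = [WordClass(0, 0, 1.0)]

_add_special_tokens(id_to_word, word_id_to_class_id, word_classes)

print(id_to_word)           # ['hello', '<s>', '</s>', '<unk>']
print(word_id_to_class_id)  # [0, 1, 2, 3]
print(len(word_classes))    # 4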
Example #3
    @classmethod
    def from_state(cls, state):
        """Reads the vocabulary from a network state.

        :type state: h5py.File
        :param state: HDF5 file that contains the architecture parameters
        """

        if 'vocabulary' not in state:
            raise IncompatibleStateError(
                "Vocabulary is missing from neural network state.")
        h5_vocabulary = state['vocabulary']

        if 'words' not in h5_vocabulary:
            raise IncompatibleStateError(
                "Vocabulary parameter 'words' is missing from neural network "
                "state.")
        id_to_word = h5_vocabulary['words'].value

        if 'classes' not in h5_vocabulary:
            raise IncompatibleStateError(
                "Vocabulary parameter 'classes' is missing from neural network "
                "state.")
        word_id_to_class_id = h5_vocabulary['classes'].value

        if 'probs' not in h5_vocabulary:
            raise IncompatibleStateError(
                "Vocabulary parameter 'probs' is missing from neural network "
                "state.")
        num_classes = word_id_to_class_id.max() + 1
        word_classes = [None] * num_classes
        h5_probs = h5_vocabulary['probs'].value
        for word_id, prob in enumerate(h5_probs):
            class_id = word_id_to_class_id[word_id]
            if word_classes[class_id] is None:
                word_class = WordClass(class_id, word_id, prob)
                word_classes[class_id] = word_class
            else:
                word_classes[class_id].add(word_id, prob)

        result = cls(id_to_word.tolist(),
                     word_id_to_class_id.tolist(),
                     word_classes)

        if 'unigram_probs' in h5_vocabulary:
            result._unigram_probs = h5_vocabulary['unigram_probs'].value
            if len(result._unigram_probs) != result.num_words():
                raise IncompatibleStateError(
                    "Incorrect number of word unigram probabilities in neural "
                    "network state.")
            oos_probs = result._unigram_probs[result.num_shortlist_words():]
            if oos_probs.size:
                logging.debug("Out-of-shortlist word log probabilities are in "
                              "the range [%f, %f].",
                              numpy.log(oos_probs.min()),
                              numpy.log(oos_probs.max()))
        else:
            logging.debug("Word unigram probabilities are missing from state.")

        return result
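
A hedged usage sketch: assuming this classmethod belongs to a ``Vocabulary`` class (the class name is not visible in the snippet) and that the state was saved with h5py, restoring the vocabulary could look like the following; the file name is hypothetical. Note that the ``.value`` accessor used in the snippet only exists in h5py versions before 3.0; newer code would read the dataset with ``[()]`` instead.

import h5py

# 'model.h5' is a hypothetical path to a saved network state.
with h5py.File('model.h5', 'r') as state:
    vocabulary = Vocabulary.from_state(state)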
Example #4
    @classmethod
    def from_word_counts(cls, word_counts, num_classes=None):
        """Creates a vocabulary and classes from word counts. All the words will
        be in the shortlist.

        If ``num_classes`` is specified, words will be assigned to classes using
        modulo arithmetic. The class membership probabilities will be computed
        from the word counts.

        :type word_counts: dict
        :param word_counts: dictionary from words to the number of occurrences
                            in the corpus

        :type num_classes: int
        :param num_classes: number of classes to create in addition to the
                            special classes, or None for one class per word
        """

        # The special tokens should not be included when creating the classes.
        # They are added to separate classes in the end.
        word_counts = dict(word_counts)
        if '<s>' in word_counts:
            del word_counts['<s>']
        if '</s>' in word_counts:
            del word_counts['</s>']
        if '<unk>' in word_counts:
            del word_counts['<unk>']

        id_to_word = []
        word_id_to_class_id = []
        word_classes = []

        if num_classes is None:
            num_classes = len(word_counts)

        class_id = 0
        for word, _ in sorted(word_counts.items(),
                              key=lambda x: x[1]):
            word_id = len(id_to_word)
            id_to_word.append(word)

            if class_id < len(word_classes):
                word_classes[class_id].add(word_id, 1.0)
            else:
                assert class_id == len(word_classes)
                word_class = WordClass(class_id, word_id, 1.0)
                word_classes.append(word_class)

            assert word_id == len(word_id_to_class_id)
            word_id_to_class_id.append(class_id)
            class_id = (class_id + 1) % num_classes

        _add_special_tokens(id_to_word, word_id_to_class_id, word_classes)

        result = cls(id_to_word, word_id_to_class_id, word_classes)
        result.compute_probs(word_counts, update_class_probs=True)
        return result
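
For illustration, again assuming the enclosing class is called ``Vocabulary``: with ``num_classes=2`` the modulo assignment alternates the regular words between two classes, and the special tokens are appended afterwards in classes of their own.

word_counts = {'a': 10, 'b': 5, 'c': 3, 'd': 1}
vocabulary = Vocabulary.from_word_counts(word_counts, num_classes=2)
# The four regular words share classes 0 and 1; <s>, </s>, and <unk> are then
# added by _add_special_tokens, each in its own class, for five classes total.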
Example #5
    @classmethod
    def from_file(cls, input_file, input_format, oos_words=None):
        """Reads the shortlist words and possibly word classes from a vocabulary
        file.

        ``input_format`` is one of:

        * "words": ``input_file`` contains one word per line. Each word will be
                   assigned to its own class.
        * "classes": ``input_file`` contains a word followed by whitespace
                     followed by class ID on each line. Each word will be
                     assigned to the specified class. The class IDs can be
                     anything; they will be translated to consecutive numbers
                     after reading the file.
        * "srilm-classes": ``input_file`` contains a class name, membership
                           probability, and word, separated by whitespace, on
                           each line.

        The words read from the vocabulary file are put in the shortlist. If
        ``oos_words`` is given, those words are given an ID and added to the
        vocabulary as out-of-shortlist words if they don't exist in the
        vocabulary file.

        :type input_file: file object
        :param input_file: input vocabulary file

        :type input_format: str
        :param input_format: format of the input vocabulary file, "words",
                             "classes", or "srilm-classes"

        :type oos_words: list of strs
        :param oos_words: add words from this list to the vocabulary as
                          out-of-shortlist words, if they're not in the
                          vocabulary file
        """

        # We also keep a set of the words for faster checking of whether a
        # word has already been encountered.
        words = set()
        id_to_word = []
        word_id_to_class_id = []
        word_classes = []
        # Mapping from the IDs in the file to our internal class IDs.
        file_id_to_class_id = dict()

        for line in input_file:
            line = line.strip()
            fields = line.split()
            if not fields:
                continue
            if input_format == 'words' and len(fields) == 1:
                word = fields[0]
                file_id = None
                prob = 1.0
            elif input_format == 'classes' and len(fields) == 2:
                word = fields[0]
                file_id = int(fields[1])
                prob = 1.0
            elif input_format == 'srilm-classes' and len(fields) == 3:
                file_id = fields[0]
                prob = float(fields[1])
                word = fields[2]
            else:
                raise InputError(
                    "%d fields on one line of vocabulary file: %s" %
                    (len(fields), line))

            if word in words:
                raise InputError("Word `%s´ appears more than once in the "
                                 "vocabulary file." % word)
            words.add(word)
            word_id = len(id_to_word)
            id_to_word.append(word)

            if file_id in file_id_to_class_id:
                class_id = file_id_to_class_id[file_id]
                word_classes[class_id].add(word_id, prob)
            else:
                # No ID in the file or a new ID.
                class_id = len(word_classes)
                word_class = WordClass(class_id, word_id, prob)
                word_classes.append(word_class)
                if file_id is not None:
                    file_id_to_class_id[file_id] = class_id

            assert word_id == len(word_id_to_class_id)
            word_id_to_class_id.append(class_id)

        _add_special_tokens(id_to_word, word_id_to_class_id, word_classes)
        words |= {'<s>', '</s>', '<unk>'}

        if oos_words is not None:
            for word in oos_words:
                if word not in words:
                    words.add(word)
                    id_to_word.append(word)

        return cls(id_to_word, word_id_to_class_id, word_classes)
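
A usage sketch for the "srilm-classes" format, under the same assumption about the ``Vocabulary`` class name. Each input line carries a class name, a membership probability, and a word; ``oos_words`` adds extra words without class assignments.

from io import StringIO

vocabulary_file = StringIO(
    "CLASS-1 0.5 yes\n"
    "CLASS-1 0.5 no\n"
    "CLASS-2 1.0 maybe\n")
vocabulary = Vocabulary.from_file(vocabulary_file, 'srilm-classes',
                                  oos_words=['perhaps'])
# "yes" and "no" end up in one class, "maybe" in another, the special tokens
# in their own classes, and "perhaps" is appended as an out-of-shortlist word
# with no class assignment.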