示例#1
0
    def _read_slf_header(self, fields):
        """Parses SLF lattice header fields and stores their values on ``self``.

        Recognized fields are:

        * ``UTTERANCE``/``U``: utterance ID (kept as a string),
        * ``SUBLAT``/``S``: rejected, since sub-lattices are unsupported,
        * ``base``: logarithm base of the scores; 0 selects linear
          probabilities (``self._log_scale`` becomes ``None``), any other value
          sets ``self._log_scale`` to its natural logarithm,
        * ``lmscale`` and ``wdpenalty``: LM scale and word insertion penalty,
        * ``start``/``end``: IDs of the initial and final node,
        * ``NODES``/``N`` and ``LINKS``/``L``: node and link counts.

        Any other field is ignored.

        :type fields: list of strs
        :param fields: fields, such as name="value"

        :raises InputError: if the header declares a sub-lattice, or a field
                            has no '='
        """

        # Header fields whose value is an integer, mapped to the member
        # variable that receives it.
        integer_attributes = {
            'start': '_initial_node_id',
            'end': '_final_node_id',
            'NODES': '_num_nodes',
            'N': '_num_nodes',
            'LINKS': '_num_links',
            'L': '_num_links',
        }

        for field in fields:
            key, text = _split_slf_field(field)
            if key in ('UTTERANCE', 'U'):
                self.utterance_id = text
            elif key in ('SUBLAT', 'S'):
                raise InputError("Sub-lattices are not supported.")
            elif key == 'base':
                log_base = numpy.float64(text)
                # A zero base marks a file with linear probabilities.
                self._log_scale = (None if log_base == 0.0
                                   else logprob_type(numpy.log(log_base)))
            elif key == 'lmscale':
                self.lm_scale = logprob_type(text)
            elif key == 'wdpenalty':
                self.wi_penalty = logprob_type(text)
            elif key in integer_attributes:
                setattr(self, integer_attributes[key], int(text))
示例#2
0
    def _layer_options_from_description(self, description):
        """Creates layer options based on textual architecture description.

        Every field of the description is copied unchanged, except ``inputs``.
        That field, a list of layer names, is replaced by ``input_layers``, the
        list of the corresponding layers looked up by name in ``self.layers``.

        :type description: dict
        :param description: dictionary of textual layer fields

        :rtype: dict
        :returns: layer options

        :raises InputError: if one of the input layers does not exist
        """

        result = dict()
        for variable, value in description.items():
            if variable != 'inputs':
                result[variable] = value
                continue
            try:
                result['input_layers'] = [self.layers[x] for x in value]
            except KeyError as e:
                # ``get`` keeps a description without a name from raising a
                # second, unrelated KeyError while reporting this error.
                raise InputError("Input layer `{}' does not exist, when "
                                 "creating layer `{}'.".format(
                                     e.args[0], description.get('name'))) from e
        return result
示例#3
0
    def sorted_nodes(self):
        """Sorts nodes topologically, then by time.

        Returns a list which contains the nodes in sorted order. Uses the Kahn's
        algorithm to sort the nodes topologically, but always picks the node
        from the queue that has the lowest time stamp, if the nodes contain time
        stamps.

        The traversal starts from ``self.initial_node``. Nodes that are never
        reached, or whose incoming links are not all traversed, are left out of
        the result and a warning is logged.

        :rtype: list
        :returns: nodes of ``self.nodes`` in sorted order

        :raises InputError: if the counter of untraversed incoming links of a
                            node drops below zero, which is reported as a cycle
        """

        result = []
        # A queue of nodes to be visited next:
        node_queue = [self.initial_node]
        # The number of incoming links not traversed yet, indexed by node ID.
        # NOTE(review): assumes node.id equals the node's position in
        # self.nodes.
        in_degrees = [len(node.in_links) for node in self.nodes]
        while node_queue:
            # The queue is re-sorted after every append, so its last element
            # is the ready node with the lowest time stamp.
            node = node_queue.pop()
            result.append(node)
            for link in node.out_links:
                next_node = link.end_node
                in_degrees[next_node.id] -= 1
                if in_degrees[next_node.id] == 0:
                    node_queue.append(next_node)
                    # Descending order puts nodes without a time stamp first
                    # and the earliest time stamp last, where pop() takes it.
                    # The leading boolean keeps None from ever being compared
                    # with a number.
                    node_queue.sort(key=lambda x: (x.time is None, x.time),
                                    reverse=True)
                elif in_degrees[next_node.id] < 0:
                    # NOTE(review): this can only fire when a cycle passes
                    # through the initial node (it is queued regardless of its
                    # incoming links, so it may be visited twice). Nodes on
                    # other cycles never reach zero and are reported by the
                    # "unreachable nodes" warning below instead.
                    raise InputError("Word lattice contains a cycle.")

        if len(result) < len(self.nodes):
            logging.warning("Word lattice contains unreachable nodes.")
        else:
            assert len(result) == len(self.nodes)

        return result
示例#4
0
    def _read_slf_link(self, link_id, fields):
        """Reads SLF lattice link fields and creates such link.

        Recognizes the start node (``START``/``S``), end node (``END``/``E``),
        word (``WORD``/``W``), acoustic score (``acoustic``/``a``), and language
        model score (``language``/``l``). Other fields are ignored. Scores are
        multiplied by ``self._log_scale``, or, when it is ``None`` (linear
        probabilities), replaced by their natural logarithm.

        :type link_id: int
        :param link_id: ID of the link

        :type fields: list of strs
        :param fields: the rest of the link fields after ID

        :raises InputError: if the start or end node is not specified or refers
                            to a node ID outside the lattice
        """

        def node_by_id(text, role):
            """Returns the node whose ID is given in ``text``."""
            node_id = int(text)
            # A negative index would silently wrap around to another node,
            # and a too large one would raise a bare IndexError.
            if not 0 <= node_id < len(self.nodes):
                raise InputError(
                    "Link {} refers to a nonexistent {} node {}.".format(
                        link_id, role, node_id))
            return self.nodes[node_id]

        def to_logprob(text):
            """Converts a score read from the file to the log domain."""
            if self._log_scale is None:
                return logprob_type(numpy.log(numpy.float64(text)))
            return logprob_type(text) * self._log_scale

        start_node = None
        end_node = None
        word = None
        ac_logprob = None
        lm_logprob = None

        for field in fields:
            name, value = _split_slf_field(field)
            if name in ('START', 'S'):
                start_node = node_by_id(value, 'start')
            elif name in ('END', 'E'):
                end_node = node_by_id(value, 'end')
            elif name in ('WORD', 'W'):
                word = value
            elif name in ('acoustic', 'a'):
                ac_logprob = to_logprob(value)
            elif name in ('language', 'l'):
                lm_logprob = to_logprob(value)

        if start_node is None:
            raise InputError(
                "Start node is not specified for link {}.".format(link_id))
        if end_node is None:
            raise InputError(
                "End node is not specified for link {}.".format(link_id))
        link = self._add_link(start_node, end_node)
        link.word = word
        link.ac_logprob = ac_logprob
        link.lm_logprob = lm_logprob
示例#5
0
 def visit_link(link):
     """Moves words from node definitions onto the links that enter them.

     Starting from ``link``, processes every link reachable through nodes whose
     IDs are not yet in ``visited`` (a collection from the enclosing scope).
     When a link's end node has a string ``word``, that word is copied to the
     link. An explicit stack replaces recursion so that long lattices cannot
     exceed Python's recursion limit. Links are processed in the same
     depth-first order as the recursive formulation.

     :param link: the link to start from

     :raises InputError: if a link already has a word and its end node has one
                         too
     """
     pending = [link]
     while pending:
         current = pending.pop()
         end_node = current.end_node
         if hasattr(end_node, 'word') and isinstance(end_node.word, str):
             if current.word is None:
                 current.word = end_node.word
             else:
                 raise InputError("SLF lattice contains words both in nodes "
                                  "and links.")
         if end_node.id not in visited:
             visited.add(end_node.id)
             # Pushed in reverse so that the first outgoing link is popped
             # (and processed) first, exactly like the recursive traversal.
             pending.extend(reversed(list(end_node.out_links)))
示例#6
0
 def visit_link(link):
     """A function that moves a word from the link end node to the link, for
     every link reachable from ``link``.

     Only links leaving nodes whose IDs are not yet in ``visited`` (a
     collection from the enclosing scope) are followed. When a link's end node
     has a string ``word``, that word is copied to the link. An explicit stack
     replaces recursion so that long lattices cannot exceed Python's recursion
     limit. Links are processed in the same depth-first order as the recursive
     formulation.

     :param link: the link to start from

     :raises InputError: if a link already has a word and its end node has one
                         too
     """
     pending = [link]
     while pending:
         current = pending.pop()
         end_node = current.end_node
         if hasattr(end_node, 'word') and isinstance(end_node.word, str):
             if current.word is None:
                 current.word = end_node.word
             else:
                 raise InputError(
                     "SLF lattice contains words both in nodes "
                     "and links.")
         if end_node.id not in visited:
             visited.add(end_node.id)
             # Pushed in reverse so that the first outgoing link is popped
             # (and processed) first, exactly like the recursive traversal.
             pending.extend(reversed(list(end_node.out_links)))
示例#7
0
def _split_slf_field(field):
    """Splits an SLF lattice field into its name and value.

    Only the first ``=`` separates the name from the value, so the value may
    itself contain ``=`` characters. Either part may be empty.

    :type field: str
    :param field: a field, such as 'UTTERANCE=utterance 123'

    :rtype: tuple of two strs
    :returns: the name and value of the field

    :raises InputError: if the field contains no '='
    """

    name, separator, value = field.partition('=')
    if not separator:
        raise InputError(
            "Expected '=' in SLF lattice field: '{}'".format(field))
    return name, value
示例#8
0
    def from_description(classname, description_file):
        """Reads a description of the network architecture from a text file.

        Every non-empty line describes either an input or a layer. It consists
        of the keyword ``input`` or ``layer`` followed by ``field=value``
        pairs. Inputs must define ``type`` and ``name``. Layers must define
        ``type``, ``name``, and at least one ``input``. ``input`` may be
        repeated; its values are collected in order under the key ``inputs``.
        A layer's ``size`` is converted to an integer.

        :type description_file: file or file-like object
        :param description_file: text file containing the description; its
                                 ``name`` attribute, if present, is used in
                                 error messages

        :rtype: Network.Architecture
        :returns: an object describing the network architecture

        :raises InputError: if a line is malformed, a required field is missing,
                            a layer size is not an integer, or the description
                            contains no inputs or no layers
        """

        # File-like objects such as io.StringIO have no ``name`` attribute;
        # reporting an error must not itself fail with an AttributeError.
        file_name = getattr(description_file, 'name', '<description>')

        def parse_fields(fields, element):
            """Splits ``field=value`` strings into (field, value) pairs."""
            pairs = []
            for field in fields:
                parts = field.split('=', 1)
                if len(parts) != 2:
                    raise InputError(
                        "'field=value' expected but '{}' found in {} "
                        "description in '{}'.".format(field, element, file_name))
                pairs.append((parts[0], parts[1]))
            return pairs

        def require(description, variable, element):
            """Raises InputError if ``variable`` is missing from a description."""
            if variable not in description:
                raise InputError(
                    "'{}' is not given in {} description in '{}'.".format(
                        variable, element, file_name))

        inputs = []
        layers = []

        for line in description_file:
            fields = line.split()
            if not fields:
                continue

            if fields[0] == 'input':
                input_description = dict(parse_fields(fields[1:], 'an input'))
                require(input_description, 'type', 'an input')
                require(input_description, 'name', 'an input')
                inputs.append(input_description)

            elif fields[0] == 'layer':
                layer_description = {'inputs': []}
                for variable, value in parse_fields(fields[1:], 'a layer'):
                    if variable == 'size':
                        try:
                            layer_description[variable] = int(value)
                        except ValueError:
                            raise InputError(
                                "Invalid layer size '{}' in a layer "
                                "description in '{}'.".format(
                                    value, file_name)) from None
                    elif variable == 'input':
                        layer_description['inputs'].append(value)
                    else:
                        layer_description[variable] = value
                require(layer_description, 'type', 'a layer')
                require(layer_description, 'name', 'a layer')
                if not layer_description['inputs']:
                    raise InputError(
                        "'input' is not given in a layer description in '{}'.".
                        format(file_name))
                layers.append(layer_description)

            else:
                raise InputError("Invalid element in architecture "
                                 "description: {}".format(fields[0]))

        if not inputs:
            raise InputError("Architecture description contains no inputs.")
        if not layers:
            raise InputError("Architecture description contains no layers.")

        return classname(inputs, layers)
示例#9
0
    def from_file(cls, input_file, input_format, oos_words=None):
        """Reads the shortlist words and possibly word classes from a vocabulary
        file.

        ``input_format`` is one of:

        * "words": ``input_file`` contains one word per line. Each word will be
                   assigned to its own class.
        * "classes": ``input_file`` contains a word followed by whitespace
                     followed by class ID on each line. Each word will be
                     assigned to the specified class. The class IDs can be
                     any integers; they will be translated to consecutive
                     numbers after reading the file.
        * "srilm-classes": ``input_file`` contains a class name, membership
                           probability, and word, separated by whitespace, on
                           each line.

        The words read from the vocabulary file are put in the shortlist. If
        ``oos_words`` is given, those words are given an ID and added to the
        vocabulary as out-of-shortlist words if they don't exist in the
        vocabulary file.

        :type input_file: file object
        :param input_file: input vocabulary file

        :type input_format: str
        :param input_format: format of the input vocabulary file, "words",
                             "classes", or "srilm-classes"

        :type oos_words: list of strs
        :param oos_words: add words from this list to the vocabulary as
                          out-of-shortlist words, if they're not in the
                          vocabulary file

        :raises InputError: if a line has the wrong number of fields for
                            ``input_format``, a class ID or probability is not
                            a valid number, or a word appears more than once in
                            the vocabulary file
        """

        def parse_number(convert, text, line):
            """Converts ``text`` with ``convert`` (``int`` or ``float``).

            Malformed numbers are reported as InputError, like the other
            vocabulary file errors, instead of a bare ValueError.
            """
            try:
                return convert(text)
            except ValueError:
                raise InputError(
                    "Invalid number `%s' on a line of vocabulary file: %s" %
                    (text, line)) from None

        # We have also a set of the words just for faster checking if a word has
        # already been encountered.
        words = set()
        id_to_word = []
        word_id_to_class_id = []
        word_classes = []
        # Mapping from the IDs in the file to our internal class IDs.
        file_id_to_class_id = dict()

        for line in input_file:
            line = line.strip()
            fields = line.split()
            if not fields:
                continue
            if input_format == 'words' and len(fields) == 1:
                word = fields[0]
                file_id = None
                prob = 1.0
            elif input_format == 'classes' and len(fields) == 2:
                word = fields[0]
                file_id = parse_number(int, fields[1], line)
                prob = 1.0
            elif input_format == 'srilm-classes' and len(fields) == 3:
                file_id = fields[0]
                prob = parse_number(float, fields[1], line)
                word = fields[2]
            else:
                raise InputError(
                    "%d fields on one line of vocabulary file: %s" %
                    (len(fields), line))

            if word in words:
                raise InputError("Word `%s' appears more than once in the "
                                 "vocabulary file." % word)
            words.add(word)
            word_id = len(id_to_word)
            id_to_word.append(word)

            if file_id in file_id_to_class_id:
                class_id = file_id_to_class_id[file_id]
                word_classes[class_id].add(word_id, prob)
            else:
                # No ID in the file or a new ID.
                class_id = len(word_classes)
                word_class = WordClass(class_id, word_id, prob)
                word_classes.append(word_class)
                if file_id is not None:
                    file_id_to_class_id[file_id] = class_id

            assert word_id == len(word_id_to_class_id)
            word_id_to_class_id.append(class_id)

        _add_special_tokens(id_to_word, word_id_to_class_id, word_classes)
        words |= {'<s>', '</s>', '<unk>'}

        # Out-of-shortlist words get an ID but no entry in
        # word_id_to_class_id.
        if oos_words is not None:
            for word in oos_words:
                if word not in words:
                    words.add(word)
                    id_to_word.append(word)

        return cls(id_to_word, word_id_to_class_id, word_classes)
示例#10
0
    def __init__(self, lattice_file):
        """Reads an SLF lattice file.

        If ``lattice_file`` is ``None``, creates an empty lattice (useful for
        testing).

        Header lines are read until both the node count and the link count have
        been seen. The remaining lines define nodes (``I=``) and links
        (``J=``). If the header does not give the start or end node, the first
        node without incoming or outgoing links, respectively, is used.

        :type lattice_file: file object
        :param lattice_file: a file in SLF lattice format

        :raises InputError: if the header does not give the node or link count,
                            the number of links does not match the header, the
                            initial or final node cannot be determined, or a
                            link has no word identity
        """

        super().__init__()

        # No log conversion by default. "None" means the lattice file uses
        # linear probabilities.
        self._log_scale = logprob_type(1.0)

        self._initial_node_id = None
        self._final_node_id = None

        if lattice_file is None:
            self._num_nodes = 0
            self._num_links = 0
            return

        self._num_nodes = None
        self._num_links = None
        for line in lattice_file:
            fields = _split_slf_line(line)
            self._read_slf_header(fields)
            if (self._num_nodes is not None) and (self._num_links is not None):
                break

        # The counts are needed below; a missing node count would otherwise
        # fail with a TypeError from range(None).
        if self._num_nodes is None:
            raise InputError(
                "SLF lattice does not specify the number of nodes.")
        if self._num_links is None:
            raise InputError(
                "SLF lattice does not specify the number of links.")

        # Apply the header's log base to the penalty, or take its logarithm if
        # the file uses linear probabilities.
        if self.wi_penalty is not None:
            if self._log_scale is None:
                self.wi_penalty = numpy.log(self.wi_penalty)
            else:
                self.wi_penalty *= self._log_scale

        self.nodes = [self.Node(node_id) for node_id in range(self._num_nodes)]

        # Continue reading the same file after the header.
        for line in lattice_file:
            fields = _split_slf_line(line)
            if not fields:
                continue
            name, value = _split_slf_field(fields[0])
            if name == 'I':
                self._read_slf_node(int(value), fields[1:])
            elif name == 'J':
                self._read_slf_link(int(value), fields[1:])

        if len(self.links) != self._num_links:
            raise InputError(
                "Number of links in SLF lattice doesn't match the "
                "LINKS field.")

        if self._initial_node_id is not None:
            self.initial_node = self.nodes[self._initial_node_id]
        else:
            # Find the node with no incoming links.
            self.initial_node = None
            for node in self.nodes:
                if len(node.in_links) == 0:
                    self.initial_node = node
                    break
            if self.initial_node is None:
                raise InputError("Could not find initial node in SLF lattice.")

        if self._final_node_id is not None:
            self.final_node = self.nodes[self._final_node_id]
        else:
            # Find the node with no outgoing links.
            self.final_node = None
            for node in self.nodes:
                if len(node.out_links) == 0:
                    self.final_node = node
                    break
            if self.final_node is None:
                raise InputError("Could not find final node in SLF lattice.")

        # If word identity information is not present in node definitions then
        # it must appear in link definitions.
        self._move_words_to_links()
        for link in self.links:
            if link.word is None:
                raise InputError("SLF lattice does not contain word identity "
                                 "in link {} or in the following node.".format(
                                     link.id))
示例#11
0
    def from_file(classname, input_file, input_format):
        """Reads vocabulary and possibly word classes from a text file.

        ``input_format`` is one of:

        * "words": ``input_file`` contains one word per line. Each word will be
                   assigned to its own class.
        * "classes": ``input_file`` contains a word followed by whitespace
                     followed by class ID on each line. Each word will be
                     assigned to the specified class. The class IDs can be
                     any integers; they will be translated to consecutive
                     numbers after reading the file.
        * "srilm-classes": ``input_file`` contains a class name, membership
                           probability, and word, separated by whitespace, on
                           each line.

        The special tokens ``<s>``, ``</s>``, and ``<unk>`` are skipped if they
        appear in the file.

        :type input_file: file object
        :param input_file: input vocabulary file

        :type input_format: str
        :param input_format: format of the input vocabulary file, "words",
                             "classes", or "srilm-classes"

        :raises InputError: if a line has the wrong number of fields for
                            ``input_format``, a class ID or probability is not
                            a valid number, or a word appears more than once in
                            the vocabulary file
        """

        def parse_number(convert, text, line):
            """Converts ``text`` with ``convert`` (``int`` or ``float``).

            Malformed numbers are reported as InputError, like the other
            vocabulary file errors, instead of a bare ValueError.
            """
            try:
                return convert(text)
            except ValueError:
                raise InputError(
                    "Invalid number `%s' on a line of vocabulary file: %s" %
                    (text, line)) from None

        # We have also a set of the words just for faster checking if a word has
        # already been encountered.
        words = set()
        id_to_word = []
        word_id_to_class_id = []
        word_classes = []
        # Mapping from the IDs in the file to our internal class IDs.
        file_id_to_class_id = dict()

        for line in input_file:
            line = line.strip()
            fields = line.split()
            if not fields:
                continue
            if input_format == 'words' and len(fields) == 1:
                word = fields[0]
                file_id = None
                prob = 1.0
            elif input_format == 'classes' and len(fields) == 2:
                word = fields[0]
                file_id = parse_number(int, fields[1], line)
                prob = 1.0
            elif input_format == 'srilm-classes' and len(fields) == 3:
                file_id = fields[0]
                prob = parse_number(float, fields[1], line)
                word = fields[2]
            else:
                raise InputError(
                    "%d fields on one line of vocabulary file: %s" %
                    (len(fields), line))

            if word in ('<s>', '</s>', '<unk>'):
                # These special symbols are automatically added
                continue

            if word in words:
                raise InputError(
                    "Word `%s' appears more than once in the vocabulary file."
                    % word)
            words.add(word)
            word_id = len(id_to_word)
            id_to_word.append(word)

            if file_id in file_id_to_class_id:
                class_id = file_id_to_class_id[file_id]
                word_classes[class_id].add(word_id, prob)
            else:
                # No ID in the file or a new ID.
                class_id = len(word_classes)
                word_class = Vocabulary.WordClass(class_id, word_id, prob)
                word_classes.append(word_class)
                if file_id is not None:
                    file_id_to_class_id[file_id] = class_id

            assert word_id == len(word_id_to_class_id)
            word_id_to_class_id.append(class_id)

        return classname(id_to_word, word_id_to_class_id, word_classes)
示例#12
0
    def decode(self, lattice):
        """Propagates tokens through given lattice and returns a list of tokens
        in the final node.

        Propagates tokens at a node to every outgoing link by creating a copy of
        each token and updating the language model scores according to the link.

        The LM scale and the word insertion penalty given to the decoder take
        precedence over the ones stored in the lattice. If neither is given, an
        LM scale of 1.0 and a word insertion penalty of 0.0 are used.

        :type lattice: Lattice
        :param lattice: a word lattice to be decoded

        :rtype: list of LatticeDecoder.Tokens
        :returns: the final tokens sorted by total log probability in descending
                  order

        :raises InputError: if the final node of the lattice is not reached
        """

        if self._lm_scale is not None:
            lm_scale = logprob_type(self._lm_scale)
        elif lattice.lm_scale is not None:
            lm_scale = logprob_type(lattice.lm_scale)
        else:
            lm_scale = logprob_type(1.0)

        # Same precedence as the LM scale: a penalty configured in the decoder
        # must be neither overridden by the lattice nor reset to zero.
        if self._wi_penalty is not None:
            wi_penalty = logprob_type(self._wi_penalty)
        elif lattice.wi_penalty is not None:
            wi_penalty = logprob_type(lattice.wi_penalty)
        else:
            wi_penalty = logprob_type(0.0)

        self._tokens = [list() for _ in lattice.nodes]
        initial_state = RecurrentState(self._network.recurrent_state_size)
        initial_token = self.Token(history=[self._sos_id], state=initial_state)
        initial_token.recompute_hash(self._recombination_order)
        initial_token.recompute_total(self._nnlm_weight, lm_scale, wi_penalty,
                                      self._linear_interpolation)
        self._tokens[lattice.initial_node.id].append(initial_token)
        lattice.initial_node.best_logprob = initial_token.total_logprob

        self._sorted_nodes = lattice.sorted_nodes()
        num_nodes = len(self._sorted_nodes)
        # Progress is logged roughly every 5 % of the nodes.
        log_interval = max(1, math.ceil(num_nodes / 20))
        nodes_processed = 0
        for node in self._sorted_nodes:
            node_tokens = self._tokens[node.id]
            assert node_tokens
            num_pruned_tokens = len(node_tokens)
            self._prune(node)
            # Re-read the list in case _prune() replaced it.
            node_tokens = self._tokens[node.id]
            assert node_tokens
            num_pruned_tokens -= len(node_tokens)

            if node.id == lattice.final_node.id:
                # NOTE(review): a None link presumably finalizes the tokens
                # (e.g. adds the sentence end) — confirm in _propagate().
                new_tokens = self._propagate(node_tokens, None, lm_scale,
                                             wi_penalty)
                return sorted(new_tokens,
                              key=lambda token: token.total_logprob,
                              reverse=True)

            num_new_tokens = 0
            for link in node.out_links:
                new_tokens = self._propagate(node_tokens, link, lm_scale,
                                             wi_penalty)
                self._tokens[link.end_node.id].extend(new_tokens)
                num_new_tokens += len(new_tokens)

            nodes_processed += 1
            if nodes_processed % log_interval == 0:
                logging.debug("[%d] (%.2f %%) -- tokens = %d +%d -%d",
                              nodes_processed,
                              nodes_processed / num_nodes * 100,
                              len(node_tokens), num_new_tokens,
                              num_pruned_tokens)

        raise InputError("Could not reach the final node of word lattice.")