def _read_slf_header(self, fields):
    """Reads SLF lattice header fields and saves them in member variables.

    :type fields: list of strs
    :param fields: fields, such as name="value"
    """

    for field in fields:
        name, value = _split_slf_field(field)
        if (name == 'UTTERANCE') or (name == 'U'):
            self.utterance_id = value
        elif (name == 'SUBLAT') or (name == 'S'):
            raise InputError("Sub-lattices are not supported.")
        elif name == 'base':
            value = numpy.float64(value)
            if value == 0.0:
                self._log_scale = None
            else:
                self._log_scale = logprob_type(numpy.log(value))
        elif name == 'lmscale':
            self.lm_scale = logprob_type(value)
        elif name == 'wdpenalty':
            self.wi_penalty = logprob_type(value)
        elif name == 'start':
            self._initial_node_id = int(value)
        elif name == 'end':
            self._final_node_id = int(value)
        elif (name == 'NODES') or (name == 'N'):
            self._num_nodes = int(value)
        elif (name == 'LINKS') or (name == 'L'):
            self._num_links = int(value)
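# Usage sketch (hypothetical header fields; assumes the line has already
# been tokenized, e.g. by _split_slf_line):
#
#     fields = ['UTTERANCE=utt1', 'lmscale=14.0', 'N=5', 'L=7']
#     lattice._read_slf_header(fields)
#     # lattice.utterance_id == 'utt1', lattice.lm_scale == 14.0,
#     # lattice._num_nodes == 5, lattice._num_links == 7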
def _layer_options_from_description(self, description):
    """Creates layer options based on a textual architecture description.

    Most of the fields in a layer description are kept as strings. The
    field ``input_layers`` is converted to a list of actual layers found
    in ``self.layers``.

    :type description: dict
    :param description: dictionary of textual layer fields

    :rtype: dict
    :returns: layer options
    """

    result = dict()
    for variable, value in description.items():
        if variable == 'inputs':
            try:
                result['input_layers'] = [self.layers[x] for x in value]
            except KeyError as e:
                raise InputError("Input layer `{}' does not exist, when "
                                 "creating layer `{}'.".format(
                                     e.args[0], description['name']))
        else:
            result[variable] = value
    return result
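# Usage sketch (hypothetical layer names; assumes self.layers maps the
# names of previously created layers to layer objects):
#
#     description = {'name': 'hidden1', 'type': 'lstm', 'size': '500',
#                    'inputs': ['proj']}
#     options = network._layer_options_from_description(description)
#     # options['input_layers'] == [self.layers['proj']]; all other
#     # fields are passed through unchanged.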
def sorted_nodes(self):
    """Sorts nodes topologically, then by time.

    Returns a list which contains the nodes in sorted order. Uses Kahn's
    algorithm to sort the nodes topologically, but always picks the node
    from the queue that has the lowest time stamp, if the nodes contain
    time stamps.
    """

    result = []
    # A queue of nodes to be visited next:
    node_queue = [self.initial_node]
    # The number of incoming links not traversed yet:
    in_degrees = [len(node.in_links) for node in self.nodes]
    while node_queue:
        node = node_queue.pop()
        result.append(node)
        for link in node.out_links:
            next_node = link.end_node
            in_degrees[next_node.id] -= 1
            if in_degrees[next_node.id] == 0:
                node_queue.append(next_node)
                node_queue.sort(key=lambda x: (x.time is None, x.time),
                                reverse=True)
            elif in_degrees[next_node.id] < 0:
                raise InputError("Word lattice contains a cycle.")

    if len(result) < len(self.nodes):
        logging.warning("Word lattice contains unreachable nodes.")
    else:
        assert len(result) == len(self.nodes)

    return result
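# Note on the queue ordering (an explanatory sketch, not part of the
# original code): sorting by (x.time is None, x.time) with reverse=True
# places untimed nodes first and timed nodes in decreasing time order, so
# pop(), which removes the last element, always takes the timed node with
# the smallest time stamp, and untimed nodes are popped last:
#
#     queue = [n_t5, n_none, n_t2]  # hypothetical nodes, times 5, None, 2
#     queue.sort(key=lambda x: (x.time is None, x.time), reverse=True)
#     # queue is now [n_none, n_t5, n_t2]; queue.pop() returns n_t2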
def _read_slf_link(self, link_id, fields):
    """Reads SLF lattice link fields and creates such a link.

    :type link_id: int
    :param link_id: ID of the link

    :type fields: list of strs
    :param fields: the rest of the link fields after ID
    """

    start_node = None
    end_node = None
    word = None
    ac_logprob = None
    lm_logprob = None
    for field in fields:
        name, value = _split_slf_field(field)
        if (name == 'START') or (name == 'S'):
            start_node = self.nodes[int(value)]
        elif (name == 'END') or (name == 'E'):
            end_node = self.nodes[int(value)]
        elif (name == 'WORD') or (name == 'W'):
            word = value
        elif (name == 'acoustic') or (name == 'a'):
            if self._log_scale is None:
                ac_logprob = logprob_type(numpy.log(numpy.float64(value)))
            else:
                ac_logprob = logprob_type(value) * self._log_scale
        elif (name == 'language') or (name == 'l'):
            if self._log_scale is None:
                lm_logprob = logprob_type(numpy.log(numpy.float64(value)))
            else:
                lm_logprob = logprob_type(value) * self._log_scale

    if start_node is None:
        raise InputError(
            "Start node is not specified for link {}.".format(link_id))
    if end_node is None:
        raise InputError(
            "End node is not specified for link {}.".format(link_id))
    link = self._add_link(start_node, end_node)
    link.word = word
    link.ac_logprob = ac_logprob
    link.lm_logprob = lm_logprob
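# Usage sketch (hypothetical link line; assumes self.nodes has been
# populated from the header's node count):
#
#     fields = _split_slf_line('J=3 S=0 E=1 W=hello a=-1250.5 l=-18.2')
#     lattice._read_slf_link(3, fields[1:])
#     # Creates a link from node 0 to node 1 with word 'hello'; the
#     # acoustic and LM scores are multiplied by self._log_scale, or
#     # log-converted if the lattice uses linear probabilities.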
def visit_link(link): """A function that is called recursively to move a word from the link end node to the link. """ end_node = link.end_node if hasattr(end_node, 'word') and isinstance(end_node.word, str): if link.word is None: link.word = end_node.word else: raise InputError( "SLF lattice contains words both in nodes " "and links.") if end_node.id not in visited: visited.add(end_node.id) for next_link in end_node.out_links: visit_link(next_link)
def _split_slf_field(field):
    """Parses the name and value from an SLF lattice field.

    :type field: str
    :param field: a field, such as 'UTTERANCE=utterance 123'

    :rtype: tuple of two strs
    :returns: the name and value of the field
    """

    name_value = field.split('=', 1)
    if len(name_value) != 2:
        raise InputError(
            "Expected '=' in SLF lattice field: '{}'".format(field))
    name, value = name_value
    return name, value
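# Usage sketch, following the docstring:
#
#     _split_slf_field('UTTERANCE=utterance 123')
#     # -> ('UTTERANCE', 'utterance 123')
#     _split_slf_field('lmscale=14.0')    # -> ('lmscale', '14.0')
#     _split_slf_field('no-equals-sign')  # raises InputError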
def from_description(classname, description_file):
    """Reads a description of the network architecture from a text file.

    :type description_file: file or file-like object
    :param description_file: text file containing the description

    :rtype: Network.Architecture
    :returns: an object describing the network architecture
    """

    inputs = []
    layers = []

    for line in description_file:
        fields = line.split()
        if not fields:
            continue
        if fields[0] == 'input':
            input_description = dict()
            for field in fields[1:]:
                parts = field.split('=', 1)
                if len(parts) != 2:
                    raise InputError(
                        "'field=value' expected but '{}' found in an input "
                        "description in '{}'.".format(
                            field, description_file.name))
                variable, value = parts
                input_description[variable] = value
            if 'type' not in input_description:
                raise InputError(
                    "'type' is not given in an input description in "
                    "'{}'.".format(description_file.name))
            if 'name' not in input_description:
                raise InputError(
                    "'name' is not given in an input description in "
                    "'{}'.".format(description_file.name))
            inputs.append(input_description)
        elif fields[0] == 'layer':
            layer_description = {'inputs': []}
            for field in fields[1:]:
                parts = field.split('=', 1)
                if len(parts) != 2:
                    raise InputError(
                        "'field=value' expected but '{}' found in a layer "
                        "description in '{}'.".format(
                            field, description_file.name))
                variable, value = parts
                if variable == 'size':
                    layer_description[variable] = int(value)
                elif variable == 'input':
                    layer_description['inputs'].append(value)
                else:
                    layer_description[variable] = value
            if 'type' not in layer_description:
                raise InputError(
                    "'type' is not given in a layer description in "
                    "'{}'.".format(description_file.name))
            if 'name' not in layer_description:
                raise InputError(
                    "'name' is not given in a layer description in "
                    "'{}'.".format(description_file.name))
            if not layer_description['inputs']:
                raise InputError(
                    "'input' is not given in a layer description in "
                    "'{}'.".format(description_file.name))
            layers.append(layer_description)
        else:
            raise InputError("Invalid element in architecture "
                             "description: {}".format(fields[0]))

    if not inputs:
        raise InputError("Architecture description contains no inputs.")
    if not layers:
        raise InputError("Architecture description contains no layers.")

    return classname(inputs, layers)
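# Usage sketch (hypothetical description file; assumes from_description is
# a classmethod of an Architecture class, so `classname` is bound to the
# class itself):
#
#     architecture.txt:
#         input type=class name=class_input
#         layer type=projection name=proj input=class_input size=100
#         layer type=lstm name=hidden1 input=proj size=500
#         layer type=softmax name=output input=hidden1
#
#     with open('architecture.txt', 'r') as description_file:
#         architecture = Architecture.from_description(description_file)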
def from_file(cls, input_file, input_format, oos_words=None):
    """Reads the shortlist words and possibly word classes from a
    vocabulary file.

    ``input_format`` is one of:

    * "words": ``input_file`` contains one word per line. Each word will
      be assigned to its own class.
    * "classes": ``input_file`` contains a word followed by whitespace
      followed by class ID on each line. Each word will be assigned to
      the specified class. The class IDs can be anything; they will be
      translated to consecutive numbers after reading the file.
    * "srilm-classes": ``input_file`` contains a class name, membership
      probability, and word, separated by whitespace, on each line.

    The words read from the vocabulary file are put in the shortlist. If
    ``oos_words`` is given, those words are given an ID and added to the
    vocabulary as out-of-shortlist words, if they don't exist in the
    vocabulary file.

    :type input_file: file object
    :param input_file: input vocabulary file

    :type input_format: str
    :param input_format: format of the input vocabulary file, "words",
                         "classes", or "srilm-classes"

    :type oos_words: list of strs
    :param oos_words: add words from this list to the vocabulary as
                      out-of-shortlist words, if they're not in the
                      vocabulary file
    """

    # We also keep a set of the words, just for faster checking whether a
    # word has already been encountered.
    words = set()
    id_to_word = []
    word_id_to_class_id = []
    word_classes = []
    # Mapping from the IDs in the file to our internal class IDs.
    file_id_to_class_id = dict()

    for line in input_file:
        line = line.strip()
        fields = line.split()
        if not fields:
            continue
        if input_format == 'words' and len(fields) == 1:
            word = fields[0]
            file_id = None
            prob = 1.0
        elif input_format == 'classes' and len(fields) == 2:
            word = fields[0]
            file_id = int(fields[1])
            prob = 1.0
        elif input_format == 'srilm-classes' and len(fields) == 3:
            file_id = fields[0]
            prob = float(fields[1])
            word = fields[2]
        else:
            raise InputError(
                "%d fields on one line of vocabulary file: %s" %
                (len(fields), line))

        if word in words:
            raise InputError("Word `%s' appears more than once in the "
                             "vocabulary file." % word)
        words.add(word)
        word_id = len(id_to_word)
        id_to_word.append(word)

        if file_id in file_id_to_class_id:
            class_id = file_id_to_class_id[file_id]
            word_classes[class_id].add(word_id, prob)
        else:
            # No ID in the file or a new ID.
            class_id = len(word_classes)
            word_class = WordClass(class_id, word_id, prob)
            word_classes.append(word_class)
            if file_id is not None:
                file_id_to_class_id[file_id] = class_id

        assert word_id == len(word_id_to_class_id)
        word_id_to_class_id.append(class_id)

    _add_special_tokens(id_to_word, word_id_to_class_id, word_classes)
    words |= {'<s>', '</s>', '<unk>'}

    if oos_words is not None:
        for word in oos_words:
            if word not in words:
                words.add(word)
                id_to_word.append(word)

    return cls(id_to_word, word_id_to_class_id, word_classes)
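# Usage sketch (hypothetical file contents; in "srilm-classes" format each
# line is "class-name membership-probability word"):
#
#     vocab.txt:
#         CLASS-1 1.0 hello
#         CLASS-2 0.5 yes
#         CLASS-2 0.5 no
#
#     with open('vocab.txt', 'r') as input_file:
#         vocabulary = Vocabulary.from_file(input_file, 'srilm-classes',
#                                           oos_words=['world'])
#     # 'hello', 'yes', and 'no' enter the shortlist; 'world' is added as
#     # an out-of-shortlist word with an ID but no class.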
def __init__(self, lattice_file):
    """Reads an SLF lattice file.

    If ``lattice_file`` is ``None``, creates an empty lattice (useful for
    testing).

    :type lattice_file: file object
    :param lattice_file: a file in SLF lattice format
    """

    super().__init__()

    # No log conversion by default. "None" means the lattice file uses
    # linear probabilities.
    self._log_scale = logprob_type(1.0)

    self._initial_node_id = None
    self._final_node_id = None

    if lattice_file is None:
        self._num_nodes = 0
        self._num_links = 0
        return

    self._num_nodes = None
    self._num_links = None
    for line in lattice_file:
        fields = _split_slf_line(line)
        self._read_slf_header(fields)
        if (self._num_nodes is not None) and (self._num_links is not None):
            break
    if self.wi_penalty is not None:
        if self._log_scale is None:
            self.wi_penalty = numpy.log(self.wi_penalty)
        else:
            self.wi_penalty *= self._log_scale

    self.nodes = [self.Node(node_id) for node_id in range(self._num_nodes)]

    for line in lattice_file:
        fields = _split_slf_line(line)
        if not fields:
            continue
        name, value = _split_slf_field(fields[0])
        if name == 'I':
            self._read_slf_node(int(value), fields[1:])
        elif name == 'J':
            self._read_slf_link(int(value), fields[1:])

    if len(self.links) != self._num_links:
        raise InputError("Number of links in SLF lattice doesn't match "
                         "the LINKS field.")

    if self._initial_node_id is not None:
        self.initial_node = self.nodes[self._initial_node_id]
    else:
        # Find the node with no incoming links.
        self.initial_node = None
        for node in self.nodes:
            if len(node.in_links) == 0:
                self.initial_node = node
                break
        if self.initial_node is None:
            raise InputError("Could not find initial node in SLF lattice.")

    if self._final_node_id is not None:
        self.final_node = self.nodes[self._final_node_id]
    else:
        # Find the node with no outgoing links.
        self.final_node = None
        for node in self.nodes:
            if len(node.out_links) == 0:
                self.final_node = node
                break
        if self.final_node is None:
            raise InputError("Could not find final node in SLF lattice.")

    # If word identity information is not present in the node definitions,
    # it must appear in the link definitions.
    self._move_words_to_links()
    for link in self.links:
        if link.word is None:
            raise InputError("SLF lattice does not contain word identity "
                             "in link {} or in the following node.".format(
                                 link.id))
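# Usage sketch (SLFLattice is an assumed name for the class that defines
# this constructor):
#
#     with open('utterance.lat', 'r') as lattice_file:
#         lattice = SLFLattice(lattice_file)
#     print(lattice.initial_node.id, lattice.final_node.id)
#
#     empty = SLFLattice(None)  # empty lattice, useful for testing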
def from_file(classname, input_file, input_format):
    """Reads vocabulary and possibly word classes from a text file.

    ``input_format`` is one of:

    * "words": ``input_file`` contains one word per line. Each word will
      be assigned to its own class.
    * "classes": ``input_file`` contains a word followed by whitespace
      followed by class ID on each line. Each word will be assigned to
      the specified class. The class IDs can be anything; they will be
      translated to consecutive numbers after reading the file.
    * "srilm-classes": ``input_file`` contains a class name, membership
      probability, and word, separated by whitespace, on each line.

    :type input_file: file object
    :param input_file: input vocabulary file

    :type input_format: str
    :param input_format: format of the input vocabulary file, "words",
                         "classes", or "srilm-classes"
    """

    # We also keep a set of the words, just for faster checking whether a
    # word has already been encountered.
    words = set()
    id_to_word = []
    word_id_to_class_id = []
    word_classes = []
    # Mapping from the IDs in the file to our internal class IDs.
    file_id_to_class_id = dict()

    for line in input_file:
        line = line.strip()
        fields = line.split()
        if not fields:
            continue
        if input_format == 'words' and len(fields) == 1:
            word = fields[0]
            file_id = None
            prob = 1.0
        elif input_format == 'classes' and len(fields) == 2:
            word = fields[0]
            file_id = int(fields[1])
            prob = 1.0
        elif input_format == 'srilm-classes' and len(fields) == 3:
            file_id = fields[0]
            prob = float(fields[1])
            word = fields[2]
        else:
            raise InputError(
                "%d fields on one line of vocabulary file: %s" %
                (len(fields), line))

        if word in ('<s>', '</s>', '<unk>'):
            # These special symbols are added automatically.
            continue
        if word in words:
            raise InputError("Word `%s' appears more than once in the "
                             "vocabulary file." % word)
        words.add(word)
        word_id = len(id_to_word)
        id_to_word.append(word)

        if file_id in file_id_to_class_id:
            class_id = file_id_to_class_id[file_id]
            word_classes[class_id].add(word_id, prob)
        else:
            # No ID in the file or a new ID.
            class_id = len(word_classes)
            word_class = Vocabulary.WordClass(class_id, word_id, prob)
            word_classes.append(word_class)
            if file_id is not None:
                file_id_to_class_id[file_id] = class_id

        assert word_id == len(word_id_to_class_id)
        word_id_to_class_id.append(class_id)

    return classname(id_to_word, word_id_to_class_id, word_classes)
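# Usage sketch (hypothetical "classes" format file, "word class-ID" on
# each line; file IDs are renumbered to consecutive internal class IDs):
#
#     vocab.txt:
#         yes 10
#         no 10
#         hello 25
#
#     with open('vocab.txt', 'r') as input_file:
#         vocabulary = Vocabulary.from_file(input_file, 'classes')
#     # 'yes' and 'no' share internal class 0; 'hello' gets class 1.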
def decode(self, lattice):
    """Propagates tokens through the given lattice and returns a list of
    tokens in the final node.

    Propagates the tokens at a node to every outgoing link by creating a
    copy of each token and updating the language model scores according
    to the link.

    :type lattice: Lattice
    :param lattice: a word lattice to be decoded

    :rtype: list of LatticeDecoder.Tokens
    :returns: the final tokens sorted by total log probability in
              descending order
    """

    if self._lm_scale is not None:
        lm_scale = logprob_type(self._lm_scale)
    elif lattice.lm_scale is not None:
        lm_scale = logprob_type(lattice.lm_scale)
    else:
        lm_scale = logprob_type(1.0)

    # Mirror the lm_scale logic: the decoder option takes precedence over
    # the value from the lattice. (The original used a plain "if" here,
    # which discarded self._wi_penalty whenever lattice.wi_penalty was
    # unset.)
    if self._wi_penalty is not None:
        wi_penalty = logprob_type(self._wi_penalty)
    elif lattice.wi_penalty is not None:
        wi_penalty = logprob_type(lattice.wi_penalty)
    else:
        wi_penalty = logprob_type(0.0)

    self._tokens = [list() for _ in lattice.nodes]
    initial_state = RecurrentState(self._network.recurrent_state_size)
    initial_token = self.Token(history=[self._sos_id], state=initial_state)
    initial_token.recompute_hash(self._recombination_order)
    initial_token.recompute_total(self._nnlm_weight, lm_scale, wi_penalty,
                                  self._linear_interpolation)
    self._tokens[lattice.initial_node.id].append(initial_token)
    lattice.initial_node.best_logprob = initial_token.total_logprob

    self._sorted_nodes = lattice.sorted_nodes()
    nodes_processed = 0
    for node in self._sorted_nodes:
        node_tokens = self._tokens[node.id]
        assert node_tokens
        num_pruned_tokens = len(node_tokens)
        self._prune(node)
        node_tokens = self._tokens[node.id]
        assert node_tokens
        num_pruned_tokens -= len(node_tokens)

        if node.id == lattice.final_node.id:
            new_tokens = self._propagate(
                node_tokens, None, lm_scale, wi_penalty)
            return sorted(new_tokens,
                          key=lambda token: token.total_logprob,
                          reverse=True)

        num_new_tokens = 0
        for link in node.out_links:
            new_tokens = self._propagate(
                node_tokens, link, lm_scale, wi_penalty)
            self._tokens[link.end_node.id].extend(new_tokens)
            num_new_tokens += len(new_tokens)

        nodes_processed += 1
        if nodes_processed % math.ceil(len(self._sorted_nodes) / 20) == 0:
            logging.debug("[%d] (%.2f %%) -- tokens = %d +%d -%d",
                          nodes_processed,
                          nodes_processed / len(self._sorted_nodes) * 100,
                          len(node_tokens),
                          num_new_tokens,
                          num_pruned_tokens)

    raise InputError("Could not reach the final node of word lattice.")
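# Usage sketch (the constructor arguments are assumptions for illustration,
# not a verified API):
#
#     decoder = LatticeDecoder(network, decoding_options)
#     with open('utterance.lat', 'r') as lattice_file:
#         lattice = SLFLattice(lattice_file)
#     tokens = decoder.decode(lattice)
#     best_token = tokens[0]  # highest total log probability first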