def __init__(self, *args, **kwargs):
    '''Set up the FST lattice generator and the keraslm rater.

    Injects this tool's ocrd-tool description and version into the base
    ``Processor`` constructor, then loads the correction models from the
    configured parameter files.

    Raises:
        RuntimeError: if the base class did not provide a workspace
            (i.e. the processor was instantiated outside a workflow run).
    '''
    # NOTE(review): this constructor is textually duplicated inside
    # PageXMLProcessor below — possibly a leftover copy; confirm which
    # one is actually used.
    kwargs['ocrd_tool'] = OCRD_TOOL['tools']['ocrd-cor-asv-fst-process']
    kwargs['version'] = OCRD_TOOL['version']
    super(PageXMLProcessor, self).__init__(*args, **kwargs)
    if not hasattr(self, 'workspace') or not self.workspace:
        raise RuntimeError('no workspace specified!')
    # initialize the decoder
    LOG.info("Loading the correction models")
    self.latticegen = FSTLatticeGenerator(
        self.parameter['lexicon_file'],
        self.parameter['error_model_file'],
        lattice_format='networkx',
        words_per_window=self.parameter['words_per_window'],
        rejection_weight=self.parameter['rejection_weight'],
        beam_width=self.parameter['beam_width'])
    # initialize the language model
    self.rater = Rater(logger=LOG)
    self.rater.load_config(self.parameter['keraslm_file'])
    # overrides for incremental mode necessary before compilation:
    self.rater.stateful = False  # no implicit state transfer
    self.rater.incremental = True  # but explicit state transfer
    self.rater.configure()
    self.rater.load_weights(self.parameter['keraslm_file'])
def __init__(self, *args, **kwargs):
    '''Set up the FST lattice generator and the keraslm rater.

    Injects this tool's ocrd-tool description and version into the base
    ``Processor`` constructor. If no ``parameter`` attribute exists (the
    processor was instantiated only to show help or dump JSON), skips
    model loading entirely.
    '''
    # NOTE(review): this constructor is textually duplicated inside
    # FSTCorrection below — possibly a leftover copy; confirm which
    # one is actually used.
    kwargs['ocrd_tool'] = OCRD_TOOL['tools']['ocrd-cor-asv-fst-process']
    kwargs['version'] = OCRD_TOOL['version']
    super(FSTCorrection, self).__init__(*args, **kwargs)
    if not hasattr(self, 'parameter'):
        # instantiated in non-processing context (e.g. -J/-h)
        return
    # initialize the decoder
    LOG.info("Loading the correction models")
    self.latticegen = FSTLatticeGenerator(
        self.parameter['lexicon_file'],
        self.parameter['error_model_file'],
        lattice_format='networkx',
        words_per_window=self.parameter['words_per_window'],
        rejection_weight=self.parameter['rejection_weight'],
        pruning_weight=self.parameter['pruning_weight'])
    # initialize the language model
    self.rater = Rater(logger=LOG)
    self.rater.load_config(self.parameter['keraslm_file'])
    # overrides for incremental mode necessary before compilation:
    self.rater.stateful = False  # no implicit state transfer
    self.rater.incremental = True  # but explicit state transfer
    self.rater.configure()
    self.rater.load_weights(self.parameter['keraslm_file'])
class PageXMLProcessor(Processor):
    '''
    Class responsible for processing the input data in PageXML format
    within the OCR-D workflow.
    '''

    def __init__(self, *args, **kwargs):
        '''Load correction models (error/lexicon FST and LSTM LM rater).

        Raises:
            RuntimeError: if the base class did not provide a workspace.
        '''
        kwargs['ocrd_tool'] = OCRD_TOOL['tools']['ocrd-cor-asv-fst-process']
        kwargs['version'] = OCRD_TOOL['version']
        super(PageXMLProcessor, self).__init__(*args, **kwargs)
        if not hasattr(self, 'workspace') or not self.workspace:
            raise RuntimeError('no workspace specified!')
        # initialize the decoder
        LOG.info("Loading the correction models")
        self.latticegen = FSTLatticeGenerator(
            self.parameter['lexicon_file'],
            self.parameter['error_model_file'],
            lattice_format='networkx',
            words_per_window=self.parameter['words_per_window'],
            rejection_weight=self.parameter['rejection_weight'],
            beam_width=self.parameter['beam_width'])
        # initialize the language model
        self.rater = Rater(logger=LOG)
        self.rater.load_config(self.parameter['keraslm_file'])
        # overrides for incremental mode necessary before compilation:
        self.rater.stateful = False  # no implicit state transfer
        self.rater.incremental = True  # but explicit state transfer
        self.rater.configure()
        self.rater.load_weights(self.parameter['keraslm_file'])

    def process(self):
        '''Correct each input PAGE file and write the result back.

        Downloads every file of the input group, parses it into a PcGts
        tree, runs line-by-line correction via ``_process_page``, and
        serialises the updated tree into the output file group.
        '''
        for (n, input_file) in enumerate(self.input_files):
            LOG.info("INPUT FILE %i / %s", n, input_file)
            local_input_file = self.workspace.download_file(input_file)
            pcgts = parse(local_input_file.url, silence=True)
            LOG.info("Scoring text in page '%s' at the %s level",
                     pcgts.get_pcGtsId(), self.parameter['textequiv_level'])
            self._process_page(pcgts)
            # write back result
            file_id = concat_padded(self.output_file_grp, n)
            self.workspace.add_file(
                ID=file_id,
                file_grp=self.output_file_grp,
                local_filename=os.path.join(self.output_file_grp, file_id),
                mimetype=MIMETYPE_PAGE,
                content=to_xml(pcgts),
            )

    def _process_page(self, pcgts):
        '''Correct all text lines of one PcGts page tree in place.

        For each line, builds a window lattice, FST-processes the windows,
        combines them into a graph, and runs LM beam search. The best path
        for a line is only committed once the NEXT line has been scored
        (the traceback/beam is carried across lines), so the last line is
        flushed after the loop.
        '''
        self._add_my_metadata_to_page(pcgts)
        prev_traceback = None  # TODO: pass from previous page
        prev_line = None  # TODO: pass from previous page
        for n_line in _page_get_lines(pcgts):
            # decoding: line -> windows -> lattice
            windows = self._line_to_windows(n_line)
            self._process_windows(windows)
            graph = self._combine_windows_to_line_graph(windows)
            # find best path for previous line, advance traceback/beam for
            # current line
            line_start_node = 0
            line_end_node = max(i + j for i, j in windows)
            context = self._get_context_from_identifier(
                self.workspace.mets.unique_identifier)
            path, entropy, traceback = self.rater.rate_best(
                graph, line_start_node, line_end_node,
                start_traceback = prev_traceback,
                context = context,
                lm_weight = self.parameter['lm_weight'],
                beam_width = self.parameter['lm_beam_width'],
                beam_clustering_dist = \
                    BEAM_CLUSTERING_DIST if BEAM_CLUSTERING_ENABLE else 0)
            # apply best path to line in PAGE
            if prev_line:
                _line_update_from_path(prev_line, path, entropy)
            prev_line = n_line
            prev_traceback = traceback
        # apply best path to last line in PAGE
        # TODO only to last line in document (when passing traceback between
        # pages)
        if prev_line:
            path, entropy, _ = self.rater.next_path(
                prev_traceback[0], ([], prev_traceback[1]))
            _line_update_from_path(prev_line, path, entropy)
        # ensure parent textequivs are up to date:
        page_update_higher_textequiv_levels('word', pcgts)

    def _add_my_metadata_to_page(self, pcgts):
        '''Append a processingStep MetadataItem (incl. all parameters) to the page.'''
        metadata = pcgts.get_Metadata()
        metadata.add_MetadataItem(
            MetadataItemType(
                type_='processingStep',
                name=OCRD_TOOL['tools']['ocrd-cor-asv-fst-process']['steps'][0],
                value='ocrd-cor-asv-fst-process',
                Labels=[
                    LabelsType(
                        externalRef='parameters',
                        Label=[
                            LabelType(type_=name, value=self.parameter[name])
                            for name in self.parameter.keys()
                        ])
                ]))

    def _line_to_tokens(self, n_line):
        '''Return the first TextEquiv string of each Word in the line.

        Words without any text result are skipped (with a warning), so the
        returned list may be shorter than the line's Word list.
        '''
        result = []
        n_words = n_line.get_Word()
        if not n_words:
            LOG.warning("Line '%s' contains no word", n_line.id)
        for n_word in n_words:
            n_textequivs = n_word.get_TextEquiv()
            if n_textequivs and n_textequivs[0].Unicode:
                result.append(n_textequivs[0].Unicode)
            else:
                LOG.warning("Word '%s' contains no text results", n_word.id)
        return result

    def _line_to_windows(self, n_line):
        '''Build all token windows of up to ``max_window_size`` words.

        Returns:
            dict mapping (start index, length) to a triple of
            (merged Word element, window FST, token list).
        '''
        # TODO currently: gets the text from the Word elements and returns
        # lines as lists of words;
        # needed: also get glyph alternatives
        # FIXME code duplication! this should be done by FSTLatticeGenerator
        n_words = n_line.get_Word()
        tokens = self._line_to_tokens(n_line)
        return {
            (i, j): (self._merge_word_nodes(n_words[i:i+j]),
                     create_window(tokens[i:i+j]),
                     tokens[i:i+j])
            for i in range(len(tokens))
            for j in range(1, min(self.parameter['max_window_size']+1,
                                  len(tokens)-i+1))}

    def _merge_word_nodes(self, nodes):
        '''Merge several Word elements into one synthetic WordType.

        The merged id is the comma-joined ids; language is a majority vote.
        Returns None (with an error logged) for an empty input.
        '''
        if not nodes:
            LOG.error('nothing to merge')
            return None
        merged = WordType()
        merged.set_id(','.join([n.id for n in nodes]))
        # NOTE(review): assumes every node has Coords — verify upstream
        points = ' '.join([n.get_Coords().points for n in nodes])
        if points:
            merged.set_Coords(nodes[0].get_Coords())  # TODO merge
        # make other attributes and TextStyle a majority vote, but no Glyph
        # (too fine-grained) or TextEquiv (overwritten from best path anyway)
        languages = list(map(lambda elem: elem.get_language(), nodes))
        if languages:
            merged.set_language(max(set(languages), key=languages.count))
        # TODO other attributes...
        styles = map(lambda elem: elem.get_TextStyle(), nodes)
        if any(styles):
            # TODO make a majority vote on each attribute here
            merged.set_TextStyle(nodes[0].get_TextStyle())
        return merged

    def _process_windows(self, windows):
        '''Run FST error/lexicon composition on every window in place.'''
        for (i, j), (ref, fst, tokens) in windows.items():
            LOG.debug('Processing window ({}, {})'.format(i, j))
            fst = process_window(
                ' '.join(tokens),
                fst,
                (self.latticegen.error_fst, self.latticegen.window_fst),
                beam_width=self.parameter['beam_width'],
                rejection_weight=self.parameter['rejection_weight'])
            windows[(i, j)] = (ref, fst, tokens)

    def _combine_windows_to_line_graph(self, windows):
        '''Combine processed windows into a networkx DiGraph over token
        boundary positions; each edge carries the merged Word element and
        its TextEquiv alternatives (confidence = 2**(-weight)).'''
        graph = nx.DiGraph()
        line_end_node = max(i + j for i, j in windows)
        graph.add_nodes_from(range(line_end_node + 1))
        for (i, j), (ref, fst, tokens) in windows.items():
            start_node = i
            end_node = i + j
            paths = [(output_str, float(weight))
                     for input_str, output_str, weight in
                     fst.paths().items()]
            if paths:
                for path in paths:
                    LOG.info('({}, {}, \'{}\', {})'.format(
                        start_node, end_node, path[0], pow(2, -path[1])))
                graph.add_edge(
                    start_node, end_node,
                    element=ref,
                    alternatives=list(
                        map(
                            lambda path: TextEquivType(
                                Unicode=path[0],
                                conf=pow(2, -path[1])),
                            paths)))
            else:
                LOG.warning('No path from {} to {}.'.format(i, i + j))
        return graph

    def _get_context_from_identifier(self, identifier):
        '''Derive an LM context vector (decade number) from the document
        identifier's trailing year suffix; fall back to [0].'''
        context = [0]
        if identifier:
            name = identifier.split('/')[-1]
            year = name.split('_')[-1]
            if year.isnumeric():
                year = ceil(int(year) / 10)
                context = [year]
        return context
def main():
    '''
    Read OCR-ed lines:
    - either from files following the path scheme <directory>/<ID>.<suffix>,
      where each file contains one line of text,
    - or from a single, two-column file: <ID> <TAB> <line>.
    Correct each line and save output according to one of the two
    above-mentioned schemata.

    Raises:
        RuntimeError: if neither a combined input/output file nor a
            directory-plus-suffix pair is supplied.
    '''
    # PROCESSOR is a module-level global so that worker processes spawned by
    # parallel_process can reach the (heavy) models without re-pickling them.
    global PROCESSOR
    # parse command-line arguments and set up various parameters
    args = parse_arguments()
    logging.basicConfig(level=logging.getLevelName(args.log_level))
    # check the validity of parameters specifying input/output
    if args.input_file is None and \
            (args.input_suffix is None or args.directory is None):
        raise RuntimeError('No input data supplied! You have to specify either'
                           ' -i or -I and the data directory.')
    if args.output_file is None and \
            (args.output_suffix is None or args.directory is None):
        # fixed typo: 'speficied' -> 'specified'
        raise RuntimeError('No output file specified! You have to specify '
                           'either -o or -O and the data directory.')
    using_lm = (args.language_model_file is not None)
    # without an LM the lattice can stay in plain FST format; with an LM the
    # rater needs a networkx graph
    latticegen = FSTLatticeGenerator(
        args.lexicon_file,
        args.error_model_file,
        lattice_format='networkx' if using_lm else 'fst',
        words_per_window=args.words_per_window,
        rejection_weight=args.rejection_weight,
        pruning_weight=args.pruning_weight)
    lm = None
    if using_lm:
        lm = Rater(logger=logging)
        lm.load_config(args.language_model_file)
        # overrides for incremental mode necessary before compilation:
        lm.stateful = False  # no implicit state transfer
        lm.incremental = True  # but explicit state transfer
        lm.configure()
        lm.load_weights(args.language_model_file)
    PROCESSOR = PlaintextProcessor(latticegen, lm)
    # load input data
    pairs = load_pairs_from_file(args.input_file) \
        if args.input_file is not None \
        else load_pairs_from_dir(args.directory, args.input_suffix)
    # process
    results = parallel_process(pairs, args.processes) \
        if args.processes > 1 \
        else [(basename, PROCESSOR.correct_string(input_str))
              for basename, input_str in pairs]
    # save results
    if args.output_file is not None:
        save_pairs_to_file(results, args.output_file)
    else:
        save_pairs_to_dir(results, args.directory, args.output_suffix)
class FSTCorrection(Processor):
    '''Perform OCR post-correction with error/lexicon FST and
    character-level LSTM LM.

    Open and deserialise PAGE input files, then iterate over the element
    hierarchy down to the requested `textequiv_level`, creating a lattice
    of Word elements with different spans (from 1 input token up to N
    successors) for each line. (When merging input tokens, concatenate
    their string values (TextEquiv) with spaces, and combine their
    coordinates and other attributes as precise as possible. Where the
    output contains spaces, introduced by the correction model, do not
    attempt to split, but keep the original Word.)

    Each lattice element (multi-token Word) now represents a _window_ of
    input string hypotheses which can be FST-processed efficiently,
    producing a number of output string hypotheses from its local n-best
    paths. These strings are written to the elements' TextEquivs.

    The lattice is then passed to language model rescoring and best path
    search: The LM decoder combines alternatives from all elements into
    sequences which can be fed into the LM rater, but not exhaustively
    (which is infeasible) but in a A* depth-first beam search. It does so
    by iteratively adding new input characters from the lattice to
    existing LM state representations of a priority queue (beam) of
    best-scoring character sequences (i.e. histories / lattice paths) up
    to that point.

    For each line, the LM decoder outputs the beam at the end of the
    input lattice, which will be passed in with the next line, and it
    outputs the decision on the best-scoring path up to the end of the
    previous line. (This way, the context on the next line is used to
    re-rank the beam of the current.) This path is used to concatenate
    the Word elements to be annotated for the line.

    Finally, make the levels above `textequiv_level` consistent with that
    textual result (by concatenation joined by whitespace).

    Produce new output files by serialising the resulting hierarchy.
    '''

    def __init__(self, *args, **kwargs):
        '''Load correction models (error/lexicon FST and LSTM LM rater),
        unless instantiated in a non-processing context.'''
        kwargs['ocrd_tool'] = OCRD_TOOL['tools']['ocrd-cor-asv-fst-process']
        kwargs['version'] = OCRD_TOOL['version']
        super(FSTCorrection, self).__init__(*args, **kwargs)
        if not hasattr(self, 'parameter'):
            # instantiated in non-processing context (e.g. -J/-h)
            return
        # initialize the decoder
        LOG.info("Loading the correction models")
        self.latticegen = FSTLatticeGenerator(
            self.parameter['lexicon_file'],
            self.parameter['error_model_file'],
            lattice_format='networkx',
            words_per_window=self.parameter['words_per_window'],
            rejection_weight=self.parameter['rejection_weight'],
            pruning_weight=self.parameter['pruning_weight'])
        # initialize the language model
        self.rater = Rater(logger=LOG)
        self.rater.load_config(self.parameter['keraslm_file'])
        # overrides for incremental mode necessary before compilation:
        self.rater.stateful = False  # no implicit state transfer
        self.rater.incremental = True  # but explicit state transfer
        self.rater.configure()
        self.rater.load_weights(self.parameter['keraslm_file'])

    def process(self):
        '''Correct each input PAGE file and add the result (one output
        file per input file, same pageId) to the workspace.'''
        assert_file_grp_cardinality(self.input_file_grp, 1)
        assert_file_grp_cardinality(self.output_file_grp, 1)
        for (n, input_file) in enumerate(self.input_files):
            LOG.info("INPUT FILE %i / %s", n, input_file)
            pcgts = page_from_file(self.workspace.download_file(input_file))
            LOG.info("Scoring text in page '%s' at the %s level",
                     pcgts.get_pcGtsId(), self.parameter['textequiv_level'])
            self._process_page(pcgts)
            # write back result
            file_id = make_file_id(input_file, self.output_file_grp)
            self.workspace.add_file(
                ID=file_id,
                file_grp=self.output_file_grp,
                pageId=input_file.pageId,
                local_filename=os.path.join(self.output_file_grp,
                                            file_id + '.xml'),
                mimetype=MIMETYPE_PAGE,
                content=to_xml(pcgts),
            )

    def _process_page(self, pcgts):
        '''Correct all text lines of one PcGts page tree in place.

        For each line, builds a window lattice, FST-processes the windows,
        combines them into a graph, and runs LM beam search. The best path
        for a line is only committed once the NEXT line has been scored
        (the traceback/beam is carried across lines), so the last line is
        flushed after the loop.
        '''
        self.add_metadata(pcgts)
        prev_traceback = None  # TODO: pass from previous page
        prev_line = None  # TODO: pass from previous page
        for n_line in _page_get_lines(pcgts):
            # decoding: line -> windows -> lattice
            windows = self._line_to_windows(n_line)
            self._process_windows(windows)
            graph = self._combine_windows_to_line_graph(windows)
            # find best path for previous line, advance traceback/beam for
            # current line
            line_start_node = 0
            line_end_node = max(i + j for i, j in windows)
            context = self._get_context_from_identifier(
                self.workspace.mets.unique_identifier)
            path, entropy, traceback = self.rater.rate_best(
                graph, line_start_node, line_end_node,
                start_traceback = prev_traceback,
                context = context,
                lm_weight = self.parameter['lm_weight'],
                beam_width = self.parameter['beam_width'],
                beam_clustering_dist = \
                    BEAM_CLUSTERING_DIST if BEAM_CLUSTERING_ENABLE else 0)
            # apply best path to line in PAGE
            if prev_line:
                _line_update_from_path(prev_line, path, entropy)
            prev_line = n_line
            prev_traceback = traceback
        # apply best path to last line in PAGE
        # TODO only to last line in document (when passing traceback between
        # pages)
        if prev_line:
            path, entropy, _ = self.rater.next_path(
                prev_traceback[0], ([], prev_traceback[1]))
            _line_update_from_path(prev_line, path, entropy)
        # ensure parent textequivs are up to date:
        page_update_higher_textequiv_levels('word', pcgts)

    def _line_to_tokens(self, n_line):
        '''Return the first TextEquiv string of each Word in the line.

        Words without any text result are skipped (with a warning), so the
        returned list may be shorter than the line's Word list.
        '''
        result = []
        n_words = n_line.get_Word()
        if not n_words:
            LOG.warning("Line '%s' contains no word", n_line.id)
        for n_word in n_words:
            n_textequivs = n_word.get_TextEquiv()
            if n_textequivs and n_textequivs[0].Unicode:
                result.append(n_textequivs[0].Unicode)
            else:
                LOG.warning("Word '%s' contains no text results", n_word.id)
        return result

    def _line_to_windows(self, n_line):
        '''Build all token windows of up to ``max_window_size`` words.

        Returns:
            dict mapping (start index, length) to a triple of
            (merged Word element, input string FST, token list).
        '''
        # currently: creates a lattice of (multi-word-) tokens/windows,
        # each as a tuple of (merged) Word object, input string FST, and
        # input string list;
        # todo: read graph directly from OCR's CTC decoder, 'splitting'
        # sub-graphs at whitespace candidates
        # FIXME: also get glyph alternatives (textequiv_level=glyph)
        # FIXME: also import confidence
        # FIXME: split the line(s) into words (textequiv_level=line)
        # FIXME code duplication! this should be done by FSTLatticeGenerator
        n_words = n_line.get_Word()
        tokens = self._line_to_tokens(n_line)
        return {
            (i, j): (self._merge_word_nodes(n_words[i:i+j]),
                     create_window(tokens[i:i+j]),
                     tokens[i:i+j])
            for i in range(len(tokens))
            for j in range(1, min(self.parameter['max_window_size']+1,
                                  len(tokens)-i+1))}

    def _merge_word_nodes(self, nodes):
        '''Merge several Word elements into one synthetic WordType.

        The merged id is the comma-joined ids; language is a majority vote.
        Returns None (with an error logged) for an empty input.
        '''
        if not nodes:
            LOG.error('nothing to merge')
            return None
        merged = WordType()
        merged.set_id(','.join([n.id for n in nodes]))
        # NOTE(review): assumes every node has Coords — verify upstream
        points = ' '.join([n.get_Coords().points for n in nodes])
        if points:
            merged.set_Coords(nodes[0].get_Coords())  # TODO merge
        # make other attributes and TextStyle a majority vote, but no Glyph
        # (too fine-grained) or TextEquiv (overwritten from best path anyway)
        languages = list(map(lambda elem: elem.get_language(), nodes))
        if languages:
            merged.set_language(max(set(languages), key=languages.count))
        # TODO other attributes...
        styles = map(lambda elem: elem.get_TextStyle(), nodes)
        if any(styles):
            # TODO make a majority vote on each attribute here
            merged.set_TextStyle(nodes[0].get_TextStyle())
        return merged

    def _process_windows(self, windows):
        '''Run FST error/lexicon composition on every window in place.'''
        for (i, j), (ref, fst, tokens) in windows.items():
            LOG.debug('Processing window ({}, {})'.format(i, j))
            # FIXME: this NEEDS multiprocessing (as before 81dd2c0c)!
            fst = process_window(
                ' '.join(tokens),
                fst,
                (self.latticegen.error_fst, self.latticegen.window_fst),
                pruning_weight=self.parameter['pruning_weight'],
                rejection_weight=self.parameter['rejection_weight'])
            windows[(i, j)] = (ref, fst, tokens)

    def _combine_windows_to_line_graph(self, windows):
        '''Combine processed windows into a networkx DiGraph over token
        boundary positions; each edge carries the merged Word element and
        its TextEquiv alternatives (confidence = 2**(-weight)).'''
        graph = nx.DiGraph()
        line_end_node = max(i + j for i, j in windows)
        graph.add_nodes_from(range(line_end_node + 1))
        for (i, j), (ref, fst, tokens) in windows.items():
            start_node = i
            end_node = i + j
            # FIXME: this will NOT work without spaces and newlines
            # (as before 81dd2c0c)!
            paths = [(output_str, float(weight))
                     for input_str, output_str, weight in
                     fst.paths().items()]
            if paths:
                for path in paths:
                    LOG.info('({}, {}, \'{}\', {})'.format(
                        start_node, end_node, path[0], pow(2, -path[1])))
                graph.add_edge(
                    start_node, end_node,
                    element=ref,
                    alternatives=list(
                        map(
                            lambda path: TextEquivType(
                                Unicode=path[0],
                                conf=pow(2, -path[1])),
                            paths)))
            else:
                LOG.warning('No path from {} to {}.'.format(i, i + j))
        return graph

    def _get_context_from_identifier(self, identifier):
        '''Derive an LM context vector (decade number) from the document
        identifier's trailing year suffix; fall back to [0].'''
        context = [0]
        if identifier:
            name = identifier.split('/')[-1]
            year = name.split('_')[-1]
            if year.isnumeric():
                year = ceil(int(year) / 10)
                context = [year]
        return context