def __init__(self, baseurl, cachefile, login, password=None): """ Create a nashi client Parameters ---------- baseurl : web address of nashi instance cachefile : filename of hdf5-cache login : user for nashi password : asks for user input if empty """ self.baseurl = baseurl self.session = None self.traindata = None self.recogdata = None self.valdata = None self.bookcache = {} self.cachefile = cachefile self.login(login, password) params = DataPreprocessorParams() params.line_height = 48 params.pad = 16 params.pad_value = 1 params.no_invert = False params.no_transpose = False self.data_proc = MultiDataProcessor([ DataRangeNormalizer(), CenterNormalizer(params), FinalPreparation(params, as_uint8=True), ]) # Text pre processing (reading) preproc = TextProcessorParams() preproc.type = TextProcessorParams.MULTI_NORMALIZER default_text_normalizer_params(preproc.children.add(), default="NFC") default_text_regularizer_params(preproc.children.add(), groups=["extended"]) strip_processor_params = preproc.children.add() strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER self.txt_preproc = text_processor_from_proto(preproc, "pre") # Text post processing (prediction) postproc = TextProcessorParams() postproc.type = TextProcessorParams.MULTI_NORMALIZER default_text_normalizer_params(postproc.children.add(), default="NFC") default_text_regularizer_params(postproc.children.add(), groups=["extended"]) strip_processor_params = postproc.children.add() strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER self.text_postproc = text_processor_from_proto(postproc, "post") # BIDI text preprocessing bidi_processor_params = preproc.children.add() bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_RTL self.bidi_preproc = text_processor_from_proto(preproc, "pre") bidi_processor_params = postproc.children.add() bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO self.bidi_postproc = text_processor_from_proto(postproc, "post")
def main(): parser = argparse.ArgumentParser() parser.add_argument("--files", type=str, nargs="+", required=True, help="Text files to apply text processing") parser.add_argument("--line_height", type=int, default=48, help="The line height") parser.add_argument("--pad", type=int, default=16, help="Padding (left right) of the line") parser.add_argument("--pad_value", type=int, default=1, help="Padding (left right) of the line") parser.add_argument("--processes", type=int, default=1) parser.add_argument("--verbose", action="store_true") parser.add_argument("--invert", action="store_true") parser.add_argument("--transpose", action="store_true") parser.add_argument("--dry_run", action="store_true", help="No not overwrite files, just run") args = parser.parse_args() params = DataPreprocessorParams() params.line_height = args.line_height params.pad = args.pad params.pad_value = args.pad_value params.no_invert = not args.invert params.no_transpos = not args.transpose data_proc = MultiDataProcessor([ DataRangeNormalizer(), CenterNormalizer(params), FinalPreparation(params, as_uint8=True), ]) print("Resolving files") img_files = sorted(glob_all(args.files)) handler = Handler(data_proc, args.dry_run) with multiprocessing.Pool(processes=args.processes, maxtasksperchild=100) as pool: list( tqdm(pool.imap(handler.handle_single, img_files), desc="Processing", total=len(img_files)))
class NashiClient(): def __init__(self, baseurl, cachefile, login, password=None): """ Create a nashi client Parameters ---------- baseurl : web address of nashi instance cachefile : filename of hdf5-cache login : user for nashi password : asks for user input if empty """ self.baseurl = baseurl self.session = None self.traindata = None self.recogdata = None self.valdata = None self.bookcache = {} self.cachefile = cachefile self.login(login, password) params = DataPreprocessorParams() params.line_height = 48 params.pad = 16 params.pad_value = 1 params.no_invert = False params.no_transpose = False self.data_proc = MultiDataProcessor([ DataRangeNormalizer(), CenterNormalizer(params), FinalPreparation(params, as_uint8=True), ]) # Text pre processing (reading) preproc = TextProcessorParams() preproc.type = TextProcessorParams.MULTI_NORMALIZER default_text_normalizer_params(preproc.children.add(), default="NFC") default_text_regularizer_params(preproc.children.add(), groups=["extended"]) strip_processor_params = preproc.children.add() strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER self.txt_preproc = text_processor_from_proto(preproc, "pre") # Text post processing (prediction) postproc = TextProcessorParams() postproc.type = TextProcessorParams.MULTI_NORMALIZER default_text_normalizer_params(postproc.children.add(), default="NFC") default_text_regularizer_params(postproc.children.add(), groups=["extended"]) strip_processor_params = postproc.children.add() strip_processor_params.type = TextProcessorParams.STRIP_NORMALIZER self.text_postproc = text_processor_from_proto(postproc, "post") # BIDI text preprocessing bidi_processor_params = preproc.children.add() bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_RTL self.bidi_preproc = text_processor_from_proto(preproc, "pre") bidi_processor_params = postproc.children.add() bidi_processor_params.type = TextProcessorParams.BIDI_NORMALIZER bidi_processor_params.bidi_direction = TextProcessorParams.BIDI_AUTO self.bidi_postproc = text_processor_from_proto(postproc, "post") def login(self, email, pw): if pw is None: pw = getpass("Password: "******"/login") res = html.document_fromstring(r.text) csrf = res.get_element_by_id("csrf_token").attrib["value"] lg = s.post(self.baseurl+"/login", data={ "csrf_token": csrf, "email": email, "password": pw, "submit": "Login" }) if "<li>Invalid password</li>" in lg.text: raise Exception("Login failed.") self.session = s def update_books(self, books, gt_layer=0, rect=False, rtl=False): """ Update books cache Parameters ---------- books : book title or list of titles gt_layer : index of ground truth in PAGE files rect : cut out rectangles instead of line polygons rtl : set text direction to rtl """ cache = self.cache = h5py.File(self.cachefile, 'a', libver='latest') if isinstance(books, str): books = [books] icnt, tcnt = 0, 0 for b in books: print("Updating {}…".format(b)) if b.endswith("_ar") and not rtl: print("Warning: Title ends with _ar but rtl is not set!") if b.endswith("_ar") and not rect: print("Warning: Title ends with _ar but rect is not set!") book = self.getbook(b) # remove pages not contained in the nashi book if b in cache: for p in cache.get(b): if p not in book: _ = cache[b].pop(p) else: cache.create_group(b) cache[b].attrs["dir"] = "rtl" if rtl else "ltr" for p, root in book.items(): icnt_this, tcnt_this = 0, 0 print(p, end="… ") ns = {"ns": root.nsmap[None]} if p not in cache[b]: cache.create_group(b+"/"+p) cache[b][p].attrs["img_w"] = int(root.xpath('//ns:Page', namespaces=ns)[0].attrib["imageWidth"]) cache[b][p].attrs["image_file"] = root.xpath('//ns:Page', namespaces=ns)[0].attrib["imageFilename"] pageimg = None lines = root.xpath('//ns:TextLine', namespaces=ns) lids = [l.attrib["id"] for l in lines] # remove lines not contained in the page anymore for lid in cache[b][p]: if lid not in lids: _ = cache[b][p].pop(lid) for l in lines: coords = l.xpath('./ns:Coords', namespaces=ns).pop().attrib.get("points") lid = l.attrib["id"] # update line image if coords changed if lid not in cache[b][p] \ or cache[b][p][lid].attrs.get("coords") != coords: icnt_this += 1 if pageimg is None: imgresp = self.session.get(self.baseurl+"/books/{}/{}".format( b, cache[b][p].attrs["image_file"]), params={"upgrade": "nrm"}) f = BytesIO(imgresp.content) im = Image.open(f) pageimg = np.array(im) f.close() if len(pageimg.shape) > 2: pageimg = pageimg[:, :, 0] if pageimg.dtype == bool: pageimg = pageimg.astype("uint8") * 255 limg = cutout(pageimg, coords, scale=pageimg.shape[1] / cache[b][p].attrs["img_w"], rect=rect) limg = self.data_proc.apply(limg)[0] if lid not in cache[b][p]: cache[b][p].create_dataset(lid, data=limg, maxshape=(None, 48)) else: if cache[b][p][lid].shape != limg.shape: cache[b][p][lid].resize(limg.shape) cache[b][p][lid][:, :] = limg cache[b][p][lid].attrs["coords"] = coords comments = l.attrib.get("comments") if comments is not None and comments.strip(): cache[b][p][lid].attrs["comments"] = comments.strip() rtype = l.getparent().attrib.get("type") cache[b][p][lid].attrs["rtype"] = rtype if rtype is not None else "" ucd = l.xpath('./ns:TextEquiv[@index="{}"]/ns:Unicode'.format(gt_layer), namespaces=ns) rawtext = ucd[0].text if ucd else None if rawtext != cache[b][p][lid].attrs.get("text_raw"): tcnt_this += 1 cache[b][p][lid].attrs["text_raw"] = rawtext preproc = self.bidi_preproc if rtl else self.txt_preproc cache[b][p][lid].attrs["text"] = preproc.apply(rawtext) icnt += icnt_this tcnt += tcnt_this print("i: {} / t: {}".format(icnt_this, tcnt_this)) cache.flush() cache.close() def train_books(self, books, output_model_prefix, weights=None, train_to_val=1, max_iters=100000, display=500, checkpoint_frequency=-1, preload=False): if isinstance(books, str): books = [books] dset = Nash5DataSet(DataSetMode.TRAIN, self.cachefile, books) if 0 < train_to_val < 1: valsamples = random.sample(dset._samples, int((1-train_to_val)*len(dset))) for s in valsamples: dset._samples.remove(s) vdset = Nash5DataSet(DataSetMode.TRAIN, self.cachefile, []) vdset._samples = valsamples else: vdset = None parser = argparse.ArgumentParser() setup_train_args(parser, omit=["files", "validation"]) args = parser.parse_known_args()[0] with h5py.File(self.cachefile, 'r', libver='latest', swmr=True) as cache: if all(cache[b].attrs.get("dir") == "rtl" for b in books): args.bidi_dir = "rtl" params = params_from_args(args) params.output_model_prefix = output_model_prefix params.early_stopping_best_model_prefix = "best_" + output_model_prefix params.max_iters = max_iters params.display = display params.checkpoint_frequency = checkpoint_frequency trainer = Trainer(params, dset, txt_preproc=NoopTextProcessor(), data_preproc=NoopDataPreprocessor(), validation_dataset=vdset, weights=weights, preload_training=preload, preload_validation=True) trainer.train(progress_bar=True) def predict_books(self, books, models, pageupload=False, text_index=1): if type(books) == str: books = [books] if type(models) == str: models = [models] dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books) voter_params = VoterParams() voter_params.type = VoterParams.Type.Value("confidence_voter_default_ctc".upper()) voter = voter_from_proto(voter_params) # predict for all models predictor = MultiPredictor(checkpoints=models, data_preproc=NoopDataPreprocessor(), batch_size=1, processes=1) do_prediction = predictor.predict_dataset(dset, progress_bar=True) avg_sentence_confidence = 0 n_predictions = 0 # output the voted results to the appropriate files for result, sample in do_prediction: n_predictions += 1 for i, p in enumerate(result): p.prediction.id = "fold_{}".format(i) # vote the results (if only one model is given, this will just return the sentences) prediction = voter.vote_prediction_result(result) prediction.id = "voted" sentence = prediction.sentence avg_sentence_confidence += prediction.avg_char_probability dset.store_text(sentence, sample, output_dir=None, extension=".pred.txt") print("Average sentence confidence: {:.2%}".format(avg_sentence_confidence / n_predictions)) dset.store() print("All files written") def evaluate_books(self, books, models, mode="auto", sample=-1): if type(books) == str: books = [books] if type(models) == str: models = [models] results = {} if mode == "auto": with h5py.File(self.cachefile, 'r', libver='latest', swmr=True) as cache: for b in books: for p in cache[b]: for s in cache[b][p]: if "text" in cache[b][p][s].attrs: mode = "eval" break if mode != "auto": break if mode != "auto": break if mode == "auto": mode = "conf" if mode == "conf": dset = Nash5DataSet(DataSetMode.PREDICT, self.cachefile, books) else: dset = Nash5DataSet(DataSetMode.EVAL, self.cachefile, books) if 0 < sample < len(dset): delsamples = random.sample(dset._samples, len(dset) - sample) for s in delsamples: dset._samples.remove(s) if mode == "conf": for model in models: if isinstance(model, str): model = [model] predictor = MultiPredictor(checkpoints=model, data_preproc=NoopDataPreprocessor(), batch_size=1, processes=1) voter_params = VoterParams() voter_params.type = VoterParams.Type.Value("confidence_voter_default_ctc".upper()) voter = voter_from_proto(voter_params) do_prediction = predictor.predict_dataset(dset, progress_bar=True) avg_sentence_confidence = 0 n_predictions = 0 for result, sample in do_prediction: n_predictions += 1 prediction = voter.vote_prediction_result(result) avg_sentence_confidence += prediction.avg_char_probability results["/".join(model)] = avg_sentence_confidence / n_predictions else: for model in models: if isinstance(model, str): model = [model] predictor = MultiPredictor(checkpoint=model, data_preproc=NoopDataPreprocessor(), batch_size=1, processes=1, with_gt=True) out_gen = predictor.predict_dataset(dset, progress_bar=True, apply_preproc=False) result = Evaluator.evaluate_single_list(map(Evaluator.evaluate_single_args, map(lambda d: tuple([''.join(d[0].ground_truth), ''.join(d[0].chars)]), out_gen))) results["/".join(model)] = 1 - result["avg_ler"] return results def upload_books(self, books, text_index=1): """ Upload books from the cachefile to the server Parameters ---------- bookname : title of the book or list of titles text_index : index of the TextEquiv to write to Returns ---------- dict mapping page names to lxml etree instances """ cache = h5py.File(self.cachefile, 'r', libver='latest', swmr=True) ocrdata = {} if type(books) == str: books = [books] savelines = [cache[b][p][l] for b in books for p in cache[b] for l in cache[b][p] if cache[b][p][l].attrs.get("pred") is not None] for line in savelines: _, b, p, l = line.name.split("/") if b not in ocrdata: ocrdata[b] = {} if p not in ocrdata[b]: ocrdata[b][p] = {} ocrdata[b][p][l] = line.attrs.get("pred") data = {"ocrdata": ocrdata, "index": text_index} self.session.post(self.baseurl+"/_ocrdata", data=gzip.compress(json.dumps(data).encode("utf-8")), headers={"Content-Type": "application/json;charset=UTF-8", "Content-Encoding": "gzip"}) cache.close() def getbook(self, bookname): """ Download a book from the nashi server Parameters ---------- bookname : title of the book to load Returns ---------- dict mapping page names to lxml etree instances """ pagezip = self.session.get(self.baseurl+"/books/{}_PageXML.zip" .format(bookname)) f = BytesIO(pagezip.content) zf = zipfile.ZipFile(f) book = {} for fn in zf.namelist(): filename = fn pagename = path.splitext(path.split(filename)[1])[0] with zf.open(fn) as fo: book[pagename] = etree.parse(fo).getroot() f.close() return book