def _load_models(self):
    if self._predictor is not None:
        return
    if self._ocr == "FAKE":
        return

    batch_size = self._options["batch_size"]
    if batch_size > 0:
        batch_size_kwargs = dict(batch_size=batch_size)
    else:
        batch_size_kwargs = dict()
    self._chunk_size = batch_size

    if len(self._models) == 1:
        self._predictor = Predictor(str(self._models[0]), **batch_size_kwargs)
        self._predict_kwargs = batch_size_kwargs
        self._voter = None
        self._line_height = int(self._predictor.model_params.line_height)
    else:
        logging.info("using Calamari voting with %d models." % len(self._models))
        self._predictor = MultiPredictor(
            checkpoints=[str(p) for p in self._models], **batch_size_kwargs)
        self._predict_kwargs = dict()
        self._voter = ConfidenceVoter()
        self._line_height = int(self._predictor.predictors[0].model_params.line_height)
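
# --- Illustrative sketch (not part of the original module) ---------------------
# A hedged example of how the lazily loaded predictor above might be driven for
# the single-model branch. `engine` and `line_images` are hypothetical
# placeholders; predict_raw(...) and the .sentence attribute mirror their use in
# the prediction tests below, so treat this as a sketch, not the canonical API.
def _example_recognize_lines(engine, line_images):
    engine._load_models()                      # builds engine._predictor on first use
    if engine._voter is None:                  # single checkpoint, no voting
        results = engine._predictor.predict_raw(line_images, progress_bar=False,
                                                **engine._predict_kwargs)
        return [r.sentence for r in results]
    # with multiple checkpoints, the MultiPredictor/ConfidenceVoter pair is used instead
    raise NotImplementedError("voting path omitted in this sketch")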
def test_raw_prediction(self):
    args = PredictionAttrs()
    predictor = Predictor(checkpoint=args.checkpoint[0])
    images = [np.array(Image.open(file), dtype=np.uint8) for file in args.files]
    for file, image in zip(args.files, images):
        r = list(predictor.predict_raw([image], progress_bar=False))[0]
        print(file, r.sentence)
def test_raw_dataset_prediction(self):
    args = PredictionAttrs()
    predictor = Predictor(checkpoint=args.checkpoint[0])
    data = create_dataset(
        DataSetType.FILE,
        DataSetMode.PREDICT,
        images=args.files,
    )
    for prediction, sample in predictor.predict_dataset(data):
        pass
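
# --- Illustrative sketch (not part of the original tests) ----------------------
# The dataset-based test above discards its results; a hedged variant that also
# inspects them could look like this. `prediction.sentence` mirrors `r.sentence`
# from test_raw_prediction; the exact contents of `sample` are an assumption and
# may differ between Calamari versions.
def example_print_dataset_predictions(predictor, data):
    for prediction, sample in predictor.predict_dataset(data):
        print(sample, prediction.sentence)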
def train(self, progress_bar=False):
    checkpoint_params = self.checkpoint_params

    train_start_time = time.time() + self.checkpoint_params.total_time

    self.dataset.load_samples(processes=1, progress_bar=progress_bar)
    datas, txts = self.dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
    if len(datas) == 0:
        raise Exception("Empty dataset is not allowed. Check if the data is at the correct location")

    if self.validation_dataset:
        self.validation_dataset.load_samples(processes=1, progress_bar=progress_bar)
        validation_datas, validation_txts = self.validation_dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
        if len(validation_datas) == 0:
            raise Exception("Validation dataset is empty. Provide valid validation data for early stopping.")
    else:
        validation_datas, validation_txts = [], []

    # preprocessing steps
    texts = self.txt_preproc.apply(txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
    datas = self.data_preproc.apply(datas, processes=checkpoint_params.processes, progress_bar=progress_bar)
    validation_txts = self.txt_preproc.apply(validation_txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
    validation_datas = self.data_preproc.apply(validation_datas, processes=checkpoint_params.processes, progress_bar=progress_bar)

    # compute the codec
    codec = self.codec if self.codec else Codec.from_texts(texts, whitelist=self.codec_whitelist)

    # data augmentation on preprocessed data
    if self.data_augmenter:
        datas, texts = self.data_augmenter.augment_datas(datas, texts, n_augmentations=self.n_augmentations,
                                                         processes=checkpoint_params.processes, progress_bar=progress_bar)

        # TODO: validation data augmentation
        # validation_datas, validation_txts = self.data_augmenter.augment_datas(validation_datas, validation_txts, n_augmentations=0,
        #                                                                       processes=checkpoint_params.processes, progress_bar=progress_bar)

    # create backend
    network_params = checkpoint_params.model.network
    network_params.features = checkpoint_params.model.line_height
    network_params.classes = len(codec)
    if self.weights:
        # if we load the weights, take care of codec changes as well
        with open(self.weights + '.json', 'r') as f:
            restore_checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            restore_model_params = restore_checkpoint_params.model

        # checks
        if checkpoint_params.model.line_height != network_params.features:
            raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                network_params.features, checkpoint_params.model.line_height
            ))

        # create codec of the same type
        restore_codec = codec.__class__(restore_model_params.codec.charset)
        # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
        codec_changes = restore_codec.align(codec)
        codec = restore_codec
        print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
        # The actual weight/bias matrix will be changed after loading the old weights
    else:
        codec_changes = None

    # store the new codec
    checkpoint_params.model.codec.charset[:] = codec.charset
    print("CODEC: {}".format(codec.charset))

    # compute the labels with (new/current) codec
    labels = [codec.encode(txt) for txt in texts]

    backend = create_backend_from_proto(network_params,
                                        weights=self.weights,
                                        )
    backend.set_train_data(datas, labels)
    backend.set_prediction_data(validation_datas)
    if codec_changes:
        backend.realign_model_labels(*codec_changes)
    backend.prepare(train=True)

    loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
    ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
    dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

    early_stopping_enabled = self.validation_dataset is not None \
                             and checkpoint_params.early_stopping_frequency > 0 \
                             and checkpoint_params.early_stopping_nbest > 1
    early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
    early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
    early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

    early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc,
                                         backend=backend)

    # Start the actual training
    # ====================================================================================

    iter = checkpoint_params.iter

    # helper function to write a checkpoint
    def make_checkpoint(base_dir, prefix, version=None):
        if version:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
        else:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
        print("Storing checkpoint to '{}'".format(checkpoint_path))
        backend.save_checkpoint(checkpoint_path)
        checkpoint_params.iter = iter
        checkpoint_params.loss_stats[:] = loss_stats.values
        checkpoint_params.ler_stats[:] = ler_stats.values
        checkpoint_params.dt_stats[:] = dt_stats.values
        checkpoint_params.total_time = time.time() - train_start_time
        checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
        checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
        checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

        with open(checkpoint_path + ".json", 'w') as f:
            f.write(json_format.MessageToJson(checkpoint_params))

        return checkpoint_path

    try:
        last_checkpoint = None

        # Training loop, can be interrupted by early stopping
        for iter in range(iter, checkpoint_params.max_iters):
            checkpoint_params.iter = iter

            iter_start_time = time.time()
            result = backend.train_step(checkpoint_params.batch_size)

            if not np.isfinite(result['loss']):
                print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                if not last_checkpoint:
                    raise Exception("No checkpoint written yet. Training must be stopped.")
                else:
                    # reload also non-trainable weights, such as solver-specific variables
                    backend.load_checkpoint_weights(last_checkpoint, restore_only_trainable=False)
                    continue

            loss_stats.push(result['loss'])
            ler_stats.push(result['ler'])

            dt_stats.push(time.time() - iter_start_time)

            if iter % checkpoint_params.display == 0:
                pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))

                print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                    iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                print(" PRED: '{}'".format(pred_sentence))
                print(" TRUE: '{}'".format(gt_sentence))

            if (iter + 1) % checkpoint_params.checkpoint_frequency == 0:
                last_checkpoint = make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix)

            if early_stopping_enabled and (iter + 1) % checkpoint_params.early_stopping_frequency == 0:
                print("Checking early stopping model")

                out = early_stopping_predictor.predict_raw(validation_datas, batch_size=checkpoint_params.batch_size,
                                                           progress_bar=progress_bar, apply_preproc=False)
                pred_texts = [d.sentence for d in out]
                result = Evaluator.evaluate(gt_data=validation_txts, pred_data=pred_texts, progress_bar=progress_bar)
                accuracy = 1 - result["avg_ler"]

                if accuracy > early_stopping_best_accuracy:
                    early_stopping_best_accuracy = accuracy
                    early_stopping_best_cur_nbest = 1
                    early_stopping_best_at_iter = iter + 1
                    # overwrite as best model
                    last_checkpoint = make_checkpoint(
                        checkpoint_params.early_stopping_best_model_output_dir,
                        prefix="",
                        version=checkpoint_params.early_stopping_best_model_prefix,
                    )
                    print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                else:
                    early_stopping_best_cur_nbest += 1
                    print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".format(
                        early_stopping_best_accuracy, early_stopping_best_at_iter,
                        checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                    print("Early stopping now.")
                    break

    except KeyboardInterrupt as e:
        print("Storing interrupted checkpoint")
        make_checkpoint(checkpoint_params.output_dir,
                        checkpoint_params.output_model_prefix,
                        "interrupted")
        raise e

    print("Total time {}s for {} iterations.".format(time.time() - train_start_time, iter))
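
# --- Illustrative aside (invented numbers, not taken from a real training run) --
# The early-stopping bookkeeping in train() keeps a "current nbest" counter that
# resets to 1 whenever validation accuracy improves and stops training once it
# reaches early_stopping_nbest. A compact dry run of that rule:
accuracies = [0.90, 0.92, 0.91, 0.91, 0.91, 0.91]   # hypothetical validation scores
nbest, best, cur = 5, 0.0, 0
for check, acc in enumerate(accuracies, start=1):
    if acc > best:
        best, cur = acc, 1          # better model found: counter restarts at 1
    else:
        cur += 1                    # no improvement: counter grows
    if best > 0 and cur >= nbest:
        print("early stopping after validation check", check)   # -> check 6
        break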
def _run_train(self, train_net, test_net, codec, train_start_time, progress_bar):
    checkpoint_params = self.checkpoint_params
    validation_dataset = test_net.input_dataset
    iters_per_epoch = max(1, int(train_net.input_dataset.epoch_size() / checkpoint_params.batch_size))

    loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
    ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
    dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

    display = checkpoint_params.display
    display_epochs = display <= 1
    if display <= 0:
        display = 0  # to not display anything
    elif display_epochs:
        display = max(1, int(display * iters_per_epoch))  # relative to epochs
    else:
        display = max(1, int(display))  # iterations

    checkpoint_frequency = checkpoint_params.checkpoint_frequency
    early_stopping_frequency = checkpoint_params.early_stopping_frequency
    if early_stopping_frequency < 0:
        # set early stopping frequency to half an epoch
        early_stopping_frequency = int(0.5 * iters_per_epoch)
    elif 0 < early_stopping_frequency <= 1:
        early_stopping_frequency = int(early_stopping_frequency * iters_per_epoch)  # relative to epochs
    else:
        early_stopping_frequency = int(early_stopping_frequency)
    early_stopping_frequency = max(1, early_stopping_frequency)

    if checkpoint_frequency < 0:
        checkpoint_frequency = early_stopping_frequency
    elif 0 < checkpoint_frequency <= 1:
        checkpoint_frequency = int(checkpoint_frequency * iters_per_epoch)  # relative to epochs
    else:
        checkpoint_frequency = int(checkpoint_frequency)

    early_stopping_enabled = self.validation_dataset is not None \
                             and checkpoint_params.early_stopping_frequency > 0 \
                             and checkpoint_params.early_stopping_nbest > 1
    early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
    early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
    early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

    early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc,
                                         network=test_net)

    # Start the actual training
    # ====================================================================================

    iter = checkpoint_params.iter

    # helper function to write a checkpoint
    def make_checkpoint(base_dir, prefix, version=None):
        if version:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
        else:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
        print("Storing checkpoint to '{}'".format(checkpoint_path))
        train_net.save_checkpoint(checkpoint_path)
        checkpoint_params.version = Checkpoint.VERSION
        checkpoint_params.iter = iter
        checkpoint_params.loss_stats[:] = loss_stats.values
        checkpoint_params.ler_stats[:] = ler_stats.values
        checkpoint_params.dt_stats[:] = dt_stats.values
        checkpoint_params.total_time = time.time() - train_start_time
        checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
        checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
        checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

        with open(checkpoint_path + ".json", 'w') as f:
            f.write(json_format.MessageToJson(checkpoint_params))

        return checkpoint_path

    try:
        last_checkpoint = None
        n_infinite_losses = 0
        n_max_infinite_losses = 5

        # Training loop, can be interrupted by early stopping
        for iter in range(iter, checkpoint_params.max_iters):
            checkpoint_params.iter = iter

            iter_start_time = time.time()
            result = train_net.train_step()

            if not np.isfinite(result['loss']):
                n_infinite_losses += 1

                if n_max_infinite_losses == n_infinite_losses:
                    print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                    if not last_checkpoint:
                        raise Exception("No checkpoint written yet. Training must be stopped.")
                    else:
                        # reload also non trainable weights, such as solver-specific variables
                        train_net.load_weights(last_checkpoint, restore_only_trainable=False)
                        continue
                else:
                    continue

            n_infinite_losses = 0

            loss_stats.push(result['loss'])
            ler_stats.push(result['ler'])

            dt_stats.push(time.time() - iter_start_time)

            if display > 0 and iter % display == 0:
                # apply postprocessing to display the true output
                pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))

                if display_epochs:
                    print("#{:08f}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                        iter / iters_per_epoch, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                else:
                    print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(
                        iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))

                # Insert utf-8 ltr/rtl direction marks for bidi support
                lr = "\u202A\u202B"
                print(" PRED: '{}{}{}'".format(lr[bidi.get_base_level(pred_sentence)], pred_sentence, "\u202C"))
                print(" TRUE: '{}{}{}'".format(lr[bidi.get_base_level(gt_sentence)], gt_sentence, "\u202C"))

            if checkpoint_frequency > 0 and (iter + 1) % checkpoint_frequency == 0:
                last_checkpoint = make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix)

            if early_stopping_enabled and (iter + 1) % early_stopping_frequency == 0:
                print("Checking early stopping model")

                out_gen = early_stopping_predictor.predict_input_dataset(validation_dataset,
                                                                         progress_bar=progress_bar)
                result = Evaluator.evaluate_single_list(
                    map(Evaluator.evaluate_single_args,
                        map(lambda d: tuple(self.txt_preproc.apply([''.join(d.ground_truth), d.sentence])),
                            out_gen)))
                accuracy = 1 - result["avg_ler"]

                if accuracy > early_stopping_best_accuracy:
                    early_stopping_best_accuracy = accuracy
                    early_stopping_best_cur_nbest = 1
                    early_stopping_best_at_iter = iter + 1
                    # overwrite as best model
                    last_checkpoint = make_checkpoint(
                        checkpoint_params.early_stopping_best_model_output_dir,
                        prefix="",
                        version=checkpoint_params.early_stopping_best_model_prefix,
                    )
                    print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                else:
                    early_stopping_best_cur_nbest += 1
                    print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".format(
                        early_stopping_best_accuracy, early_stopping_best_at_iter,
                        checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                    print("Early stopping now.")
                    break

                if accuracy >= 1:
                    print("Reached perfect score on validation set. Early stopping now.")
                    break

    except KeyboardInterrupt as e:
        print("Storing interrupted checkpoint")
        make_checkpoint(checkpoint_params.output_dir,
                        checkpoint_params.output_model_prefix,
                        "interrupted")
        raise e

    print("Total time {}s for {} iterations.".format(time.time() - train_start_time, iter))
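
# --- Illustrative aside (invented numbers, not from the code above) ------------
# In _run_train, display / checkpoint_frequency / early_stopping_frequency values
# in (0, 1] are read as fractions of an epoch, negative values fall back to a
# default, and anything larger is taken as a plain iteration count. For example:
iters_per_epoch = 1200                                          # hypothetical

display = 0.5                                                   # <= 1: relative to epochs
display = max(1, int(display * iters_per_epoch))                # -> report every 600 iterations

early_stopping_frequency = -1                                   # negative: default to half an epoch
early_stopping_frequency = max(1, int(0.5 * iters_per_epoch))   # -> 600

checkpoint_frequency = -1                                       # negative: follow early stopping
checkpoint_frequency = early_stopping_frequency                 # -> 600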
import time, io, sys
from tqdm import tqdm
import tensorflow as tf
import sklearn
from PIL import Image
import numpy as np
import pickle
from cleverhans import utils_tf
from util import cvt2Image, sparse_tuple_from
from calamari_ocr.ocr.backends.tensorflow_backend.tensorflow_model import TensorflowModel
from calamari_ocr.ocr import Predictor

checkpoint = '/home/chenlu/calamari/models/antiqua_modern/4.ckpt.json'
predictor = Predictor(checkpoint=checkpoint, batch_size=1, processes=10)
network = predictor.network
sess, graph = network.session, network.graph
codec = network.codec
charset = codec.charset
encode, decode = codec.encode, codec.decode
code2char, char2code = codec.code2char, codec.char2code


def invert(data):
    # invert the grayscale values
    if data.max() < 1.5:
        return 1 - data
    else:
        return 255 - data


def transpose(data):
    # rotate by 90 degrees
parser.add_argument("--eps_iter", help="coefficient to adjust step size of each iteration", type=float)
parser.add_argument("--nb_iter", help="maximum number of iterations", type=int)
parser.add_argument("--batch_size", help="the number of samples per batch", type=int)
parser.add_argument("--clip_min", help="the minimum value of images", type=float)
parser.add_argument("--clip_max", help="the maximum value of images", type=float)
args = parser.parse_args()

predictor = Predictor(checkpoint=args.model_path, batch_size=1, processes=10)
network = predictor.network
sess, graph = network.session, network.graph
encode, decode = network.codec.encode, network.codec.decode

# set parameters
font_name = args.font_name
case = args.case
pert_type = args.pert_type
eps = args.eps
eps_iter = args.eps_iter
nb_iter = args.nb_iter
batch_size = args.batch_size
clip_min, clip_max = args.clip_min, args.clip_max

# load img data
import tensorflow as tf
import sklearn
from PIL import Image
import numpy as np
import pickle, glob, time, sys, os
from tqdm import tqdm
from cleverhans import utils_tf
from util import get_argparse, cvt2Image, sparse_tuple_from
from calamari_ocr.ocr.backends.tensorflow_backend.tensorflow_model import TensorflowModel
from calamari_ocr.ocr import Predictor

# parse the parameters from the shell
parser = get_argparse()
args = parser.parse_args()

predictor = Predictor(checkpoint=os.path.join("ocr_model", args.model_path), batch_size=1, processes=10)
network = predictor.network
sess, graph = network.session, network.graph
encode, decode = network.codec.encode, network.codec.decode

# build graph
with graph.as_default():
    # the trailing _ values are the data_iterator placeholders, only used for dataset input
    inputs, input_seq_len, targets, dropout_rate, _, _ = network.create_placeholders()
    output_seq_len, time_major_logits, time_major_softmax, logits, softmax, decoded, sparse_decoded, scale_factor, log_prob = \
        network.create_network(inputs, input_seq_len, dropout_rate, reuse_variables=tf.AUTO_REUSE)
    loss = tf.nn.ctc_loss(labels=targets,
                          inputs=time_major_logits,
                          sequence_length=output_seq_len,
                          time_major=True,
                          ctc_merge_repeated=True,