def tensor_define(self, model_path, charset_path, label_file):
    converter = LabelConverter(chars_file=charset_path)
    dataset = ImgDataset(label_file, converter, batch_size=1, shuffle=False)
    model = CRNN(self.cfg, num_classes=converter.num_classes)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    sess.run(dataset.init_op)
    sess = self.load_model(model_path=model_path, sess=sess)
    global_step = sess.run(model.global_step)

    self.sess = sess
    self.converter = converter
    self.dataset = dataset
    self.input = [
        model.inputs, model.labels, model.bat_labels, model.len_labels,
        model.char_num, model.char_pos_init, model.is_training
    ]
    self.output = [model.dense_decoded, model.char_pos, model.embedding]
    self.global_step = global_step
def infer(args):
    converter = LabelConverter(chars_file=args.chars_file)
    dataset = ImgDataset(args.infer_dir, converter, args.infer_batch_size, shuffle=False)

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        restore_ckpt(sess, args.ckpt_dir)

        # for node in sess.graph.as_graph_def().node:
        #     print(node.name)

        # https://stackoverflow.com/questions/46912721/tensorflow-restore-model-with-sparse-placeholder
        labels_placeholder = tf.SparseTensor(
            values=sess.graph.get_tensor_by_name('labels/values:0'),
            indices=sess.graph.get_tensor_by_name('labels/indices:0'),
            dense_shape=sess.graph.get_tensor_by_name('labels/shape:0')
        )

        feeds = {
            'inputs': sess.graph.get_tensor_by_name('inputs:0'),
            'is_training': sess.graph.get_tensor_by_name('is_training:0'),
            'labels': labels_placeholder
        }

        fetches = [
            sess.graph.get_tensor_by_name('SparseToDense:0'),    # dense_decoded
            sess.graph.get_tensor_by_name('Mean_1:0'),           # mean edit distance
            sess.graph.get_tensor_by_name('edit_distance:0')     # batch edit distances
        ]

        validation(sess, feeds, fetches, dataset, converter, args.result_dir,
                   name='infer', print_batch_info=True,
                   copy_failed=args.infer_copy_failed)
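# Note: restore_ckpt() above is the project's checkpoint-loading helper. Since
# infer() looks tensors up by name afterwards, it presumably imports the graph
# definition as well as the weights. A minimal sketch of that behaviour under
# the standard TF1 meta-graph layout follows; restore_ckpt_sketch is an
# illustrative name, not the project's actual implementation.
def restore_ckpt_sketch(sess, ckpt_dir):
    ckpt_path = tf.train.latest_checkpoint(ckpt_dir)          # prefix of the newest checkpoint
    saver = tf.train.import_meta_graph(ckpt_path + '.meta')   # rebuild the graph definition
    saver.restore(sess, ckpt_path)                            # load the weights into the session
    return saver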
def validation(sess, feeds, fetches, dataset: ImgDataset, converter, result_dir,
               name, step=None, print_batch_info=False, copy_failed=False):
    """
    Run evaluation on a dataset and save the result file as {result_dir}/{name}/{acc}_{step}.txt
    :param sess: tensorflow session
    :param feeds: dict of placeholders (inputs, labels, is_training)
    :param fetches: [dense_decoded, mean edit distance, batch edit distances]
    :param dataset: ImgDataset to evaluate
    :param converter: LabelConverter used to decode predictions
    :param result_dir: directory the result file is written to
    :param name: val, test or infer; used as the result sub-directory name
    :param step: global step, appended to the result file name if given
    :param print_batch_info: print per-batch accuracy and timing
    :param copy_failed: copy images whose prediction does not match the label to a separate dir
    :return: (accuracy, mean edit distance)
    """
    sess.run(dataset.init_op)

    img_paths = []
    predicts = []
    labels = []
    edit_distances = []
    total_batch_time = 0

    for batch in range(dataset.num_batches):
        img_batch, label_batch, batch_labels, batch_img_paths = dataset.get_next_batch(sess)
        batch_start_time = time.time()

        feed = {
            feeds['inputs']: img_batch,
            feeds['labels']: label_batch,
            feeds['is_training']: False
        }
        batch_predicts, edit_distance, batch_edit_distances = sess.run(fetches, feed)
        batch_predicts = [converter.decode(p, CRNN.CTC_INVALID_INDEX) for p in batch_predicts]

        img_paths.extend(batch_img_paths)
        predicts.extend(batch_predicts)
        labels.extend(batch_labels)
        edit_distances.extend(batch_edit_distances)

        acc, correct_count = calculate_accuracy(batch_predicts, batch_labels)
        batch_time = time.time() - batch_start_time
        total_batch_time += batch_time
        if print_batch_info:
            print("Batch [{}/{}] {:.03f}s accuracy: {:.03f} ({}/{}), edit_distance: {:.03f}"
                  .format(batch, dataset.num_batches, batch_time, acc,
                          correct_count, dataset.batch_size, edit_distance))

    acc, correct_count = calculate_accuracy(predicts, labels)
    edit_distance_mean = calculate_edit_distance_mean(edit_distances)
    acc_str = "Accuracy: {:.03f} ({}/{}), Average edit distance: {:.03f}, Average batch time: {:.03f}" \
        .format(acc, correct_count, dataset.size, edit_distance_mean,
                total_batch_time / dataset.num_batches)
    print(acc_str)

    save_dir = os.path.join(result_dir, name)
    utils.check_dir_exist(save_dir)
    if step is not None:
        file_path = os.path.join(save_dir, '%.3f_%d.txt' % (acc, step))
    else:
        file_path = os.path.join(save_dir, '%.3f.txt' % acc)

    print("Write result to %s" % file_path)
    with open(file_path, 'w', encoding='utf-8') as f:
        for i, p_label in enumerate(predicts):
            t_label = labels[i]
            f.write("{:08d}\n".format(i))
            f.write("input: {:17s} length: {}\n".format(t_label, len(t_label)))
            f.write("predict: {:17s} length: {}\n".format(p_label, len(p_label)))
            f.write("all match: {}\n".format(1 if t_label == p_label else 0))
            f.write("edit distance: {}\n".format(edit_distances[i]))
            f.write('-' * 30 + '\n')
        f.write(acc_str + "\n")

    # Copy images whose prediction did not fully match the label to a separate dir
    if copy_failed:
        failed_infer_img_dir = file_path[:-4] + "_failed"
        if os.path.exists(failed_infer_img_dir) and os.path.isdir(failed_infer_img_dir):
            shutil.rmtree(failed_infer_img_dir)

        utils.check_dir_exist(failed_infer_img_dir)

        failed_image_indices = []
        for i, val in enumerate(edit_distances):
            if val != 0:
                failed_image_indices.append(i)

        for i in failed_image_indices:
            img_path = img_paths[i]
            img_name = os.path.basename(img_path)
            dst_path = os.path.join(failed_infer_img_dir, img_name)
            shutil.copyfile(img_path, dst_path)

        failed_infer_result_file_path = os.path.join(failed_infer_img_dir, "result.txt")
        with open(failed_infer_result_file_path, 'w', encoding='utf-8') as f:
            for i in failed_image_indices:
                p_label = predicts[i]
                t_label = labels[i]
                f.write("{:08d}\n".format(i))
                f.write("input: {:17s} length: {}\n".format(t_label, len(t_label)))
                f.write("predict: {:17s} length: {}\n".format(p_label, len(p_label)))
                f.write("edit distance: {}\n".format(edit_distances[i]))
                f.write('-' * 30 + '\n')

    return acc, edit_distance_mean
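# The accuracy reported above comes from the project's calculate_accuracy
# helper. A minimal sketch of its assumed contract, exact string match between
# prediction and ground truth; calculate_accuracy_sketch is an illustrative
# name, not the project's implementation.
def calculate_accuracy_sketch(predicts, labels):
    correct_count = sum(1 for p, t in zip(predicts, labels) if p == t)
    accuracy = correct_count / max(len(labels), 1)
    return accuracy, correct_count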
def train(self, log_dir, restore, log_step, ckpt_dir, val_step, cfg_name,
          chars_file, train_txt, val_txt, test_txt, result_dir):
    cfg = load_config(cfg_name)
    converter = LabelConverter(chars_file=chars_file)

    tr_ds = ImgDataset(train_txt, converter, cfg.batch_size)
    cfg.lr_boundaries = [10000]
    cfg.lr_values = [cfg.lr * (cfg.lr_decay_rate ** i)
                     for i in range(len(cfg.lr_boundaries) + 1)]

    if val_txt is None:
        val_ds = None
    else:
        val_ds = ImgDataset(val_txt, converter, cfg.batch_size, shuffle=False)

    if test_txt is None:
        test_ds = None
    else:
        # Test images often have different size, so set batch_size to 1
        test_ds = ImgDataset(test_txt, converter, shuffle=False, batch_size=1)

    model = CRNN(cfg, num_classes=converter.num_classes)

    epoch_start_index = 0
    batch_start_index = 0

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    sess = tf.Session(config=config)

    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=8)
    train_writer = tf.summary.FileWriter(log_dir, sess.graph)

    if restore:
        self._restore(sess, saver, model, tr_ds, ckpt_dir)

    print('Begin training...')
    for epoch in range(epoch_start_index, cfg.epochs):
        sess.run(tr_ds.init_op)

        for batch in range(batch_start_index, tr_ds.num_batches):
            batch_start_time = time.time()

            if batch != 0 and (batch % log_step == 0):
                batch_cost, global_step, lr = self._train_with_summary(
                    model, tr_ds, sess, train_writer, converter)
            else:
                batch_cost, global_step, lr = self._train(model, tr_ds, sess)

            print("epoch: {}, batch: {}/{}, step: {}, time: {:.02f}s, loss: {:.05}, lr: {:.05}"
                  .format(epoch, batch, tr_ds.num_batches, global_step,
                          time.time() - batch_start_time, batch_cost, lr))

            if global_step != 0 and (global_step % val_step == 0):
                val_acc = self._do_val(val_ds, epoch, global_step, "val", sess,
                                       model, converter, train_writer, cfg, result_dir)
                test_acc = self._do_val(test_ds, epoch, global_step, "test", sess,
                                        model, converter, train_writer, cfg, result_dir)
                self._save_checkpoint(ckpt_dir, global_step, saver, sess, val_acc, test_acc)

        batch_start_index = 0
class Trainer(object):
    def __init__(self, args):
        self.args = args
        self.cfg = load_config(args.cfg_name)
        self.converter = LabelConverter(chars_file=args.chars_file)

        self.tr_ds = ImgDataset(args.train_dir, self.converter, self.cfg.batch_size)

        self.cfg.lr_boundaries = [
            self.tr_ds.num_batches * epoch for epoch in self.cfg.lr_decay_epochs
        ]
        self.cfg.lr_values = [
            self.cfg.lr * (self.cfg.lr_decay_rate ** i)
            for i in range(len(self.cfg.lr_boundaries) + 1)
        ]

        if args.val_dir is None:
            self.val_ds = None
        else:
            self.val_ds = ImgDataset(args.val_dir, self.converter,
                                     self.cfg.batch_size, shuffle=False)

        if args.test_dir is None:
            self.test_ds = None
        else:
            # Test images often have different size, so set batch_size to 1
            self.test_ds = ImgDataset(args.test_dir, self.converter,
                                      shuffle=False, batch_size=1)

        self.model = CRNN(self.cfg, num_classes=self.converter.num_classes)
        self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

        self.epoch_start_index = 0
        self.batch_start_index = 0

    def train(self):
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=8)
        self.train_writer = tf.summary.FileWriter(self.args.log_dir, self.sess.graph)

        if self.args.restore:
            self._restore()

        print('Begin training...')
        for epoch in range(self.epoch_start_index, self.cfg.epochs):
            self.sess.run(self.tr_ds.init_op)

            for batch in range(self.batch_start_index, self.tr_ds.num_batches):
                batch_start_time = time.time()

                if batch != 0 and (batch % self.args.log_step == 0):
                    batch_cost, global_step, lr = self._train_with_summary()
                else:
                    batch_cost, global_step, lr = self._train()

                print("epoch: {}, batch: {}/{}, step: {}, time: {:.02f}s, loss: {:.05}, lr: {:.05}"
                      .format(epoch, batch, self.tr_ds.num_batches, global_step,
                              time.time() - batch_start_time, batch_cost, lr))

                if global_step != 0 and (global_step % self.args.val_step == 0):
                    val_acc = self._do_val(self.val_ds, epoch, global_step, "val")
                    test_acc = self._do_val(self.test_ds, epoch, global_step, "test")
                    self._save_checkpoint(self.args.ckpt_dir, global_step, val_acc, test_acc)

            self.batch_start_index = 0

    def _restore(self):
        utils.restore_ckpt(self.sess, self.saver, self.args.ckpt_dir)

        step_restored = self.sess.run(self.model.global_step)

        self.epoch_start_index = math.floor(step_restored / self.tr_ds.num_batches)
        self.batch_start_index = step_restored % self.tr_ds.num_batches

        print("Restored global step: %d" % step_restored)
        print("Restored epoch: %d" % self.epoch_start_index)
        print("Restored batch_start_index: %d" % self.batch_start_index)

    def _train(self):
        img_batch, label_batch, labels, _ = self.tr_ds.get_next_batch(self.sess)
        feed = {
            self.model.inputs: img_batch,
            self.model.labels: label_batch,
            self.model.is_training: True
        }

        fetches = [
            self.model.total_loss, self.model.ctc_loss,
            self.model.regularization_loss, self.model.global_step,
            self.model.lr, self.model.train_op
        ]

        batch_cost, _, _, global_step, lr, _ = self.sess.run(fetches, feed)
        return batch_cost, global_step, lr

    def _train_with_summary(self):
        img_batch, label_batch, labels, _ = self.tr_ds.get_next_batch(self.sess)
        feed = {
            self.model.inputs: img_batch,
            self.model.labels: label_batch,
            self.model.is_training: True
        }

        fetches = [
            self.model.total_loss, self.model.ctc_loss,
            self.model.regularization_loss, self.model.global_step,
            self.model.lr, self.model.merged_summay, self.model.dense_decoded,
            self.model.edit_distance, self.model.train_op
        ]

        batch_cost, _, _, global_step, lr, summary, predicts, edit_distance, _ = \
            self.sess.run(fetches, feed)
        self.train_writer.add_summary(summary, global_step)

        predicts = [self.converter.decode(p, CRNN.CTC_INVALID_INDEX) for p in predicts]
        accuracy, _ = infer.calculate_accuracy(predicts, labels)

        tf_utils.add_scalar_summary(self.train_writer, "train_accuracy",
                                    accuracy, global_step)
        tf_utils.add_scalar_summary(self.train_writer, "train_edit_distance",
                                    edit_distance, global_step)

        return batch_cost, global_step, lr

    def _do_val(self, dataset, epoch, step, name):
        if dataset is None:
            return None

        accuracy, edit_distance = infer.validation(self.sess, self.model.feeds(),
                                                   self.model.fetches(), dataset,
                                                   self.converter,
                                                   self.args.result_dir, name, step)

        tf_utils.add_scalar_summary(self.train_writer, "%s_accuracy" % name,
                                    accuracy, step)
        tf_utils.add_scalar_summary(self.train_writer, "%s_edit_distance" % name,
                                    edit_distance, step)

        print("epoch: %d/%d, %s accuracy = %.3f" %
              (epoch, self.cfg.epochs, name, accuracy))
        return accuracy

    def _save_checkpoint(self, ckpt_dir, step, val_acc=None, test_acc=None):
        ckpt_name = "crnn_%d" % step
        if val_acc is not None:
            ckpt_name += '_val_%.03f' % val_acc
        if test_acc is not None:
            ckpt_name += '_test_%.03f' % test_acc

        name = os.path.join(ckpt_dir, ckpt_name)
        print("save checkpoint %s" % name)

        meta_exists, meta_file_name = self._meta_file_exist(ckpt_dir)

        self.saver.save(self.sess, name)

        # Remove the old meta file to save disk space
        if meta_exists:
            try:
                os.remove(os.path.join(ckpt_dir, meta_file_name))
            except OSError:
                print('Failed to remove meta file: %s' % meta_file_name)

    def _meta_file_exist(self, ckpt_dir):
        fnames = os.listdir(ckpt_dir)
        meta_exists = False
        meta_file_name = ''
        for n in fnames:
            if 'meta' in n:
                meta_exists = True
                meta_file_name = n
                break
        return meta_exists, meta_file_name
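# Hypothetical entry point (not part of the original file): a sketch of how
# Trainer might be wired up with argparse. The flag names mirror the attributes
# read from `args` above (cfg_name, chars_file, train_dir, ...); the default
# values shown here are assumptions, not the project's actual defaults.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg_name', default='default')               # assumed default
    parser.add_argument('--chars_file', default='./data/chars.txt')    # assumed path
    parser.add_argument('--train_dir', required=True)
    parser.add_argument('--val_dir', default=None)
    parser.add_argument('--test_dir', default=None)
    parser.add_argument('--log_dir', default='./output/log')           # assumed path
    parser.add_argument('--ckpt_dir', default='./output/checkpoint')   # assumed path
    parser.add_argument('--result_dir', default='./output/result')     # assumed path
    parser.add_argument('--restore', action='store_true')
    parser.add_argument('--log_step', type=int, default=100)           # assumed value
    parser.add_argument('--val_step', type=int, default=500)           # assumed value
    args = parser.parse_args()

    trainer = Trainer(args)
    trainer.train()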