def inference(config_file, image_file):
    """Run the text recognition network on a single image file.

    Args:
        config_file: path to the configuration consumed by `Flags`.
        image_file: path to the image to recognize.

    Returns:
        The recognized text, normalized by `adjust_string` with the
        eval-time lowercase/alphanumeric flags.

    Raises:
        ValueError: if `image_file` cannot be read by OpenCV.
    """
    # Get config
    FLAGS = Flags(config_file).get()
    out_charset = load_charset(FLAGS.charset)
    num_classes = len(out_charset)
    net = get_network(FLAGS, out_charset)

    # Channel count and imread mode must agree with how the net was trained.
    if FLAGS.use_rgb:
        num_channel = 3
        mode = cv2.IMREAD_COLOR
    else:
        num_channel = 1
        mode = cv2.IMREAD_GRAYSCALE

    # Input node: HWC uint8 image with variable spatial size.
    image = tf.placeholder(tf.uint8,
                           shape=[None, None, num_channel],
                           name='input_node')

    # Network: preprocess, add a batch dimension, and pin the static shape
    # expected by the recognition graph.
    proc_image = net.preprocess_image(image, is_train=False)
    proc_image = tf.expand_dims(proc_image, axis=0)
    proc_image.set_shape(
        [None, FLAGS.resize_hw.height, FLAGS.resize_hw.width, num_channel])
    logits, sequence_length = net.get_logits(proc_image,
                                             is_train=False,
                                             label=None)
    prediction, log_prob = net.get_prediction(logits, sequence_length)
    # Densify the sparse prediction; `num_classes` (one past the last valid
    # charset index) is used as the padding/default label.
    prediction = tf.sparse_to_dense(sparse_indices=prediction.indices,
                                    sparse_values=prediction.values,
                                    output_shape=prediction.dense_shape,
                                    default_value=num_classes,
                                    name='output_node')

    # Restore and run inside a managed session so resources are released
    # on every exit path (the original leaked the tf.Session).
    restore_model = get_init_trained()
    with tf.Session() as sess:
        restore_model(sess, FLAGS.eval.model_path)

        img = cv2.imread(image_file, mode)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise ValueError('Failed to read image: {}'.format(image_file))
        img = np.reshape(img, [img.shape[0], img.shape[1], num_channel])
        predicted = sess.run(prediction, feed_dict={image: img})

    string = get_string(predicted[0], out_charset)
    string = adjust_string(string, FLAGS.eval.lowercase,
                           FLAGS.eval.alphanumeric)
    print(string)

    return string
def main(config_file):
    """Train the text recognition network.

    Builds one train tower and one valid tower per available GPU, averages
    gradients across towers, then runs a MonitoredTrainingSession loop that
    periodically computes train/valid sequence-error rates, tracks the best
    validation models, and logs progress.

    Args:
        config_file: path to the configuration consumed by `Flags`.
    """
    # Parse configs
    FLAGS = Flags(config_file).get()

    # Set directory, seed, logger
    model_dir = create_model_dir(FLAGS.model_dir)
    logger = get_logger(model_dir, 'train')
    best_model_dir = os.path.join(model_dir, 'best_models')
    set_seed(FLAGS.seed)

    # Print configs
    flag_strs = [
        '{}:\t{}'.format(name, value)
        for name, value in FLAGS._asdict().items()
    ]
    log_formatted(logger, '[+] Model configurations', *flag_strs)

    # Print system environments
    num_gpus = count_available_gpus()
    num_cpus = os.cpu_count()
    mem_size = virtual_memory().available // (1024**3)  # available RAM in GiB
    log_formatted(logger, '[+] System environments',
                  'The number of gpus : {}'.format(num_gpus),
                  'The number of cpus : {}'.format(num_cpus),
                  'Memory Size : {}G'.format(mem_size))

    # Get optimizer and network
    global_step = tf.train.get_or_create_global_step()
    optimizer, learning_rate = get_optimizer(FLAGS.train.optimizer,
                                             global_step)
    out_charset = load_charset(FLAGS.charset)
    net = get_network(FLAGS, out_charset)
    # CTC-style networks need different label encoding/decoding downstream.
    is_ctc = (net.loss_fn == 'ctc_loss')

    # Multi tower for multi-gpu training
    tower_grads = []
    tower_extra_update_ops = []
    tower_preds = []
    tower_gts = []
    tower_losses = []
    batch_size = FLAGS.train.batch_size
    # Even split of the global batch; the last tower takes the remainder.
    tower_batch_size = batch_size // num_gpus
    val_tower_outputs = []
    eval_tower_outputs = []

    for gpu_indx in range(num_gpus):

        # Train tower
        print('[+] Build Train tower GPU:%d' % gpu_indx)
        input_device = '/gpu:%d' % gpu_indx

        # Last tower size = batch_size - (num_gpus-1) * per-tower size, so
        # the towers together consume exactly `batch_size` samples.
        tower_batch_size = tower_batch_size \
            if gpu_indx < num_gpus - 1 \
            else batch_size - tower_batch_size * (num_gpus - 1)

        # NOTE(review): `DatasetLodaer` is the project's class name (sic).
        train_loader = DatasetLodaer(
            dataset_paths=FLAGS.train.dataset_paths,
            dataset_portions=FLAGS.train.dataset_portions,
            batch_size=tower_batch_size,
            label_maxlen=FLAGS.label_maxlen,
            out_charset=out_charset,
            preprocess_image=net.preprocess_image,
            is_train=True,
            is_ctc=is_ctc,
            shuffle_and_repeat=True,
            concat_batch=True,
            input_device=input_device,
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            worker_index=gpu_indx,
            use_rgb=FLAGS.use_rgb,
            seed=FLAGS.seed,
            name='train')
        tower_output = single_tower(net,
                                    gpu_indx,
                                    train_loader,
                                    out_charset,
                                    optimizer,
                                    name='train',
                                    is_train=True)
        # Keep only gradients that actually exist for this tower's variables.
        tower_grads.append([x for x in tower_output.grads
                            if x[0] is not None])
        tower_extra_update_ops.append(tower_output.extra_update_ops)
        tower_preds.append(tower_output.prediction)
        tower_gts.append(tower_output.text)
        tower_losses.append(tower_output.loss)

        # Print network structure (once, after the first tower builds it)
        if gpu_indx == 0:
            param_stats = tf.profiler.profile(tf.get_default_graph())
            logger.info('total_params: %d\n' % param_stats.total_parameters)

        # Valid tower
        print('[+] Build Valid tower GPU:%d' % gpu_indx)
        valid_loader = DatasetLodaer(
            dataset_paths=FLAGS.valid.dataset_paths,
            dataset_portions=None,
            batch_size=FLAGS.valid.batch_size // num_gpus,
            label_maxlen=FLAGS.label_maxlen,
            out_charset=out_charset,
            preprocess_image=net.preprocess_image,
            is_train=False,
            is_ctc=is_ctc,
            shuffle_and_repeat=False,
            concat_batch=False,
            input_device=input_device,
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            worker_index=gpu_indx,
            use_rgb=FLAGS.use_rgb,
            seed=FLAGS.seed,
            name='valid')
        val_tower_output = single_tower(net,
                                        gpu_indx,
                                        valid_loader,
                                        out_charset,
                                        optimizer=None,
                                        name='valid',
                                        is_train=False)
        val_tower_outputs.append(
            (val_tower_output.loss, val_tower_output.prediction,
             val_tower_output.text, val_tower_output.filename,
             val_tower_output.dataset))

    # Aggregate gradients
    losses = tf.reduce_mean(tower_losses)
    grads = _average_gradients(tower_grads)

    # NOTE(review): only the LAST tower's extra update ops (e.g. batch-norm
    # moving averages) gate the train op — presumably the towers share
    # variables so one tower's ops suffice; confirm against single_tower.
    with tf.control_dependencies(tower_extra_update_ops[-1]):
        if FLAGS.train.optimizer.grad_clip_norm is not None:
            grads, global_norm = _clip_gradients(
                grads, FLAGS.train.optimizer.grad_clip_norm)
            tf.summary.scalar('global_norm', global_norm)
        train_op = optimizer.apply_gradients(grads, global_step=global_step)

    # Define config, scaffold
    saver = tf.train.Saver()
    sess_config = get_session_config()
    scaffold = get_scaffold(saver, FLAGS.train.tune_from, 'train')
    restore_model = get_init_trained()

    # Define validation saver, summary writer (one writer per valid dataset
    # plus an aggregate 'total_valid' writer)
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    val_summary_op = tf.summary.merge(
        [s for s in summaries if 'valid' in s.name])
    val_summary_writer = {
        dataset_name:
        tf.summary.FileWriter(os.path.join(model_dir, 'valid', dataset_name))
        for dataset_name in valid_loader.dataset_names
    }
    val_summary_writer['total_valid'] = tf.summary.FileWriter(
        os.path.join(model_dir, 'valid', 'total_valid'))
    # One best-model slot per dataset plus one for the aggregate.
    val_saver = tf.train.Saver(max_to_keep=len(valid_loader.dataset_names) +
                               1)
    best_val_err_rates = {}
    best_steps = {}

    # Training
    print('[+] Make Session...')

    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=model_dir,
            scaffold=scaffold,
            config=sess_config,
            save_checkpoint_steps=FLAGS.train.save_steps,
            save_checkpoint_secs=None,
            save_summaries_steps=FLAGS.train.summary_steps,
            save_summaries_secs=None,
    ) as sess:

        log_formatted(logger, 'Training started!')
        _step = 0
        train_t = 0
        start_t = time.time()

        while _step < FLAGS.train.max_num_steps \
                and not sess.should_stop():

            # Train step
            step_t = time.time()
            # NOTE(review): `lr` (learning_rate) is fetched but never used
            # below; it is presumably kept for debugging.
            [step_loss, _, _step, preds, gts, lr] = sess.run([
                losses, train_op, global_step, tower_preds[0], tower_gts[0],
                learning_rate
            ])
            train_t += time.time() - step_t

            # Summary
            if _step % FLAGS.valid.steps == 0:

                # Train summary: sequence-error rate on the first tower's
                # most recent batch.
                train_err = 0.

                for i, (p, g) in enumerate(zip(preds, gts)):
                    s = get_string(p, out_charset, is_ctc=is_ctc)
                    g = g.decode('utf8').replace(DELIMITER, '')
                    s = adjust_string(s, FLAGS.train.lowercase,
                                      FLAGS.train.alphanumeric)
                    g = adjust_string(g, FLAGS.train.lowercase,
                                      FLAGS.train.alphanumeric)
                    e = int(s != g)
                    train_err += e

                    # Show at most 5 sample predictions per validation round.
                    if FLAGS.train.verbose and i < 5:
                        print('TRAIN :\t{}\t{}\t{}'.format(s, g, not bool(e)))

                train_err_rate = \
                    train_err / len(gts)

                # Valid summary (also checkpoints best models per dataset)
                val_cnts, val_errs, val_err_rates, _ = \
                    validate(sess, _step, val_tower_outputs, out_charset,
                             is_ctc, val_summary_op, val_summary_writer,
                             val_saver, best_val_err_rates, best_steps,
                             best_model_dir, FLAGS.valid.lowercase,
                             FLAGS.valid.alphanumeric)

                # Logging: per-dataset accuracy lines, then a summary banner.
                log_strings = ['', '-' * 28 + ' VALID_DETAIL ' + '-' * 28, '']

                for dataset in sorted(val_err_rates.keys()):

                    if dataset == 'total_valid':
                        continue

                    cnt = val_cnts[dataset]
                    err = val_errs[dataset]
                    err_rate = val_err_rates[dataset]
                    best_step = best_steps[dataset]
                    s = '%s : %.2f%%(%d/%d)\tBEST_STEP : %d' % \
                        (dataset, (1.-err_rate)*100, cnt-err, cnt, best_step)
                    log_strings.append(s)

                # Linear extrapolation of remaining wall-clock time.
                elapsed_t = float(time.time() - start_t) / 60
                remain_t = (elapsed_t / (_step+1)) * \
                    (FLAGS.train.max_num_steps - _step - 1)
                log_formatted(
                    logger,
                    'STEP : %d\tTRAIN_LOSS : %f' % (_step, step_loss),
                    'ELAPSED : %.2f min\tREMAIN : %.2f min\t'
                    'STEP_TIME: %.1f sec' %
                    (elapsed_t, remain_t, float(train_t) / (_step + 1)),
                    'TRAIN_SEQ_ERR : %f\tVALID_SEQ_ERR : %f' %
                    (train_err_rate, val_err_rates['total_valid']),
                    'BEST_STEP : %d\tBEST_VALID_SEQ_ERR : %f' %
                    (best_steps['total_valid'],
                     best_val_err_rates['total_valid']), *log_strings)

        log_formatted(logger, 'Training is completed!')
def main(config_file=None):
    """Run evaluation of the text recognition network.

    Builds one eval tower per available GPU, restores the checkpoint from
    FLAGS.eval.model_path, runs `validate` over the eval datasets, and writes
    per-dataset prediction files plus accuracy lines under the model path.

    Args:
        config_file: path to the configuration consumed by `Flags`.
    """
    # Parse Config
    print('[+] Model configurations')
    FLAGS = Flags(config_file).get()
    for name, value in FLAGS._asdict().items():
        print('{}:\t{}'.format(name, value))
    print('\n')

    # System environments
    num_gpus = count_available_gpus()
    num_cpus = os.cpu_count()
    mem_size = virtual_memory().available // (1024**3)  # GiB
    out_charset = load_charset(FLAGS.charset)
    print('[+] System environments')
    print('The number of gpus : {}'.format(num_gpus))
    print('The number of cpus : {}'.format(num_cpus))
    print('Memory Size : {}G'.format(mem_size))
    print('The number of characters : {}\n'.format(len(out_charset)))

    # Make results dir (prediction files live next to the checkpoint)
    res_dir = os.path.join(FLAGS.eval.model_path)
    os.makedirs(res_dir, exist_ok=True)

    # Get network
    net = get_network(FLAGS, out_charset)
    is_ctc = (net.loss_fn == 'ctc_loss')

    # Define Graph
    eval_tower_outputs = []
    global_step = tf.train.get_or_create_global_step()

    for gpu_indx in range(num_gpus):

        # Get eval dataset
        input_device = '/gpu:%d' % gpu_indx
        print('[+] Build Eval tower GPU:%d' % gpu_indx)
        # NOTE(review): `DatasetLodaer` is the project's class name (sic).
        eval_loader = DatasetLodaer(dataset_paths=FLAGS.eval.dataset_paths,
                                    dataset_portions=None,
                                    batch_size=FLAGS.eval.batch_size,
                                    label_maxlen=FLAGS.label_maxlen,
                                    out_charset=out_charset,
                                    preprocess_image=net.preprocess_image,
                                    is_train=False,
                                    is_ctc=is_ctc,
                                    shuffle_and_repeat=False,
                                    concat_batch=False,
                                    input_device=input_device,
                                    num_cpus=num_cpus,
                                    num_gpus=num_gpus,
                                    worker_index=gpu_indx,
                                    use_rgb=FLAGS.use_rgb,
                                    seed=FLAGS.seed,
                                    name='eval')
        eval_tower_output = single_tower(net,
                                         gpu_indx,
                                         eval_loader,
                                         out_charset,
                                         optimizer=None,
                                         name='eval',
                                         is_train=False)
        eval_tower_outputs.append(
            (eval_tower_output.loss, eval_tower_output.prediction,
             eval_tower_output.text, eval_tower_output.filename,
             eval_tower_output.dataset))

    # Summary: one writer per eval dataset plus an aggregate writer.
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    summary_op = tf.summary.merge([s for s in summaries])
    summary_writer = {
        dataset_name:
        tf.summary.FileWriter(os.path.join(res_dir, dataset_name))
        for dataset_name in eval_loader.dataset_names
    }
    # NOTE(review): the aggregate writer is keyed 'total_valid' (apparently
    # the key `validate` uses) but writes into a 'total_eval' directory —
    # confirm the key matches `validate`'s aggregate key.
    summary_writer['total_valid'] = tf.summary.FileWriter(
        os.path.join(res_dir, 'total_eval'))

    # Define config, scaffold, hooks
    saver = tf.train.Saver()
    sess_config = get_session_config()
    restore_model = get_init_trained()
    scaffold = get_scaffold(saver, None, 'eval')

    # Testing
    with tf.train.MonitoredTrainingSession(scaffold=scaffold,
                                           config=sess_config) as sess:

        # Restore and init.
        restore_model(sess, FLAGS.eval.model_path)
        _step = sess.run(global_step)

        # Run test. (The original also set a dead `infet_t = 0` — a typo of
        # `infer_t` that was never read; removed.)
        start_t = time.time()
        eval_cnts, eval_errs, eval_err_rates, eval_preds = \
            validate(sess, _step, eval_tower_outputs, out_charset, is_ctc,
                     summary_op, summary_writer,
                     lowercase=FLAGS.eval.lowercase,
                     alphanumeric=FLAGS.eval.alphanumeric)
        infer_t = time.time() - start_t

        # Log: write per-dataset prediction files and print accuracy lines.
        total_total = 0

        for dataset, result in eval_preds.items():
            total = eval_cnts[dataset]
            correct = total - eval_errs[dataset]
            acc = 1. - eval_err_rates[dataset]
            total_total += total

            # `with` guarantees the file is closed even if decoding or
            # writing fails mid-loop (original used open()/close()).
            with open(os.path.join(res_dir, '{}.txt'.format(dataset)),
                      'w') as res_file:
                for f, s, g in result:
                    f = f.decode('utf8')
                    if FLAGS.eval.verbose:
                        print('FILE : ' + f)
                        print('PRED : ' + s)
                        print('ANSW : ' + g)
                        print('=' * 50)
                    res_file.write('{}\t{}\n'.format(f, s))

                res_s = 'DATASET : %s\tCORRECT : %d\tTOTAL : %d\tACC : %f' % \
                    (dataset, correct, total, acc)
                print(res_s)
                res_file.write(res_s)

        eval_loader.flush_tmpfile()
        # Assumes at least one evaluated sample; total_total == 0 would
        # raise ZeroDivisionError here, matching original behavior.
        print('INFER TIME(PER IMAGE) : %f s' % (float(infer_t) / total_total))