def train(self, data_path, num_epoch):
    logging.info('num_epoch: %d', num_epoch)
    s_gen = DataGen(
        data_path, self.buckets,
        epochs=num_epoch, max_width=self.max_original_width
    )
    step_time = 0.0
    loss = 0.0
    current_step = 0
    skipped_counter = 0
    writer = tf.summary.FileWriter(self.model_dir, self.sess.graph)

    logging.info('Starting the training process.')
    for batch in s_gen.gen(self.batch_size):
        current_step += 1
        start_time = time.time()
        try:
            result = self.step(batch, self.forward_only)
        except Exception as e:
            skipped_counter += 1
            logging.info('Step %d failed, batch skipped. Total skipped: %d',
                         current_step, skipped_counter)
            logging.error('Step %d failed. Exception details: %s',
                          current_step, str(e))
            continue

        loss += result['loss'] / self.steps_per_checkpoint
        curr_step_time = time.time() - start_time
        step_time += curr_step_time / self.steps_per_checkpoint

        writer.add_summary(result['summaries'], current_step)

        step_perplexity = math.exp(result['loss']) if result['loss'] < 300 else float('inf')
        logging.info('Step %i: %.3fs, loss: %f, perplexity: %f.',
                     current_step, curr_step_time, result['loss'], step_perplexity)

        # Once in a while, save a checkpoint and print statistics.
        if current_step % self.steps_per_checkpoint == 0:
            perplexity = math.exp(loss) if loss < 300 else float('inf')
            # Print statistics for the completed checkpoint interval.
            logging.info("Global step %d. Time: %.3f, loss: %f, perplexity: %.2f.",
                         self.sess.run(self.global_step), step_time, loss, perplexity)
            # Save a checkpoint and reset the timer and loss.
            logging.info("Saving the model at step %d.", current_step)
            self.saver_all.save(self.sess, self.checkpoint_path,
                                global_step=self.global_step)
            step_time, loss = 0.0, 0.0

    # Print statistics for the final (possibly partial) interval.
    perplexity = math.exp(loss) if loss < 300 else float('inf')
    logging.info("Global step %d. Time: %.3f, loss: %f, perplexity: %.2f.",
                 self.sess.run(self.global_step), step_time, loss, perplexity)
    if skipped_counter:
        logging.info("Skipped %d batches due to errors.", skipped_counter)

    # Save the final checkpoint.
    logging.info("Finishing the training and saving the model at step %d.",
                 current_step)
    self.saver_all.save(self.sess, self.checkpoint_path,
                        global_step=self.global_step)
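
# The exp-of-loss guard in `train` (`math.exp(loss) if loss < 300 else
# float('inf')`) is repeated at every logging site. A minimal sketch of a
# shared helper; the name `safe_perplexity` is hypothetical and the call
# sites above keep their inline version:
def safe_perplexity(loss):
    """Return exp(loss), reporting infinity once exp() would overflow."""
    return math.exp(loss) if loss < 300 else float('inf')
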
def test(self, data_path):
    current_step = 0
    num_correct = 0.0
    num_total = 0.0

    s_gen = DataGen(data_path, self.buckets,
                    epochs=1, max_width=self.max_original_width)
    for batch in s_gen.gen(1):
        current_step += 1
        # Get a batch (one image) and make a step.
        start_time = time.time()
        result = self.step(batch, self.forward_only)
        curr_step_time = time.time() - start_time

        num_total += 1

        output = result['prediction']
        ground = batch['labels'][0]
        comment = batch['comments'][0]
        if sys.version_info >= (3,):
            output = output.decode('utf-8')
            ground = ground.decode('utf-8')
            comment = comment.decode('utf-8')
        ground = revert_lex(ground)

        probability = result['probability']

        if self.use_distance:
            if not ground:
                # An empty ground truth is correct only for an empty prediction.
                incorrect = 0 if not output else 1
            else:
                incorrect = distance.levenshtein(output, ground)
                incorrect = min(1, float(incorrect) / len(ground))
        else:
            incorrect = 0 if output == ground else 1

        num_correct += 1. - incorrect

        if self.visualize:
            # Attention visualization.
            threshold = 0.5
            normalize = True
            binarize = True
            attns_list = [[a.tolist() for a in step_attn]
                          for step_attn in result['attentions']]
            attns = np.array(attns_list).transpose([1, 0, 2])
            visualize_attention(batch['data'],
                                'out',
                                attns,
                                output,
                                self.max_width,
                                DataGen.IMAGE_HEIGHT,
                                threshold=threshold,
                                normalize=normalize,
                                binarize=binarize,
                                ground=ground,
                                flag=None)

        step_accuracy = "{:>4.0%}".format(1. - incorrect)
        if incorrect:
            correctness = step_accuracy + " ({} vs {}) {}".format(
                output, ground, comment)
        else:
            correctness = step_accuracy + " (" + str(ground) + ")"

        logging.info('Step {:.0f} ({:.3f}s). '
                     'Accuracy: {:6.2%}, '
                     'loss: {:f}, perplexity: {:0<7.6}, probability: {:6.2%} {}'.format(
                         current_step,
                         curr_step_time,
                         num_correct / num_total,
                         result['loss'],
                         math.exp(result['loss']) if result['loss'] < 300 else float('inf'),
                         probability,
                         correctness))
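
# A minimal sketch of the per-sample score used in `test` above, assuming the
# `distance` package's `levenshtein(seq1, seq2)`: the edit distance is
# normalized by the ground-truth length and capped at 1, so the sample's
# accuracy is 1 minus that ratio. The helper name `character_accuracy` is
# hypothetical; `test` keeps its inline version.
def character_accuracy(output, ground):
    """Return 1 - min(1, levenshtein(output, ground) / len(ground))."""
    if not ground:
        return 1.0 if not output else 0.0
    errors = float(distance.levenshtein(output, ground))
    return 1.0 - min(1.0, errors / len(ground))

# Example: distance.levenshtein('hel1o', 'hello') == 1, so a 5-character
# ground truth scores 1 - 1/5 = 0.8.
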
def main(args=None):
    if args is None:
        args = sys.argv[1:]

    parameters = process_args(args, Config)
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s',
        filename=parameters.log_path)

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        if parameters.phase == 'dataset':
            dataset.generate(parameters.annotations_path,
                             parameters.output_path,
                             parameters.log_step,
                             parameters.force_uppercase,
                             parameters.save_filename)
            return

        if parameters.full_ascii:
            DataGen.set_full_ascii_charmap()

        model = Model(
            phase=parameters.phase,
            visualize=parameters.visualize,
            output_dir=parameters.output_dir,
            batch_size=parameters.batch_size,
            initial_learning_rate=parameters.initial_learning_rate,
            steps_per_checkpoint=parameters.steps_per_checkpoint,
            model_dir=parameters.model_dir,
            target_embedding_size=parameters.target_embedding_size,
            attn_num_hidden=parameters.attn_num_hidden,
            attn_num_layers=parameters.attn_num_layers,
            clip_gradients=parameters.clip_gradients,
            max_gradient_norm=parameters.max_gradient_norm,
            session=sess,
            load_model=parameters.load_model,
            gpu_id=parameters.gpu_id,
            use_gru=parameters.use_gru,
            use_distance=parameters.use_distance,
            max_image_width=parameters.max_width,
            max_image_height=parameters.max_height,
            max_prediction_length=parameters.max_prediction,
            channels=parameters.channels,
        )

        if parameters.phase == 'train':
            model.train(data_path=parameters.dataset_path,
                        num_epoch=parameters.num_epoch)
        elif parameters.phase == 'test':
            model.test(data_path=parameters.dataset_path)
        elif parameters.phase == 'predict':
            for line in sys.stdin:
                filename = line.rstrip()
                try:
                    with open(filename, 'rb') as img_file:
                        img_file_data = img_file.read()
                except IOError:
                    logging.error('Result: error while opening file %s.', filename)
                    continue
                text, probability = model.predict(img_file_data)
                logging.info('Result: OK. %s %s',
                             '{:.2f}'.format(probability), text)
        elif parameters.phase == 'export':
            exporter = Exporter(model)
            exporter.save(parameters.export_path, parameters.format)
            return
        else:
            raise NotImplementedError
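
# Standard CLI entry-point guard; an assumption, since the original module may
# instead be invoked through a console-script wrapper defined elsewhere.
if __name__ == '__main__':
    main()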