Example #1
    def train(self, data_path, num_epoch):
        logging.info('num_epoch: %d', num_epoch)
        s_gen = DataGen(
            data_path, self.buckets,
            epochs=num_epoch, max_width=self.max_original_width
        )
        step_time = 0.0
        loss = 0.0
        current_step = 0
        skipped_counter = 0
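        # TensorBoard summary writer pointed at the model directory.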
        writer = tf.summary.FileWriter(self.model_dir, self.sess.graph)

        logging.info('Starting the training process.')
        for batch in s_gen.gen(self.batch_size):

            current_step += 1

            start_time = time.time()
            result = None
            try:
                result = self.step(batch, self.forward_only)
            except Exception as e:
                skipped_counter += 1
                logging.info("Step {} failed, batch skipped." +
                             " Total skipped: {}".format(current_step, skipped_counter))
                logging.error(
                    "Step {} failed. Exception details: {}".format(current_step, str(e)))
                continue

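            # Accumulate loss and step time as running averages over the
            # current checkpoint window (reset every steps_per_checkpoint steps).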
            loss += result['loss'] / self.steps_per_checkpoint
            curr_step_time = (time.time() - start_time)
            step_time += curr_step_time / self.steps_per_checkpoint

            # num_correct = 0

            # step_outputs = result['prediction']
            # grounds = batch['labels']
            # for output, ground in zip(step_outputs, grounds):
            #     if self.use_distance:
            #         incorrect = distance.levenshtein(output, ground)
            #         incorrect = float(incorrect) / len(ground)
            #         incorrect = min(1.0, incorrect)
            #     else:
            #         incorrect = 0 if output == ground else 1
            #     num_correct += 1. - incorrect

            writer.add_summary(result['summaries'], current_step)

            # precision = num_correct / len(batch['labels'])
            step_perplexity = math.exp(result['loss']) if result['loss'] < 300 else float('inf')

            # logging.info('Step %i: %.3fs, precision: %.2f, loss: %f, perplexity: %f.'
            #              % (current_step, curr_step_time, precision*100,
            #                 result['loss'], step_perplexity))

            logging.info('Step %i: %.3fs, loss: %f, perplexity: %f.',
                         current_step, curr_step_time, result['loss'], step_perplexity)

            # Once in a while, we save checkpoint, print statistics, and run evals.
            if current_step % self.steps_per_checkpoint == 0:
                perplexity = math.exp(loss) if loss < 300 else float('inf')
                # Print statistics for the previous epoch.
                logging.info("Global step %d. Time: %.3f, loss: %f, perplexity: %.2f.",
                             self.sess.run(self.global_step), step_time, loss, perplexity)
                # Save checkpoint and reset timer and loss.
                logging.info("Saving the model at step %d.", current_step)
                self.saver_all.save(self.sess, self.checkpoint_path, global_step=self.global_step)
                step_time, loss = 0.0, 0.0

        # Print statistics for the previous epoch.
        perplexity = math.exp(loss) if loss < 300 else float('inf')
        logging.info("Global step %d. Time: %.3f, loss: %f, perplexity: %.2f.",
                     self.sess.run(self.global_step), step_time, loss, perplexity)

        if skipped_counter:
            logging.info("Skipped {} batches due to errors.".format(skipped_counter))

        # Save checkpoint and reset timer and loss.
        logging.info("Finishing the training and saving the model at step %d.", current_step)
        self.saver_all.save(self.sess, self.checkpoint_path, global_step=self.global_step)
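
The loss accumulated in the loop above is a running average over one checkpoint window, and perplexity is computed with a guard because math.exp overflows for large losses. A minimal standalone sketch of that bookkeeping (the helper name and sample values are illustrative, not from the original code):

import math

def checkpoint_perplexity(step_losses, steps_per_checkpoint):
    # Average the per-step losses over one checkpoint window and convert
    # to perplexity, guarding math.exp against overflow for large losses.
    loss = sum(step_losses[-steps_per_checkpoint:]) / steps_per_checkpoint
    return math.exp(loss) if loss < 300 else float('inf')

# Example: a window of four steps averaging to a loss of 2.025.
print(checkpoint_perplexity([2.3, 2.1, 1.9, 1.8], 4))  # ~7.58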
Example #2
    def test(self, data_path):
        current_step = 0
        num_correct = 0.0
        num_total = 0.0

        s_gen = DataGen(data_path,
                        self.buckets,
                        epochs=1,
                        max_width=self.max_original_width)

        for batch in s_gen.gen(1):
            current_step += 1
            # Get a batch (one image) and make a step.
            start_time = time.time()
            result = self.step(batch, self.forward_only)
            curr_step_time = (time.time() - start_time)

            num_total += 1

            output = result['prediction']
            ground = batch['labels'][0]
            comment = batch['comments'][0]
            if sys.version_info >= (3, ):
                output = output.decode('utf-8')
                ground = ground.decode('utf-8')
                comment = comment.decode('utf-8')

            ground = revert_lex(ground)
            #print('output ground ', output, ground)

            probability = result['probability']

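            # Score the prediction: length-normalised edit distance when
            # use_distance is set, otherwise strict string equality.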
            if self.use_distance:
                incorrect = distance.levenshtein(output, ground)
                if not ground:
                    if not output:
                        incorrect = 0
                    else:
                        incorrect = 1
                else:
                    incorrect = float(incorrect) / len(ground)
                incorrect = min(1, incorrect)
            else:
                incorrect = 0 if output == ground else 1

            num_correct += 1. - incorrect

            if self.visualize:
                # Attention visualization.
                threshold = 0.5
                normalize = True
                binarize = True
                attns_list = [[a.tolist() for a in step_attn]
                              for step_attn in result['attentions']]
                attns = np.array(attns_list).transpose([1, 0, 2])
                visualize_attention(batch['data'],
                                    'out',
                                    attns,
                                    output,
                                    self.max_width,
                                    DataGen.IMAGE_HEIGHT,
                                    threshold=threshold,
                                    normalize=normalize,
                                    binarize=binarize,
                                    ground=ground,
                                    flag=None)

            step_accuracy = "{:>4.0%}".format(1. - incorrect)

            if incorrect:
                correctness = step_accuracy + " ({} vs {}) {}".format(
                    output, ground, comment)
            else:
                correctness = step_accuracy + " (" + str(ground) + ")"

            logging.info(
                'Step {:.0f} ({:.3f}s). '
                'Accuracy: {:6.2%}, '
                'loss: {:f}, perplexity: {:0<7.6}, probability: {:6.2%} {}'.
                format(
                    current_step, curr_step_time, num_correct / num_total,
                    result['loss'],
                    math.exp(result['loss']) if result['loss'] < 300 else
                    float('inf'), probability, correctness))
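
The per-image score above is one minus a length-normalised Levenshtein distance, clamped to [0, 1], with an explicit guard when the ground truth is empty. A self-contained sketch of that metric, using the same distance package the method already relies on (the helper name is illustrative):

import distance

def sequence_accuracy(output, ground):
    # 1.0 for an exact match, 0.0 when the prediction is entirely wrong.
    if not ground:
        # An empty ground truth is only matched by an empty prediction.
        return 1.0 if not output else 0.0
    error_rate = distance.levenshtein(output, ground) / float(len(ground))
    return 1.0 - min(1.0, error_rate)

print(sequence_accuracy('hello', 'hello'))  # 1.0
print(sequence_accuracy('hallo', 'hello'))  # 0.8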
Example #3
def main(args=None):

    if args is None:
        args = sys.argv[1:]

    parameters = process_args(args, Config)
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s',
        filename=parameters.log_path)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)

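    # allow_soft_placement lets TensorFlow place an op on another available device
    # when the requested one has no kernel for it (e.g. on a CPU-only machine).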
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:

        if parameters.phase == 'dataset':
            dataset.generate(parameters.annotations_path,
                             parameters.output_path, parameters.log_step,
                             parameters.force_uppercase,
                             parameters.save_filename)
            return

        if parameters.full_ascii:
            DataGen.set_full_ascii_charmap()

        model = Model(
            phase=parameters.phase,
            visualize=parameters.visualize,
            output_dir=parameters.output_dir,
            batch_size=parameters.batch_size,
            initial_learning_rate=parameters.initial_learning_rate,
            steps_per_checkpoint=parameters.steps_per_checkpoint,
            model_dir=parameters.model_dir,
            target_embedding_size=parameters.target_embedding_size,
            attn_num_hidden=parameters.attn_num_hidden,
            attn_num_layers=parameters.attn_num_layers,
            clip_gradients=parameters.clip_gradients,
            max_gradient_norm=parameters.max_gradient_norm,
            session=sess,
            load_model=parameters.load_model,
            gpu_id=parameters.gpu_id,
            use_gru=parameters.use_gru,
            use_distance=parameters.use_distance,
            max_image_width=parameters.max_width,
            max_image_height=parameters.max_height,
            max_prediction_length=parameters.max_prediction,
            channels=parameters.channels,
        )

        if parameters.phase == 'train':
            model.train(data_path=parameters.dataset_path,
                        num_epoch=parameters.num_epoch)
        elif parameters.phase == 'test':
            model.test(data_path=parameters.dataset_path)
        elif parameters.phase == 'predict':
            for line in sys.stdin:
                filename = line.rstrip()
                try:
                    with open(filename, 'rb') as img_file:
                        img_file_data = img_file.read()
                except IOError:
                    logging.error('Result: error while opening file %s.',
                                  filename)
                    continue
                text, probability = model.predict(img_file_data)
                logging.info('Result: OK. %.2f %s', probability, text)
        elif parameters.phase == 'export':
            exporter = Exporter(model)
            exporter.save(parameters.export_path, parameters.format)
            return
        else:
            raise NotImplementedError
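
The predict branch reads one image path per line from standard input and logs one 'Result:' line per image. A hedged usage sketch, assuming process_args accepts the phase as a positional argument and that the sample paths exist (neither is confirmed by the code above):

import io
import sys

# Hypothetical driver: pipe image paths into the predict loop programmatically.
sys.stdin = io.StringIO('test_images/sample_01.png\ntest_images/sample_02.png\n')
main(['predict'])  # logs one "Result: OK. <probability> <text>" line per image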