Example 1
def main():
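    # Paths to the training images and the label file (Huawei date-text dataset).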
    data_root_dir = r'D:\myData\huawei_datetext\train_img'
    data_path = r'D:\myData\huawei_datetext\train_txt.txt'
    lexicon_file = 'date_lexicon.txt'

    gens = DataGen(data_root_dir,
                   data_path,
                   lexicon_file=lexicon_file,
                   mean=[128],
                   channel=1,
                   evaluate=False,
                   valid_target_len=float('inf'))
    batch_size = 1
    count = 2000
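    # Sweep batch sizes 2, 4, ..., 256, printing progress every count-th batch.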
    for k in range(8):
        batch_size *= 2
        count = count // 2
        print('batch_size =', batch_size)
        for i, batch in enumerate(gens.gen(batch_size)):
            if i % count == 0:
                print('got batch index: ' + str(i))
Example 2
    def train(self, data_path, num_epoch, learning_rate):
        logging.info('num_epoch: %d', num_epoch)
        s_gen = DataGen(data_path,
                        self.buckets,
                        epochs=num_epoch,
                        max_width=self.max_original_width)
        step_time = 0.0
        loss = 0.0
        current_step = 0
        skipped_counter = 0
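        # Write per-step summaries for TensorBoard.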
        writer = tf.summary.FileWriter(self.model_dir, self.sess.graph)

        logging.info('Starting the training process.')
        for batch in s_gen.gen(self.batch_size):

            current_step += 1

            start_time = time.time()
            # result = self.step(batch, self.forward_only)
            result = None
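            # Run a single training step; a batch that raises an exception is logged and skipped.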
            try:
                result = self.step(batch, self.forward_only, learning_rate)
            except Exception as e:
                skipped_counter += 1
                logging.info(
                    "Step {} failed, batch skipped. Total skipped: {}".format(
                        current_step, skipped_counter))
                logging.error("Step {} failed. Exception details: {}".format(
                    current_step, str(e)))
                continue

            loss += result['loss'] / self.steps_per_checkpoint
            curr_step_time = (time.time() - start_time)
            step_time += curr_step_time / self.steps_per_checkpoint

            # num_correct = 0

            # step_outputs = result['prediction']
            # grounds = batch['labels']
            # for output, ground in zip(step_outputs, grounds):
            #     if self.use_distance:
            #         incorrect = distance.levenshtein(output, ground)
            #         incorrect = float(incorrect) / len(ground)
            #         incorrect = min(1.0, incorrect)
            #     else:
            #         incorrect = 0 if output == ground else 1
            #     num_correct += 1. - incorrect

            writer.add_summary(result['summaries'], current_step)

            # precision = num_correct / len(batch['labels'])
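            # Per-step perplexity; very large losses are reported as infinite.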
            step_perplexity = math.exp(
                result['loss']) if result['loss'] < 300 else float('inf')

            # logging.info('Step %i: %.3fs, precision: %.2f, loss: %f, perplexity: %f.'
            #              % (current_step, curr_step_time, precision*100,
            #                 result['loss'], step_perplexity))

            logging.info('Step %i: %.3fs, loss: %f, perplexity: %f.',
                         current_step, curr_step_time, result['loss'],
                         step_perplexity)

            # Once in a while, save a checkpoint and print statistics.
            if current_step % self.steps_per_checkpoint == 0:
                perplexity = math.exp(loss) if loss < 300 else float('inf')
                # Print statistics for the previous checkpoint period.
                logging.info(
                    "Global step %d. Time: %.3f, loss: %f, perplexity: %.2f.",
                    self.sess.run(self.global_step), step_time, loss,
                    perplexity)
                # Save checkpoint and reset timer and loss.
                logging.info("Saving the model at step %d.", current_step)
                self.saver_all.save(self.sess,
                                    self.checkpoint_path,
                                    global_step=self.global_step)
                step_time, loss = 0.0, 0.0

        # Print statistics for the final checkpoint period.
        perplexity = math.exp(loss) if loss < 300 else float('inf')
        logging.info("Global step %d. Time: %.3f, loss: %f, perplexity: %.2f.",
                     self.sess.run(self.global_step), step_time, loss,
                     perplexity)

        if skipped_counter:
            logging.info(
                "Skipped {} batches due to errors.".format(skipped_counter))

        # Save the final checkpoint.
        logging.info("Finishing the training and saving the model at step %d.",
                     current_step)
        self.saver_all.save(self.sess,
                            self.checkpoint_path,
                            global_step=self.global_step)
Example 3
    def train(self, data_path, num_epoch):
        logging.info('num_epoch: %d', num_epoch)
        s_gen = DataGen(data_path,
                        self.buckets,
                        epochs=num_epoch,
                        max_width=self.max_original_width)
        step_time = 0.0
        loss = 0.0
        current_step = 0
        writer = tf.summary.FileWriter(self.model_dir, self.sess.graph)

        # Training progress is appended to a plain-text log file.
        with open('log.txt', 'a') as log:
            log.write('Starting the training process.')
        for batch in s_gen.gen(self.batch_size):

            current_step += 1

            start_time = time.time()
            result = self.step(batch, self.forward_only)
            loss += result['loss'] / self.steps_per_checkpoint
            curr_step_time = (time.time() - start_time)
            step_time += curr_step_time / self.steps_per_checkpoint

            # num_correct = 0

            # step_outputs = result['prediction']
            # grounds = batch['labels']
            # for output, ground in zip(step_outputs, grounds):
            #     if self.use_distance:
            #         incorrect = distance.levenshtein(output, ground)
            #         incorrect = float(incorrect) / len(ground)
            #         incorrect = min(1.0, incorrect)
            #     else:
            #         incorrect = 0 if output == ground else 1
            #     num_correct += 1. - incorrect

            writer.add_summary(result['summaries'], current_step)

            # precision = num_correct / len(batch['labels'])
            step_perplexity = math.exp(
                result['loss']) if result['loss'] < 300 else float('inf')

            # logging.info('Step %i: %.3fs, precision: %.2f, loss: %f, perplexity: %f.'
            #              % (current_step, curr_step_time, precision*100, result['loss'], step_perplexity))

            with open('log.txt', 'a') as log:
                log.write('\nStep %i: %.3fs, loss: %f, perplexity: %f.' %
                          (current_step, curr_step_time, result['loss'],
                           step_perplexity))

                # Once in a while, save a checkpoint and print statistics.
                if current_step % self.steps_per_checkpoint == 0:
                    perplexity = math.exp(loss) if loss < 300 else float('inf')
                    # Print statistics for the previous checkpoint period.
                    log.write(
                        "\nGlobal step %d. Time: %.3f, loss: %f, perplexity: %.2f."
                        % (self.sess.run(self.global_step), step_time, loss,
                           perplexity))
                    # Save checkpoint and reset timer and loss.
                    log.write("\nSaving the model at step %d." % current_step)
                    self.saver_all.save(self.sess,
                                        self.checkpoint_path,
                                        global_step=self.global_step)
                    step_time, loss = 0.0, 0.0

        # Print statistics for the final checkpoint period.
        perplexity = math.exp(loss) if loss < 300 else float('inf')
        with open('log.txt', 'a') as log:
            log.write(
                "\nGlobal step %d. Time: %.3f, loss: %f, perplexity: %.2f." %
                (self.sess.run(self.global_step), step_time, loss, perplexity))
            # Save the final checkpoint.
            log.write(
                "\nFinishing the training and saving the model at step %d." %
                current_step)
        self.saver_all.save(self.sess,
                            self.checkpoint_path,
                            global_step=self.global_step)
Example 4
    def test(self, data_path):
        current_step = 0
        num_correct = 0.0
        num_total = 0.0

        s_gen = DataGen(data_path,
                        self.buckets,
                        epochs=1,
                        max_width=self.max_original_width)
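        # One pass over the data, one image per step, accumulating accuracy.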
        for batch in s_gen.gen(1):
            current_step += 1
            # Get a batch (one image) and make a step.
            start_time = time.time()
            result = self.step(batch, self.forward_only, 0.0)
            curr_step_time = (time.time() - start_time)

            num_total += 1

            output = result['prediction']
            ground = batch['labels'][0]
            comment = batch['comments'][0]
            if sys.version_info >= (3, ):
                output = output.decode('iso-8859-1')
                ground = ground.decode('iso-8859-1')
                comment = comment.decode('iso-8859-1')

            probability = result['probability']

            # Score the prediction: normalized edit distance when enabled,
            # otherwise exact string match.
            if self.use_distance:
                incorrect = distance.levenshtein(output, ground)
                if not ground:
                    incorrect = 0 if not output else 1
                else:
                    incorrect = float(incorrect) / len(ground)
                incorrect = min(1, incorrect)
            else:
                incorrect = 0 if output == ground else 1

            num_correct += 1. - incorrect

            if self.visualize:
                # Attention visualization.
                threshold = 0.5
                normalize = True
                binarize = True
                attns_list = [[a.tolist() for a in step_attn]
                              for step_attn in result['attentions']]
                attns = np.array(attns_list).transpose([1, 0, 2])
                visualize_attention(batch['data'],
                                    'out',
                                    attns,
                                    output,
                                    self.max_width,
                                    DataGen.IMAGE_HEIGHT,
                                    threshold=threshold,
                                    normalize=normalize,
                                    binarize=binarize,
                                    ground=ground,
                                    flag=None)

            step_accuracy = "{:>4.0%}".format(1. - incorrect)
            if incorrect:
                correctness = step_accuracy + " ({} vs {}) {}".format(
                    output, ground, comment)
            else:
                correctness = step_accuracy + " (" + ground + ")"

            logging.info(
                'Step {:.0f} ({:.3f}s). '
                'Accuracy: {:6.2%}, '
                'loss: {:f}, perplexity: {:0<7.6}, probability: {:6.2%} {}'.
                format(
                    current_step, curr_step_time, num_correct / num_total,
                    result['loss'],
                    math.exp(result['loss']) if result['loss'] < 300 else
                    float('inf'), probability, correctness))
        return num_correct / num_total
Example 5
    def test(self, data_path):
        current_step = 0
        num_correct = 0.0
        num_total = 0.0

        s_gen = DataGen(data_path,
                        self.buckets,
                        epochs=1,
                        max_width=self.max_original_width)
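        # Single-epoch evaluation, one image per step.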
        for batch in s_gen.gen(1):
            current_step += 1
            # Get a batch (one image) and make a step.
            start_time = time.time()
            result = self.step(batch, self.forward_only)
            curr_step_time = (time.time() - start_time)

            if self.visualize:
                # Collect the per-step attention weights for later visualization.
                step_attns = np.array([[a.tolist() for a in step_attn]
                                       for step_attn in result['attentions']
                                       ]).transpose([1, 0, 2])

            num_total += 1

            output = result['prediction']
            ground = batch['labels'][0]
            comment = batch['comments'][0]
            if sys.version_info >= (3, ):
                output = output.decode('iso-8859-1')
                ground = ground.decode('iso-8859-1')
                comment = comment.decode('iso-8859-1')

            probability = result['probability']

            if self.use_distance:
                incorrect = distance.levenshtein(output, ground)
                if not ground:
                    incorrect = 0 if not output else 1
                else:
                    incorrect = float(incorrect) / len(ground)
                incorrect = min(1, incorrect)
            else:
                incorrect = 0 if output == ground else 1

            num_correct += 1. - incorrect

            if self.visualize:
                self.visualize_attention(batch['data'], step_attns[0], output,
                                         ground, incorrect)

            step_accuracy = "{:>4.0%}".format(1. - incorrect)
            correctness = step_accuracy + (" ({} vs {}) {}".format(
                output, ground, comment) if incorrect else " (" + ground + ")")

            with open('log.txt', 'a') as log:
                log.write(
                    '\nStep {:.0f} ({:.3f}s). Accuracy: {:6.2%}, loss: {:f}, '
                    'perplexity: {:0<7.6}, probability: {:6.2%} {}'.format(
                        current_step, curr_step_time, num_correct / num_total,
                        result['loss'],
                        math.exp(result['loss']) if result['loss'] < 300 else
                        float('inf'), probability, correctness))

        # Overall accuracy over the evaluation set, as in the other test() variant.
        return num_correct / num_total