Example no. 1
  def train(self, corpus, epochs=10, direction="pronounce", window=-1):
    """Runs training."""
    # Create training log that also redirects to stdout.
    stdout_file = sys.stdout
    logfile = os.path.join(self._checkpoint_dir, "train.log")
    print("Training log: {}".format(logfile))
    sys.stdout = utils.DualLogger(logfile)

    # Dump some parameters.
    print("           Direction: {}".format(direction))
    print("            # Epochs: {}".format(epochs))
    print("          Batch size: {}".format(self._batch_size))
    print("         Window size: {}".format(window))
    print("     Max written len: {}".format(corpus.max_written_len))
    print("        Max pron len: {}".format(corpus.max_pronounce_len))
    print("Max written word len: {}".format(corpus.max_written_word_len))
    print("   Max pron word len: {}".format(corpus.max_pronounce_word_len))
    print("       Use residuals: {}".format(self._use_residuals))

    # Perform training.
    best_total_loss = float("inf")
    nbatches = data.num_batches(corpus, self._batch_size, direction=direction,
                                window=window)
    for epoch in range(epochs):
      self._train_accuracy.reset_states()

      start = time.time()
      total_loss = 0
      steps = 0
      batches = data.batchify(corpus, self._batch_size, direction,
                              window=window)
      batch, (inputs, targ) = next(batches)
      while batch > -1:
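        # Prepend the <s> (start-of-sequence) id to every target row for
        # teacher forcing; e.g. with a hypothetical id of 2, a target row
        # [5, 6, 7] becomes [2, 5, 6, 7].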
        bos = np.expand_dims(
            [self._output_symbols.find("<s>")] * np.shape(targ)[0], 1)
        targets = np.concatenate((bos, targ), axis=-1)
        batch_loss = self._train_step(inputs, targets)
        total_loss += batch_loss
        if batch % 10 == 0:
          print("Epoch {} Batch {} (/{}) Loss {:.4f}".format(
              epoch + 1,
              batch,
              nbatches,
              batch_loss))
        steps += 1
        batch, (inputs, targ) = next(batches)
      total_loss /= max(steps, 1)  # Guard against an empty epoch.
      print("Epoch {} Loss {:.4f} Accuracy {:.4f}".format(
          epoch + 1, total_loss, self._train_accuracy.result()))

      if total_loss < best_total_loss:
        self._checkpoint.save(file_prefix=self._checkpoint_prefix)
        print("Saved checkpoint to {}".format(self._checkpoint_prefix))
        best_total_loss = total_loss
      print("Time taken for 1 epoch {} sec\n".format(
          time.time() - start))
    print("Best total loss: {:.4f}".format(best_total_loss))

    # Restore stdout.
    sys.stdout = stdout_file
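
The utils.DualLogger used above is not defined in these excerpts. Below is a
minimal sketch of what such a tee-style writer could look like, assuming only
the behavior implied by the call sites (everything printed goes both to the
terminal and to the log file); the project's real implementation may differ.

import sys


class DualLogger:
  """Tee-style writer: echoes every message to stdout and a log file."""

  def __init__(self, logfile):
    self._stdout = sys.stdout
    self._log = open(logfile, mode="wt", encoding="utf-8")

  def write(self, message):
    self._stdout.write(message)
    self._log.write(message)

  def flush(self):
    # Forward flushes so buffered log contents survive a crash mid-run.
    self._stdout.flush()
    self._log.flush()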
Example no. 2
def _test_language(language,
                   corpus,
                   model,
                   print_predictions=False,
                   show_plots=False,
                   compute_deviation=True,
                   deviation_only_for_correct=True,
                   simple_skew=False):
    """Runs model evaluation."""
    # Create test log that also redirects to stdout.
    stdout_file = sys.stdout
    log_name = "{}.log".format(_eval_file_prefix(model))
    logfile = os.path.join(model.checkpoint_dir, log_name)
    print("Test log: {}".format(logfile))
    sys.stdout = utils.DualLogger(logfile)

    print("Window size: {}".format(FLAGS.window))
    print("# test examples: {}".format(FLAGS.ntest))
    test_examples = data.test_examples(corpus,
                                       FLAGS.direction,
                                       window=FLAGS.window)
    indices = data.random_test_indices(test_examples, k=FLAGS.ntest)
    tot, cor, rat, nrat = eval_lib.eval_and_plot(
        model,
        test_examples,
        indices,
        show_plots=show_plots,
        print_predictions=print_predictions,
        print_attention=PRINT_ATTENTION,
        compute_deviation=compute_deviation,
        deviation_mask_sigma=DEVIATION_MASK_SIGMA,
        deviation_only_for_correct=deviation_only_for_correct,
        simple_skew=simple_skew,
        report_type_stats=FLAGS.report_type_stats,
        figsize=FIGSIZE)

    # Write results to the log, stdout and the dedicated file.
    results = [
        "*" * 80,
        "Language: {}".format(language),
        "Total non-trivial predictions: {}".format(tot),
        "Accuracy: {}".format(cor),
        "Ratio: {}".format(rat),
        "tf.reduce_max'ed normalized ratio: {}".format(nrat),
    ]
    print("\n".join(results))
    results_file = "{}_results.txt".format(_eval_file_prefix(model))
    with open(os.path.join(model.checkpoint_dir, results_file),
              encoding="utf-8",
              mode="wt") as f:
        f.write("\n".join(results) + "\n")

    # Restore stdout.
    sys.stdout = stdout_file
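
Note that the functions above restore sys.stdout only on the happy path; an
exception inside training or evaluation would leave stdout redirected. A small
hedged sketch of the same pattern with a try/finally guard follows, with
utils.DualLogger behaving as assumed after Example no. 1; the wrapper name is
illustrative, not part of the project.

import sys

import utils  # The project module providing DualLogger, as in the examples.


def run_with_dual_log(logfile, fn, *args, **kwargs):
    """Runs fn with stdout teed to logfile, restoring stdout afterwards."""
    stdout_file = sys.stdout
    sys.stdout = utils.DualLogger(logfile)
    try:
        return fn(*args, **kwargs)
    finally:
        sys.stdout = stdout_file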
Example no. 3
def _test_language(language,
                   corpus,
                   model,
                   print_predictions=False,
                   show_plots=False,
                   compute_deviation=True,
                   deviation_only_for_correct=True,
                   simple_skew=False):
    """Runs model evaluation."""
    # Create test log that also redirects to stdout.
    stdout_file = sys.stdout
    logfile = os.path.join(model.checkpoint_dir, "eval.log")
    print("Test log: {}".format(logfile))
    sys.stdout = utils.DualLogger(logfile)

    print("Window size: {}".format(FLAGS.window))
    print("# test examples: {}".format(FLAGS.ntest))
    test_examples = data.test_examples(corpus,
                                       FLAGS.direction,
                                       window=FLAGS.window)
    indices = data.random_test_indices(test_examples, k=FLAGS.ntest)
    tot, cor, rat, nrat = evaluate.eval_and_plot(
        model,
        test_examples,
        indices,
        show_plots=show_plots,
        print_predictions=print_predictions,
        print_attention=_PRINT_ATTENTION,
        compute_deviation=compute_deviation,
        deviation_mask_sigma=_DEVIATION_MASK_SIGMA,
        deviation_only_for_correct=deviation_only_for_correct,
        simple_skew=simple_skew,
        report_type_stats=FLAGS.report_type_stats,
        figsize=_FIGSIZE)
    print("*" * 80)
    print("Language: {}".format(language))
    print("Total non-trivial predictions: {}".format(tot))
    print("Accuracy: {}".format(cor))
    print("Ratio: {}".format(rat))
    print("tf.reduce_max'ed normalized ratio: {}".format(nrat))

    # Restore stdout.
    sys.stdout = stdout_file
Example no. 4
    def train(self, corpus, epochs=10, direction="pronounce", window=-1):
        """Main entry point for running the training."""
        # Create training log that also redirects to stdout.
        stdout_file = sys.stdout
        logfile = os.path.join(self._checkpoint_dir, "train.log")
        print("Training log: {}".format(logfile))
        sys.stdout = utils.DualLogger(logfile)

        # Dump some parameters.
        print("           Direction: {}".format(direction))
        print("            # Epochs: {}".format(epochs))
        print("          Batch size: {}".format(self._batch_size))
        print("         Window size: {}".format(window))
        print("     Max written len: {}".format(corpus.max_written_len))
        print("        Max pron len: {}".format(corpus.max_pronounce_len))
        print("Max written word len: {}".format(corpus.max_written_word_len))
        print("   Max pron word len: {}".format(corpus.max_pronounce_word_len))

        # Perform training.
        best_total_loss = float("inf")
        nbatches = data.num_batches(corpus,
                                    self._batch_size,
                                    direction=direction,
                                    window=window)
        for epoch in range(epochs):
            start = time.time()
            total_loss = 0
            steps = 0
            batches = data.batchify(corpus,
                                    self._batch_size,
                                    direction,
                                    window=window)
            batch, (inp, targ) = next(batches)
            enc_hidden = self._encoder.initialize_hidden_state(
                self._batch_size)
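            # Cache padded input/output sequence lengths from the first
            # batch; subsequent batches are assumed to share them.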
            if self._input_length == -1:
                # TODO(agutkin,rws): Following two lines will break if batchify()
                # returns an empty tuple.
                if isinstance(inp, np.ndarray):
                    self._input_length = inp.shape[1]
                if isinstance(targ, np.ndarray):
                    self._output_length = targ.shape[1]
            while batch > -1:
                batch_loss = self._train_step(inp, targ, enc_hidden)
                total_loss += batch_loss
                if batch % 10 == 0:
                    print("Epoch {} Batch {} (/{}) Loss {:.4f}".format(
                        epoch + 1, batch, nbatches, batch_loss.numpy()))
                steps += 1
                batch, (inp, targ) = next(batches)
            total_loss /= max(steps, 1)  # Guard against an empty epoch.
            print("Epoch {} Loss {:.4f}".format(epoch + 1, total_loss))
            if total_loss < best_total_loss:
                self._checkpoint.save(file_prefix=self._checkpoint_prefix)
                print("Saved checkpoint to {}".format(self._checkpoint_prefix))
                best_total_loss = total_loss
            print("Time taken for 1 epoch {} sec\n".format(time.time() -
                                                           start))

        print("Best total loss: {:.4f}".format(best_total_loss))

        # Restore stdout.
        sys.stdout = stdout_file
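
Both training loops fetch batches with batch, (inp, targ) = next(batches) and
stop once the batch index drops to -1, which implies that data.batchify() ends
with a sentinel rather than raising StopIteration. A minimal sketch of a
generator honoring that contract follows; the function body, names, and the
(inputs, targets) layout are illustrative assumptions, not the project's code.

import numpy as np


def batchify(examples, batch_size):
    """Yields (batch_index, (inputs, targets)), ending with a -1 sentinel."""
    for i in range(len(examples) // batch_size):
        chunk = examples[i * batch_size:(i + 1) * batch_size]
        inputs = np.stack([inp for inp, _ in chunk])
        targets = np.stack([targ for _, targ in chunk])
        yield i, (inputs, targets)
    # Sentinel consumed by the `while batch > -1` loops above.
    yield -1, (None, None)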