コード例 #1
0
ファイル: cnn_model.py プロジェクト: tallamjr/google-research
  def train(self, corpus, epochs=10, direction="pronounce", window=-1):
    """Runs training.

    For the duration of the call `sys.stdout` is replaced by a
    `utils.DualLogger` built for `train.log` in the checkpoint directory. The
    previous stream is put back when the method exits, including when
    training raises.

    A checkpoint is saved after every epoch whose mean batch loss is lower
    than the best one seen so far in this call.

    Args:
      corpus: Training corpus. It is handed to `data.num_batches` and
        `data.batchify`, and its `max_*_len` attributes are logged.
      epochs: Number of passes over the batches.
      direction: Forwarded to `data.num_batches` and `data.batchify`.
      window: Forwarded to `data.num_batches` and `data.batchify`.
    """
    # Create training log that also redirects to stdout.
    stdout_file = sys.stdout
    logfile = os.path.join(self._checkpoint_dir, "train.log")
    print("Training log: {}".format(logfile))
    sys.stdout = utils.DualLogger(logfile)

    # From here on every print goes through the logger, so the redirect has to
    # be undone on every exit path, not only on success.
    try:
      # Dump some parameters.
      print("           Direction: {}".format(direction))
      print("            # Epochs: {}".format(epochs))
      print("          Batch size: {}".format(self._batch_size))
      print("         Window size: {}".format(window))
      print("     Max written len: {}".format(corpus.max_written_len))
      print("        Max pron len: {}".format(corpus.max_pronounce_len))
      print("Max written word len: {}".format(corpus.max_written_word_len))
      print("   Max pron word len: {}".format(corpus.max_pronounce_word_len))
      print("       Use residuals: {}".format(self._use_residuals))

      # Perform training.
      best_total_loss = 1000000
      nbatches = data.num_batches(corpus, self._batch_size,
                                  direction=direction, window=window)
      for epoch in range(epochs):
        self._train_accuracy.reset_states()

        start = time.time()
        total_loss = 0
        steps = 0
        batches = data.batchify(corpus, self._batch_size, direction,
                                window=window)
        batch, (inputs, targ) = next(batches)
        # The loop ends once the generator yields a batch index below zero.
        while batch > -1:
          # Prepend the "<s>" symbol id as one extra leading column on every
          # row of the target batch.
          bos = np.expand_dims(
              [self._output_symbols.find("<s>")] * np.shape(targ)[0], 1)
          targets = np.concatenate((bos, targ), axis=-1)
          batch_loss = self._train_step(inputs, targets)
          total_loss += batch_loss
          if batch % 10 == 0:
            print("Epoch {} Batch {} (/{}) Loss {:.4f}".format(
                epoch + 1,
                batch,
                nbatches,
                batch_loss))
          steps += 1
          batch, (inputs, targ) = next(batches)
        # NOTE: if the very first batch index is already negative, `steps` is
        # still 0 here and this division raises ZeroDivisionError.
        total_loss /= steps
        print("Epoch {} Loss {:.4f} Accuracy {:.4f}".format(
            epoch + 1, total_loss, self._train_accuracy.result()))

        if total_loss < best_total_loss:
          self._checkpoint.save(file_prefix=self._checkpoint_prefix)
          print("Saved checkpoint to {}".format(self._checkpoint_prefix))
          best_total_loss = total_loss
        print("Time taken for 1 epoch {} sec\n".format(
            time.time() - start))
      print("Best total loss: {:.4f}".format(best_total_loss))
    finally:
      # Restore stdout.
      sys.stdout = stdout_file
コード例 #2
0
    def train(self, corpus, epochs=10, direction="pronounce", window=-1):
        """Main entry point for running the training.

        For the duration of the call `sys.stdout` is replaced by a
        `utils.DualLogger` built for `train.log` in the checkpoint directory.
        The previous stream is put back when the method exits, including when
        training raises.

        A checkpoint is saved after every epoch whose mean batch loss is lower
        than the best one seen so far in this call.

        Args:
          corpus: Training corpus. It is handed to `data.num_batches` and
            `data.batchify`, and its `max_*_len` attributes are logged.
          epochs: Number of passes over the batches.
          direction: Forwarded to `data.num_batches` and `data.batchify`.
          window: Forwarded to `data.num_batches` and `data.batchify`.
        """
        # Create training log that also redirects to stdout.
        stdout_file = sys.stdout
        logfile = os.path.join(self._checkpoint_dir, "train.log")
        print("Training log: {}".format(logfile))
        sys.stdout = utils.DualLogger(logfile)

        # From here on every print goes through the logger, so the redirect
        # has to be undone on every exit path, not only on success.
        try:
            # Dump some parameters.
            print("           Direction: {}".format(direction))
            print("            # Epochs: {}".format(epochs))
            print("          Batch size: {}".format(self._batch_size))
            print("         Window size: {}".format(window))
            print("     Max written len: {}".format(corpus.max_written_len))
            print("        Max pron len: {}".format(corpus.max_pronounce_len))
            print("Max written word len: {}".format(
                corpus.max_written_word_len))
            print("   Max pron word len: {}".format(
                corpus.max_pronounce_word_len))

            # Perform training.
            best_total_loss = 1000000
            nbatches = data.num_batches(corpus,
                                        self._batch_size,
                                        direction=direction,
                                        window=window)
            for epoch in range(epochs):
                start = time.time()
                total_loss = 0
                steps = 0
                batches = data.batchify(corpus,
                                        self._batch_size,
                                        direction,
                                        window=window)
                batch, (inp, targ) = next(batches)
                enc_hidden = self._encoder.initialize_hidden_state(
                    self._batch_size)
                # Input/output lengths are taken from the first batch, and
                # only while the input length is still the -1 placeholder.
                if self._input_length == -1:
                    # TODO(agutkin,rws): Following two lines will break if
                    # batchify() returns an empty tuple.
                    if isinstance(inp, np.ndarray):
                        self._input_length = inp.shape[1]
                    if isinstance(targ, np.ndarray):
                        self._output_length = targ.shape[1]
                # The loop ends once the generator yields a batch index below
                # zero.
                while batch > -1:
                    batch_loss = self._train_step(inp, targ, enc_hidden)
                    total_loss += batch_loss
                    if batch % 10 == 0:
                        print("Epoch {} Batch {} (/{}) Loss {:.4f}".format(
                            epoch + 1, batch, nbatches, batch_loss.numpy()))
                    steps += 1
                    batch, (inp, targ) = next(batches)
                # NOTE: if the very first batch index is already negative,
                # `steps` is still 0 here and this division raises
                # ZeroDivisionError.
                total_loss /= steps
                print("Epoch {} Loss {:.4f}".format(epoch + 1, total_loss))
                if total_loss < best_total_loss:
                    self._checkpoint.save(
                        file_prefix=self._checkpoint_prefix)
                    print("Saved checkpoint to {}".format(
                        self._checkpoint_prefix))
                    best_total_loss = total_loss
                print("Time taken for 1 epoch {} sec\n".format(
                    time.time() - start))

            print("Best total loss: {:.4f}".format(best_total_loss))
        finally:
            # Restore stdout.
            sys.stdout = stdout_file