Example #1
def feature_encoders(self, data_dir):
    """Loads the source/target subword vocab files and returns the encoders."""
    source_vocab_filename = os.path.join(data_dir, self.source_vocab_name)
    target_vocab_filename = os.path.join(data_dir, self.target_vocab_name)
    source_token = text_encoder.SubwordTextEncoder(source_vocab_filename)
    target_token = text_encoder.SubwordTextEncoder(target_vocab_filename)
    return {
        "inputs": source_token,
        "targets": target_token,
    }
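A minimal usage sketch for the encoders returned above; the `problem` instance and the data directory path are assumptions for illustration, not part of the original example:

# Hypothetical usage: `problem` is some Problem subclass instance and
# /tmp/t2t_data already contains both vocab files.
encoders = problem.feature_encoders("/tmp/t2t_data")
ids = encoders["inputs"].encode("hello world")   # text -> subword ids
text = encoders["inputs"].decode(ids)            # subword ids -> text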
Example #2
def get_or_create_vocab(self, data_dir, tmp_dir, force_get=False):
    """Returns a text encoder, generating the subword vocab if needed."""
    if self.vocab_type == VocabType.CHARACTER:
        encoder = text_encoder.ByteTextEncoder()
    elif self.vocab_type == VocabType.SUBWORD:
        if force_get:
            # With force_get the vocab file must already exist; just load it.
            vocab_filepath = os.path.join(data_dir, self.vocab_filename)
            encoder = text_encoder.SubwordTextEncoder(vocab_filepath)
        else:
            other_problem = self.use_vocab_from_other_problem
            if other_problem:
                # Delegate to the problem whose vocab this one shares.
                return other_problem.get_or_create_vocab(
                    data_dir, tmp_dir, force_get)
            encoder = generator_utils.get_or_generate_vocab_inner(
                data_dir,
                self.vocab_filename,
                self.approx_vocab_size,
                self.generate_text_for_vocab(data_dir, tmp_dir),
                max_subtoken_length=self.max_subtoken_length,
                reserved_tokens=(text_encoder.RESERVED_TOKENS +
                                 self.additional_reserved_tokens))
    elif self.vocab_type == VocabType.TOKEN:
        vocab_filename = os.path.join(data_dir, self.vocab_filename)
        encoder = text_encoder.TokenTextEncoder(vocab_filename,
                                                replace_oov=self.oov_token)
    else:
        raise ValueError("Unrecognized VocabType: %s" %
                         str(self.vocab_type))
    return encoder
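A hedged usage sketch, assuming a Text2TextProblem subclass instance `problem` whose `vocab_type` is `VocabType.SUBWORD` (the instance and paths are illustrative):

# Hypothetical call: generates the subword vocab file on the first run,
# then reloads the stored file on subsequent runs.
encoder = problem.get_or_create_vocab("/tmp/t2t_data", "/tmp/t2t_tmp")
print(encoder.vocab_size)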
Example #3
def get_or_create_vocab(self, data_dir, tmp_dir, force_get=False):
    """Get vocab for distill problems."""
    # We assume the vocab file is already present in data_dir, the same
    # directory where the generated data will be stored.
    vocab_filepath = os.path.join(data_dir, self.vocab_filename)
    encoder = text_encoder.SubwordTextEncoder(vocab_filepath)
    return encoder
Example #4
def get_or_generate_vocab_inner(data_dir,
                                vocab_filename,
                                vocab_size,
                                generator,
                                max_subtoken_length=None,
                                reserved_tokens=None):
    """Inner implementation for vocab generators.

    Args:
      data_dir: The base directory where data and vocab files are stored. If
        None, then do not save the vocab even if it doesn't exist.
      vocab_filename: relative filename where the vocab file is stored
      vocab_size: target size of the vocabulary constructed by
        SubwordTextEncoder
      generator: a generator that produces tokens from the vocabulary
      max_subtoken_length: an optional integer. Set this to a finite value to
        avoid quadratic costs during vocab building.
      reserved_tokens: List of reserved tokens. `text_encoder.RESERVED_TOKENS`
        should be a prefix of `reserved_tokens`. If `None`, defaults to
        `RESERVED_TOKENS`.

    Returns:
      A SubwordTextEncoder vocabulary object.
    """
    if data_dir and vocab_filename:
        vocab_filepath = os.path.join(data_dir, vocab_filename)
        if tf.gfile.Exists(vocab_filepath):
            # Reuse an existing vocab file instead of rebuilding it.
            tf.logging.info("Found vocab file: %s", vocab_filepath)
            return text_encoder.SubwordTextEncoder(vocab_filepath)
    else:
        vocab_filepath = None

    tf.logging.info("Generating vocab file: %s", vocab_filepath)
    vocab = text_encoder.SubwordTextEncoder.build_from_generator(
        generator,
        vocab_size,
        max_subtoken_length=max_subtoken_length,
        reserved_tokens=reserved_tokens)

    if vocab_filepath:
        tf.gfile.MakeDirs(data_dir)
        vocab.store_to_file(vocab_filepath)

    return vocab
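A minimal sketch of calling this helper directly; the generator, filenames, and sample text are assumptions for illustration:

# Hypothetical direct call: builds a subword vocab from a tiny in-memory
# corpus and stores it under data_dir for reuse on later calls.
def corpus_generator():
    yield "the quick brown fox"
    yield "jumps over the lazy dog"

vocab = get_or_generate_vocab_inner(
    data_dir="/tmp/t2t_data",
    vocab_filename="vocab.demo.8192.subwords",
    vocab_size=8192,
    generator=corpus_generator())
print(vocab.vocab_size)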
Example #5
def main(unused_argv):
  """Builds a SubwordTextEncoder from a corpus or an existing vocab file."""
  if FLAGS.corpus_filepattern and FLAGS.vocab_filepattern:
    raise ValueError(
        'Must only provide one of --corpus_filepattern or --vocab_filepattern')

  elif FLAGS.corpus_filepattern:
    token_counts = tokenizer.corpus_token_counts(
        FLAGS.corpus_filepattern,
        FLAGS.corpus_max_lines,
        split_on_newlines=FLAGS.split_on_newlines)

  elif FLAGS.vocab_filepattern:
    token_counts = tokenizer.vocab_token_counts(FLAGS.vocab_filepattern,
                                                FLAGS.corpus_max_lines)

  else:
    raise ValueError(
        'Must provide one of --corpus_filepattern or --vocab_filepattern')

  encoder = text_encoder.SubwordTextEncoder()
  encoder.build_from_token_counts(token_counts, FLAGS.min_count,
                                  FLAGS.num_iterations)
  encoder.store_to_file(FLAGS.output_filename)
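This `main` comes from tensor2tensor's subword-vocab build script. A hedged invocation sketch, using only the flags visible in the snippet above; the script name is an assumption:

# Hypothetical invocation (flag names from the snippet; the script path
# is an assumption):
#
#   python text_encoder_build_subword.py \
#     --corpus_filepattern='/data/corpus-*.txt' \
#     --corpus_max_lines=100000 \
#     --min_count=5 \
#     --num_iterations=4 \
#     --output_filename=/tmp/vocab.subwords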