示例#1
0
    def test_save_and_reload(self):
        """Storing a vocab to disk and loading it back must not alter it.

        This test reads and writes real files under ``self.test_temp_dir``,
        which is why its size has to be "large".
        """
        letters = "A B C D E F G H I J K L M N O P Q R S T U V W X Y Z".split()
        vocab_path = os.path.join(self.test_temp_dir, "abc.vocab")

        # Build an encoder from an in-memory token list and persist its vocab.
        original = text_encoder.TokenTextEncoder(None, vocab_list=letters)
        original.store_to_file(vocab_path)

        # Rebuild an encoder purely from the file that was just written.
        reloaded = text_encoder.TokenTextEncoder(vocab_path)

        # Both directions of the id <-> token mapping must survive the trip.
        for attr in ("_id_to_token", "_token_to_id"):
            self.assertEqual(getattr(original, attr), getattr(reloaded, attr))
示例#2
0
 def feature_encoders(self, data_dir):
     """Build the feature-name -> encoder mapping for this problem.

     Character-level problems get a byte encoder. Otherwise the vocab file
     ``self.vocab_file`` under ``data_dir`` is loaded as a subword vocab when
     ``self.use_subword_tokenizer`` is set, or as a token vocab if not. The
     same encoder instance is used for every returned feature.
     """
     if self.is_character_level:
         encoder = text_encoder.ByteTextEncoder()
     else:
         encoder_cls = (text_encoder.SubwordTextEncoder
                        if self.use_subword_tokenizer
                        else text_encoder.TokenTextEncoder)
         vocab_filename = os.path.join(data_dir, self.vocab_file)
         encoder = encoder_cls(vocab_filename)

     if not self.has_inputs:
         return {"targets": encoder}
     return {"inputs": encoder, "targets": encoder}
示例#3
0
    def test_reserved_tokens_in_corpus(self):
        """Reserved tokens appearing in the corpus must not be duplicated."""
        eos, pad = text_encoder.EOS, text_encoder.PAD
        corpus = "A B {} D E F {} G {}".format(eos, eos, pad)

        encoder = text_encoder.TokenTextEncoder(
            None, vocab_list=corpus.split())
        tokens = list(encoder._id_to_token.values())

        # When reserved tokens are filtered out of the corpus correctly, every
        # id maps to a distinct token, so no entry is repeated.
        self.assertEqual(len(tokens), len(set(tokens)))
示例#4
0
 def generator(self, data_dir, tmp_dir, train):
     """Token generator for the WMT en->de BPE task (train or newstest2013).

     Copies the vocab file ``self.vocab_file`` from ``tmp_dir`` into
     ``data_dir``, appends an ``UNK`` entry to the copy, and builds a
     ``TokenTextEncoder`` that maps out-of-vocabulary tokens to ``UNK``.

     Args:
       data_dir: directory the vocab file is copied into.
       tmp_dir: directory holding the downloaded dataset and the source vocab.
       train: if True, use "train.tok.clean.bpe.32000"; otherwise use
         "newstest2013.tok.bpe.32000".

     Returns:
       The result of ``token_generator`` over the ``.en`` / ``.de`` files of
       the selected split, using the vocab above and ``EOS``.
     """
     dataset_path = ("train.tok.clean.bpe.32000"
                     if train else "newstest2013.tok.bpe.32000")
     train_path = _get_wmt_ende_bpe_dataset(tmp_dir, dataset_path)
     token_tmp_path = os.path.join(tmp_dir, self.vocab_file)
     token_path = os.path.join(data_dir, self.vocab_file)
     # Overwrite restores the pristine vocab, so repeated calls don't stack
     # multiple UNK lines at the end of the file.
     tf.gfile.Copy(token_tmp_path, token_path, overwrite=True)
     with tf.gfile.GFile(token_path, mode="r") as f:
         vocab_text = f.read()
     with tf.gfile.GFile(token_path, mode="a") as f:
         # If the vocab lacks a trailing newline, appending "UNK\n" directly
         # would glue UNK onto the last token and lose both entries.
         if vocab_text and not vocab_text.endswith("\n"):
             f.write("\n")
         f.write("UNK\n")  # Add UNK to the vocab.
     token_vocab = text_encoder.TokenTextEncoder(token_path,
                                                 replace_oov="UNK")
     return token_generator(train_path + ".en", train_path + ".de",
                            token_vocab, EOS)
示例#5
0
def main(_):
    """Print statistics (and optionally contents) of a TFRecord of examples.

    Reads ``FLAGS.input_filename`` as a TFRecord of serialized
    ``tf.train.Example`` protos with int64 ``inputs`` and ``targets``
    features. When ``FLAGS.print_inputs`` / ``FLAGS.print_targets`` are set,
    each sequence is printed, decoded through the encoder chosen by the
    encoder flags, or as a raw id list when no encoder flag is given.
    Totals and maximum lengths are logged at the end.
    """
    if FLAGS.subword_text_encoder_filename:
        encoder = text_encoder.SubwordTextEncoder(
            FLAGS.subword_text_encoder_filename)
    elif FLAGS.token_text_encoder_filename:
        encoder = text_encoder.TokenTextEncoder(
            FLAGS.token_text_encoder_filename)
    elif FLAGS.byte_text_encoder:
        encoder = text_encoder.ByteTextEncoder()
    else:
        encoder = None

    def _render(ids):
        # Decode when an encoder is configured; otherwise show the raw ids.
        return encoder.decode(ids) if encoder is not None else str(ids)

    reader = tf.python_io.tf_record_iterator(FLAGS.input_filename)
    total_sequences = 0
    total_input_tokens = 0
    total_target_tokens = 0
    max_input_length = 0
    max_target_length = 0
    for record in reader:
        x = tf.train.Example()
        x.ParseFromString(record)
        inputs = [
            int(i) for i in x.features.feature["inputs"].int64_list.value
        ]
        targets = [
            int(i) for i in x.features.feature["targets"].int64_list.value
        ]
        # The label is concatenated outside the conditional so it is printed
        # whether or not an encoder is available (`+` binds tighter than
        # `if/else`, so an unparenthesized form drops it in the raw case).
        if FLAGS.print_inputs:
            print("INPUTS:\n" + _render(inputs))
        if FLAGS.print_targets:
            print("TARGETS:\n" + _render(targets))
        total_input_tokens += len(inputs)
        total_target_tokens += len(targets)
        total_sequences += 1
        max_input_length = max(max_input_length, len(inputs))
        max_target_length = max(max_target_length, len(targets))

    tf.logging.info("total_sequences: %d", total_sequences)
    tf.logging.info("total_input_tokens: %d", total_input_tokens)
    tf.logging.info("total_target_tokens: %d", total_target_tokens)
    tf.logging.info("max_input_length: %d", max_input_length)
    tf.logging.info("max_target_length: %d", max_target_length)
示例#6
0
 def feature_encoders(self, data_dir):
     """Return one shared token encoder for both "inputs" and "targets".

     The vocab is loaded from ``self.vocab_file`` under ``data_dir``, with
     out-of-vocabulary handling delegated to ``replace_oov="UNK"``.
     """
     shared = text_encoder.TokenTextEncoder(
         os.path.join(data_dir, self.vocab_file), replace_oov="UNK")
     return dict(inputs=shared, targets=shared)
示例#7
0
def _get_token_encoder(vocab_dir, vocab_name, filename, vocab_size=10000):
  """Reads from file and returns a `TokenTextEncoder` for the vocabulary.

  If the vocab file does not exist yet, it is first built from `filename`
  via `_build_vocab`.

  Args:
    vocab_dir: directory containing (or that will contain) the vocab file.
    vocab_name: base name of the vocab file inside `vocab_dir`.
    filename: source passed to `_build_vocab` when the vocab file is absent.
    vocab_size: size argument forwarded to `_build_vocab`. It only takes
      effect when the vocab is built here; an existing file is reused as-is.
      Defaults to 10000.

  Returns:
    A `text_encoder.TokenTextEncoder` loaded from the vocab file.
  """
  vocab_path = os.path.join(vocab_dir, vocab_name)
  if not tf.gfile.Exists(vocab_path):
    _build_vocab(filename, vocab_path, vocab_size)
  return text_encoder.TokenTextEncoder(vocab_path)