def testCharacterGenerator(self):
        """Check that wmt.character_generator reproduces a small file pair.

        Writes a two-line ``.src`` file and a two-line ``.tgt`` file, runs
        ``wmt.character_generator`` over them and verifies that every yielded
        dictionary has exactly the keys "inputs" and "targets", and that the
        integer sequences map back (via ``six.int2byte``) to the original lines
        without their trailing newline.
        """
        # Generate a trivial source and target file.
        tmp_dir = self.get_temp_dir()
        (tmp_fd, tmp_file_path) = tempfile.mkstemp(dir=tmp_dir)
        # mkstemp returns an already-open OS-level descriptor; close it right
        # away so it is not leaked for the rest of the test run.
        os.close(tmp_fd)
        src_path = tmp_file_path + ".src"
        tgt_path = tmp_file_path + ".tgt"

        try:
            # The files are opened in binary mode, so the payload has to be
            # bytes; bytes literals behave the same on Python 2 and Python 3.
            with io.open(src_path, "wb") as src_file:
                src_file.write(b"source1\n")
                src_file.write(b"source2\n")
            with io.open(tgt_path, "wb") as tgt_file:
                tgt_file.write(b"target1\n")
                tgt_file.write(b"target2\n")

            # Call character generator on the generated files.
            results_src, results_tgt = [], []
            for dictionary in wmt.character_generator(src_path, tgt_path):
                self.assertEqual(sorted(dictionary), ["inputs", "targets"])
                results_src.append(dictionary["inputs"])
                results_tgt.append(dictionary["targets"])

            # Check that the results match the files. six.int2byte yields
            # bytes objects, so they are joined with a bytes separator and
            # compared against bytes literals.
            self.assertEqual(len(results_src), 2)
            self.assertEqual(
                b"".join(six.int2byte(i) for i in results_src[0]), b"source1")
            self.assertEqual(
                b"".join(six.int2byte(i) for i in results_src[1]), b"source2")
            self.assertEqual(
                b"".join(six.int2byte(i) for i in results_tgt[0]), b"target1")
            self.assertEqual(
                b"".join(six.int2byte(i) for i in results_tgt[1]), b"target2")
        finally:
            # Clean up even when an assertion above fails; skip files that
            # were never created so a cleanup error cannot mask the failure.
            for path in (src_path, tgt_path, tmp_file_path):
                if os.path.exists(path):
                    os.remove(path)
  def testCharacterGenerator(self):
    """Check wmt.character_generator against a ByteTextEncoder vocabulary.

    Writes a two-line ``.src`` file and a two-line ``.tgt`` file, runs
    ``wmt.character_generator`` over them with a
    ``text_encoder.ByteTextEncoder`` and verifies that every yielded
    dictionary has exactly the keys "inputs" and "targets", that the yielded
    sequences equal the encoder's own encoding of each line (without the
    trailing newline), and that decoding them gives back the original text.
    """
    # Generate a trivial source and target file.
    tmp_dir = self.get_temp_dir()
    (tmp_fd, tmp_file_path) = tempfile.mkstemp(dir=tmp_dir)
    # mkstemp returns an already-open OS-level descriptor; close it right
    # away so it is not leaked for the rest of the test run.
    os.close(tmp_fd)
    src_path = tmp_file_path + ".src"
    tgt_path = tmp_file_path + ".tgt"

    try:
      # The files are opened in binary mode, so write bytes literals; they
      # behave the same on Python 2 and Python 3.
      with io.open(src_path, "wb") as src_file:
        src_file.write(b"source1\n")
        src_file.write(b"source2\n")
      with io.open(tgt_path, "wb") as tgt_file:
        tgt_file.write(b"target1\n")
        tgt_file.write(b"target2\n")

      # Call character generator on the generated files.
      results_src, results_tgt = [], []
      character_vocab = text_encoder.ByteTextEncoder()
      for dictionary in wmt.character_generator(
          src_path, tgt_path, character_vocab):
        self.assertEqual(sorted(dictionary), ["inputs", "targets"])
        results_src.append(dictionary["inputs"])
        results_tgt.append(dictionary["targets"])

      # Check that the results match the files.
      # First check that the results match the encoded original strings;
      # this is a comparison of integer arrays.
      self.assertEqual(len(results_src), 2)
      self.assertEqual(results_src[0],
                       character_vocab.encode("source1"))
      self.assertEqual(results_src[1],
                       character_vocab.encode("source2"))
      self.assertEqual(results_tgt[0],
                       character_vocab.encode("target1"))
      self.assertEqual(results_tgt[1],
                       character_vocab.encode("target2"))
      # Then decode the results and compare with the original strings;
      # this is a comparison of strings.
      self.assertEqual(character_vocab.decode(results_src[0]),
                       "source1")
      self.assertEqual(character_vocab.decode(results_src[1]),
                       "source2")
      self.assertEqual(character_vocab.decode(results_tgt[0]),
                       "target1")
      self.assertEqual(character_vocab.decode(results_tgt[1]),
                       "target2")
    finally:
      # Clean up even when an assertion above fails; skip files that were
      # never created so a cleanup error cannot mask the real failure.
      for path in (src_path, tgt_path, tmp_file_path):
        if os.path.exists(path):
          os.remove(path)
 def generator(self, data_dir, tmp_dir, train):
   """Build the character-level generator for the text simplification data.

   Args:
     data_dir: not used by this method.
     tmp_dir: not used by this method.
     train: truthy to use _TEXT_SIMPLIFICATION_TRAIN_DATASETS, falsy to use
       _TEXT_SIMPLIFICATION_TEST_DATASETS.

   Returns:
     Whatever character_generator returns for the first two entries of the
     chosen dataset constant, a ByteTextEncoder and EOS.
   """
   byte_vocab = text_encoder.ByteTextEncoder()
   if train:
     chosen = _TEXT_SIMPLIFICATION_TRAIN_DATASETS
   else:
     chosen = _TEXT_SIMPLIFICATION_TEST_DATASETS
   # NOTE(review): entries 0 and 1 are presumably the source and target file
   # paths, in that order -- confirm against character_generator's signature.
   first_entry = chosen[0]
   second_entry = chosen[1]
   return character_generator(first_entry, second_entry, byte_vocab, EOS)