示例#1
0
    def generator(self, data_dir, tmp_dir, train):
        """Return a character-level sample generator for the train or dev split.

        The two .txt files of the selected split must be present in data_dir:
          trainSource.txt, trainTarget.txt   (when `train` is true)
          devSource.txt, devTarget.txt       (otherwise)

        Params:
          :data_dir: Directory that contains the source/target .txt files.
          :tmp_dir: Not used by this method.
          :train: Whether we are in train mode or not.

        Returns:
          Whatever `translate.character_generator` returns for the split's
          source/target files, using a byte-level text encoder and `EOS`.

        Raises:
          IOError: (FileNotFoundError on Python 3) if the split's source or
            target file is missing from data_dir.
        """
        import errno

        character_vocab = text_encoder.ByteTextEncoder()
        mode = "train" if train else "dev"
        print("t2t_csaky_log: " + mode + " data generation activated.")

        source_path = os.path.join(data_dir, mode + "Source.txt")
        target_path = os.path.join(data_dir, mode + "Target.txt")

        # Try to find the txt files.
        if os.path.isfile(source_path) and os.path.isfile(target_path):
            print("t2t_csaky_log: Generating " + mode + " files in " +
                  data_dir)
            return translate.character_generator(source_path, target_path,
                                                 character_vocab, EOS)

        print("t2t_csaky_log: " + mode +
              " source or target file not found, please check " +
              "that the following files exist in your " + data_dir +
              " directory and rerun this program:")
        print("  trainSource.txt")
        print("  trainTarget.txt")
        print("  devSource.txt")
        print("  devTarget.txt")
        # Fail here instead of implicitly returning None: a None "generator"
        # would only surface later as an opaque TypeError when iterated.
        # OSError/IOError built with an errno maps ENOENT to FileNotFoundError
        # on Python 3 while remaining a plain IOError on Python 2.
        missing = target_path if os.path.isfile(source_path) else source_path
        raise IOError(errno.ENOENT, os.strerror(errno.ENOENT), missing)
  def generator(self, data_dir, tmp_dir, train):
    '''Return a character-level sample generator for the train or dev split.

    The two .txt files of the selected split must be present in data_dir:
      trainSource.txt, trainTarget.txt   (when `train` is true)
      devSource.txt, devTarget.txt       (otherwise)

    Params:
      :data_dir: Directory that contains the source/target .txt files.
      :tmp_dir: Not used by this method.
      :train: Whether we are in train mode or not.

    Returns:
      Whatever `translate.character_generator` returns for the split's
      source/target files, using a byte-level text encoder and `EOS`.

    Raises:
      IOError: (FileNotFoundError on Python 3) if the split's source or
        target file is missing from data_dir.
    '''
    import errno

    character_vocab = text_encoder.ByteTextEncoder()
    mode = 'train' if train else 'dev'
    print('t2t_csaky_log: ' + mode + ' data generation activated.')

    source_path = os.path.join(data_dir, mode + 'Source.txt')
    target_path = os.path.join(data_dir, mode + 'Target.txt')

    # Try to find the txt files.
    if os.path.isfile(source_path) and os.path.isfile(target_path):
      print('t2t_csaky_log: Generating ' + mode + ' files in ' + data_dir)
      return translate.character_generator(source_path,
                                           target_path,
                                           character_vocab,
                                           EOS)

    print('t2t_csaky_log: ' + mode +
          ' source or target file not found, please check ' +
          'that the following files exist in your ' + data_dir +
          ' directory and rerun this program:')
    print('  trainSource.txt')
    print('  trainTarget.txt')
    print('  devSource.txt')
    print('  devTarget.txt')
    # Fail here instead of implicitly returning None: a None "generator"
    # would only surface later as an opaque TypeError when iterated.
    # OSError/IOError built with an errno maps ENOENT to FileNotFoundError
    # on Python 3 while remaining a plain IOError on Python 2.
    missing = target_path if os.path.isfile(source_path) else source_path
    raise IOError(errno.ENOENT, os.strerror(errno.ENOENT), missing)
示例#3
0
    def generator(self, _, tmp_dir, train):
        """Build a character-level generator over one source/target text pair.

        The pair is "trainSource.txt"/"trainTarget.txt" when `train` is
        true, else "devSource.txt"/"devTarget.txt". Both names are bare
        relative paths, so they resolve against the current working
        directory. Neither the data-dir argument (`_`) nor `tmp_dir` is used.
        """
        # NOTE(review): sibling generators join these names with data_dir;
        # confirm that reading from the CWD here is intentional.
        byte_encoder = text_encoder.ByteTextEncoder()
        if train:
            split = "train"
        else:
            split = "dev"
        source_file = split + "Source.txt"
        target_file = split + "Target.txt"
        return translate.character_generator(
            source_file, target_file, byte_encoder, EOS)
示例#4
0
 def generator(self, _, tmp_dir, train):
   """Return a character-level generator over compiled WMT En-De data.

   Chooses the train or test dataset list, passes it to
   `translate.compile_data(tmp_dir, ...)` under the name
   "wmt_ende_chr_train" or "wmt_ende_chr_dev", and uses the returned path as
   a prefix for the ".lang1"/".lang2" files handed to
   `translate.character_generator` with a byte-level encoder and `EOS`.
   The data-dir argument (`_`) is unused.
   """
   encoder = text_encoder.ByteTextEncoder()
   if train:
     sources, split = _ENDE_TRAIN_DATASETS, "train"
   else:
     sources, split = _ENDE_TEST_DATASETS, "dev"
   prefix = translate.compile_data(tmp_dir, sources,
                                   "wmt_ende_chr_%s" % split)
   src_file = prefix + ".lang1"
   tgt_file = prefix + ".lang2"
   return translate.character_generator(src_file, tgt_file, encoder, EOS)
 def generator(self, _, tmp_dir, train):
     """Character-level generator for the WMT En-De train/dev split.

     The train or test dataset list is compiled with
     `translate.compile_data` into a path prefix named "wmt_ende_chr_<tag>"
     (tag is "train" or "dev"). The prefix's ".lang1"/".lang2" files are then
     read by `translate.character_generator` with a byte-level vocabulary and
     `EOS`. The data-dir argument (`_`) is ignored.
     """
     vocab = text_encoder.ByteTextEncoder()
     datasets, tag = ((_ENDE_TRAIN_DATASETS, "train") if train
                      else (_ENDE_TEST_DATASETS, "dev"))
     base = translate.compile_data(tmp_dir, datasets, "wmt_ende_chr_%s" % tag)
     inputs_path, targets_path = (base + ext for ext in (".lang1", ".lang2"))
     return translate.character_generator(inputs_path, targets_path,
                                          vocab, EOS)
示例#6
0
 def generator(self, data_dir, tmp_dir, train):
   """Return a character-level generator over compiled WMT En-Fr data.

   The dataset list is taken from the small or large En-Fr collection
   (per `self.use_small_dataset`) and from its train or test variant
   (per `train`). The list is compiled with `translate.compile_data` into a
   "wmt_enfr_chr_<train|dev>" path prefix whose ".lang1"/".lang2" files
   feed `translate.character_generator`. `data_dir` is not used.
   """
   vocab = text_encoder.ByteTextEncoder()
   small = self.use_small_dataset
   if train:
     datasets = _ENFR_TRAIN_SMALL_DATA if small else _ENFR_TRAIN_LARGE_DATA
     tag = "train"
   else:
     datasets = _ENFR_TEST_SMALL_DATA if small else _ENFR_TEST_LARGE_DATA
     tag = "dev"
   compiled = translate.compile_data(tmp_dir, datasets,
                                     "wmt_enfr_chr_%s" % tag)
   return translate.character_generator(compiled + ".lang1",
                                        compiled + ".lang2",
                                        vocab, EOS)
示例#7
0
  def testCharacterGenerator(self):
    """translate.character_generator round-trips a 2-line source/target pair.

    Writes two lines to each of a ".src" and a ".tgt" file and runs
    translate.character_generator over them. It then checks that every
    yielded dict has exactly the "inputs"/"targets" keys, and that the values
    equal the encoded lines and decode back to the original strings.
    """
    tmp_dir = self.get_temp_dir()
    # mkstemp returns an already-open OS-level descriptor; only the path is
    # needed, so close it immediately instead of leaking the fd.
    fd, tmp_file_path = tempfile.mkstemp(dir=tmp_dir)
    os.close(fd)
    src_path = tmp_file_path + ".src"
    tgt_path = tmp_file_path + ".tgt"
    if six.PY2:
      enc_f = lambda s: s
    else:
      enc_f = lambda s: s.encode("utf-8")

    try:
      # Generate a trivial source and target file.
      with io.open(src_path, "wb") as src_file:
        src_file.write(enc_f("source1\n"))
        src_file.write(enc_f("source2\n"))
      with io.open(tgt_path, "wb") as tgt_file:
        tgt_file.write(enc_f("target1\n"))
        tgt_file.write(enc_f("target2\n"))

      # Call character generator on the generated files.
      results_src, results_tgt = [], []
      character_vocab = text_encoder.ByteTextEncoder()
      for dictionary in translate.character_generator(
          src_path, tgt_path, character_vocab):
        self.assertEqual(sorted(dictionary), ["inputs", "targets"])
        results_src.append(dictionary["inputs"])
        results_tgt.append(dictionary["targets"])

      # Check that the results match the files.
      # First check that the results match the encoded original strings;
      # this is a comparison of integer arrays.
      self.assertEqual(len(results_src), 2)
      self.assertEqual(len(results_tgt), 2)
      self.assertEqual(results_src[0],
                       character_vocab.encode("source1"))
      self.assertEqual(results_src[1],
                       character_vocab.encode("source2"))
      self.assertEqual(results_tgt[0],
                       character_vocab.encode("target1"))
      self.assertEqual(results_tgt[1],
                       character_vocab.encode("target2"))
      # Then decode the results and compare with the original strings;
      # this is a comparison of strings.
      self.assertEqual(character_vocab.decode(results_src[0]),
                       "source1")
      self.assertEqual(character_vocab.decode(results_src[1]),
                       "source2")
      self.assertEqual(character_vocab.decode(results_tgt[0]),
                       "target1")
      self.assertEqual(character_vocab.decode(results_tgt[1]),
                       "target2")
    finally:
      # Clean up even when an assertion fails; skip files that were never
      # created (e.g. if writing the .src file raised).
      for path in (src_path, tgt_path, tmp_file_path):
        if os.path.exists(path):
          os.remove(path)