예제 #1
0
 def test_load_data_single_path(self):
     """Load a single binarized en-fr corpus pair into the task.

     Binarizes one source and one target test text file, loads them under
     split "0", and checks that the task exposes a ``LanguagePairDataset``
     of length 4 for that split.
     """
     # Function-scope imports keep this method self-contained.
     import os
     import shutil

     test_args = test_utils.ModelParamsDict()
     test_args.source_lang = "en"
     test_args.target_lang = "fr"
     test_args.log_verbose = False
     src_dict, tgt_dict = test_utils.create_vocab_dictionaries()
     src_text_file, tgt_text_file = test_utils.create_test_text_files()

     # Write the binarized output into a private directory that lives for
     # the whole test. A path borrowed from a throwaway NamedTemporaryFile
     # can be unlinked by that object's finalizer at an arbitrary later
     # time, and nothing would ever remove the files written here.
     output_dir = tempfile.mkdtemp()
     self.addCleanup(shutil.rmtree, output_dir, ignore_errors=True)

     def binarize(text_file, dictionary, name):
         # Shared settings for both sides of the language pair.
         return preprocess.binarize_text_file(
             text_file=text_file,
             dictionary=dictionary,
             output_path=os.path.join(output_dir, name),
             append_eos=True,
             reverse_order=False,
         )

     src_bin_path = binarize(src_text_file, src_dict, "src")
     tgt_bin_path = binarize(tgt_text_file, tgt_dict, "tgt")

     task = tasks.PytorchTranslateTask(test_args, src_dict, tgt_dict)
     split = "0"
     task.load_dataset(split, src_bin_path, tgt_bin_path)
     self.assertEqual(len(task.datasets[split]), 4)
     self.assertIsInstance(task.datasets[split], LanguagePairDataset)
예제 #2
0
 def test_load_data_multi_path(self):
     """Load several binarized en-fr corpus pairs into the task at once.

     Binarizes 4 source/target test text file pairs, passes them to
     ``load_dataset`` as dicts keyed by corpus index under split "1", and
     checks that the task exposes a ``MultiCorpusSampledDataset`` of
     length 16 for that split.
     """
     # Function-scope imports keep this method self-contained.
     import os
     import shutil

     test_args = test_utils.ModelParamsDict()
     test_args.source_lang = "en"
     test_args.target_lang = "fr"
     test_args.log_verbose = False
     src_dict, tgt_dict = test_utils.create_vocab_dictionaries()

     # Write the binarized output into a private directory that lives for
     # the whole test. A path borrowed from a throwaway NamedTemporaryFile
     # can be unlinked by that object's finalizer at an arbitrary later
     # time, and nothing would ever remove the files written here.
     output_dir = tempfile.mkdtemp()
     self.addCleanup(shutil.rmtree, output_dir, ignore_errors=True)

     def binarize(text_file, dictionary, name):
         # Shared settings for both sides of every language pair.
         return preprocess.binarize_text_file(
             text_file=text_file,
             dictionary=dictionary,
             output_path=os.path.join(output_dir, name),
             append_eos=True,
             reverse_order=False,
         )

     num_paths = 4
     src_bin_path, tgt_bin_path = {}, {}
     for i in range(num_paths):
         src_text_file, tgt_text_file = test_utils.create_test_text_files()
         src_bin_path[i] = binarize(
             src_text_file, src_dict, "src_{}".format(i)
         )
         tgt_bin_path[i] = binarize(
             tgt_text_file, tgt_dict, "tgt_{}".format(i)
         )

     task = tasks.PytorchTranslateTask(test_args, src_dict, tgt_dict)
     split = "1"
     task.load_dataset(split, src_bin_path, tgt_bin_path)
     self.assertEqual(len(task.datasets[split]), 16)
     self.assertIsInstance(task.datasets[split], MultiCorpusSampledDataset)
예제 #3
0
 def _prepare_data_multi_path(self, num_paths):
     """Build task arguments, dictionaries and binarized multi-corpus data.

     Args:
         num_paths: Number of source/target corpus pairs to create and
             binarize.

     Returns:
         A tuple ``(test_args, src_dict, tgt_dict, src_bin_path,
         tgt_bin_path)``. The two ``*_bin_path`` values are dicts mapping
         each corpus index in ``range(num_paths)`` to the path returned by
         ``preprocess.binarize_text_file``.
     """
     # Function-scope imports keep this method self-contained.
     import os
     import shutil

     test_args = test_utils.ModelParamsDict()
     test_args.source_lang = "en"
     test_args.target_lang = "fr"
     test_args.log_verbose = False
     test_args.dataset_upsampling = None
     test_args.dataset_relative_ratio = None
     src_dict, tgt_dict = test_utils.create_vocab_dictionaries()

     # The returned paths must stay valid after this helper returns, so the
     # output goes into a private directory whose removal is deferred to
     # test cleanup. A path borrowed from a throwaway NamedTemporaryFile
     # can be unlinked by that object's finalizer at an arbitrary later
     # time.
     # NOTE(review): relies on the enclosing class being a
     # unittest.TestCase (for addCleanup) - confirm.
     output_dir = tempfile.mkdtemp()
     self.addCleanup(shutil.rmtree, output_dir, ignore_errors=True)

     def binarize(text_file, dictionary, name):
         # Shared settings for both sides of every language pair.
         return preprocess.binarize_text_file(
             text_file=text_file,
             dictionary=dictionary,
             output_path=os.path.join(output_dir, name),
             append_eos=True,
             reverse_order=False,
         )

     src_bin_path, tgt_bin_path = {}, {}
     for i in range(num_paths):
         src_text_file, tgt_text_file = test_utils.create_test_text_files()
         src_bin_path[i] = binarize(
             src_text_file, src_dict, "src_{}".format(i)
         )
         tgt_bin_path[i] = binarize(
             tgt_text_file, tgt_dict, "tgt_{}".format(i)
         )
     return test_args, src_dict, tgt_dict, src_bin_path, tgt_bin_path
예제 #4
0
 def test_load_data_noising(self):
     """Load multi-corpus data with a noiser attached to corpus 0.

     Binarizes 4 source/target test text file pairs, loads them under
     split "1" with an ``UnsupervisedMTNoising`` noiser registered for
     key 0, and checks that the split has length 16 and that the source
     side of its first sub-dataset is a ``NoisingDataset``.
     """
     # Function-scope imports keep this method self-contained.
     import os
     import shutil

     test_args = test_utils.ModelParamsDict()
     test_args.source_lang = "en"
     test_args.target_lang = "fr"
     test_args.log_verbose = False
     src_dict, tgt_dict = test_utils.create_vocab_dictionaries()

     # Write the binarized output into a private directory that lives for
     # the whole test. A path borrowed from a throwaway NamedTemporaryFile
     # can be unlinked by that object's finalizer at an arbitrary later
     # time, and nothing would ever remove the files written here.
     output_dir = tempfile.mkdtemp()
     self.addCleanup(shutil.rmtree, output_dir, ignore_errors=True)

     def binarize(text_file, dictionary, name):
         # Shared settings for both sides of every language pair.
         return preprocess.binarize_text_file(
             text_file=text_file,
             dictionary=dictionary,
             output_path=os.path.join(output_dir, name),
             append_eos=True,
             reverse_order=False,
         )

     num_paths = 4
     src_bin_path, tgt_bin_path = {}, {}
     for i in range(num_paths):
         src_text_file, tgt_text_file = test_utils.create_test_text_files()
         src_bin_path[i] = binarize(
             src_text_file, src_dict, "src_{}".format(i)
         )
         tgt_bin_path[i] = binarize(
             tgt_text_file, tgt_dict, "tgt_{}".format(i)
         )

     # Only key 0 gets a noiser; the other corpora are loaded without one.
     noiser = {
         0: UnsupervisedMTNoising(
             dictionary=src_dict,
             max_word_shuffle_distance=3,
             word_dropout_prob=0.2,
             word_blanking_prob=0.2,
         )
     }

     task = tasks.PytorchTranslateTask(test_args, src_dict, tgt_dict)
     split = "1"
     task.load_dataset(split, src_bin_path, tgt_bin_path, noiser=noiser)
     self.assertEqual(len(task.datasets[split]), 16)
     self.assertIsInstance(
         task.datasets[split].datasets[0].src, NoisingDataset
     )