def test_load_data_single_path(self):
    """Load a single binarized en->fr corpus and check the resulting dataset.

    With plain string paths, ``load_dataset`` is expected to produce a
    ``LanguagePairDataset`` holding 4 examples.
    """
    args = test_utils.ModelParamsDict()
    args.source_lang = "en"
    args.target_lang = "fr"
    args.log_verbose = False

    src_dict, tgt_dict = test_utils.create_vocab_dictionaries()
    src_text, tgt_text = test_utils.create_test_text_files()

    def binarize(text_file, dictionary):
        # Only the unique temp path is kept. The NamedTemporaryFile object is
        # dropped immediately, and binarize_text_file writes to that path.
        # NOTE(review): nothing in this test removes the written files.
        return preprocess.binarize_text_file(
            text_file=text_file,
            dictionary=dictionary,
            output_path=tempfile.NamedTemporaryFile().name,
            append_eos=True,
            reverse_order=False,
        )

    src_bin_path = binarize(src_text, src_dict)
    tgt_bin_path = binarize(tgt_text, tgt_dict)

    task = tasks.PytorchTranslateTask(args, src_dict, tgt_dict)
    split = "0"
    task.load_dataset(split, src_bin_path, tgt_bin_path)

    loaded = task.datasets[split]
    self.assertEqual(len(loaded), 4)
    self.assertIsInstance(loaded, LanguagePairDataset)
def test_load_data_multi_path(self):
    """Load several binarized corpora keyed by index and check the result.

    Passing dicts of paths (4 corpora here) should make ``load_dataset``
    build a ``MultiCorpusSampledDataset`` whose total length is 16.
    """
    args = test_utils.ModelParamsDict()
    args.source_lang = "en"
    args.target_lang = "fr"
    args.log_verbose = False

    src_dict, tgt_dict = test_utils.create_vocab_dictionaries()

    def binarize(text_file, dictionary):
        # Only the generated temp path is reused as the output location.
        return preprocess.binarize_text_file(
            text_file=text_file,
            dictionary=dictionary,
            output_path=tempfile.NamedTemporaryFile().name,
            append_eos=True,
            reverse_order=False,
        )

    num_paths = 4
    src_bin_path = {}
    tgt_bin_path = {}
    for corpus_id in range(num_paths):
        # Each corpus gets its own freshly generated text files.
        src_text, tgt_text = test_utils.create_test_text_files()
        src_bin_path[corpus_id] = binarize(src_text, src_dict)
        tgt_bin_path[corpus_id] = binarize(tgt_text, tgt_dict)

    task = tasks.PytorchTranslateTask(args, src_dict, tgt_dict)
    split = "1"
    task.load_dataset(split, src_bin_path, tgt_bin_path)

    loaded = task.datasets[split]
    self.assertEqual(len(loaded), 16)
    self.assertIsInstance(loaded, MultiCorpusSampledDataset)
def _prepare_data_multi_path(self, num_paths):
    """Build args, dictionaries and per-corpus binarized paths for tests.

    Args:
        num_paths: number of corpora to generate; they are keyed 0..num_paths-1.

    Returns:
        A tuple ``(test_args, src_dict, tgt_dict, src_bin_path, tgt_bin_path)``.
        The two path values are dicts mapping corpus index to binarized file path.
    """
    args = test_utils.ModelParamsDict()
    args.source_lang = "en"
    args.target_lang = "fr"
    args.log_verbose = False
    # Explicitly disable upsampling / relative-ratio sampling; callers may
    # override these on the returned args.
    args.dataset_upsampling = None
    args.dataset_relative_ratio = None

    src_dict, tgt_dict = test_utils.create_vocab_dictionaries()

    def binarize(text_file, dictionary):
        # Only the unique temp path is kept; the file object is discarded.
        return preprocess.binarize_text_file(
            text_file=text_file,
            dictionary=dictionary,
            output_path=tempfile.NamedTemporaryFile().name,
            append_eos=True,
            reverse_order=False,
        )

    src_paths = {}
    tgt_paths = {}
    for corpus_id in range(num_paths):
        src_text, tgt_text = test_utils.create_test_text_files()
        src_paths[corpus_id] = binarize(src_text, src_dict)
        tgt_paths[corpus_id] = binarize(tgt_text, tgt_dict)

    return args, src_dict, tgt_dict, src_paths, tgt_paths
def test_load_data_noising(self):
    """Attach a source-side noiser to corpus 0 of a multi-path load.

    After loading, the source side of sub-dataset 0 should be wrapped in a
    ``NoisingDataset``, and the combined dataset should still have 16 examples.
    """
    args = test_utils.ModelParamsDict()
    args.source_lang = "en"
    args.target_lang = "fr"
    args.log_verbose = False

    src_dict, tgt_dict = test_utils.create_vocab_dictionaries()

    def binarize(text_file, dictionary):
        # Only the generated temp path is reused as the output location.
        return preprocess.binarize_text_file(
            text_file=text_file,
            dictionary=dictionary,
            output_path=tempfile.NamedTemporaryFile().name,
            append_eos=True,
            reverse_order=False,
        )

    num_paths = 4
    src_bin_path = {}
    tgt_bin_path = {}
    for corpus_id in range(num_paths):
        src_text, tgt_text = test_utils.create_test_text_files()
        src_bin_path[corpus_id] = binarize(src_text, src_dict)
        tgt_bin_path[corpus_id] = binarize(tgt_text, tgt_dict)

    task = tasks.PytorchTranslateTask(args, src_dict, tgt_dict)
    split = "1"
    # Only corpus 0 receives a noiser; its keys mirror the path-dict keys.
    noisers = {
        0: UnsupervisedMTNoising(
            dictionary=src_dict,
            max_word_shuffle_distance=3,
            word_dropout_prob=0.2,
            word_blanking_prob=0.2,
        )
    }
    task.load_dataset(split, src_bin_path, tgt_bin_path, noiser=noisers)

    loaded = task.datasets[split]
    self.assertEqual(len(loaded), 16)
    self.assertIsInstance(loaded.datasets[0].src, NoisingDataset)