def test_run_translation(self):
    """Run ``run_translation.main()`` in-process and check the eval BLEU score.

    The script is driven through a patched ``sys.argv`` for 50 training steps
    on the ``wmt16/sample.json`` fixture using the
    ``sshleifer/student_marian_en_ro_6_1`` checkpoint. The test passes when
    the ``eval_bleu`` entry returned by ``get_results`` is at least 30.
    """
    # Echo log records to stdout while the test runs. The handler is detached
    # again in ``finally`` so repeated tests do not stack up handlers (and
    # duplicated log lines) on the shared ``logger``.
    stream_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stream_handler)
    try:
        tmp_dir = self.get_auto_remove_tmp_dir()

        # NOTE(review): --source_lang/--target_lang are passed twice (en/ro,
        # then en_XX/ro_RO). Presumably the parser keeps the last occurrence,
        # as plain argparse does - confirm and drop the redundant pair.
        testargs = f"""
            run_translation.py
            --model_name_or_path sshleifer/student_marian_en_ro_6_1
            --source_lang en
            --target_lang ro
            --train_file tests/fixtures/tests_samples/wmt16/sample.json
            --validation_file tests/fixtures/tests_samples/wmt16/sample.json
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --max_steps=50
            --warmup_steps=8
            --do_train
            --do_eval
            --learning_rate=3e-3
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --predict_with_generate
            --source_lang en_XX
            --target_lang ro_RO
        """.split()

        with patch.object(sys, "argv", testargs):
            run_translation.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_bleu"], 30)
    finally:
        logger.removeHandler(stream_handler)
def run_trainer(
    self,
    eval_steps: int,
    max_len: int,
    model_name: str,
    num_train_epochs: int,
    learning_rate: float = 3e-3,
    distributed: bool = False,
    extra_args_str: str = None,
    predict_with_generate: bool = True,
):
    """Drive ``run_translation.py`` on the ``wmt_en_ro`` test data.

    Args:
        eval_steps: Value used for both ``--eval_steps`` and ``--save_steps``.
        max_len: Value used for ``--max_source_length``,
            ``--max_target_length`` and ``--val_max_target_length``.
        model_name: Value passed as ``--model_name_or_path``.
        num_train_epochs: Value passed as ``--num_train_epochs``.
        learning_rate: Value passed as ``--learning_rate``.
        distributed: When true, run the script in a subprocess through
            ``python -m torch.distributed.launch`` with ``--nproc_per_node``
            set to ``get_gpu_count()``; otherwise call ``main()`` in-process
            with ``sys.argv`` patched.
        extra_args_str: Optional whitespace-separated arguments appended
            after the generated ones.
        predict_with_generate: When true, append ``--predict_with_generate``.

    Returns:
        The temporary output directory handed to the script.
    """
    data_dir = self.examples_dir / "test_data/wmt_en_ro"
    output_dir = self.get_auto_remove_tmp_dir()
    epochs = str(num_train_epochs)
    step_interval = str(eval_steps)

    # Adjacent literals are joined at compile time; every fragment carries its
    # own leading space and the final one a trailing space, so the text is
    # split on whitespace into one token per flag/value below.
    cli_text = (
        f" --model_name_or_path {model_name}"
        f" --train_file {data_dir}/train.json"
        f" --validation_file {data_dir}/val.json"
        f" --test_file {data_dir}/test.json"
        f" --output_dir {output_dir}"
        " --overwrite_output_dir"
        " --max_train_samples 8"
        " --max_val_samples 8"
        f" --max_source_length {max_len}"
        f" --max_target_length {max_len}"
        f" --val_max_target_length {max_len}"
        " --do_train"
        " --do_eval"
        " --do_predict"
        f" --num_train_epochs {epochs}"
        " --per_device_train_batch_size 4"
        " --per_device_eval_batch_size 4"
        f" --learning_rate {learning_rate}"
        " --warmup_steps 8"
        " --evaluation_strategy steps"
        " --logging_steps 0"
        f" --eval_steps {step_interval}"
        f" --save_steps {step_interval}"
        " --group_by_length"
        " --label_smoothing_factor 0.1"
        " --adafactor"
        " --target_lang ro_RO"
        " --source_lang en_XX "
    )

    script_args = cli_text.split()
    if predict_with_generate:
        script_args.append("--predict_with_generate")
    if extra_args_str is not None:
        script_args += extra_args_str.split()

    if not distributed:
        # In-process run: the script reads its arguments from sys.argv.
        with patch.object(sys, "argv", ["run_translation.py", *script_args]):
            main()
        return output_dir

    n_gpu = get_gpu_count()
    launcher_args = (
        f" -m torch.distributed.launch --nproc_per_node={n_gpu}"
        f" {self.examples_dir_str}/seq2seq/run_translation.py "
    ).split()
    cmd = [sys.executable, *launcher_args, *script_args]
    execute_subprocess_async(cmd, env=self.get_env())
    return output_dir
def run_trainer(
    self,
    eval_steps: int,
    max_len: int,
    model_name: str,
    num_train_epochs: int,
    learning_rate: float = 3e-3,
    distributed: bool = False,
    extra_args_str: str = None,
    predict_with_generate: bool = True,
    do_train: bool = True,
    do_eval: bool = True,
    do_predict: bool = True,
):
    """Drive ``run_translation.py`` on the ``wmt_en_ro`` fixture data.

    The command line is assembled from a group of flags that every phase
    needs (model, data files, output directory, length limits, language
    codes) plus one optional group per enabled phase. Keeping the shared
    flags out of the train-only group means ``do_train=False`` still passes
    the model and output directory to the script.

    Args:
        eval_steps: Value used for ``--save_steps`` (train group) and
            ``--eval_steps`` (eval group).
        max_len: Value used for ``--max_source_length``,
            ``--max_target_length`` and ``--val_max_target_length``.
        model_name: Value passed as ``--model_name_or_path``.
        num_train_epochs: Value passed as ``--num_train_epochs``.
        learning_rate: Value passed as ``--learning_rate``.
        distributed: When true, run the script in a subprocess through
            ``python -m torch.distributed.launch`` with ``--nproc_per_node``
            set to ``get_gpu_count()`` and ``--master_port`` set to
            ``get_torch_dist_unique_port()``; otherwise call ``main()``
            in-process with ``sys.argv`` patched.
        extra_args_str: Optional whitespace-separated arguments appended
            after the generated ones.
        predict_with_generate: When true, append ``--predict_with_generate``.
        do_train: Include ``--do_train`` and the training flags.
        do_eval: Include ``--do_eval`` and the evaluation flags.
        do_predict: Include ``--do_predict``.

    Returns:
        The temporary output directory handed to the script.
    """
    data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
    output_dir = self.get_auto_remove_tmp_dir()

    # Flags that do not depend on which phases are enabled.
    args_common = f"""
        --model_name_or_path {model_name}
        --train_file {data_dir}/train.json
        --validation_file {data_dir}/val.json
        --test_file {data_dir}/test.json
        --output_dir {output_dir}
        --overwrite_output_dir
        --max_source_length {max_len}
        --max_target_length {max_len}
        --target_lang ro_RO
        --source_lang en_XX
    """.split()

    args_train = f"""
        --do_train
        --max_train_samples 8
        --num_train_epochs {num_train_epochs}
        --per_device_train_batch_size 4
        --learning_rate {learning_rate}
        --warmup_steps 8
        --logging_steps 0
        --logging_strategy no
        --save_steps {eval_steps}
        --group_by_length
        --label_smoothing_factor 0.1
        --adafactor
    """.split()

    args_eval = f"""
        --do_eval
        --per_device_eval_batch_size 4
        --max_eval_samples 8
        --val_max_target_length {max_len}
        --evaluation_strategy steps
        --eval_steps {eval_steps}
    """.split()

    args_predict = ["--do_predict"]

    args = list(args_common)
    if do_train:
        args += args_train
    if do_eval:
        args += args_eval
    if do_predict:
        args += args_predict
    if predict_with_generate:
        args.append("--predict_with_generate")
    if extra_args_str is not None:
        args.extend(extra_args_str.split())

    if distributed:
        n_gpu = get_gpu_count()
        master_port = get_torch_dist_unique_port()
        distributed_args = f"""
            -m torch.distributed.launch
            --nproc_per_node={n_gpu}
            --master_port={master_port}
            {self.examples_dir_str}/pytorch/translation/run_translation.py
        """.split()
        cmd = [sys.executable] + distributed_args + args
        # keep for quick debug
        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
        execute_subprocess_async(cmd, env=self.get_env())
    else:
        # In-process run: the script reads its arguments from sys.argv.
        testargs = ["run_translation.py"] + args
        with patch.object(sys, "argv", testargs):
            main()

    return output_dir