Example #1
    def test_clm_from_config_zero3(self):
        # this test exercises AutoModel.from_config(config) - to ensure zero.Init is called

        data_dir = self.tests_dir / "fixtures"
        output_dir = self.get_auto_remove_tmp_dir()
        args = f"""
            --model_type gpt2
            --tokenizer_name sshleifer/tiny-gpt2
            --train_file {data_dir}/sample_text.txt
            --validation_file {data_dir}/sample_text.txt
            --output_dir {output_dir}
            --overwrite_output_dir
            --do_train
            --max_train_samples 4
            --per_device_train_batch_size 2
            --num_train_epochs 1
            --warmup_steps 8
            --block_size 8
            --fp16
            --report_to none
            """.split()

        ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split(
        )
        script = [f"{self.examples_dir_str}/pytorch/language-modeling/run_clm.py"]
        launcher = self.get_launcher(distributed=True)

        cmd = launcher + script + args + ds_args
        # keep for quick debug
        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
        with CaptureStderr() as cs:
            execute_subprocess_async(cmd, env=self.get_env())
        assert "Detected DeepSpeed ZeRO-3" in cs.err
Example #2
    def test_trainer_log_level_replica(self, experiment_id):
        # each sub-test is slow-ish, so the test is split into multiple sub-tests to avoid a CI timeout
        experiments = dict(
            # test with the default log_level - should be info and thus log info once
            base=dict(extra_args_str="", n_matches=1),
            # test with low log_level and log_level_replica - should be noisy on all processes
            # now the info string should appear twice on 2 processes
            low=dict(
                extra_args_str="--log_level debug --log_level_replica debug",
                n_matches=2),
            # test with high log_level and low log_level_replica
            # now the info string should appear once only on the replica
            high=dict(
                extra_args_str="--log_level error --log_level_replica debug",
                n_matches=1),
            # test with high log_level and log_level_replica - should be quiet on all processes
            mixed=dict(
                extra_args_str="--log_level error --log_level_replica error",
                n_matches=0),
        )

        data = experiments[experiment_id]
        kwargs = dict(distributed=True,
                      predict_with_generate=False,
                      do_eval=False,
                      do_predict=False)
        log_info_string = "Running training"
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(**kwargs,
                                   extra_args_str=data["extra_args_str"])
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, data["n_matches"])
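
The expected n_matches values follow from how each process resolves its effective verbosity: the main process uses --log_level, replicas use --log_level_replica, and "Running training" is logged at info level. Below is a rough single-process sketch of that resolution, assuming transformers' TrainingArguments API (the output_dir path is hypothetical):

import logging
from transformers import TrainingArguments

# sketch only: with log_level=error the main process stays quiet, while
# log_level_replica=debug keeps the replicas noisy - hence n_matches=1 in the
# "high" case above (the info line comes from the replica only)
args = TrainingArguments(
    output_dir="/tmp/log_level_demo",  # hypothetical path
    log_level="error",
    log_level_replica="debug",
)
# the main process resolves this from log_level, other ranks from log_level_replica
print(logging.getLevelName(args.get_process_log_level()))
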
Example #3
def test_finetune_lr_schedulers():
    args_d: dict = CHEAP_ARGS.copy()

    task = "summarization"
    tmp_dir = make_test_data_dir()

    model = BART_TINY
    output_dir = tempfile.mkdtemp(prefix="output_1_")

    args_d.update(
        data_dir=tmp_dir,
        model_name_or_path=model,
        output_dir=output_dir,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        do_predict=False,
        task=task,
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )

    # emulate finetune.py
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    args = {"--help": True}

    # --help test
    with pytest.raises(SystemExit) as excinfo:
        with CaptureStdout() as cs:
            args = parser.parse_args(args)
        assert False, "--help is expected to sys.exit"
    assert excinfo.type == SystemExit
    expected = lightning_base.arg_to_scheduler_metavar
    assert expected in cs.out, "--help is expected to list the supported schedulers"

    # --lr_scheduler=non_existing_scheduler test
    unsupported_param = "non_existing_scheduler"
    args = {f"--lr_scheduler={unsupported_param}"}
    with pytest.raises(SystemExit) as excinfo:
        with CaptureStderr() as cs:
            args = parser.parse_args(args)
        assert False, "invalid argument is expected to sys.exit"
    assert excinfo.type == SystemExit
    expected = f"invalid choice: '{unsupported_param}'"
    assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"

    # --lr_scheduler=existing_scheduler test
    supported_param = "cosine"
    args_d1 = args_d.copy()
    args_d1["lr_scheduler"] = supported_param
    args = argparse.Namespace(**args_d1)
    model = main(args)
    assert getattr(
        model.hparams, "lr_scheduler"
    ) == supported_param, f"lr_scheduler={supported_param} shouldn't fail"
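
Both SystemExit branches above lean on standard argparse behavior: --help prints the metavar (which lists the supported schedulers) to stdout and exits with code 0, while a value outside choices prints "invalid choice: ..." to stderr and exits with code 2. A self-contained sketch of that behavior follows; the scheduler names are illustrative stand-ins, not lightning_base's actual list.

import argparse

schedulers = ["linear", "cosine", "polynomial"]  # assumed choices for illustration
parser = argparse.ArgumentParser()
parser.add_argument(
    "--lr_scheduler",
    choices=schedulers,
    metavar="{" + ", ".join(schedulers) + "}",  # what --help shows for this option
)

try:
    parser.parse_args(["--lr_scheduler=non_existing_scheduler"])
except SystemExit as e:
    # argparse has already written "invalid choice: 'non_existing_scheduler'" to stderr
    print("exit code:", e.code)  # 2
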
Example #4
    def test_trainer_log_level_replica(self):
        log_info_string = "Running training"
        kwargs = dict(distributed=True,
                      predict_with_generate=False,
                      do_eval=False,
                      do_predict=False)

        # test with the default log_level - should be info and thus log info once
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(
                **kwargs,
                extra_args_str="",
            )
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, 1)

        # test with low log_level and log_level_replica - should be noisy on all processes
        # now the info string should appear twice on 2 processes
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(
                **kwargs,
                extra_args_str="--log_level debug --log_level_replica debug",
            )
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, 2)

        # test with high log_level and low log_level_replica
        # now the info string should appear once only on the replica
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(
                **kwargs,
                extra_args_str="--log_level error --log_level_replica debug",
            )
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, 1)

        # test with high log_level and log_level_replica - should be quiet on all processes
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(
                **kwargs,
                extra_args_str="--log_level error --log_level_replica error",
            )
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, 0)
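
Every variant of this test uses the same pattern: capture stderr, then count occurrences with re.findall. Here is a stdlib-only stand-in for that capture-and-count pattern (transformers' CaptureStderr is a comparable context manager, but this is not its implementation):

import contextlib
import io
import re
import sys

buf = io.StringIO()
with contextlib.redirect_stderr(buf):
    # stand-in for the trainer's "Running training" info lines
    print("Running training", file=sys.stderr)
    print("Running training", file=sys.stderr)

n_matches = len(re.findall("Running training", buf.getvalue()))
assert n_matches == 2
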
Example #5
    def test_load_best_model(self, stage):
        # this test exercises --load_best_model_at_end - the key is being able to resume after some training

        data_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        args = f"""
            --model_name_or_path {T5_TINY}
            --tokenizer_name {T5_TINY}
            --train_file {data_dir}/train.json
            --validation_file {data_dir}/val.json
            --output_dir {output_dir}
            --overwrite_output_dir
            --source_lang en
            --target_lang ro
            --do_train
            --max_train_samples 3
            --do_eval
            --max_eval_samples 1
            --logging_strategy steps
            --logging_steps 1
            --evaluation_strategy steps
            --eval_steps 1
            --save_strategy steps
            --save_steps 1
            --load_best_model_at_end
            --per_device_train_batch_size 1
            --per_device_eval_batch_size 1
            --num_train_epochs 1
            --fp16
            --report_to none
            """.split()
        args.extend(["--source_prefix", "translate English to Romanian: "])

        ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split(
        )
        script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
        launcher = self.get_launcher(distributed=False)

        cmd = launcher + script + args + ds_args
        # keep for quick debug
        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
        with CaptureStderr() as cs:
            execute_subprocess_async(cmd, env=self.get_env())
        # enough to test it didn't fail
        assert "Detected DeepSpeed ZeRO-3" in cs.err
    def test_batch_encoding_with_labels_jax(self):
        batch = BatchEncoding({
            "inputs": [[1, 2, 3], [4, 5, 6]],
            "labels": [0, 1]
        })
        tensor_batch = batch.convert_to_tensors(tensor_type="jax")
        self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
        self.assertEqual(tensor_batch["labels"].shape, (2, ))
        # test converting the converted
        with CaptureStderr() as cs:
            tensor_batch = batch.convert_to_tensors(tensor_type="jax")
        self.assertFalse(len(cs.err),
                         msg=f"should have no warning, but got {cs.err}")

        batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
        tensor_batch = batch.convert_to_tensors(tensor_type="jax",
                                                prepend_batch_axis=True)
        self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
        self.assertEqual(tensor_batch["labels"].shape, (1, ))