def test_run_glue(self):
        """Fine-tune distilbert on the MRPC fixtures across 8 TPU cores and check the result.

        ``run_glue.py`` is launched through ``xla_spawn.main()`` with ``sys.argv``
        patched. The test then asserts ``eval_accuracy >= 0.75`` (read back via
        ``get_results``) and that the launch finished in under 500 seconds.
        """
        import xla_spawn

        tmp_dir = self.get_auto_remove_tmp_dir()
        script = "./examples/pytorch/text-classification/run_glue.py"
        fixtures = "./tests/fixtures/tests_samples/MRPC"
        # Build argv as an explicit list instead of splitting a formatted string,
        # so an output directory containing whitespace stays a single argument.
        # NOTE(review): the script path appears twice -- presumably argv[0] plus the
        # script xla_spawn launches; confirm against xla_spawn's argument parser.
        testargs = [
            script,
            "--num_cores=8",
            script,
            "--model_name_or_path", "distilbert-base-uncased",
            "--output_dir", str(tmp_dir),
            "--overwrite_output_dir",
            "--train_file", f"{fixtures}/train.csv",
            "--validation_file", f"{fixtures}/dev.csv",
            "--do_train",
            "--do_eval",
            "--debug", "tpu_metrics_debug",
            "--per_device_train_batch_size=2",
            "--per_device_eval_batch_size=1",
            "--learning_rate=1e-4",
            "--max_steps=10",
            "--warmup_steps=2",
            "--seed=42",
            "--max_seq_length=128",
        ]

        with patch.object(sys, "argv", testargs):
            start = time()
            xla_spawn.main()
            end = time()

        result = get_results(tmp_dir)
        self.assertGreaterEqual(result["eval_accuracy"], 0.75)

        # Assert that the script takes less than 500 seconds to make sure it doesn't hang.
        self.assertLess(end - start, 500)
Beispiel #2
0
    def test_trainer_tpu(self):
        """Run ``./tests/test_trainer_tpu.py`` on 8 TPU cores through ``xla_spawn.main()``."""
        import xla_spawn

        trainer_script = "./tests/test_trainer_tpu.py"
        # NOTE(review): the script path is listed twice -- presumably argv[0] plus the
        # script xla_spawn launches; confirm against xla_spawn's argument parser.
        spawn_argv = [trainer_script, "--num_cores=8", trainer_script]

        with patch.object(sys, "argv", spawn_argv):
            xla_spawn.main()
Beispiel #3
0
    def test_run_glue(self):
        """Train and evaluate bert-base-cased on MRPC across 8 TPU cores; check metrics and runtime.

        ``text-classification/run_glue.py`` is launched through ``xla_spawn.main()``
        with ``sys.argv`` patched. Afterwards ``eval_results_mrpc.txt`` in the
        output directory is parsed as ``key = value`` lines. Every metric except
        ``eval_loss`` must be at least 0.70, and the whole launch must finish in
        under 100 seconds.
        """
        import xla_spawn

        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        # Detach the handler when the test ends; otherwise every run leaves another
        # stdout handler on the shared ``logger`` and log lines get duplicated.
        self.addCleanup(logger.removeHandler, stream_handler)

        output_directory = "run_glue_output"

        testargs = f"""
            text-classification/run_glue.py
            --num_cores=8
            text-classification/run_glue.py
            --do_train
            --do_eval
            --task_name=MRPC
            --data_dir=../glue_data/MRPC
            --cache_dir=./cache_dir
            --num_train_epochs=1
            --max_seq_length=128
            --learning_rate=3e-5
            --output_dir={output_directory}
            --overwrite_output_dir
            --logging_steps=5
            --save_steps=5
            --overwrite_cache
            --tpu_metrics_debug
            --model_name_or_path=bert-base-cased
            --per_device_train_batch_size=64
            --per_device_eval_batch_size=64
            --evaluate_during_training
            """.split()
        with patch.object(sys, "argv", testargs):
            start = time()
            xla_spawn.main()
            end = time()

        result = {}
        with open(f"{output_directory}/eval_results_mrpc.txt", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    # Tolerate blank / trailing lines instead of failing to unpack them.
                    continue
                key, value = line.split(" = ", 1)
                result[key] = float(value)

        # The loss is excluded from the ">= 0.70" threshold applied to the other metrics.
        result.pop("eval_loss", None)
        # Without this, an empty results file would make the loop below pass vacuously.
        self.assertTrue(result, "no evaluation metrics found besides eval_loss")
        for value in result.values():
            # Assert that the model trains
            self.assertGreaterEqual(value, 0.70)

        # Assert that the script takes less than 100 seconds to make sure it doesn't hang.
        self.assertLess(end - start, 100)