Exemplo n.º 1
0
def finetune():
    """Fine-tune a language model by delegating to run_language_modeling.main().

    Paths and model identifiers come from module-level constants
    (FILE_PATH, MODEL_TYPE, MODEL_PATH, TRAIN_FILE, GRADIENT_CHECKPOINTING).

    Flag reference:
      https://github.com/huggingface/transformers/tree/master/examples/language-modeling
      https://github.com/huggingface/transformers/blob/master/src/transformers/training_args.py
    """
    # CLI-style flags forwarded verbatim to the training script.
    extra_flags = [
        f"--output_dir={FILE_PATH}/output",
        "--overwrite_output_dir",
        f"--model_type={MODEL_TYPE}",
        f"--model_name_or_path={MODEL_PATH}",
        "--do_train",
        f"--train_data_file={TRAIN_FILE}",
        "--per_device_train_batch_size=1",
        "--num_train_epochs=1",
        "--fp16",
    ]
    run_language_modeling.main(extra_flags, GRADIENT_CHECKPOINTING)
Exemplo n.º 2
0
    def test_run_language_modeling(self):
        """End-to-end smoke test: one MLM train+eval epoch on a tiny fixture.

        Trains distilroberta-base on the sample text and asserts the
        resulting perplexity is below a loose sanity bound (42).
        """
        # Mirror log output to stdout so the test runner captures it.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        # Build argv as an explicit list rather than whitespace-splitting an
        # f-string: splitting would silently fragment tmp_dir into multiple
        # bogus tokens if the temp path ever contained a space.
        testargs = [
            "run_language_modeling.py",
            "--model_name_or_path", "distilroberta-base",
            "--model_type", "roberta",
            "--mlm",
            "--line_by_line",
            "--train_data_file", "./tests/fixtures/sample_text.txt",
            "--eval_data_file", "./tests/fixtures/sample_text.txt",
            "--output_dir", tmp_dir,
            "--overwrite_output_dir",
            "--do_train",
            "--do_eval",
            "--num_train_epochs=1",
        ]

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            result = run_language_modeling.main()
            self.assertLess(result["perplexity"], 42)
Exemplo n.º 3
0
    def test_run_language_modeling(self):
        """End-to-end smoke test: one MLM train+eval epoch on a tiny fixture.

        Trains distilroberta-base on the sample text in a uniquely named
        temp directory and asserts perplexity is below a loose bound (35).
        """
        # Mirror log output to stdout so the test runner captures it.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        # Build argv as an explicit list: the original whitespace-split
        # f-string approach only worked because the literal happened to end
        # in whitespace before "--output_dir" was concatenated on, and it
        # would break on any path containing a space.
        testargs = [
            "run_language_modeling.py",
            "--model_name_or_path", "distilroberta-base",
            "--model_type", "roberta",
            "--mlm",
            "--line_by_line",
            "--train_data_file", "./tests/fixtures/sample_text.txt",
            "--eval_data_file", "./tests/fixtures/sample_text.txt",
            "--overwrite_output_dir",
            "--do_train",
            "--do_eval",
            "--num_train_epochs=1",
            "--no_cuda",
        ]
        # Unique per-invocation directory name, derived from the args.
        output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(
            hash(tuple(testargs)))
        testargs += ["--output_dir", output_dir]

        try:
            with patch.object(sys, "argv", testargs):
                result = run_language_modeling.main()
                self.assertLess(result["perplexity"], 35)
        finally:
            # Always clean up, even when training or the assertion fails;
            # the original skipped cleanup on any exception, leaking dirs.
            clean_test_dir(output_dir)
Exemplo n.º 4
0
    def test_run_language_modeling(self):
        """End-to-end smoke test: one MLM train+eval epoch on a tiny fixture.

        Trains distilroberta-base on the sample text and asserts the
        resulting perplexity is below a loose sanity bound (35).
        """
        # Mirror log output to stdout so the test runner captures it.
        logger.addHandler(logging.StreamHandler(sys.stdout))
        # TODO: switch to smaller model like sshleifer/tiny-distilroberta-base

        fake_argv = [
            "run_language_modeling.py",
            "--model_name_or_path", "distilroberta-base",
            "--model_type", "roberta",
            "--mlm",
            "--line_by_line",
            "--train_data_file", "./tests/fixtures/sample_text.txt",
            "--eval_data_file", "./tests/fixtures/sample_text.txt",
            "--output_dir", "./tests/fixtures/tests_samples/temp_dir",
            "--overwrite_output_dir",
            "--do_train",
            "--do_eval",
            "--num_train_epochs=1",
            "--no_cuda",
        ]
        with patch.object(sys, "argv", fake_argv):
            result = run_language_modeling.main()
            self.assertLess(result["perplexity"], 35)