def test_train_gpt2_decoder(self):
    """Smoke-test training a funnel encoder + distilgpt2 decoder VAE.

    Runs ``train.py``'s ``main()`` for 10 steps on a tiny line-by-line
    fixture and asserts only that training completes without error.
    """
    # Echo the training log to stdout so the test runner shows it on failure.
    stream_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stream_handler)

    if torch.cuda.device_count() > 1:
        # Not enough batches for multi-GPU training; would also need drop_last.
        # NOTE(review): `self.skipTest(...)` would report a skip instead of a
        # silent pass — confirm the base class is a unittest.TestCase.
        return

    tmp_dir = self.get_auto_remove_tmp_dir()
    # Build argv as an explicit list: the previous f-string ``.split()``
    # would corrupt a ``tmp_dir`` path containing whitespace.
    testargs = [
        "train.py",
        "--train_file", "./tests/fixtures/line_by_line_max_len_3.txt",
        "--validation_file", "./tests/fixtures/line_by_line_max_len_3.txt",
        "--do_train",
        "--max_steps=10",
        "--per_device_train_batch_size", "2",
        "--encoder_model", "full-1st-token",
        "--decoder_model", "n-tokens",
        "--set_seq_size", "8",
        "--transformer_type", "funnel-gpt2",
        "--transformer_name", "funnel-transformer/intermediate",
        "--transformer_decoder_name", "distilgpt2",
        "--tokenizer_name", "distilgpt2",
        "--output_dir", str(tmp_dir),
        "--overwrite_output_dir",
    ]

    if torch_device != "cuda":
        testargs.append("--no_cuda")

    with patch.object(sys, "argv", testargs):
        main()
    def test_train_unsupervised_classification_agnews(self):
        """Smoke-test unsupervised-classification training on ag_news.

        Trains for 10 steps with latent-classification evaluation enabled
        and asserts only that ``main()`` completes without error.
        """
        # Echo the training log to stdout so the test runner shows it on failure.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        if torch.cuda.device_count() > 1:
            # Not enough batches for multi-GPU training; would also need drop_last.
            # NOTE(review): `self.skipTest(...)` would report a skip instead of a
            # silent pass — confirm the base class is a unittest.TestCase.
            return

        tmp_dir = self.get_auto_remove_tmp_dir()
        # Build argv as an explicit list: the previous f-string ``.split()``
        # would corrupt a ``tmp_dir`` path containing whitespace.
        testargs = [
            "train.py",
            "--dataset_name=ag_news",
            "--classification_column=label",
            "--do_train",
            "--max_steps=10",
            "--validation_name=test",
            "--test_classification",
            "--per_device_train_batch_size", "2",
            "--per_device_eval_batch_size", "2",
            "--max_validation_size", "100",
            "--encoder_model", "full-1st-token",
            "--decoder_model", "n-tokens",
            "--latent_size", "2",
            "--transformer_name", "t5-small",
            "--output_dir", str(tmp_dir),
            "--overwrite_output_dir",
        ]

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            main()
# --- Example (Beispiel) #3 --- (separator from the aggregated source; kept as a comment)
    def test_train_n_tokens_model(self):
        """Smoke-test one epoch of training with an n-latent-tokens model.

        Uses a tiny line-by-line fixture and asserts only that ``main()``
        completes without error.
        """
        # Echo the training log to stdout so the test runner shows it on failure.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        if torch.cuda.device_count() > 1:
            # Not enough batches for multi-GPU training; would also need drop_last.
            # NOTE(review): `self.skipTest(...)` would report a skip instead of a
            # silent pass — confirm the base class is a unittest.TestCase.
            return

        tmp_dir = self.get_auto_remove_tmp_dir()
        # Build argv as an explicit list: the previous f-string ``.split()``
        # would corrupt a ``tmp_dir`` path containing whitespace.
        testargs = [
            "train.py",
            "--train_file", "./tests/fixtures/line_by_line_max_len_3.txt",
            "--validation_file", "./tests/fixtures/line_by_line_max_len_3.txt",
            "--do_train",
            "--per_device_train_batch_size", "2",
            "--num_train_epochs", "1",
            "--set_seq_size", "4",
            "--n_latent_tokens", "2",
            "--latent_size", "2",
            "--output_dir", str(tmp_dir),
            "--overwrite_output_dir",
        ]

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            main()
    def test_train_json(self):
        """Train + eval for 2 epochs from a JSON fixture and check the epoch count.

        Asserts the result dict returned by ``main()`` reports ``epoch == 2.0``.
        """
        # Echo the training log to stdout so the test runner shows it on failure.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        if torch.cuda.device_count() > 1:
            # Not enough batches for multi-GPU training; would also need drop_last.
            # NOTE(review): `self.skipTest(...)` would report a skip instead of a
            # silent pass — confirm the base class is a unittest.TestCase.
            return

        tmp_dir = self.get_auto_remove_tmp_dir()
        # Build argv as an explicit list: the previous f-string ``.split()``
        # would corrupt a ``tmp_dir`` path containing whitespace.
        testargs = [
            "train.py",
            "--train_file", "./tests/fixtures/max_len_3.json",
            "--validation_file", "./tests/fixtures/max_len_3.json",
            "--do_train",
            "--do_eval",
            "--per_device_train_batch_size", "5",
            "--per_device_eval_batch_size", "5",
            "--num_train_epochs", "2",
            "--set_seq_size", "4",
            "--latent_size", "2",
            "--transformer_name", "t5-small",
            "--output_dir", str(tmp_dir),
            "--overwrite_output_dir",
        ]

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            result = main()
            self.assertAlmostEqual(result["epoch"], 2.0)
    def test_train_unsupervised_classification(self):
        """Eval-only run on the news-category dataset.

        Asserts ``main()`` returns a positive ``eval_loss`` and that no
        ``epoch`` key is present (since no training was requested).
        """
        # Echo the training log to stdout so the test runner shows it on failure.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        if torch.cuda.device_count() > 1:
            # Not enough batches for multi-GPU training; would also need drop_last.
            # NOTE(review): `self.skipTest(...)` would report a skip instead of a
            # silent pass — confirm the base class is a unittest.TestCase.
            return

        tmp_dir = self.get_auto_remove_tmp_dir()
        # Build argv as an explicit list: the previous f-string ``.split()``
        # would corrupt a ``tmp_dir`` path containing whitespace.
        testargs = [
            "train.py",
            "--dataset_name=Fraser/news-category-dataset",
            "--text_column=headline",
            "--classification_column=category_num",
            "--do_eval",
            "--per_device_train_batch_size", "2",
            "--per_device_eval_batch_size", "2",
            "--max_validation_size", "100",
            "--eval_steps", "4",
            "--encoder_model", "full-1st-token",
            "--decoder_model", "n-tokens",
            "--latent_size", "2",
            "--transformer_name", "t5-small",
            "--output_dir", str(tmp_dir),
            "--overwrite_output_dir",
        ]

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            result = main()
            self.assertGreater(result["eval_loss"], 0.0)
            self.assertNotIn("epoch", result)
# --- Example (Beispiel) #6 --- (separator from the aggregated source; kept as a comment)
    def test_train_cycle_loss(self):
        """Train + eval for 1 epoch with the cycle loss enabled.

        Samples from the latent during eval and asserts the result dict
        returned by ``main()`` reports ``epoch == 1.0``.
        """
        # Echo the training log to stdout so the test runner shows it on failure.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        if torch.cuda.device_count() > 1:
            # Not enough batches for multi-GPU training; would also need drop_last.
            # NOTE(review): `self.skipTest(...)` would report a skip instead of a
            # silent pass — confirm the base class is a unittest.TestCase.
            return

        tmp_dir = self.get_auto_remove_tmp_dir()
        # Build argv as an explicit list: the previous f-string ``.split()``
        # would corrupt a ``tmp_dir`` path containing whitespace.
        testargs = [
            "train.py",
            "--train_file", "./tests/fixtures/line_by_line_max_len_3.txt",
            "--validation_file", "./tests/fixtures/line_by_line_max_len_3.txt",
            "--do_train",
            "--do_eval",
            "--eval_steps", "3",
            "--evaluation_strategy", "steps",
            "--sample_from_latent",
            "--per_device_train_batch_size", "4",
            "--per_device_eval_batch_size", "4",
            "--num_train_epochs", "1",
            "--set_seq_size", "8",
            "--n_latent_tokens", "1",
            "--latent_size", "2",
            "--cycle_loss",
            "--output_dir", str(tmp_dir),
            "--overwrite_output_dir",
        ]

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            result = main()
            self.assertAlmostEqual(result["epoch"], 1.0)
# --- Example (Beispiel) #7 --- (separator from the aggregated source; kept as a comment)
    def test_train_render_text_image(self):
        """Eval run on the mnist-text dataset with text-to-image rendering.

        Asserts the result dict returned by ``main()`` reports ``epoch == 2.0``.
        """
        # Echo the training log to stdout so the test runner shows it on failure.
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        if torch.cuda.device_count() > 1:
            # Not enough batches for multi-GPU training; would also need drop_last.
            # NOTE(review): `self.skipTest(...)` would report a skip instead of a
            # silent pass — confirm the base class is a unittest.TestCase.
            return

        tmp_dir = self.get_auto_remove_tmp_dir()
        # Build argv as an explicit list: the previous f-string ``.split()``
        # would corrupt a ``tmp_dir`` path containing whitespace.
        testargs = [
            "train.py",
            "--dataset_name=Fraser/mnist-text-default",
            "--eval_steps", "2",
            "--validation_name", "test",
            "--do_eval",
            "--tokenizer_name", "tokenizers/tkn_mnist-text-small_byte",
            "--sample_from_latent",
            "--render_text_image",
            "--seq_check", "python",
            "--dont_clean_up_tokenization_spaces",
            "--per_device_train_batch_size", "2",
            "--per_device_eval_batch_size", "2",
            "--num_train_epochs", "2",
            "--set_seq_size", "237",
            "--generate_max_len", "2",
            "--latent_size", "2",
            "--output_dir", str(tmp_dir),
            "--overwrite_output_dir",
        ]

        if torch_device != "cuda":
            testargs.append("--no_cuda")

        with patch.object(sys, "argv", testargs):
            result = main()
            self.assertAlmostEqual(result["epoch"], 2.0)
# --- Example (Beispiel) #8 --- (separator from the aggregated source; kept as a comment)
from transformer_vae.train import main

# Script entry point: delegate to the package CLI's main().
if __name__ == "__main__":
    main()