def test_train_gpt2_decoder(self):
    """Smoke-test training a funnel-encoder / GPT2-decoder model on a tiny fixture.

    Runs ``train.py``'s ``main()`` for 10 steps and only checks that training
    completes without raising.
    """
    if torch.cuda.device_count() > 1:
        # Not enough batches to train across GPUs + would need a drop_last to work.
        self.skipTest("not enough batches for multi-GPU training")
    # Mirror training logs to stdout so failures show context in test output.
    stream_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stream_handler)
    tmp_dir = self.get_auto_remove_tmp_dir()
    testargs = f"""
        train.py
        --train_file ./tests/fixtures/line_by_line_max_len_3.txt
        --validation_file ./tests/fixtures/line_by_line_max_len_3.txt
        --do_train
        --max_steps=10
        --per_device_train_batch_size 2
        --encoder_model full-1st-token
        --decoder_model n-tokens
        --set_seq_size 8
        --transformer_type funnel-gpt2
        --transformer_name funnel-transformer/intermediate
        --transformer_decoder_name distilgpt2
        --tokenizer_name distilgpt2
        --output_dir {tmp_dir}
        --overwrite_output_dir
        """.split()
    if torch_device != "cuda":
        testargs.append("--no_cuda")
    # Run the training script as if invoked from the command line.
    with patch.object(sys, "argv", testargs):
        main()
def test_train_unsupervised_classification_agnews(self):
    """Smoke-test unsupervised-classification training on the ag_news dataset.

    Runs ``train.py``'s ``main()`` for 10 steps with ``--test_classification``
    and only checks that training completes without raising.
    """
    if torch.cuda.device_count() > 1:
        # Not enough batches to train across GPUs + would need a drop_last to work.
        self.skipTest("not enough batches for multi-GPU training")
    # Mirror training logs to stdout so failures show context in test output.
    stream_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stream_handler)
    tmp_dir = self.get_auto_remove_tmp_dir()
    testargs = f"""
        train.py
        --dataset_name=ag_news
        --classification_column=label
        --do_train
        --max_steps=10
        --validation_name=test
        --test_classification
        --per_device_train_batch_size 2
        --per_device_eval_batch_size 2
        --max_validation_size 100
        --encoder_model full-1st-token
        --decoder_model n-tokens
        --latent_size 2
        --transformer_name t5-small
        --output_dir {tmp_dir}
        --overwrite_output_dir
        """.split()
    if torch_device != "cuda":
        testargs.append("--no_cuda")
    # Run the training script as if invoked from the command line.
    with patch.object(sys, "argv", testargs):
        main()
def test_train_n_tokens_model(self):
    """Smoke-test one epoch of training with ``--n_latent_tokens`` on a tiny fixture.

    Only checks that ``train.py``'s ``main()`` completes without raising.
    """
    if torch.cuda.device_count() > 1:
        # Not enough batches to train across GPUs + would need a drop_last to work.
        self.skipTest("not enough batches for multi-GPU training")
    # Mirror training logs to stdout so failures show context in test output.
    stream_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stream_handler)
    tmp_dir = self.get_auto_remove_tmp_dir()
    testargs = f"""
        train.py
        --train_file ./tests/fixtures/line_by_line_max_len_3.txt
        --validation_file ./tests/fixtures/line_by_line_max_len_3.txt
        --do_train
        --per_device_train_batch_size 2
        --num_train_epochs 1
        --set_seq_size 4
        --n_latent_tokens 2
        --latent_size 2
        --output_dir {tmp_dir}
        --overwrite_output_dir
        """.split()
    if torch_device != "cuda":
        testargs.append("--no_cuda")
    # Run the training script as if invoked from the command line.
    with patch.object(sys, "argv", testargs):
        main()
def test_train_json(self):
    """Train + evaluate on a JSON fixture and check the reported epoch count.

    Runs ``train.py``'s ``main()`` for 2 epochs and asserts the returned
    metrics record ``epoch == 2.0``.
    """
    if torch.cuda.device_count() > 1:
        # Not enough batches to train across GPUs + would need a drop_last to work.
        self.skipTest("not enough batches for multi-GPU training")
    # Mirror training logs to stdout so failures show context in test output.
    stream_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stream_handler)
    tmp_dir = self.get_auto_remove_tmp_dir()
    testargs = f"""
        train.py
        --train_file ./tests/fixtures/max_len_3.json
        --validation_file ./tests/fixtures/max_len_3.json
        --do_train
        --do_eval
        --per_device_train_batch_size 5
        --per_device_eval_batch_size 5
        --num_train_epochs 2
        --set_seq_size 4
        --latent_size 2
        --transformer_name t5-small
        --output_dir {tmp_dir}
        --overwrite_output_dir
        """.split()
    if torch_device != "cuda":
        testargs.append("--no_cuda")
    # Run the training script as if invoked from the command line.
    with patch.object(sys, "argv", testargs):
        result = main()
    self.assertAlmostEqual(result["epoch"], 2.0)
def test_train_unsupervised_classification(self):
    """Eval-only run on the news-category dataset; checks the returned metrics.

    Asserts a positive ``eval_loss`` and that no ``epoch`` key is reported
    (no training was requested).
    """
    if torch.cuda.device_count() > 1:
        # Not enough batches to train across GPUs + would need a drop_last to work.
        self.skipTest("not enough batches for multi-GPU training")
    # Mirror training logs to stdout so failures show context in test output.
    stream_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stream_handler)
    tmp_dir = self.get_auto_remove_tmp_dir()
    testargs = f"""
        train.py
        --dataset_name=Fraser/news-category-dataset
        --text_column=headline
        --classification_column=category_num
        --do_eval
        --per_device_train_batch_size 2
        --per_device_eval_batch_size 2
        --max_validation_size 100
        --eval_steps 4
        --encoder_model full-1st-token
        --decoder_model n-tokens
        --latent_size 2
        --transformer_name t5-small
        --output_dir {tmp_dir}
        --overwrite_output_dir
        """.split()
    if torch_device != "cuda":
        testargs.append("--no_cuda")
    # Run the training script as if invoked from the command line.
    with patch.object(sys, "argv", testargs):
        result = main()
    self.assertGreater(result["eval_loss"], 0.0)
    self.assertNotIn("epoch", result)
def test_train_cycle_loss(self):
    """Train + evaluate with ``--cycle_loss`` enabled on a tiny fixture.

    Runs ``train.py``'s ``main()`` for 1 epoch and asserts the returned
    metrics record ``epoch == 1.0``.
    """
    if torch.cuda.device_count() > 1:
        # Not enough batches to train across GPUs + would need a drop_last to work.
        self.skipTest("not enough batches for multi-GPU training")
    # Mirror training logs to stdout so failures show context in test output.
    stream_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stream_handler)
    tmp_dir = self.get_auto_remove_tmp_dir()
    testargs = f"""
        train.py
        --train_file ./tests/fixtures/line_by_line_max_len_3.txt
        --validation_file ./tests/fixtures/line_by_line_max_len_3.txt
        --do_train
        --do_eval
        --eval_steps 3
        --evaluation_strategy steps
        --sample_from_latent
        --per_device_train_batch_size 4
        --per_device_eval_batch_size 4
        --num_train_epochs 1
        --set_seq_size 8
        --n_latent_tokens 1
        --latent_size 2
        --cycle_loss
        --output_dir {tmp_dir}
        --overwrite_output_dir
        """.split()
    if torch_device != "cuda":
        testargs.append("--no_cuda")
    # Run the training script as if invoked from the command line.
    with patch.object(sys, "argv", testargs):
        result = main()
    self.assertAlmostEqual(result["epoch"], 1.0)
def test_train_render_text_image(self):
    """Eval run on mnist-text with ``--render_text_image`` enabled.

    Runs ``train.py``'s ``main()`` and asserts the returned metrics record
    ``epoch == 2.0``.
    """
    if torch.cuda.device_count() > 1:
        # Not enough batches to train across GPUs + would need a drop_last to work.
        self.skipTest("not enough batches for multi-GPU training")
    # Mirror training logs to stdout so failures show context in test output.
    stream_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stream_handler)
    tmp_dir = self.get_auto_remove_tmp_dir()
    testargs = f"""
        train.py
        --dataset_name=Fraser/mnist-text-default
        --eval_steps 2
        --validation_name test
        --do_eval
        --tokenizer_name tokenizers/tkn_mnist-text-small_byte
        --sample_from_latent
        --render_text_image
        --seq_check python
        --dont_clean_up_tokenization_spaces
        --per_device_train_batch_size 2
        --per_device_eval_batch_size 2
        --num_train_epochs 2
        --set_seq_size 237
        --generate_max_len 2
        --latent_size 2
        --output_dir {tmp_dir}
        --overwrite_output_dir
        """.split()
    if torch_device != "cuda":
        testargs.append("--no_cuda")
    # Run the training script as if invoked from the command line.
    with patch.object(sys, "argv", testargs):
        result = main()
    self.assertAlmostEqual(result["epoch"], 2.0)
from transformer_vae.train import main

# Allow running the training entry point directly as a script.
if __name__ == "__main__":
    main()