import os

import numpy as np
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import XLAStatsMonitor
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.loggers.csv_logs import ExperimentWriter
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel


def test_xla_stats_monitor(tmpdir):
    """Test XLA stats are logged using a logger."""
    model = BoringModel()
    xla_stats = XLAStatsMonitor()
    logger = CSVLogger(tmpdir)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=5,
        tpu_cores=8,
        callbacks=[xla_stats],
        logger=logger,
    )
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    path_csv = os.path.join(logger.log_dir, ExperimentWriter.NAME_METRICS_FILE)
    met_data = np.genfromtxt(path_csv, delimiter=',', names=True, deletechars='', replace_space=' ')

    fields = ['avg. free memory (MB)', 'avg. peak memory (MB)']
    for f in fields:
        assert any(f in h for h in met_data.dtype.names)
def test_xla_stats_monitor_no_logger(tmpdir):
    """Test XLAStatsMonitor with no logger in Trainer."""
    model = BoringModel()
    xla_stats = XLAStatsMonitor()

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[xla_stats],
        max_epochs=1,
        accelerator="tpu",
        devices=[1],
        logger=False,
    )

    with pytest.raises(MisconfigurationException, match="Trainer that has no logger."):
        trainer.fit(model)
def test_xla_stats_monitor_no_tpu_warning(tmpdir):
    """Test XLAStatsMonitor raises a MisconfigurationException when not running on TPUs."""
    model = BoringModel()
    xla_stats = XLAStatsMonitor()

    trainer = Trainer(default_root_dir=tmpdir, callbacks=[xla_stats], max_steps=1, tpu_cores=None)

    with pytest.raises(MisconfigurationException, match="not running on TPU"):
        trainer.fit(model)
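# For context, below is a minimal sketch of the kind of callback these tests
# exercise, assuming torch_xla's `xm.get_memory_info`, which reports per-device
# memory as `kb_free`/`kb_total`. Illustrative only: the real XLAStatsMonitor
# additionally averages these stats across all TPU cores before logging them.
import torch_xla.core.xla_model as xm

from pytorch_lightning.callbacks import Callback


class SimpleXLAMemoryMonitor(Callback):
    """Hypothetical, stripped-down stand-in for XLAStatsMonitor."""

    def on_train_epoch_end(self, trainer, pl_module):
        memory_info = xm.get_memory_info(pl_module.device)
        free_memory = memory_info["kb_free"] / 1024  # kB -> MB
        peak_memory = (memory_info["kb_total"] - memory_info["kb_free"]) / 1024
        trainer.logger.log_metrics(
            {"avg. free memory (MB)": free_memory, "avg. peak memory (MB)": peak_memory},
            step=trainer.global_step,
        )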
# CharDataset, GPT and LearningRateDecayCallback are assumed to be defined
# elsewhere in this example (they follow the minGPT character-level setup).
from torch.utils.data import DataLoader

# you can download this file at
# https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt
text = open('input.txt', 'r').read()  # don't worry we won't run out of file handles
train_dataset = CharDataset(text, args.block_size)  # one line of poem is roughly 50 characters
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)

model = GPT(
    vocab_size=train_dataset.vocab_size,
    block_size=train_dataset.block_size,
    n_layer=args.n_layer,
    n_head=args.n_head,
    n_embd=args.n_embd,
    learning_rate=args.learning_rate,
)

lr_decay = LearningRateDecayCallback(
    learning_rate=6e-4,
    warmup_tokens=512 * 20,
    final_tokens=2 * len(train_dataset) * args.block_size,
)

trainer = Trainer.from_argparse_args(
    args,
    max_epochs=5,
    tpu_cores=8,
    gradient_clip_val=1.0,
    callbacks=[lr_decay, XLAStatsMonitor()],
)
trainer.fit(model, train_loader)
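# A minimal sketch of the argument parsing the training snippet above assumes.
# The flag names mirror how `args` is used there; the defaults are illustrative,
# not the example's actual values.
from argparse import ArgumentParser

from pytorch_lightning import Trainer

parser = ArgumentParser()
parser.add_argument('--block_size', default=128, type=int)
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--num_workers', default=4, type=int)
parser.add_argument('--n_layer', default=8, type=int)
parser.add_argument('--n_head', default=8, type=int)
parser.add_argument('--n_embd', default=512, type=int)
parser.add_argument('--learning_rate', default=6e-4, type=float)
parser = Trainer.add_argparse_args(parser)  # adds the Trainer flags consumed by Trainer.from_argparse_args(args)
args = parser.parse_args()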