Code Example #1
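The first three examples are pytest tests for PyTorch Lightning's XLAStatsMonitor callback, which records TPU memory statistics through the attached logger. To run standalone they need roughly the following imports; the BoringModel path is an assumption, since that test helper has moved between Lightning releases:

import os

import numpy as np
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import XLAStatsMonitor
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.loggers.csv_logs import ExperimentWriter
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel  # location varies by Lightning release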
def test_xla_stats_monitor(tmpdir):
    """Test XLA stats are logged using a logger."""

    model = BoringModel()
    xla_stats = XLAStatsMonitor()
    logger = CSVLogger(tmpdir)

    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=2,
                      limit_train_batches=5,
                      tpu_cores=8,
                      callbacks=[xla_stats],
                      logger=logger)

    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    path_csv = os.path.join(logger.log_dir, ExperimentWriter.NAME_METRICS_FILE)
    met_data = np.genfromtxt(path_csv,
                             delimiter=',',
                             names=True,
                             deletechars='',
                             replace_space=' ')

    fields = ['avg. free memory (MB)', 'avg. peak memory (MB)']

    for f in fields:
        assert any(f in h for h in met_data.dtype.names)
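Outside of a test, wiring up the callback is a one-liner. A minimal sketch, assuming a TPU runtime and that the callback's verbose flag (printing stats to stdout in addition to logging them) is available in your release:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import XLAStatsMonitor

# Logs average free/peak TPU memory at epoch boundaries via the attached logger.
trainer = Trainer(tpu_cores=8, callbacks=[XLAStatsMonitor(verbose=True)])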
Code Example #2
def test_xla_stats_monitor_no_logger(tmpdir):
    """Test XLAStatsMonitor with no logger in Trainer."""

    model = BoringModel()
    xla_stats = XLAStatsMonitor()

    trainer = Trainer(
        default_root_dir=tmpdir, callbacks=[xla_stats], max_epochs=1, accelerator="tpu", devices=[1], logger=False
    )

    with pytest.raises(MisconfigurationException, match="Trainer that has no logger."):
        trainer.fit(model)
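The failure is by design: XLAStatsMonitor reports its numbers through trainer.logger, so constructing the Trainer with logger=False leaves it nowhere to log, and fit() aborts with a MisconfigurationException before training starts.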
Code Example #3
def test_xla_stats_monitor_no_tpu_warning(tmpdir):
    """Test XLAStatsMonitor raises a warning when not training on TPUs."""

    model = BoringModel()
    xla_stats = XLAStatsMonitor()

    trainer = Trainer(default_root_dir=tmpdir,
                      callbacks=[xla_stats],
                      max_steps=1,
                      tpu_cores=None)

    with pytest.raises(MisconfigurationException, match="not running on TPU"):
        trainer.fit(model)
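On Lightning v1.5 and later, XLAStatsMonitor is deprecated in favour of the accelerator-agnostic DeviceStatsMonitor. A minimal sketch of the replacement, again assuming a TPU machine:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor

# DeviceStatsMonitor sources its stats from the active accelerator (TPU here).
trainer = Trainer(accelerator="tpu", devices=8, callbacks=[DeviceStatsMonitor()])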
Code Example #4
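This fragment appears to be the training entry point of a minGPT-style character-level GPT example (a Lightning adaptation of karpathy/minGPT): CharDataset, GPT, and LearningRateDecayCallback are defined elsewhere in that script, and args comes from a command-line parser built earlier.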
    # you can download this file at https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt
    text = open('input.txt',
                'r').read()  # don't worry we won't run out of file handles
    train_dataset = CharDataset(
        text, args.block_size)  # one line of poem is roughly 50 characters
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers)

    model = GPT(vocab_size=train_dataset.vocab_size,
                block_size=train_dataset.block_size,
                n_layer=args.n_layer,
                n_head=args.n_head,
                n_embd=args.n_embd,
                learning_rate=args.learning_rate)

    lr_decay = LearningRateDecayCallback(learning_rate=6e-4,
                                         warmup_tokens=512 * 20,
                                         final_tokens=2 * len(train_dataset) *
                                         args.block_size)

    trainer = Trainer.from_argparse_args(
        args,
        max_epochs=5,
        tpu_cores=8,
        gradient_clip_val=1.0,
        callbacks=[lr_decay, XLAStatsMonitor()],
    )
    trainer.fit(model, train_loader)
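The fragment reads several attributes off an args namespace constructed before this point. A hypothetical parser consistent with exactly those attributes (argument names are taken from the fragment; the defaults are illustrative, not from the source):

import argparse

from pytorch_lightning import Trainer

parser = argparse.ArgumentParser()
parser.add_argument('--block_size', type=int, default=128)      # model context length
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--n_layer', type=int, default=8)           # transformer depth
parser.add_argument('--n_head', type=int, default=8)
parser.add_argument('--n_embd', type=int, default=512)
parser.add_argument('--learning_rate', type=float, default=6e-4)
parser = Trainer.add_argparse_args(parser)  # adds the Trainer flags consumed by from_argparse_args
args = parser.parse_args()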