Ejemplo n.º 1
0
def test_pytorch_profiler_trainer_ddp(tmpdir, use_output_filename):
    """Ensure the profiler can be passed to the Trainer and the default steps are recorded under DDP.

    Runs a ``fast_dev_run`` fit of ``BoringModel`` on 2 GPUs with the ``ddp``
    accelerator and checks which actions the ``PyTorchProfiler`` recorded.

    Args:
        tmpdir: Temporary directory in which the optional report file is created.
        use_output_filename: When true, the profiler is given
            ``<tmpdir>/profiler.txt`` as its output file; otherwise no output
            filename is passed. (Presumably supplied by a
            ``pytest.mark.parametrize`` decorator outside this view.)
    """
    output_filename = os.path.join(tmpdir, "profiler.txt") if use_output_filename else None
    profiler = PyTorchProfiler(output_filename=output_filename)

    model = BoringModel()
    trainer = Trainer(
        fast_dev_run=True,
        profiler=profiler,
        accelerator="ddp",
        gpus=2,
    )
    trainer.fit(model)

    # The test expects every rank to record when an output file is used, and
    # only local rank 0 otherwise. (`A or (not A and B)` is simply `A or B`.)
    enabled = use_output_filename or profiler.local_rank == 0

    if enabled:
        assert len(profiler.summary()) > 0
        assert set(profiler.profiled_actions) == {'training_step_and_backward', 'validation_step'}
    else:
        assert profiler.summary() is None
        assert set(profiler.profiled_actions) == set()

    # TODO(tchaton): add support for all ranks.
    # Default LOCAL_RANK to "0" so a rank-0 process that does not have the
    # variable set still verifies the report file instead of silently skipping it.
    if use_output_filename and os.getenv("LOCAL_RANK", "0") == "0":
        data = Path(profiler.output_fname).read_text()
        assert len(data) > 0
Ejemplo n.º 2
0
def test_pytorch_profiler_trainer_ddp(tmpdir):
    """Ensure the profiler can be passed to the Trainer and the default steps are recorded under DDP.

    Fits ``BoringModel`` for one epoch (2 training and 2 validation batches) on
    2 GPUs with the ``ddp`` accelerator. It then checks the recorded actions and
    verifies that this process's per-rank report ``fit-profiler-<LOCAL_RANK>.txt``
    exists in the profiler's ``dirpath`` and is non-empty.

    Args:
        tmpdir: Used as the Trainer's ``default_root_dir``.
    """
    # dirpath is left as None here and read back below, so it is presumably
    # resolved by the Trainer (e.g. from default_root_dir) — TODO confirm.
    pytorch_profiler = PyTorchProfiler(dirpath=None, filename="profiler")
    model = BoringModel()
    trainer = Trainer(
        max_epochs=1,
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        profiler=pytorch_profiler,
        accelerator="ddp",
        gpus=2,
    )
    trainer.fit(model)

    assert len(pytorch_profiler.summary()) > 0
    assert set(pytorch_profiler.profiled_actions) == {'training_step_and_backward', 'validation_step'}

    rank = int(os.getenv("LOCAL_RANK", "0"))
    expected = f"fit-profiler-{rank}.txt"

    # Check membership of this rank's own file instead of indexing a sorted
    # listing. Nothing here synchronizes ranks after `fit`, so another rank's
    # file may not exist yet (`files[rank]` could raise IndexError). Sorted
    # names would also misorder once there are 10 or more ranks.
    files = {f for f in os.listdir(pytorch_profiler.dirpath) if "fit" in f}
    assert expected in files

    path = os.path.join(pytorch_profiler.dirpath, expected)
    data = Path(path).read_text("utf-8")
    assert len(data) > 0