def test_timer_time_remaining(time_mock):
    """Check the timer's elapsed/remaining readings before and after it is started.

    The mocked ``monotonic`` return value is set explicitly ahead of each group
    of assertions.
    """
    t0 = time.monotonic()
    budget = timedelta(seconds=10)
    clock = time_mock.monotonic

    clock.return_value = t0
    timer = Timer(duration=budget)
    assert timer.time_remaining() == budget.total_seconds()
    assert timer.time_elapsed() == 0

    # The timer has not been started: a later clock reading must change nothing.
    clock.return_value = t0 + 60
    assert timer.start_time() is None
    assert timer.time_remaining() == 10
    assert timer.time_elapsed() == 0

    # Start the timer with the mocked clock back at t0.
    clock.return_value = t0
    timer.on_train_start(trainer=Mock(), pl_module=Mock())
    assert timer.start_time() == t0

    # Move the mocked clock forward by three seconds.
    elapsed = 3
    clock.return_value = t0 + elapsed
    assert timer.start_time() == t0
    assert round(timer.time_remaining()) == 7
    assert round(timer.time_elapsed()) == 3
def test_timer_resume_training(tmpdir):
    """Run a time-limited fit, then resume from its checkpoint with a new Timer.

    The first run must finish with no timer offset, no time remaining, and
    before the epoch limit. The resumed run must report a positive timer offset
    and advance the global step by exactly one.
    """
    boring_model = BoringModel()
    time_budget = timedelta(milliseconds=200)
    first_timer = Timer(duration=time_budget)
    ckpt = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)

    # Initial training run.
    first_run = Trainer(default_root_dir=tmpdir, max_epochs=100, callbacks=[first_timer, ckpt])
    first_run.fit(boring_model)
    assert not first_timer._offset
    assert first_timer.time_remaining() <= 0
    assert first_run.current_epoch < 99
    steps_before_resume = first_run.global_step

    # Resume training (with depleted timer).
    resumed_timer = Timer(duration=time_budget)
    resumed_run = Trainer(
        default_root_dir=tmpdir,
        callbacks=[resumed_timer, ckpt],
        resume_from_checkpoint=ckpt.best_model_path,
    )
    resumed_run.fit(boring_model)
    assert resumed_timer._offset > 0
    assert resumed_run.global_step == steps_before_resume + 1
def test_timer_parse_duration(duration, expected):
    """Check that a Timer built from ``duration`` reports the expected remaining time.

    Args:
        duration: Value passed to ``Timer(duration=...)``.
        expected: ``None`` when the timer should report no remaining time,
            otherwise an object whose ``total_seconds()`` equals the expected
            remaining time.

    Both arguments are supplied from outside this function (presumably via
    parametrization, which is not visible here).
    """
    timer = Timer(duration=duration)
    remaining = timer.time_remaining()
    # Branch explicitly instead of relying on the chained comparison
    # ``remaining == expected is None``; a mismatch then surfaces as an
    # assertion failure rather than an AttributeError on ``None.total_seconds()``.
    if expected is None:
        assert remaining is None
    else:
        assert remaining == expected.total_seconds()