def test_timer_resume_training(tmpdir):
    """Test that the timer can resume together with the Trainer.

    A first run trains with a 200 ms timer and is expected to deplete it before
    the last epoch. A second run then resumes from the saved checkpoint with a
    fresh ``Timer`` and is expected to advance by exactly one global step.
    """
    model = BoringModel()
    first_timer = Timer(duration=timedelta(milliseconds=200))
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)

    # Phase 1: initial training, bounded by the timer rather than by max_epochs.
    first_trainer = Trainer(default_root_dir=tmpdir, max_epochs=100, callbacks=[first_timer, checkpoint_callback])
    first_trainer.fit(model)
    assert not first_timer._offset
    assert first_timer.time_remaining() <= 0
    assert first_trainer.current_epoch < 99
    saved_global_step = first_trainer.global_step

    # Phase 2: resume training (with depleted timer).
    resumed_timer = Timer(duration=timedelta(milliseconds=200))
    resumed_trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[resumed_timer, checkpoint_callback],
        resume_from_checkpoint=checkpoint_callback.best_model_path,
    )
    resumed_trainer.fit(model)
    assert resumed_timer._offset > 0
    assert resumed_trainer.global_step == saved_global_step + 1
def test_timer_duration_min_steps_override(tmpdir, min_steps, min_epochs):
    """Check that ``min_steps`` / ``min_epochs`` are still reached with a zero-duration timer.

    ``min_steps`` and ``min_epochs`` come from the caller (presumably a
    parametrization that is not visible here); a falsy value disables the
    corresponding check.
    """
    model = BoringModel()
    zero_duration = timedelta(0)
    timer = Timer(duration=zero_duration)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[timer],
        min_steps=min_steps,
        min_epochs=min_epochs,
    )
    trainer.fit(model)

    # Each lower bound is only enforced when the corresponding argument is set.
    assert not min_epochs or trainer.current_epoch >= min_epochs
    assert not min_steps or trainer.global_step >= min_steps - 1
    assert timer.time_elapsed() > zero_duration.total_seconds()
def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None:
    """Append a step-interval ``Timer`` for ``max_time`` to the trainer's callbacks.

    Does nothing when ``max_time`` is ``None``. When the callbacks list already
    holds a ``Timer``, the ``max_time`` argument is ignored and an info message
    is emitted instead.

    Args:
        max_time: Duration forwarded unchanged to ``Timer(duration=...)``.
    """
    if max_time is None:
        return

    # A Timer already in the callbacks list takes precedence over ``max_time``.
    for existing_callback in self.trainer.callbacks:
        if isinstance(existing_callback, Timer):
            rank_zero_info("Ignoring `Trainer(max_time=...)`, callbacks list already contains a Timer.")
            return

    self.trainer.callbacks.append(Timer(duration=max_time, interval="step"))
def test_timer_zero_duration_stop(tmpdir, interval):
    """Test that the timer stops training immediately after the first check occurs.

    ``interval`` is forwarded to ``Timer``; with a zero duration the run is
    expected to finish with both the global step and the epoch still at 0.
    """
    model = BoringModel()
    zero_timer = Timer(duration=timedelta(0), interval=interval)
    trainer = Trainer(default_root_dir=tmpdir, callbacks=[zero_timer])

    trainer.fit(model)

    assert trainer.global_step == 0
    assert trainer.current_epoch == 0
def test_timer_stops_training(tmpdir, caplog):
    """Test that the timer stops training before reaching max_epochs.

    Trains with a 100 ms limit and ``max_epochs=1000``, then checks that some
    progress was made, that the epoch cap was not reached, and that both stop
    messages were logged at INFO level.
    """
    model = BoringModel()
    time_limit = timedelta(milliseconds=100)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1000, callbacks=[Timer(duration=time_limit)])

    with caplog.at_level(logging.INFO):
        trainer.fit(model)

    assert trainer.global_step > 1
    assert trainer.current_epoch < 999
    for expected_message in ("Time limit reached.", "Signaling Trainer to stop."):
        assert expected_message in caplog.text
def test_timer_track_stages(tmpdir):
    """Test that the timer tracks time also for other stages (train/val/test)."""
    # note: skipped on windows because time resolution of time.monotonic() is not high enough for this fast test
    # (the skip marker itself is not part of this block)
    model = BoringModel()
    timer = Timer()
    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, callbacks=[timer])

    trainer.fit(model)

    # With no stage argument, the elapsed time must equal the "train" stage's.
    default_elapsed = timer.time_elapsed()
    train_elapsed = timer.time_elapsed("train")
    assert default_elapsed == train_elapsed
    assert train_elapsed > 0
    assert timer.time_elapsed("validate") > 0
    # Nothing has been tested yet, so the "test" stage has no recorded time.
    assert timer.time_elapsed("test") == 0

    trainer.test(model)
    assert timer.time_elapsed("test") > 0
def test_trainer_flag(caplog):
    """Check how ``Trainer(max_time=...)`` interacts with the ``Timer`` callback.

    Without an explicit ``Timer`` the trainer is expected to hold one whose
    ``_duration`` is 1337; with an explicit ``Timer`` the flag is expected to
    be ignored and a message logged.
    """

    class TestModel(BoringModel):
        def on_fit_start(self):
            # Abort as soon as fitting starts; only the callback setup is under test.
            raise SystemExit()

    max_time = dict(seconds=1337)

    # Case 1: no user-supplied Timer.
    trainer = Trainer(max_time=max_time)
    with pytest.raises(SystemExit):
        trainer.fit(TestModel())
    timers = [callback for callback in trainer.callbacks if isinstance(callback, Timer)]
    timer = timers[0]
    assert timer._duration == 1337

    # Case 2: a Timer is already in the callbacks list.
    trainer = Trainer(max_time=max_time, callbacks=[Timer()])
    with pytest.raises(SystemExit):
        with caplog.at_level(level=logging.INFO):
            trainer.fit(TestModel())
    assert "callbacks list already contains a Timer" in caplog.text
def test_timer_zero_duration_stop(tmpdir, interval):
    """Test that the timer stops training immediately after the first check occurs.

    With ``interval == "step"`` the stop is expected after a single step;
    otherwise it is expected at the end of the first epoch, i.e. after one full
    pass over the train dataloader.
    """
    model = BoringModel()
    zero_timer = Timer(duration=timedelta(0), interval=interval)
    trainer = Trainer(default_root_dir=tmpdir, callbacks=[zero_timer])

    trainer.fit(model)

    # "step": timer triggers stop on step end; otherwise it triggers on epoch end.
    expected_global_step = 1 if interval == "step" else len(trainer.train_dataloader)
    assert trainer.global_step == expected_global_step
    assert trainer.current_epoch == 0
def test_timer_parse_duration(duration, expected):
    """Check that ``Timer`` reports the remaining time matching the given duration.

    Args:
        duration: Value passed to ``Timer(duration=...)``.
        expected: ``None`` when no remaining time is expected, otherwise an
            object with ``total_seconds()`` (presumably a ``timedelta``; the
            parametrization is not visible here) that the remaining time must equal.
    """
    timer = Timer(duration=duration)
    remaining = timer.time_remaining()

    # Branch explicitly on ``expected is None``. The previous chained comparison
    # (``a == expected is None``) fell through to ``None.total_seconds()`` and
    # raised AttributeError when the timer unexpectedly reported a value.
    if expected is None:
        assert remaining is None
    else:
        assert remaining == expected.total_seconds()