def test_only_auto_str(steps_per_epoch="abc"):
    """A string steps_per_epoch other than "AUTO" (here "abc") must be rejected.

    The test expects ``resolve_steps_per_epoch`` to raise ``AssertionError``.
    """
    scheduler_section = {"steps_per_epoch": steps_per_epoch}
    cfg = DictConfig({"lr_scheduler": {"scheduler": scheduler_section}})
    with pytest.raises(AssertionError):
        resolve_steps_per_epoch(cfg, 1)
def test_int_steps_per_epoch(steps_per_epoch=123):
    """An explicit integer steps_per_epoch is passed through.

    The resolved scheduler config should equal the ``lr_scheduler`` section
    of the input config.
    """
    scheduler_section = {"steps_per_epoch": steps_per_epoch}
    cfg = DictConfig({"lr_scheduler": {"scheduler": scheduler_section}})
    resolved = resolve_steps_per_epoch(cfg, 1)
    assert cfg.lr_scheduler == resolved
def test_no_steps_per_epoch(total_steps=20000):
    """A scheduler config without steps_per_epoch (only total_steps) is unchanged.

    The resolved scheduler config should equal the ``lr_scheduler`` section
    of the input config.
    """
    scheduler_section = {"total_steps": total_steps}
    cfg = DictConfig({"lr_scheduler": {"scheduler": scheduler_section}})
    resolved = resolve_steps_per_epoch(cfg, 1)
    assert cfg.lr_scheduler == resolved
def test_auto_steps_per_epoch(random_datamodule, batch_size: int = 32):
    """steps_per_epoch="AUTO" is replaced by ceil(len_train / batch_size).

    ``len_train`` is read from the ``random_datamodule`` fixture after
    ``prepare_data()``. ``batch_size`` is taken from ``cfg.dataset``.
    """
    cfg = DictConfig({
        "dataset": {"batch_size": batch_size},
        "lr_scheduler": {"scheduler": {"steps_per_epoch": "AUTO"}},
    })

    random_datamodule.prepare_data()
    num_train = random_datamodule.len_train

    resolved = resolve_steps_per_epoch(cfg, num_train)
    # A partial final batch still counts as one step, hence the ceiling.
    expected_steps = int(math.ceil(num_train / batch_size))
    assert resolved.scheduler.steps_per_epoch == expected_steps
def test_no_lr_schedule():
    """A config with no lr_scheduler section resolves to None."""
    empty_cfg = DictConfig({})
    assert resolve_steps_per_epoch(empty_cfg, 1) is None