def test_only_auto_str(steps_per_epoch="abc"):
    """A string other than "AUTO" for steps_per_epoch must raise AssertionError."""
    scheduler_section = {"scheduler": {"steps_per_epoch": steps_per_epoch}}
    cfg = DictConfig({"lr_scheduler": scheduler_section})
    # The resolver is expected to reject non-integer, non-"AUTO" values.
    with pytest.raises(AssertionError):
        _ = resolve_steps_per_epoch(cfg, 1)
def test_int_steps_per_epoch(steps_per_epoch=123):
    """An explicit integer steps_per_epoch is returned unchanged by the resolver."""
    cfg = DictConfig({
        "lr_scheduler": {"scheduler": {"steps_per_epoch": steps_per_epoch}},
    })
    resolved = resolve_steps_per_epoch(cfg, 1)
    # The lr_scheduler section should pass through untouched.
    assert cfg.lr_scheduler == resolved
def test_no_steps_per_epoch(total_steps=20000):
    """A scheduler configured via total_steps (no steps_per_epoch key) is left as-is."""
    cfg = DictConfig({
        "lr_scheduler": {"scheduler": {"total_steps": total_steps}},
    })
    resolved = resolve_steps_per_epoch(cfg, 1)
    # Nothing to resolve, so the section comes back unchanged.
    assert cfg.lr_scheduler == resolved
def test_auto_steps_per_epoch(random_datamodule, batch_size: int = 32):
    """"AUTO" steps_per_epoch resolves to ceil(len_train / batch_size)."""
    cfg = DictConfig({
        "dataset": {"batch_size": batch_size},
        "lr_scheduler": {"scheduler": {"steps_per_epoch": "AUTO"}},
    })
    # The datamodule must be prepared before its training-set length is known.
    random_datamodule.prepare_data()
    n_train = random_datamodule.len_train
    resolved = resolve_steps_per_epoch(cfg, n_train)
    expected_steps = int(math.ceil(n_train / batch_size))
    assert resolved.scheduler.steps_per_epoch == expected_steps
def test_no_lr_schedule():
    """With no lr_scheduler section at all, the resolver returns None."""
    empty_cfg = DictConfig({})
    resolved = resolve_steps_per_epoch(empty_cfg, 1)
    assert resolved is None