def test_register():
    """Check the model registry's register/lookup round trip and duplicate rejection.

    A minimal ``BaseModel`` subclass is registered and then retrieved by its
    class name. Registering ``models.DPRNNTasNet`` is then expected to raise
    ``ValueError``. Presumably this is because that name is already in the
    registry; the registry itself is not visible here.
    """

    class Custom(BaseModel):
        def __init__(self):
            super().__init__()

    models.register_model(Custom)
    # Lookup is by the class's name string.
    assert models.get("Custom") == Custom

    with pytest.raises(ValueError):
        models.register_model(models.DPRNNTasNet)
def test_get_errors(wrong):
    """Check that ``models.get`` raises ``ValueError`` for an unresolvable identifier.

    Args:
        wrong: An identifier that ``models.get`` should not be able to resolve to
            a model. It is supplied by the test's parametrization, which is not
            visible here. Presumably it is an unknown name string or a non-model
            object; confirm against the decorator.
    """
    with pytest.raises(ValueError):
        # Should raise for anything that is not a resolvable model identifier
        # (e.g. an unknown model name).
        models.get(wrong)
def test_get(model):
    """Check that looking up ``model`` by its class name returns that same class.

    Args:
        model: A model class supplied by the test's parametrization, which is
            not visible here.
    """
    assert models.get(model.__name__) == model