Example #1
def test_lightning_setattr(tmpdir, model_cases):
    """Test that the lightning_setattr works in all cases."""
    models = model_cases
    for m in models[:3]:
        lightning_setattr(m, "learning_rate", 10)
        assert lightning_getattr(m, "learning_rate") == 10, "attribute not correctly set"

    model5, model6, model7 = models[4:]
    lightning_setattr(model5, "batch_size", 128)
    lightning_setattr(model6, "batch_size", 128)
    lightning_setattr(model7, "batch_size", 128)
    assert lightning_getattr(model5, "batch_size") == 128, "batch_size not correctly set"
    assert lightning_getattr(model6, "batch_size") == 128, "batch_size not correctly set"
    assert lightning_getattr(model7, "batch_size") == 128, "batch_size not correctly set"

    for m in models:
        with pytest.raises(
            AttributeError,
            match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule.",
        ):
            lightning_setattr(m, "this_attr_not_exist", None)
Example #2
def test_lightning_setattr(tmpdir):
    """Test that the lightning_setattr works in all cases"""
    models = _get_test_cases()
    for m in models[:3]:
        lightning_setattr(m, 'learning_rate', 10)
        assert lightning_getattr(m, 'learning_rate') == 10, \
            'attribute not correctly set'

    model5, model6, model7 = models[4:]
    lightning_setattr(model5, 'batch_size', 128)
    lightning_setattr(model6, 'batch_size', 128)
    lightning_setattr(model7, 'batch_size', 128)
    assert lightning_getattr(model5, 'batch_size') == 128, \
        'batch_size not correctly set'
    assert lightning_getattr(model6, 'batch_size') == 128, \
        'batch_size not correctly set'
    assert lightning_getattr(model7, 'batch_size') == 128, \
        'batch_size not correctly set'

    for m in models:
        with pytest.raises(
            AttributeError,
            match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
        ):
            lightning_setattr(m, "this_attr_not_exist", None)
Example #3
def test_lightning_getattr(tmpdir):
    """ Test that the lightning_getattr works in all cases"""
    models = _get_test_cases()
    for i, m in enumerate(models[:3]):
        value = lightning_getattr(m, 'learning_rate')
        assert value == i, 'attribute not correctly extracted'

    model5 = models[4]
    assert lightning_getattr(model5, 'batch_size') == 8, \
        'batch_size not correctly extracted'
Example #4
def test_lightning_setattr(tmpdir):
    """ Test that the lightning_setattr works in all cases"""
    models = _get_test_cases()
    for m in models[:3]:
        lightning_setattr(m, 'learning_rate', 10)
        assert lightning_getattr(m, 'learning_rate') == 10, \
            'attribute not correctly set'

    model5 = models[4]
    lightning_setattr(model5, 'batch_size', 128)
    assert lightning_getattr(model5, 'batch_size') == 128, \
        'batch_size not correctly set'
Example #5
def _adjust_batch_size(trainer,
                       batch_arg_name: str = 'batch_size',
                       factor: float = 1.0,
                       value: Optional[int] = None,
                       desc: Optional[str] = None):
    """ Function for adjusting the batch size. It is expected that the user
        has provided a model that has a hparam field called `batch_size` i.e.
        `model.hparams.batch_size` should exist. Additionally there can be a
        datamodule attached to either Trainer or model, in that case the attribute
        also gets updated when present.

    Args:
        trainer: instance of pytorch_lightning.Trainer

        batch_arg_name: field where batch_size is stored in `model.hparams`

        factor: value which the old batch size is multiplied by to get the
            new batch size

        value: if given, overrides the batch size with this value.
            Note that `factor` has no effect in this case

        desc: either `succeeded` or `failed`. Used purely for logging

    """
    model = trainer.get_model()
    batch_size = lightning_getattr(model, batch_arg_name)
    new_size = value if value is not None else int(batch_size * factor)
    if desc:
        log.info(
            f'Batch size {batch_size} {desc}, trying batch size {new_size}')
    lightning_setattr(model, batch_arg_name, new_size)
    return new_size
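A hedged usage sketch of how a power-scaling loop could drive this helper: double the batch size while a trial run succeeds, back off and stop once it fails. Only the `_adjust_batch_size` calls follow the docstring above; `_try_one_run` and the exception handling are hypothetical stand-ins, not the library's actual tuner loop.

def _power_scale_batch_size(trainer, init_val: int = 2, max_trials: int = 25) -> int:
    """Illustrative only: repeatedly double the batch size until a trial run fails."""
    new_size = _adjust_batch_size(trainer, value=init_val)  # start from a known-good size
    for _ in range(max_trials):
        try:
            _try_one_run(trainer)  # hypothetical helper running a few training steps
            # the trial fits in memory -> log success and try twice the size
            new_size = _adjust_batch_size(trainer, factor=2.0, desc="succeeded")
        except RuntimeError:
            # e.g. CUDA out of memory -> log failure, halve the size and stop
            new_size = _adjust_batch_size(trainer, factor=0.5, desc="failed")
            break
    return new_size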
Example #6
def _adjust_batch_size(
    trainer: "pl.Trainer",
    batch_arg_name: str = "batch_size",
    factor: float = 1.0,
    value: Optional[int] = None,
    desc: Optional[str] = None,
) -> Tuple[int, bool]:
    """Helper function for adjusting the batch size.

    Args:
        trainer: instance of pytorch_lightning.Trainer

        batch_arg_name: name of the field where batch_size is stored.

        factor: value which the old batch size is multiplied by to get the
            new batch size

        value: if given, overrides the batch size with this value.
            Note that `factor` has no effect in this case

        desc: either `succeeded` or `failed`. Used purely for logging

    Returns:
        The new batch size for the next trial and a bool that signals whether the
        new value is different from the previous batch size.
    """
    model = trainer.lightning_module
    batch_size = lightning_getattr(model, batch_arg_name)
    new_size = value if value is not None else int(batch_size * factor)
    if desc:
        log.info(
            f"Batch size {batch_size} {desc}, trying batch size {new_size}")

    if not _is_valid_batch_size(new_size, trainer.train_dataloader):
        new_size = min(new_size, len(trainer.train_dataloader.dataset))

    changed = new_size != batch_size
    lightning_setattr(model, batch_arg_name, new_size)
    return new_size, changed
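The newer signature also reports whether the value actually changed, which gives a search loop a natural stopping condition once the proposed size has been capped at the dataset length. Below is a hedged sketch extending the loop above; `_try_one_run` is again a hypothetical stand-in.

def _power_scale_batch_size(trainer, max_trials: int = 25) -> int:
    """Illustrative only: stop doubling on failure or once the size stops changing."""
    new_size, _ = _adjust_batch_size(trainer, value=2)
    for _ in range(max_trials):
        try:
            _try_one_run(trainer)  # hypothetical helper running a few training steps
            new_size, changed = _adjust_batch_size(trainer, factor=2.0, desc="succeeded")
            if not changed:
                # the proposed size collapsed back to the previous value
                # (e.g. it was capped at the dataset length) -> nothing larger to try
                break
        except RuntimeError:
            new_size, _ = _adjust_batch_size(trainer, factor=0.5, desc="failed")
            break
    return new_size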
Example #7
def test_lightning_getattr():
    """Test that the lightning_getattr works in all cases."""
    models = model_cases()
    for i, m in enumerate(models[:3]):
        value = lightning_getattr(m, "learning_rate")
        assert value == i, "attribute not correctly extracted"

    model5, model6, model7 = models[4:]
    assert lightning_getattr(model5, "batch_size") == 8, "batch_size not correctly extracted"
    assert lightning_getattr(model6, "batch_size") == 8, "batch_size not correctly extracted"
    assert lightning_getattr(model7, "batch_size") == 8, "batch_size not correctly extracted"

    for m in models:
        with pytest.raises(
            AttributeError,
            match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule.",
        ):
            lightning_getattr(m, "this_attr_not_exist")
Example #8
def test_lightning_getattr(tmpdir, model_cases):
    """Test that the lightning_getattr works in all cases"""
    models = model_cases
    for i, m in enumerate(models[:3]):
        value = lightning_getattr(m, 'learning_rate')
        assert value == i, 'attribute not correctly extracted'

    model5, model6, model7 = models[4:]
    assert lightning_getattr(model5, 'batch_size') == 8, \
        'batch_size not correctly extracted'
    assert lightning_getattr(model6, 'batch_size') == 8, \
        'batch_size not correctly extracted'
    assert lightning_getattr(model7, 'batch_size') == 8, \
        'batch_size not correctly extracted'

    for m in models:
        with pytest.raises(
            AttributeError,
            match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
        ):
            lightning_getattr(m, "this_attr_not_exist")