Esempio n. 1
0
def test_v1_7_0_datamodule_transform_properties(tmpdir):
    """Check that every ``*_transforms`` DataModule property warns as deprecated in v1.5.

    Two code paths are exercised for each property:

    * assigning it on an existing ``MNISTDataModule`` instance, and
    * passing it as a keyword argument to the ``LightningDataModule`` constructor.

    A final case passes ``test_transforms`` together with ``dims`` and expects
    the ``test_transforms`` warning to still be emitted.
    """
    # (property name, value to assign/pass, expected warning regex)
    cases = (
        (
            "train_transforms",
            "a",
            r"DataModule property `train_transforms` was deprecated in v1.5",
        ),
        (
            "val_transforms",
            "b",
            r"DataModule property `val_transforms` was deprecated in v1.5",
        ),
        (
            "test_transforms",
            "c",
            r"DataModule property `test_transforms` was deprecated in v1.5",
        ),
    )
    dm = MNISTDataModule()

    # Setter path: assigning each property on an existing datamodule.
    for prop, value, pattern in cases:
        with pytest.deprecated_call(match=pattern):
            setattr(dm, prop, value)

    # Constructor path: passing each property as a keyword argument.
    for prop, value, pattern in cases:
        with pytest.deprecated_call(match=pattern):
            _ = LightningDataModule(**{prop: value})

    # Mixed with ``dims``: the test_transforms warning must still be raised.
    test_transforms_pattern = cases[-1][2]
    with pytest.deprecated_call(match=test_transforms_pattern):
        _ = LightningDataModule(test_transforms="c", dims=(1, 1, 1))
Esempio n. 2
0
def test_v1_7_0_datamodule_dims_property(tmpdir):
    """Check that the DataModule ``dims`` property warns as deprecated in v1.5.

    The warning is expected both when reading ``dims`` from an existing
    ``MNISTDataModule`` and when passing ``dims`` to the ``LightningDataModule``
    constructor.
    """
    expected = r"DataModule property `dims` was deprecated in v1.5"
    datamodule = MNISTDataModule()

    with pytest.deprecated_call(match=expected):
        _ = datamodule.dims  # plain read access must warn

    with pytest.deprecated_call(match=expected):
        _ = LightningDataModule(dims=(1, 1, 1))  # constructor argument must warn
def model_cases():
    """Build LightningModule instances that expose an attribute in different places.

    Each model makes ``learning_rate`` and/or ``batch_size`` available from a
    different source: directly on the module, on an ``hparams`` object that is
    namespace-like or a plain dict, or on a datamodule attached through
    ``model.trainer``. Presumably consumed by tests of attribute-lookup helpers;
    confirm against callers.

    Returns:
        A 7-tuple ``(model1, ..., model7)``. The per-model comments below state
        which attribute source each one exercises.
    """
    # Namespace-like hparams object: exposes ``learning_rate`` as a class
    # attribute and reports membership only for "learning_rate".
    class TestHparamsNamespace(LightningModule):
        learning_rate = 1

        def __contains__(self, item):
            return item == "learning_rate"

    # Dict-style hparams. NOTE: the same dict object is reused by TestModel6.
    TestHparamsDict = {"learning_rate": 2}

    class TestModel1(LightningModule):  # test for namespace: attribute set directly on the module
        learning_rate = 0

    model1 = TestModel1()

    # ``hparams`` is a class attribute, so it overrides whatever ``hparams``
    # the base class provides.
    class TestModel2(LightningModule):  # test for hparams namespace
        hparams = TestHparamsNamespace()

    model2 = TestModel2()

    class TestModel3(LightningModule):  # test for hparams dict
        hparams = TestHparamsDict

    model3 = TestModel3()

    # Defines only ``batch_size`` and no ``learning_rate``. It is marked as the
    # failure case; presumably a ``learning_rate`` lookup is expected to fail.
    class TestModel4(LightningModule):  # fail case
        batch_size = 1

    model4 = TestModel4()

    # Shared trainer whose datamodule carries ``batch_size = 8``. Models 5-7
    # are attached to it via ``model.trainer = trainer``.
    trainer = Trainer()
    datamodule = LightningDataModule()
    datamodule.batch_size = 8
    trainer.datamodule = datamodule

    # Plain module with no attributes of its own; only the attached trainer's
    # datamodule provides ``batch_size``.
    model5 = LightningModule()
    model5.trainer = trainer

    # hparams (learning_rate only) lack ``batch_size``; datamodule should be used.
    class TestModel6(LightningModule):  # test for datamodule w/ hparams w/o attribute (should use datamodule)
        hparams = TestHparamsDict

    model6 = TestModel6()
    model6.trainer = trainer

    # hparams that also define ``batch_size`` (2, vs. the datamodule's 8), so
    # a lookup can tell which source won.
    TestHparamsDict2 = {"batch_size": 2}

    # Both hparams and datamodule define the attribute; datamodule should take precedence.
    class TestModel7(LightningModule):  # test for datamodule w/ hparams w/ attribute (should use datamodule)
        hparams = TestHparamsDict2

    model7 = TestModel7()
    model7.trainer = trainer

    return model1, model2, model3, model4, model5, model6, model7