Beispiel #1
0
def test_can_prepare_data(local_rank, node_rank):
    """Check which rank configurations invoke ``LightningDataModule.prepare_data``.

    The datamodule is a ``Mock`` specced on ``LightningDataModule`` so each
    scenario can assert whether ``prepare_data`` was called.

    ``local_rank`` and ``node_rank`` are presumably ``PropertyMock`` patches of
    the matching ``Trainer`` properties, injected by decorators outside this
    view (TODO confirm). Setting their ``return_value`` controls what
    ``trainer.local_rank`` / ``trainer.node_rank`` report.

    Every scenario starts with ``dm.reset_mock()`` so its call assertions only
    see calls made by that scenario.
    """
    dm = Mock(spec=LightningDataModule)
    dm.prepare_data_per_node = True
    trainer = Trainer()
    trainer.datamodule = dm

    # Case 1: prepare_data_per_node = True (decided by the local rank).
    # local rank = 0 -> prepare_data runs
    dm.prepare_data.assert_not_called()
    local_rank.return_value = 0
    assert trainer.local_rank == 0

    trainer._data_connector.prepare_data()
    dm.prepare_data.assert_called_once()

    # local rank = 1 -> prepare_data is skipped
    dm.reset_mock()
    local_rank.return_value = 1
    assert trainer.local_rank == 1

    trainer._data_connector.prepare_data()
    dm.prepare_data.assert_not_called()

    # Case 2: prepare_data_per_node = False (prepare once across all nodes).
    # node rank = 0, local rank = 0 (global rank 0) -> prepare_data runs
    dm.reset_mock()
    dm.prepare_data_per_node = False
    node_rank.return_value = 0
    local_rank.return_value = 0

    trainer._data_connector.prepare_data()
    dm.prepare_data.assert_called_once()

    # node rank = 1, local rank = 0 -> prepare_data is skipped
    dm.reset_mock()
    node_rank.return_value = 1
    local_rank.return_value = 0

    trainer._data_connector.prepare_data()
    dm.prepare_data.assert_not_called()

    # node rank = 0, local rank = 1 -> prepare_data is skipped
    dm.reset_mock()
    node_rank.return_value = 0
    local_rank.return_value = 1

    trainer._data_connector.prepare_data()
    dm.prepare_data.assert_not_called()

    # Case 3: back to prepare_data_per_node = True on local rank 0 -> prepare_data runs
    dm.reset_mock()
    dm.prepare_data_per_node = True
    local_rank.return_value = 0

    trainer._data_connector.prepare_data()
    dm.prepare_data.assert_called_once()
Beispiel #2
0
def test_can_prepare_data(local_rank, node_rank):
    """Verify which rank configurations actually run the datamodule's ``prepare_data``.

    ``dm.random_full`` serves as a sentinel. It is cleared before each
    ``prepare_data()`` call and inspected afterwards. This assumes
    ``BoringDataModule.prepare_data`` populates it (confirm against the
    datamodule definition).

    ``local_rank`` / ``node_rank`` are presumably ``PropertyMock`` patches of the
    matching ``Trainer`` properties, supplied by decorators outside this view.
    """
    dm = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = dm

    def run_prepare(should_prepare):
        # Clear the sentinel, trigger preparation, then check whether it ran.
        dm.random_full = None
        trainer._data_connector.prepare_data()
        if should_prepare:
            assert dm.random_full is not None
        else:
            assert dm.random_full is None

    # Datamodule's default prepare_data_per_node (the test treats it as True):
    # local rank 0 prepares, local rank 1 does not.
    local_rank.return_value = 0
    assert trainer.local_rank == 0
    run_prepare(should_prepare=True)

    local_rank.return_value = 1
    assert trainer.local_rank == 1
    run_prepare(should_prepare=False)

    # Preparation across all nodes: only node rank 0 with local rank 0 prepares.
    dm.prepare_data_per_node = False
    node_rank.return_value = 0
    local_rank.return_value = 0
    run_prepare(should_prepare=True)

    node_rank.return_value = 1
    local_rank.return_value = 0
    run_prepare(should_prepare=False)

    node_rank.return_value = 0
    local_rank.return_value = 1
    run_prepare(should_prepare=False)

    # Per-node preparation again on local rank 0: the hook itself must fire once.
    dm.prepare_data_per_node = True
    local_rank.return_value = 0

    with mock.patch.object(trainer.datamodule, "prepare_data") as prepare_spy:
        trainer._data_connector.prepare_data()
        prepare_spy.assert_called_once()
def test_can_prepare_data(local_rank, node_rank):
    """Walk ``data_connector.can_prepare_data`` through rank and datamodule states.

    ``local_rank`` and ``node_rank`` are presumably ``PropertyMock`` patches of
    the matching ``Trainer`` properties, injected by decorators outside this
    view (TODO confirm). Their ``return_value`` decides what the trainer reports.
    """
    dm = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = dm

    def can_prepare():
        # Resolve the connector on every call, exactly as an inline call would.
        return trainer.data_connector.can_prepare_data()

    # -- per-node preparation: local rank 0 may prepare, local rank 1 may not --
    trainer.prepare_data_per_node = True

    local_rank.return_value = 0
    assert trainer.local_rank == 0
    assert can_prepare()

    local_rank.return_value = 1
    assert trainer.local_rank == 1
    assert not can_prepare()

    # -- preparation across all nodes: only node rank 0 + local rank 0 may prepare --
    trainer.prepare_data_per_node = False

    node_rank.return_value, local_rank.return_value = 0, 0
    assert can_prepare()

    node_rank.return_value, local_rank.return_value = 1, 0
    assert not can_prepare()

    node_rank.return_value, local_rank.return_value = 0, 1
    assert not can_prepare()

    # -- datamodule state, with per-node preparation on local rank 0 --
    trainer.prepare_data_per_node = True
    local_rank.return_value = 0

    dm._has_prepared_data = True  # prepare_data already ran -> refused
    assert not can_prepare()

    dm._has_prepared_data = False  # prepare_data not run yet -> allowed
    assert can_prepare()

    dm.prepare_data = None  # intent: simulate prepare_data not being overridden
    assert can_prepare()
Beispiel #4
0
def test_inconsistent_prepare_data_per_node(tmpdir):
    """Conflicting ``prepare_data_per_node`` settings must be rejected.

    The Trainer is built with ``prepare_data_per_node=False`` while a datamodule
    is attached. ``data_connector.prepare_data()`` is expected to raise a
    ``MisconfigurationException`` reporting inconsistent settings.

    ``tmpdir`` is accepted (pytest fixture) but not used.
    """
    model = BoringModel()
    dm = BoringDataModule()
    trainer = Trainer(prepare_data_per_node=False)
    trainer.model = model
    trainer.datamodule = dm

    # Only the call under test sits inside ``pytest.raises``. Otherwise a matching
    # exception raised during setup would make the test pass for the wrong reason.
    with pytest.raises(
        MisconfigurationException,
        match="Inconsistent settings found for `prepare_data_per_node`.",
    ):
        trainer.data_connector.prepare_data()
def test_can_prepare_data(tmpdir):
    """Walk ``data_connector.can_prepare_data`` through rank and datamodule states.

    Ranks are forced by assigning ``trainer.local_rank`` / ``trainer.node_rank``
    directly. This assumes they are writable attributes on this ``Trainer``
    version (TODO confirm). ``tmpdir`` is accepted but unused.
    """
    dm = TrialMNISTDataModule()
    trainer = Trainer()
    trainer.datamodule = dm

    def expect(allowed):
        # Truthiness check, equivalent to ``assert x`` / ``assert not x``.
        assert bool(trainer.data_connector.can_prepare_data()) is allowed

    # Per-node preparation: decided by the local rank.
    trainer.prepare_data_per_node = True
    trainer.local_rank = 0
    expect(True)
    trainer.local_rank = 1
    expect(False)

    # Preparation across all nodes: only node rank 0 with local rank 0 is allowed.
    trainer.prepare_data_per_node = False
    for node, local, allowed in ((0, 0, True), (1, 0, False), (0, 1, False)):
        trainer.node_rank = node
        trainer.local_rank = local
        expect(allowed)

    # Datamodule state, with per-node preparation on local rank 0.
    trainer.prepare_data_per_node = True
    trainer.local_rank = 0

    dm._has_prepared_data = True  # prepare_data already ran -> refused
    expect(False)
    dm._has_prepared_data = False  # prepare_data not run yet -> allowed
    expect(True)
    dm.prepare_data = None  # intent: simulate prepare_data not being overridden
    expect(True)
def model_cases():
    """Build ``LightningModule`` instances covering where an attribute may live.

    Returns a 7-tuple ``(model1, ..., model7)``:

    1. ``learning_rate`` defined directly on the model.
    2. ``learning_rate`` reachable through a namespace-like ``hparams`` object
       (supports ``in`` via ``__contains__``).
    3. ``learning_rate`` reachable through an ``hparams`` dict.
    4. Failure case: only ``batch_size`` is defined on the model.
    5. Plain module attached to a trainer whose datamodule has ``batch_size=8``.
    6. Same trainer; ``hparams`` dict without ``batch_size`` (the datamodule
       value is meant to be used).
    7. Same trainer; ``hparams`` dict with ``batch_size`` (the datamodule value
       is still meant to be used).
    """
    lr_hparams = {"learning_rate": 2}
    bs_hparams = {"batch_size": 2}

    class TestHparamsNamespace(LightningModule):
        # Namespace-style stand-in: exposes ``learning_rate`` and answers
        # ``"learning_rate" in obj`` through ``__contains__``.
        learning_rate = 1

        def __contains__(self, item):
            return item == "learning_rate"

    class TestModel1(LightningModule):  # attribute on the model itself
        learning_rate = 0

    model1 = TestModel1()

    class TestModel2(LightningModule):  # attribute via namespace-like hparams
        hparams = TestHparamsNamespace()

    model2 = TestModel2()

    class TestModel3(LightningModule):  # attribute via hparams dict
        hparams = lr_hparams

    model3 = TestModel3()

    class TestModel4(LightningModule):  # failure case
        batch_size = 1

    model4 = TestModel4()

    # Shared trainer whose datamodule carries ``batch_size``; models 5-7 attach to it.
    trainer = Trainer()
    datamodule = LightningDataModule()
    datamodule.batch_size = 8
    trainer.datamodule = datamodule

    model5 = LightningModule()

    class TestModel6(LightningModule):  # datamodule + hparams lacking the attribute
        hparams = lr_hparams

    model6 = TestModel6()

    class TestModel7(LightningModule):  # datamodule + hparams holding the attribute
        hparams = bs_hparams

    model7 = TestModel7()

    for attached in (model5, model6, model7):
        attached.trainer = trainer

    return model1, model2, model3, model4, model5, model6, model7