def test_can_prepare_data(local_rank, node_rank):
    """Check on which ranks the data connector forwards ``prepare_data`` to the datamodule.

    The datamodule is a ``Mock`` (spec'd on ``LightningDataModule``), so each
    scenario is verified through the call record of ``datamodule.prepare_data``.

    Args:
        local_rank: Mock whose ``return_value`` is read back via ``trainer.local_rank``.
        node_rank: Mock whose ``return_value`` selects the node under test
            (presumably patched over ``Trainer.node_rank`` by the caller -- confirm).
    """
    datamodule = Mock(spec=LightningDataModule)
    datamodule.prepare_data_per_node = True
    trainer = Trainer()
    trainer.datamodule = datamodule

    def prepare():
        # Look the connector up on every call, exactly as the inline calls would.
        trainer._data_connector.prepare_data()

    # Baseline: the hook must be untouched before the connector has run.
    datamodule.prepare_data.assert_not_called()

    # --- prepare_data_per_node = True ---------------------------------------
    # Local rank 0: the hook is expected to run.
    local_rank.return_value = 0
    assert trainer.local_rank == 0
    prepare()
    datamodule.prepare_data.assert_called_once()

    # Local rank 1: the hook is expected to be skipped.
    datamodule.reset_mock()
    local_rank.return_value = 1
    assert trainer.local_rank == 1
    prepare()
    datamodule.prepare_data.assert_not_called()

    # --- prepare_data_per_node = False (prepare once across all nodes) -------
    # Node 0 / local rank 0: the hook is expected to run.
    datamodule.reset_mock()
    datamodule.prepare_data_per_node = False
    node_rank.return_value = 0
    local_rank.return_value = 0
    prepare()
    datamodule.prepare_data.assert_called_once()

    # Node 1 / local rank 0: skipped.
    datamodule.reset_mock()
    node_rank.return_value = 1
    local_rank.return_value = 0
    prepare()
    datamodule.prepare_data.assert_not_called()

    # Node 0 / local rank 1: skipped as well (no reset needed, nothing was recorded).
    node_rank.return_value = 0
    local_rank.return_value = 1
    prepare()
    datamodule.prepare_data.assert_not_called()

    # --- back to prepare_data_per_node = True on local rank 0 ---------------
    datamodule.prepare_data_per_node = True
    local_rank.return_value = 0
    prepare()
    datamodule.prepare_data.assert_called_once()
def test_can_prepare_data(local_rank, node_rank):
    """Check on which ranks the data connector runs the datamodule's ``prepare_data``.

    ``datamodule.random_full`` serves as the observable marker: it is reset to
    ``None`` before a scenario and inspected after the connector has run.

    Args:
        local_rank: Mock whose ``return_value`` is read back via ``trainer.local_rank``.
        node_rank: Mock whose ``return_value`` selects the node under test
            (presumably patched over ``Trainer.node_rank`` by the caller -- confirm).
    """
    datamodule = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = datamodule

    def prepare():
        # Look the connector up on every call, exactly as the inline calls would.
        trainer._data_connector.prepare_data()

    # --- datamodule's initial prepare_data_per_node setting ------------------
    # Local rank 0: the marker is expected to be populated.
    datamodule.random_full = None
    local_rank.return_value = 0
    assert trainer.local_rank == 0
    prepare()
    assert datamodule.random_full is not None

    # Local rank 1: the marker is expected to stay empty.
    datamodule.random_full = None
    local_rank.return_value = 1
    assert trainer.local_rank == 1
    prepare()
    assert datamodule.random_full is None

    # --- prepare_data_per_node = False (prepare once across all nodes) -------
    # Node 0 / local rank 0: populated.
    datamodule.random_full = None
    datamodule.prepare_data_per_node = False
    node_rank.return_value = 0
    local_rank.return_value = 0
    prepare()
    assert datamodule.random_full is not None

    # Node 1 / local rank 0: stays empty.
    datamodule.random_full = None
    node_rank.return_value = 1
    local_rank.return_value = 0
    prepare()
    assert datamodule.random_full is None

    # Node 0 / local rank 1: stays empty (marker is still None from the previous case).
    node_rank.return_value = 0
    local_rank.return_value = 1
    prepare()
    assert datamodule.random_full is None

    # --- prepare_data_per_node = True on local rank 0, observed via a patch --
    datamodule.prepare_data_per_node = True
    local_rank.return_value = 0
    with mock.patch.object(trainer.datamodule, "prepare_data") as patched_prepare_data:
        prepare()
        patched_prepare_data.assert_called_once()
def test_can_prepare_data(local_rank, node_rank):
    """Check ``data_connector.can_prepare_data()`` across ranks and datamodule states.

    Args:
        local_rank: Mock whose ``return_value`` is read back via ``trainer.local_rank``.
        node_rank: Mock whose ``return_value`` selects the node under test
            (presumably patched over ``Trainer.node_rank`` by the caller -- confirm).
    """
    datamodule = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = datamodule

    def can_prepare():
        # Resolve the connector on every call, exactly as the inline calls would.
        return trainer.data_connector.can_prepare_data()

    # --- prepare_data_per_node = True ---------------------------------------
    # Local rank 0 may prepare.
    trainer.prepare_data_per_node = True
    local_rank.return_value = 0
    assert trainer.local_rank == 0
    assert can_prepare()

    # Local rank 1 may not.
    local_rank.return_value = 1
    assert trainer.local_rank == 1
    assert not can_prepare()

    # --- prepare_data_per_node = False (prepare once across all nodes) -------
    # Node 0 / local rank 0 may prepare.
    trainer.prepare_data_per_node = False
    node_rank.return_value = 0
    local_rank.return_value = 0
    assert can_prepare()

    # Node 1 / local rank 0 may not.
    node_rank.return_value = 1
    local_rank.return_value = 0
    assert not can_prepare()

    # Node 0 / local rank 1 may not either.
    node_rank.return_value = 0
    local_rank.return_value = 1
    assert not can_prepare()

    # --- datamodule state, with prepare_data_per_node = True on local rank 0 -
    trainer.prepare_data_per_node = True
    local_rank.return_value = 0

    # Already prepared: a second preparation is refused.
    datamodule._has_prepared_data = True
    assert not can_prepare()

    # Not prepared yet: preparation is allowed.
    datamodule._has_prepared_data = False
    assert can_prepare()

    # Hook replaced by None (original note: "is_overridden prepare data = False"):
    # preparation is allowed.
    datamodule.prepare_data = None
    assert can_prepare()
def test_inconsistent_prepare_data_per_node(tmpdir):
    """Expect a ``MisconfigurationException`` about inconsistent ``prepare_data_per_node``.

    A trainer created with ``prepare_data_per_node=False`` gets a ``BoringModel``
    and a ``BoringDataModule`` attached, then ``data_connector.prepare_data()`` is
    invoked. The whole sequence runs inside ``pytest.raises``, so the error may
    come from any of these steps.

    Args:
        tmpdir: Not used by the body.
    """
    expected_message = "Inconsistent settings found for `prepare_data_per_node`."
    with pytest.raises(MisconfigurationException, match=expected_message):
        module = BoringModel()
        datamodule = BoringDataModule()
        trainer = Trainer(prepare_data_per_node=False)
        trainer.model = module
        trainer.datamodule = datamodule
        trainer.data_connector.prepare_data()
def test_can_prepare_data(tmpdir):
    """Check ``data_connector.can_prepare_data()`` across ranks and datamodule states.

    Ranks are set by assigning ``trainer.local_rank`` / ``trainer.node_rank`` directly.

    Args:
        tmpdir: Not used by the body.
    """
    datamodule = TrialMNISTDataModule()
    trainer = Trainer()
    trainer.datamodule = datamodule

    def can_prepare():
        # Resolve the connector on every call, exactly as the inline calls would.
        return trainer.data_connector.can_prepare_data()

    # --- prepare_data_per_node = True ---------------------------------------
    # Local rank 0 may prepare.
    trainer.prepare_data_per_node = True
    trainer.local_rank = 0
    assert can_prepare()

    # Local rank 1 may not.
    trainer.local_rank = 1
    assert not can_prepare()

    # --- prepare_data_per_node = False (prepare once across all nodes) -------
    # Node 0 / local rank 0 may prepare.
    trainer.prepare_data_per_node = False
    trainer.node_rank = 0
    trainer.local_rank = 0
    assert can_prepare()

    # Node 1 / local rank 0 may not.
    trainer.node_rank = 1
    trainer.local_rank = 0
    assert not can_prepare()

    # Node 0 / local rank 1 may not either.
    trainer.node_rank = 0
    trainer.local_rank = 1
    assert not can_prepare()

    # --- datamodule state, with prepare_data_per_node = True on local rank 0 -
    trainer.prepare_data_per_node = True
    trainer.local_rank = 0

    # Already prepared: a second preparation is refused.
    datamodule._has_prepared_data = True
    assert not can_prepare()

    # Not prepared yet: preparation is allowed.
    datamodule._has_prepared_data = False
    assert can_prepare()

    # Hook replaced by None (original note: "is_overridden prepare data = False"):
    # preparation is allowed.
    datamodule.prepare_data = None
    assert can_prepare()
def model_cases():
    """Build seven ``LightningModule`` instances that expose attributes in different places.

    The variants cover a plain class attribute, an ``hparams`` namespace-like
    object, an ``hparams`` dict, and models attached to a trainer whose
    datamodule carries ``batch_size``.

    Returns:
        The tuple ``(model1, model2, model3, model4, model5, model6, model7)``.
    """

    class TestHparamsNamespace(LightningModule):
        learning_rate = 1

        def __contains__(self, item):
            # Membership is answered for this single key only.
            return item == "learning_rate"

    learning_rate_hparams = {"learning_rate": 2}

    # Attribute lives directly on the module.
    class TestModel1(LightningModule):
        learning_rate = 0

    model1 = TestModel1()

    # Attribute lives on a namespace-like ``hparams`` object.
    class TestModel2(LightningModule):
        hparams = TestHparamsNamespace()

    model2 = TestModel2()

    # Attribute lives in an ``hparams`` dict.
    class TestModel3(LightningModule):
        hparams = learning_rate_hparams

    model3 = TestModel3()

    # Fail case: only ``batch_size`` is defined.
    class TestModel4(LightningModule):
        batch_size = 1

    model4 = TestModel4()

    # Shared trainer whose datamodule carries ``batch_size``.
    trainer = Trainer()
    datamodule = LightningDataModule()
    datamodule.batch_size = 8
    trainer.datamodule = datamodule

    def with_trainer(model):
        # Attach the shared trainer and hand the same instance back.
        model.trainer = trainer
        return model

    model5 = with_trainer(LightningModule())

    # Datamodule present, hparams without the attribute (should use datamodule).
    class TestModel6(LightningModule):
        hparams = learning_rate_hparams

    model6 = with_trainer(TestModel6())

    batch_size_hparams = {"batch_size": 2}

    # Datamodule present, hparams with the attribute (should use datamodule).
    class TestModel7(LightningModule):
        hparams = batch_size_hparams

    model7 = with_trainer(TestModel7())

    return model1, model2, model3, model4, model5, model6, model7