def test_can_prepare_data(local_rank, node_rank):
    """Check on which ranks the data connector lets the datamodule prepare its data.

    ``local_rank`` and ``node_rank`` are mocks supplied from outside this
    function; assigning their ``return_value`` steers the ranks seen by the
    trainer. The ``trainer.local_rank`` assertions below confirm this for the
    local rank. NOTE(review): assumed to work the same way for the node rank;
    the patching decorators are not visible here.

    ``random_full`` on the datamodule serves as the marker: it is cleared
    before a call and inspected afterwards to see whether preparation ran.
    """
    datamodule = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = datamodule

    def run_prepare():
        # Look the connector up through the trainer on every call.
        trainer._data_connector.prepare_data()

    def place_on(node, local):
        # Node rank is assigned first, then the local rank.
        node_rank.return_value = node
        local_rank.return_value = local

    # ------------------------------------------------------------------
    # Scenario 1: marker-based checks.
    # ``prepare_data_per_node`` is not touched yet; it is expected to be
    # True at this point.
    # ------------------------------------------------------------------

    # Local rank 0 -> preparation is expected to run.
    datamodule.random_full = None
    local_rank.return_value = 0
    assert trainer.local_rank == 0
    run_prepare()
    assert datamodule.random_full is not None

    # Local rank 1 -> preparation is expected to be skipped.
    datamodule.random_full = None
    local_rank.return_value = 1
    assert trainer.local_rank == 1
    run_prepare()
    assert datamodule.random_full is None

    # ``prepare_data_per_node`` switched off: prepare once across all nodes.
    # Node 0 / local 0 -> preparation is expected to run.
    datamodule.random_full = None
    datamodule.prepare_data_per_node = False
    place_on(node=0, local=0)
    run_prepare()
    assert datamodule.random_full is not None

    # Node 1 / local 0 -> skipped.
    datamodule.random_full = None
    place_on(node=1, local=0)
    run_prepare()
    assert datamodule.random_full is None

    # Node 0 / local 1 -> skipped. The marker is still None from the
    # previous check, so it is not cleared again here.
    place_on(node=0, local=1)
    run_prepare()
    assert datamodule.random_full is None

    # ------------------------------------------------------------------
    # Scenario 2: observe the datamodule hook directly.
    # Per-node preparation switched back on, local rank 0.
    # ------------------------------------------------------------------
    datamodule.prepare_data_per_node = True
    local_rank.return_value = 0
    with mock.patch.object(trainer.datamodule, "prepare_data") as prepare_hook:
        # The patched hook must be invoked exactly once by the connector.
        trainer._data_connector.prepare_data()
        prepare_hook.assert_called_once()
def test_can_prepare_data(local_rank, node_rank):
    """Check ``can_prepare_data`` and ``prepare_data`` on the trainer's data connector.

    ``local_rank`` and ``node_rank`` are mocks supplied from outside this
    function; assigning their ``return_value`` steers the ranks seen by the
    trainer. The ``trainer.local_rank`` assertions below confirm this for the
    local rank. NOTE(review): assumed to work the same way for the node rank;
    the patching decorators are not visible here.

    Two things are observed on the datamodule: ``random_full`` (cleared before
    a call and inspected afterwards to see whether preparation ran) and the
    ``_has_prepared_data`` flag, which is set by hand to drive
    ``can_prepare_data``.
    """
    model = BoringModel()
    datamodule = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = datamodule

    def reset_datamodule():
        # Clear the marker first, then the "already prepared" flag.
        datamodule.random_full = None
        datamodule._has_prepared_data = False

    def place_on(node, local):
        # Node rank is assigned first, then the local rank.
        node_rank.return_value = node
        local_rank.return_value = local

    def can_prepare():
        # Look the connector up through the trainer on every call.
        return trainer.data_connector.can_prepare_data()

    def run_prepare():
        trainer.data_connector.prepare_data(model)

    # ------------------------------------------------------------------
    # Scenario 1: per-node preparation enabled.
    # ------------------------------------------------------------------
    trainer.prepare_data_per_node = True

    # Local rank 0 -> allowed, and the marker gets populated.
    reset_datamodule()
    local_rank.return_value = 0
    assert trainer.local_rank == 0
    assert can_prepare()
    run_prepare()
    assert datamodule.random_full is not None

    # Local rank 1 -> not allowed, marker stays empty.
    reset_datamodule()
    local_rank.return_value = 1
    assert trainer.local_rank == 1
    assert not can_prepare()
    run_prepare()
    assert datamodule.random_full is None

    # Per-node preparation disabled: prepare once across all nodes.
    # Node 0 / local 0 -> allowed.
    reset_datamodule()
    trainer.prepare_data_per_node = False
    place_on(node=0, local=0)
    assert can_prepare()
    run_prepare()
    assert datamodule.random_full is not None

    # Node 1 / local 0 -> not allowed.
    reset_datamodule()
    place_on(node=1, local=0)
    assert not can_prepare()
    run_prepare()
    assert datamodule.random_full is None

    # Node 0 / local 1 -> not allowed. The datamodule is not reset here;
    # the marker is still None from the previous check.
    place_on(node=0, local=1)
    assert not can_prepare()
    run_prepare()
    assert datamodule.random_full is None

    # ------------------------------------------------------------------
    # Scenario 2: effect of the datamodule's own state.
    # Per-node preparation back on, local rank 0.
    # ------------------------------------------------------------------
    trainer.prepare_data_per_node = True
    local_rank.return_value = 0

    # Flag says data was already prepared -> not allowed again.
    datamodule._has_prepared_data = True
    assert not can_prepare()

    # Flag cleared -> allowed.
    datamodule._has_prepared_data = False
    assert can_prepare()

    # ``prepare_data`` replaced by None on the instance -> still allowed.
    datamodule.prepare_data = None
    assert can_prepare()