def test_can_prepare_data(local_rank, node_rank):
    """Check when the trainer's data connector allows data preparation.

    Args:
        local_rank: Mock whose ``return_value`` drives ``trainer.local_rank``
            (presumably a patched ``Trainer`` property; the patch decorators
            are outside this block, so confirm against them).
        node_rank: Mock whose ``return_value`` is set alongside ``local_rank``
            for the cross-node cases (presumably drives ``trainer.node_rank``).

    The test walks through two groups of scenarios:

    1. Rank-based gating, with ``prepare_data_per_node`` toggled on the trainer.
    2. Datamodule-state gating, via ``_has_prepared_data`` and by replacing
       the datamodule's ``prepare_data`` attribute with ``None``.
    """
    datamodule = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = datamodule

    def may_prepare():
        # Resolve the connector from the trainer on every call, exactly as an
        # inline ``trainer.data_connector.can_prepare_data()`` would.
        return trainer.data_connector.can_prepare_data()

    # ------------------------------------------------------------------
    # Group 1: rank-based gating
    # ------------------------------------------------------------------

    # prepare_data_per_node = True: local rank 0 is expected to prepare.
    trainer.prepare_data_per_node = True

    local_rank.return_value = 0
    assert trainer.local_rank == 0
    assert may_prepare()

    # ... and local rank 1 is expected not to.
    local_rank.return_value = 1
    assert trainer.local_rank == 1
    assert not may_prepare()

    # prepare_data_per_node = False (prepare once across all nodes):
    # node 0 / local 0 is expected to prepare.
    trainer.prepare_data_per_node = False

    node_rank.return_value = 0
    local_rank.return_value = 0
    assert may_prepare()

    # Any other (node, local) combination is expected not to prepare.
    node_rank.return_value = 1
    local_rank.return_value = 0
    assert not may_prepare()

    node_rank.return_value = 0
    local_rank.return_value = 1
    assert not may_prepare()

    # ------------------------------------------------------------------
    # Group 2: datamodule-state gating (per-node mode, local rank 0)
    # ------------------------------------------------------------------
    trainer.prepare_data_per_node = True
    local_rank.return_value = 0

    # prepare_data still defined on the datamodule and already flagged as
    # prepared -> nothing left to do.
    datamodule._has_prepared_data = True
    assert not may_prepare()

    # Same hook, not yet flagged as prepared -> preparation is allowed.
    datamodule._has_prepared_data = False
    assert may_prepare()

    # Hook replaced with None (the "not overridden" case per the original
    # test notes) -> preparation is still reported as allowed.
    datamodule.prepare_data = None
    assert may_prepare()
def test_can_prepare_data(local_rank, node_rank):
    """Check when the data connector actually runs the datamodule's prepare step.

    Args:
        local_rank: Mock whose ``return_value`` drives ``trainer.local_rank``
            (presumably a patched ``Trainer`` property; the patch decorators
            are outside this block, so confirm against them).
        node_rank: Mock whose ``return_value`` is set alongside ``local_rank``
            for the cross-node cases (presumably drives ``trainer.node_rank``).

    Whether preparation ran is observed through ``dm.random_full``: the test
    clears it to ``None`` beforehand and checks whether it is populated after
    ``trainer.data_connector.prepare_data()``. The last part swaps the
    datamodule's ``prepare_data`` for a mock and inspects its call count
    instead.
    """
    datamodule = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = datamodule

    def reset_datamodule():
        # Clear the observable result and the "already prepared" flag so the
        # next prepare call starts from a clean slate.
        datamodule.random_full = None
        datamodule._has_prepared_data = False

    def run_prepare():
        # Resolve the connector from the trainer on every call, exactly as an
        # inline ``trainer.data_connector.prepare_data()`` would.
        trainer.data_connector.prepare_data()

    # ------------------------------------------------------------------
    # Group 1: rank-based gating
    # ------------------------------------------------------------------

    # Datamodule's initial prepare_data_per_node setting, local rank 0:
    # preparation is expected to run.
    reset_datamodule()
    local_rank.return_value = 0
    assert trainer.local_rank == 0
    run_prepare()
    assert datamodule.random_full is not None

    # Local rank 1: preparation is expected to be skipped.
    reset_datamodule()
    local_rank.return_value = 1
    assert trainer.local_rank == 1
    run_prepare()
    assert datamodule.random_full is None

    # prepare_data_per_node = False (prepare once across all nodes):
    # node 0 / local 0 is expected to prepare.
    reset_datamodule()
    datamodule.prepare_data_per_node = False
    node_rank.return_value = 0
    local_rank.return_value = 0
    run_prepare()
    assert datamodule.random_full is not None

    # Node 1 / local 0: expected to be skipped.
    reset_datamodule()
    node_rank.return_value = 1
    local_rank.return_value = 0
    run_prepare()
    assert datamodule.random_full is None

    # Node 0 / local 1: expected to be skipped. No reset here on purpose;
    # the previous case left the datamodule untouched.
    node_rank.return_value = 0
    local_rank.return_value = 1
    run_prepare()
    assert datamodule.random_full is None

    # ------------------------------------------------------------------
    # Group 2: datamodule-state gating (per-node mode, local rank 0)
    # ------------------------------------------------------------------
    datamodule.prepare_data_per_node = True
    local_rank.return_value = 0

    with mock.patch.object(trainer.datamodule, "prepare_data") as prepare_mock:
        # Already flagged as prepared -> the hook must not be invoked again.
        datamodule._has_prepared_data = True
        run_prepare()
        prepare_mock.assert_not_called()

        # Not yet flagged as prepared -> the hook is invoked exactly once.
        datamodule._has_prepared_data = False
        run_prepare()
        prepare_mock.assert_called_once()