Exemplo n.º 1
0
def test_can_prepare_data(local_rank, node_rank):
    """Verify which rank configurations make the data connector prepare data.

    ``local_rank`` and ``node_rank`` are mocks whose ``return_value`` is set
    to simulate ranks; ``trainer.local_rank`` is asserted to follow
    ``local_rank``. NOTE(review): the patch decorators that supply them are
    not visible here -- presumably ``PropertyMock`` patches on ``Trainer``.

    A ``BoringDataModule`` is attached to the trainer. Whether data was
    prepared is observed through its ``random_full`` attribute, which is
    cleared to ``None`` before each case.
    """
    datamodule = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = datamodule

    def prepare_and_check(should_prepare):
        # Run the connector's prepare step, then check whether the datamodule
        # repopulated ``random_full``.
        trainer._data_connector.prepare_data()
        if should_prepare:
            assert datamodule.random_full is not None
        else:
            assert datamodule.random_full is None

    # prepare_data_per_node = True (not set here; presumably the datamodule
    # default): local rank 0 prepares, local rank 1 does not.
    for rank, should_prepare in ((0, True), (1, False)):
        datamodule.random_full = None
        local_rank.return_value = rank
        assert trainer.local_rank == rank
        prepare_and_check(should_prepare)

    # prepare_data_per_node = False (prepare across all nodes): only node
    # rank 0 with local rank 0 (i.e. global rank 0) is expected to prepare.
    datamodule.random_full = None
    datamodule.prepare_data_per_node = False
    cross_node_cases = (
        (0, 0, True),
        (1, 0, False),
        (0, 1, False),
    )
    for node, local, should_prepare in cross_node_cases:
        node_rank.return_value = node
        local_rank.return_value = local
        prepare_and_check(should_prepare)
        # Clear the marker so the next case starts from None again.
        datamodule.random_full = None

    # Back to per-node preparation on local rank 0 (node rank is still 0):
    # the datamodule's own prepare_data hook should be invoked exactly once.
    datamodule.prepare_data_per_node = True
    local_rank.return_value = 0
    with mock.patch.object(trainer.datamodule, "prepare_data") as prepare_mock:
        trainer._data_connector.prepare_data()
        prepare_mock.assert_called_once()
def test_can_prepare_data(local_rank, node_rank):
    """Check ``can_prepare_data`` and ``prepare_data`` across rank setups.

    ``local_rank`` and ``node_rank`` are mocks whose ``return_value`` simulates
    the trainer's ranks (``trainer.local_rank`` is asserted to follow
    ``local_rank``). NOTE(review): the decorators supplying them are not
    visible here -- presumably ``PropertyMock`` patches on ``Trainer``.

    This variant uses ``trainer.data_connector``, passes the model to
    ``prepare_data`` and toggles ``trainer.prepare_data_per_node``. Whether
    data was prepared is observed through the datamodule's ``random_full``.

    NOTE(review): this module defines ``test_can_prepare_data`` twice; this
    later definition replaces the earlier one at import time, so only this
    one is collected.
    """

    model = BoringModel()
    dm = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = dm

    # 1) Rank checks (datamodule attached).
    # prepare_data_per_node = True: local rank 0 may prepare.
    trainer.prepare_data_per_node = True

    dm.random_full = None
    dm._has_prepared_data = False
    local_rank.return_value = 0
    assert trainer.local_rank == 0
    assert trainer.data_connector.can_prepare_data()

    trainer.data_connector.prepare_data(model)
    assert dm.random_full is not None

    # local rank 1 may not prepare.
    dm.random_full = None
    dm._has_prepared_data = False
    local_rank.return_value = 1
    assert trainer.local_rank == 1
    assert not trainer.data_connector.can_prepare_data()

    trainer.data_connector.prepare_data(model)
    assert dm.random_full is None

    # prepare_data_per_node = False (prepare once across all nodes):
    # global rank 0 (node 0, local 0) may prepare.
    dm.random_full = None
    dm._has_prepared_data = False
    trainer.prepare_data_per_node = False
    node_rank.return_value = 0
    local_rank.return_value = 0
    assert trainer.data_connector.can_prepare_data()

    trainer.data_connector.prepare_data(model)
    assert dm.random_full is not None

    # Node 1, local 0 (global rank != 0) may not prepare.
    dm.random_full = None
    dm._has_prepared_data = False
    node_rank.return_value = 1
    local_rank.return_value = 0
    assert not trainer.data_connector.can_prepare_data()

    trainer.data_connector.prepare_data(model)
    assert dm.random_full is None

    # Node 0, local 1 may not prepare either. State is deliberately not
    # reset here: the previous case already left random_full as None.
    node_rank.return_value = 0
    local_rank.return_value = 1
    assert not trainer.data_connector.can_prepare_data()

    trainer.data_connector.prepare_data(model)
    assert dm.random_full is None

    # 2) Datamodule prepare_data hook checks.
    # prepare_data_per_node = True, local rank 0.
    trainer.prepare_data_per_node = True
    local_rank.return_value = 0

    # prepare_data overridden and already run -> must not prepare again.
    dm._has_prepared_data = True
    assert not trainer.data_connector.can_prepare_data()

    # prepare_data overridden but not yet run -> may prepare.
    dm._has_prepared_data = False
    assert trainer.data_connector.can_prepare_data()

    # prepare_data not overridden -> may prepare regardless of the
    # "already prepared" flag. Mark the datamodule as prepared first: if the
    # flag were left False (as just above), this assertion would pass even if
    # the override check were ignored, so the non-overridden branch would not
    # actually be exercised.
    dm._has_prepared_data = True
    # Shadow the hook on the instance so it presumably no longer counts as
    # overridden.
    dm.prepare_data = None
    assert trainer.data_connector.can_prepare_data()