def test_calculate_shuffle_buffer_size_small_row_size(self):
    """Check that the shuffle buffer size equals the per-worker row count.

    The Horovod handle is mocked so that ``local_size()`` returns 2 and
    ``allgather`` returns a tensor of four 2s. With an average row size of
    100 and 100 training rows per worker, the computed buffer size is
    expected to equal the per-worker row count.
    """
    total_ranks = 4
    ranks_per_host = 2

    # Stand-in for the Horovod module: only the two calls stubbed below
    # are given fixed answers.
    hvd_mock = mock.MagicMock()
    hvd_mock.local_size = lambda: ranks_per_host
    hvd_mock.allgather = lambda x: torch.tensor([ranks_per_host] * total_ranks)

    mean_row_size = 100
    rows_per_worker = 100

    buffer_size_fn = remote._calculate_shuffle_buffer_size_fn()
    result = buffer_size_fn(hvd_mock, mean_row_size, rows_per_worker)

    # NOTE(review): presumably the result is capped at the row count because
    # the rows are small -- confirm against the implementation in `remote`.
    assert result == rows_per_worker
def test_calculate_shuffle_buffer_size(self):
    """Check the shuffle buffer size against the memory-cap formula.

    ``allgather`` is mocked to return five 5s followed by three 3s, and
    ``local_size()`` is mocked independently to return 2. The expected
    buffer size is ``TOTAL_BUFFER_MEMORY_CAP_GIB * BYTES_PER_GIB``
    divided by the average row size and then by 5, compared after
    truncating both sides to ``int``.
    """
    # case with 2 workers, one with 5 ranks and second with 3 ranks
    largest_local_size = 5
    gathered_local_sizes = [5] * 5 + [3] * 3

    hvd_mock = mock.MagicMock()
    hvd_mock.local_size = lambda: 2
    hvd_mock.allgather = lambda x: torch.tensor(gathered_local_sizes)

    mean_row_size = 100000
    rows_per_worker = 1000000

    buffer_size_fn = remote._calculate_shuffle_buffer_size_fn()
    result = buffer_size_fn(hvd_mock, mean_row_size, rows_per_worker)

    # Operand order is kept as-is so the floating-point result is unchanged.
    expected = int(
        constants.TOTAL_BUFFER_MEMORY_CAP_GIB * constants.BYTES_PER_GIB
        / mean_row_size / largest_local_size
    )
    assert int(result) == expected