def test_serialize_mpi_bcast_task_to_from_file():
    """Round-trip a Task wrapping ``mpi_bcast`` through a file given by name.

    The task is serialized to ``mpi_bcast_task.task``, read back, and the
    deserialized task must compute the same value as the original.
    """
    task = Task(mpi_bcast, 'x')
    filename = 'mpi_bcast_task.task'
    try:
        serialize(task, file=filename)
        assert os.path.exists(filename)
        deserialized_task = deserialize(file=filename)
        assert deserialized_task.compute() == task.compute()
    finally:
        # Remove the file even when an assertion above fails, so a failed
        # run does not leave stale output in the working directory.
        if os.path.exists(filename):
            os.remove(filename)
def test_deserialize_mpi_bcast_to_from_filename():
    """Round-trip ``mpi_bcast`` itself through a file given by name.

    When ``file=`` is supplied, ``serialize`` is expected to return ``None``
    and write the payload to disk instead.
    """
    filename = 'serialized_mpi_bcast.out'
    try:
        serialized_mpi_bcast = serialize(mpi_bcast, file=filename)
        assert serialized_mpi_bcast is None
        assert os.path.exists(filename)
        deserialized_mpi_bcast = deserialize(file=filename)
        assert deserialized_mpi_bcast('x') == mpi_bcast('x')
    finally:
        # Clean up regardless of assertion outcome.
        if os.path.exists(filename):
            os.remove(filename)
def test_serialize_local_task():
    """A Task wrapping a function defined inside the test survives a round trip."""
    def local_test_func(x, y=3):
        return x, y

    original = Task(local_test_func, 4)
    restored = deserialize(serialize(original))
    assert restored.compute() == original.compute()
def test_deserialize_mpi_bcast_to_from_file():
    """Round-trip ``mpi_bcast`` through an already-open binary file object.

    ``serialize`` is expected to return ``None`` when writing to ``file=``.
    """
    filename = 'serialized_mpi_bcast.out'
    try:
        with open(filename, 'wb') as f:
            serialized_mpi_bcast = serialize(mpi_bcast, file=f)
        assert serialized_mpi_bcast is None
        assert os.path.exists(filename)
        with open(filename, 'rb') as f:
            deserialized_mpi_bcast = deserialize(file=f)
        assert deserialized_mpi_bcast('x') == mpi_bcast('x')
    finally:
        # open(..., 'wb') creates the file up front, so remove it even if
        # serialization or an assertion fails.
        if os.path.exists(filename):
            os.remove(filename)
def mpirun_task_file(task_file, result_file):
    """Execute a serialized Task on every MPI rank and gather the outcomes.

    Each rank deserializes the task from ``task_file`` and computes it. The
    per-rank results (or an ``ExceptionInfo`` on failure) are gathered, and
    rank 0 serializes the gathered list to ``result_file``.

    Args:
        task_file: Path (or file argument accepted by ``deserialize``) holding
            the serialized Task.
        result_file: Destination passed to ``serialize`` on rank 0.
    """
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    task = deserialize(file=task_file)
    try:
        result = task.compute()
    except BaseException:
        # Deliberately broad: every rank must still reach gather() below,
        # otherwise the remaining ranks would block waiting for this one.
        # NOTE(review): ExceptionInfo presumably captures the active
        # exception for re-raising on the caller side; confirm.
        result = ExceptionInfo(rank)
    gathered = comm.gather(result)
    if rank != 0:
        return
    serialize(gathered, file=result_file)
def wrapped_func(*args, **kwargs):
    """Run ``func(*args, **kwargs)`` under mpirun and return collected results.

    The call is packaged as a Task and written to ``<func name>.task``. It is
    run through ``launch_mpirun_task_file`` with ``self.kwargs``, and the
    per-rank results are read back from ``<func name>.result``. Both files
    are then removed. If any rank reported an exception, it is re-raised;
    otherwise ``self._collect_results`` builds the return value.

    NOTE(review): file names depend only on ``func.__name__``, so concurrent
    calls of the same function would share them.
    """
    stem = func.__name__
    task_file = '{}.task'.format(stem)
    result_file = '{}.result'.format(stem)

    serialize(Task(func, *args, **kwargs), file=task_file)
    launch_mpirun_task_file(task_file, result_file, **self.kwargs)
    results = deserialize(file=result_file)

    for path in (task_file, result_file):
        self._remove_file(path)

    first_exception = self._get_first_exception(results)
    if not first_exception:
        return self._collect_results(results)
    first_exception.reraise()
def test_mpirun_task_file_parallel():
    """Run ``main_test_func(1)`` on two MPI ranks via task/result files.

    The gathered result list must contain one entry per rank (``np=2``).
    """
    task_file = 'parallel_main_test_func.task'
    result_file = 'parallel_main_test_func.result'
    task = Task(main_test_func, 1)
    try:
        serialize(task, file=task_file)
        assert os.path.exists(task_file)
        launch_mpirun_task_file(task_file, result_file, np=2)
        assert os.path.exists(result_file)
        results = deserialize(file=result_file)
        assert results == [(1, 2), (1, 2)]
    finally:
        # Previously cleanup ran only on success; a failed assertion left
        # both files behind for the next run.
        for path in (task_file, result_file):
            if os.path.exists(path):
                os.remove(path)
def test_serialize_lambda_task():
    """A Task wrapping an anonymous lambda survives a serialize/deserialize cycle."""
    original = Task(lambda x: x, 1)
    payload = serialize(original)
    restored = deserialize(payload)
    assert restored.compute() == original.compute()
def test_serialize_mpi_bcast_task():
    """An in-memory round trip of a Task wrapping ``mpi_bcast`` preserves its result."""
    original = Task(mpi_bcast, 'x')
    restored = deserialize(serialize(original))
    assert restored.compute() == original.compute()
def test_serialize_main_task():
    """Positional and keyword arguments of a Task survive serialization."""
    original = Task(main_test_func, 1, 'a', x='y')
    payload = serialize(original)
    restored = deserialize(payload)
    assert restored.compute() == original.compute()
def test_deserialize_mpi_bcast():
    """The bare ``mpi_bcast`` function round-trips through in-memory serialization."""
    assert deserialize(serialize(mpi_bcast))('x') == mpi_bcast('x')