Exemplo n.º 1
0
def test_serialize_mpi_bcast_task_to_from_file():
    """Round-trip a ``Task`` wrapping ``mpi_bcast`` through a file on disk.

    The file is removed in a ``finally`` clause so that a failing assertion
    (or a failing ``deserialize``) does not leave ``mpi_bcast_task.task``
    behind in the working directory.
    """
    task = Task(mpi_bcast, 'x')
    filename = 'mpi_bcast_task.task'
    try:
        serialize(task, file=filename)
        assert os.path.exists(filename)
        deserialized_task = deserialize(file=filename)
        assert deserialized_task.compute() == task.compute()
    finally:
        # serialize() may have failed before the file was created.
        if os.path.exists(filename):
            os.remove(filename)
Exemplo n.º 2
0
def mpirun_task_file(task_file, result_file):
    """Run a serialized task on every MPI rank and gather the results on rank 0.

    Every rank deserializes the task stored in ``task_file`` and calls its
    ``compute()`` method. A rank whose computation raises contributes an
    ``ExceptionInfo(rank)`` instead of a result. The per-rank values are
    gathered on ``MPI.COMM_WORLD``, and rank 0 serializes the gathered list
    to ``result_file``.
    """
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    task = deserialize(file=task_file)
    # Deliberately catch everything (bare except) so that this rank still
    # reaches the collective gather below even when its computation fails.
    try:
        outcome = task.compute()
    except:
        outcome = ExceptionInfo(rank)
    gathered = comm.gather(outcome)
    if rank != 0:
        return
    serialize(gathered, file=result_file)
Exemplo n.º 3
0
def test_serialize_local_task():
    """A task wrapping a function defined inside the test must round-trip."""

    def local_test_func(x, y=3):
        return x, y

    original = Task(local_test_func, 4)
    restored = deserialize(serialize(original))
    assert restored.compute() == original.compute()
Exemplo n.º 4
0
        def wrapped_func(*args, **kwargs):
            # Run func(*args, **kwargs) via launch_mpirun_task_file by passing
            # the call and its results through files named after func.
            task_file = '{}.task'.format(func.__name__)
            result_file = '{}.result'.format(func.__name__)

            task = Task(func, *args, **kwargs)
            serialize(task, file=task_file)
            # Remove the task file even if the launch fails, so a failed run
            # does not leave it behind in the working directory.
            try:
                launch_mpirun_task_file(task_file, result_file, **self.kwargs)
            finally:
                self._remove_file(task_file)
            # Likewise, always remove the result file once we have tried to
            # read it back.
            try:
                results = deserialize(file=result_file)
            finally:
                self._remove_file(result_file)

            # If _get_first_exception finds an exception among the results,
            # re-raise it; otherwise combine the results via _collect_results.
            exception = self._get_first_exception(results)
            if exception:
                exception.reraise()
            else:
                return self._collect_results(results)
Exemplo n.º 5
0
def test_deserialize_mpi_bcast_to_from_filename():
    """Serialize ``mpi_bcast`` to a named file and load it back by filename.

    When given ``file=``, ``serialize`` is expected to return ``None``. The
    file is removed in a ``finally`` clause, because other tests reuse the
    same filename. A leftover file could otherwise make their
    ``os.path.exists`` check pass spuriously.
    """
    filename = 'serialized_mpi_bcast.out'
    try:
        serialized_mpi_bcast = serialize(mpi_bcast, file=filename)
        assert serialized_mpi_bcast is None
        assert os.path.exists(filename)
        deserialized_mpi_bcast = deserialize(file=filename)
        assert deserialized_mpi_bcast('x') == mpi_bcast('x')
    finally:
        if os.path.exists(filename):
            os.remove(filename)
Exemplo n.º 6
0
def test_mpirun_task_file_parallel():
    """Run ``Task(main_test_func, 1)`` with ``np=2`` and check the gathered results.

    The task and result files are removed in a ``finally`` clause, so they
    are cleaned up even when the launch or an assertion fails.
    """
    task_file = 'parallel_main_test_func.task'
    result_file = 'parallel_main_test_func.result'

    task = Task(main_test_func, 1)
    try:
        serialize(task, file=task_file)
        assert os.path.exists(task_file)

        launch_mpirun_task_file(task_file, result_file, np=2)
        assert os.path.exists(result_file)

        results = deserialize(file=result_file)
        assert results == [(1, 2), (1, 2)]
    finally:
        # Either file may be missing if an earlier step failed.
        for path in (task_file, result_file):
            if os.path.exists(path):
                os.remove(path)
Exemplo n.º 7
0
def test_deserialize_mpi_bcast_to_from_file():
    """Serialize ``mpi_bcast`` into an open binary file object and read it back.

    When given a file object, ``serialize`` is expected to return ``None``.
    The file is removed in a ``finally`` clause, because another test reuses
    the same filename and a leftover file would corrupt its existence check.
    """
    filename = 'serialized_mpi_bcast.out'
    try:
        with open(filename, 'wb') as f:
            serialized_mpi_bcast = serialize(mpi_bcast, file=f)
        assert serialized_mpi_bcast is None
        assert os.path.exists(filename)
        with open(filename, 'rb') as f:
            deserialized_mpi_bcast = deserialize(file=f)
        assert deserialized_mpi_bcast('x') == mpi_bcast('x')
    finally:
        if os.path.exists(filename):
            os.remove(filename)
Exemplo n.º 8
0
def test_serialize_lambda_task():
    """A task wrapping a lambda survives an in-memory serialize/deserialize."""
    identity_task = Task(lambda x: x, 1)
    round_tripped = deserialize(serialize(identity_task))
    assert round_tripped.compute() == identity_task.compute()
Exemplo n.º 9
0
def test_serialize_mpi_bcast_task():
    """A task wrapping ``mpi_bcast`` survives an in-memory round trip."""
    bcast_task = Task(mpi_bcast, 'x')
    payload = serialize(bcast_task)
    restored_task = deserialize(payload)
    assert restored_task.compute() == bcast_task.compute()
Exemplo n.º 10
0
def test_serialize_main_task():
    """A task with positional and keyword arguments survives a round trip."""
    main_task = Task(main_test_func, 1, 'a', x='y')
    restored_task = deserialize(serialize(main_task))
    assert restored_task.compute() == main_task.compute()
Exemplo n.º 11
0
def test_serialize_mpi_bcast():
    """``serialize`` called without a file returns the payload as a byte string.

    ``bytes`` is an alias of ``str`` on Python 2 (2.6+) and the binary type
    on Python 3. A single ``isinstance`` check therefore covers both
    interpreters. Unlike a per-major-version branch, it cannot silently
    skip the assertion on an unexpected version.
    """
    serialized_mpi_bcast = serialize(mpi_bcast)
    assert isinstance(serialized_mpi_bcast, bytes)
Exemplo n.º 12
0
def test_deserialize_mpi_bcast():
    """``mpi_bcast`` restored from an in-memory payload matches the original."""
    restored = deserialize(serialize(mpi_bcast))
    assert restored('x') == mpi_bcast('x')