def communication(comm: MPI.Intracomm) -> Callable[[np.ndarray], np.ndarray]:
    """
    Wrapper for the communication step function

    Args:
        comm: MPI communicator

    Returns:
        function which implements the communication step

    """
    # Precompute neighbour ranks once; Shift is available on Cartesian communicators
    left_src, left_dst = comm.Shift(direction=0, disp=-1)
    right_src, right_dst = comm.Shift(direction=0, disp=1)
    bottom_src, bottom_dst = comm.Shift(direction=1, disp=-1)
    top_src, top_dst = comm.Shift(direction=1, disp=1)

    def communicate(f: np.ndarray) -> np.ndarray:
        """
        Implements the communication step. For details we refer to the report in the repository.

        Args:
            f: probability density function

        Returns:
            probability density function after communication

        """
        # exchange along axis 0, negative direction: send the first interior
        # slice to the left neighbour, receive the right ghost slice
        recvbuf = f[-1, ...].copy()
        comm.Sendrecv(f[1, ...].copy(),
                      left_dst,
                      recvbuf=recvbuf,
                      source=left_src)
        f[-1, ...] = recvbuf
        # exchange along axis 0, positive direction: send the last interior
        # slice to the right neighbour, receive the left ghost slice
        recvbuf = f[0, ...].copy()
        comm.Sendrecv(f[-2, ...].copy(),
                      right_dst,
                      recvbuf=recvbuf,
                      source=right_src)
        f[0, ...] = recvbuf
        # exchange along axis 1, negative direction: send the first interior
        # slice to the bottom neighbour, receive the top ghost slice
        recvbuf = f[:, -1, :].copy()
        comm.Sendrecv(f[:, 1, :].copy(),
                      bottom_dst,
                      recvbuf=recvbuf,
                      source=bottom_src)
        f[:, -1, :] = recvbuf
        # exchange along axis 1, positive direction: send the last interior
        # slice to the top neighbour, receive the bottom ghost slice
        recvbuf = f[:, 0, :].copy()
        comm.Sendrecv(f[:, -2, :].copy(),
                      top_dst,
                      recvbuf=recvbuf,
                      source=top_src)
        f[:, 0, :] = recvbuf
        return f

    return communicate
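
# Usage sketch (not part of the original example): comm.Shift requires a
# Cartesian communicator, so the factory is typically given one built with
# Create_cart. The grid dims and local array shape below are placeholders;
# run with a matching rank count, e.g. `mpiexec -n 4`.
from mpi4py import MPI
import numpy as np

cart = MPI.COMM_WORLD.Create_cart(dims=[2, 2], periods=[True, True])
communicate = communication(cart)
f = np.zeros((34, 34, 9))  # local lattice incl. ghost layers, 9 channels
f = communicate(f)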
Example 2
def mpi_sync_checkpoint(comm: MPI.Intracomm, args, new_pi, old_pi):
    rank = comm.Get_rank()
    checkpath = get_modelpath(args, 'checkpoint')
    if rank == 0:
        # only rank 0 writes the checkpoint
        torch.save(new_pi.pt_model.state_dict(), checkpath)
    comm.Barrier()
    # Update other policy: every rank loads the same checkpoint
    old_pi.pt_model.load_state_dict(
        torch.load(checkpath, map_location=args.device))
Example 3
def save_mpiio(comm: MPI.Intracomm, fn: str, g_kl: np.ndarray):
    """
    Write a global two-dimensional array to a single file in the npy format
    using MPI I/O: https://docs.scipy.org/doc/numpy/neps/npy-format.html

    Arrays written with this function can be read with numpy.load.

    Args:
        comm: MPI communicator
        fn: File name
        g_kl: Portion of the global array held by this MPI process. This needs to be a two-dimensional array.

    """
    from numpy.lib.format import dtype_to_descr, magic
    magic_str = magic(1, 0)

    local_nx, local_ny = g_kl.shape
    nx = np.empty_like(local_nx)
    ny = np.empty_like(local_ny)

    # Sum the local extents over the Cartesian sub-communicators to obtain
    # the global shape (comm.Sub requires a Cartesian communicator)
    commx = comm.Sub((True, False))
    commy = comm.Sub((False, True))
    commx.Allreduce(np.asarray(local_nx), nx)
    commy.Allreduce(np.asarray(local_ny), ny)

    # .npy header: dtype, ordering, and the *global* shape
    arr_dict_str = str({
        'descr': dtype_to_descr(g_kl.dtype),
        'fortran_order': False,
        'shape': (int(nx), int(ny))
    })
    # Pad the header so its total length is a multiple of 16 bytes, as the
    # .npy format requires, then terminate it with a newline
    while (len(arr_dict_str) + len(magic_str) + 2) % 16 != 15:
        arr_dict_str += ' '
    arr_dict_str += '\n'
    header_len = len(arr_dict_str) + len(magic_str) + 2

    # Element offsets of this rank's block within the global array
    offsetx = np.zeros_like(local_nx)
    commx.Exscan(np.asarray(ny * local_nx), offsetx)
    offsety = np.zeros_like(local_ny)
    commy.Exscan(np.asarray(local_ny), offsety)

    file = MPI.File.Open(comm, fn, MPI.MODE_CREATE | MPI.MODE_WRONLY)
    if comm.Get_rank() == 0:
        # Rank 0 writes the .npy preamble: magic string, 2-byte header
        # length (assumes a little-endian platform), and the header itself
        file.Write(magic_str)
        file.Write(np.int16(len(arr_dict_str)))
        file.Write(arr_dict_str.encode('latin-1'))
    # Each rank writes its rows as a strided vector into the global layout
    mpitype = MPI._typedict[g_kl.dtype.char]
    filetype = mpitype.Create_vector(g_kl.shape[0], g_kl.shape[1], int(ny))
    filetype.Commit()
    file.Set_view(header_len + int(offsety + offsetx) * mpitype.Get_size(),
                  filetype=filetype)
    file.Write_all(g_kl.copy())
    filetype.Free()
    file.Close()
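
# Usage sketch (sizes and file name are placeholders; run with e.g.
# `mpiexec -n 4`): each rank contributes a local block of the global array,
# and the file can then be read back on a single process with numpy.load.
from mpi4py import MPI
import numpy as np

cart = MPI.COMM_WORLD.Create_cart(dims=[2, 2])
block = np.full((3, 4), cart.Get_rank(), dtype=np.float64)  # local 3x4 block
save_mpiio(cart, 'global.npy', block)
# afterwards: np.load('global.npy') has the global shape (6, 8)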
Example 4
    def _send_exit_signal(self, icomm: MPI.Intercomm) -> None:
        """Send an exit signal over the given intercommunicator."""
        service_name = self._service_name(icomm)
        logging.debug('Sending exit signal to {}:{}'
                      .format(icomm, service_name))

        icomm.send(None, dest=0)

        logging.debug('Sent exit signal to {}:{}'
                      .format(icomm, service_name))
Example 5
def mpi_disk_append_replay(comm: MPI.Intracomm, args, replays):
    rank = comm.Get_rank()
    # Ranks take turns; the barrier at the end of each iteration ensures
    # only one rank reads and rewrites the replay file at a time
    for worker in range(comm.Get_size()):
        if rank == worker:
            if os.path.exists(args.replay_path):
                all_replays = load_replay(args.replay_path)
                all_replays.extend(replays)
            else:
                all_replays = replays

            with open(args.replay_path, 'wb') as f:
                pickle.dump(all_replays, f)
        comm.Barrier()
Example 6
    def generate_formatter(comm: MPI.Intracomm) -> Formatter:
        rank = comm.Get_rank()
        hostname = MPI.Get_processor_name()

        return Formatter(
            f"%(asctime)s[{hostname}][{rank}][%(levelname)s] - %(message)s"
        )
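
# Usage sketch (the wiring below is an assumption, not from the original
# source): attach the per-rank formatter to a stream handler so every log
# record carries the host name and MPI rank.
import logging
from mpi4py import MPI

handler = logging.StreamHandler()
handler.setFormatter(generate_formatter(MPI.COMM_WORLD))
root = logging.getLogger()
root.addHandler(handler)
root.setLevel(logging.DEBUG)
logging.debug('hello from every rank')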
Example 7
def mpi_log_debug(comm: MPI.Intracomm, s):
    """
    Only the first worker prints stuff
    :param rank:
    :param s:
    :return:
    """
    rank = comm.Get_rank()
    if rank == 0:
        log_debug(s)
Example 8
def mpi_sample_eventdata(comm: MPI.Intracomm, replay_path, batches, batchsize):
    """
    :param replay_path:
    :param batches:
    :param batchsize:
    :return: Batches of sample data, len of total data that was sampled
    """
    # Workers sample data one at a time to avoid memory issues
    rank = comm.Get_rank()
    world_size = comm.Get_size()
    sample_data = None
    replay_len = None
    for worker in range(world_size):
        if rank == worker:
            replay = load_replay(replay_path)
            replay_len = len(replay)
            # Separate into black wins and black non-wins to ensure even sampling between the two
            black_wins = list(filter(lambda traj: traj.get_winner() == 1, replay))
            black_nonwins = list(filter(lambda traj: traj.get_winner() != 1, replay))
            black_wins = replay_to_events(black_wins)
            black_nonwins = replay_to_events(black_nonwins)
            n = min(len(black_wins), len(black_nonwins))
            sample_size = min(batchsize * batches // 2, n)
            sample_data = random.sample(black_wins, sample_size) + random.sample(black_nonwins, sample_size)
            # Save memory
            del replay
        comm.Barrier()

    random.shuffle(sample_data)
    sample_data = events_to_numpy(sample_data)

    sample_size = len(sample_data[0])
    for component in sample_data:
        assert len(component) == sample_size

    splits = max(sample_size // batchsize, 1)
    batched_sampledata = [np.array_split(component, splits) for component in sample_data]
    batched_sampledata = list(zip(*batched_sampledata))

    return batched_sampledata, replay_len
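
# Usage sketch (arguments are placeholders; `comm` is assumed to be an
# existing MPI communicator): draw up to 8 batches of 32 events each and
# iterate over the mini-batches, which unpack the same way as in the
# optimize method of Example 10 below.
batched, replay_len = mpi_sample_eventdata(comm, 'replay.pickle', 8, 32)
for states, actions, reward, children, terminal, wins, pi in batched:
    pass  # feed one mini-batch to a training step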
Example 9
def mpi_sync_data(comm: MPI.Intracomm, args):
    rank = comm.Get_rank()
    if rank == 0:
        # Clear worker data
        data.reset_replay(args)

        if args.baseline or args.customdir != '':
            if os.path.exists(args.checkpath):
                # Ensure that we do not use any old parameters
                os.remove(args.checkpath)
            if args.baseline:
                mpi_log_debug(comm, "Starting from baseline")
            else:
                mpi_log_debug(comm,
                              f"Starting from model file: {args.custompath}")
        else:
            # Save new model
            _, new_model = baselines.create_policy(args, '')
            torch.save(new_model.state_dict(), args.checkpath)
            mpi_log_debug(comm, "Starting from scratch")

    comm.Barrier()
Example 10
    def optimize(self, comm: MPI.Intracomm, batched_data, optimizer):
        raw_metrics = []
        self.train()
        for states, actions, reward, children, terminal, wins, pi in batched_data:
            metrics = self.train_step(optimizer, states, actions, reward,
                                      children, terminal, wins, pi)
            raw_metrics.append(metrics)

        # Sync Parameters
        average_model(comm, self)

        # Sync Metrics
        world_size = comm.Get_size()
        raw_metrics = np.array(raw_metrics, dtype=np.float64)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            mean_metrics = np.nanmean(raw_metrics, axis=0)
        reduced_metrics = comm.allreduce(mean_metrics, op=MPI.SUM) / world_size

        metrics = ModelMetrics(*reduced_metrics)

        # Return metrics
        return metrics
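
# average_model is not shown in this listing. A plausible sketch, assuming it
# averages model parameters elementwise across all ranks (the signature and
# body here are assumptions, not the original implementation):
import numpy as np
import torch
from mpi4py import MPI

def average_model(comm: MPI.Intracomm, model: torch.nn.Module) -> None:
    world_size = comm.Get_size()
    for param in model.parameters():
        local = param.detach().cpu().numpy()
        summed = np.empty_like(local)
        comm.Allreduce(local, summed, op=MPI.SUM)  # sum over all ranks
        param.data.copy_(torch.from_numpy(summed / world_size))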
Example 11
    def __init__(self, grid_file: str, twitter_file: str,
                 mpi_comm: MPI.Intracomm) -> None:
        """Initialization.

        Args:
            grid_file (str): file of grid data.
            twitter_file (str): file of twitter data.
            mpi_comm (MPI.Intracomm): MPI.COMM_WORLD.
        """
        self._grid_file = grid_file
        self._twitter_file = twitter_file
        self._mpi_comm = mpi_comm
        self._node_num = mpi_comm.Get_size()
        self._node_rank = mpi_comm.Get_rank()

        # root node
        if self._node_rank == 0:
            self._print_title("Task")
            print("task_time:   ", datetime.datetime.now())
            print("grid_file:   ", grid_file)
            print("twitter_file:", twitter_file)
            print("node_num:    ", self._node_num)
            print()
Example 12
def mpi_play(comm: MPI.Intracomm, go_env, pi1, pi2, requested_episodes):
    """
    Plays games in parallel
    :param comm:
    :param go_env:
    :param pi1:
    :param pi2:
    :param gettraj:
    :param requested_episodes:
    :return:
    """
    world_size = comm.Get_size()

    worker_episodes = int(math.ceil(requested_episodes / world_size))
    episodes = worker_episodes * world_size
    single_worker = world_size <= 1

    timestart = time.time()
    p1wr, black_wr, replay, steps = game.play_games(go_env,
                                                    pi1,
                                                    pi2,
                                                    worker_episodes,
                                                    progress=single_worker)
    timeend = time.time()

    duration = timeend - timestart
    avg_time = comm.allreduce(duration / worker_episodes,
                              op=MPI.SUM) / world_size
    p1wr = comm.allreduce(p1wr, op=MPI.SUM) / world_size
    black_wr = comm.allreduce(black_wr, op=MPI.SUM) / world_size
    avg_steps = comm.allreduce(sum(steps), op=MPI.SUM) / episodes

    mpi_log_debug(
        comm,
        f'{pi1} V {pi2} | {episodes} GAMES, {avg_time:.1f} SEC/GAME, {avg_steps:.0f} STEPS/GAME, '
        f'{100 * p1wr:.1f}% WIN({100 * black_wr:.1f}% BLACK_WIN)')
    return p1wr, black_wr, replay
Example 13
def mpi_config_log(args, comm: MPI.Intracomm):
    if comm.Get_rank() == 0:
        config_log(args)

    comm.Barrier()