def communication(comm: MPI.Intracomm) -> Callable[[np.ndarray], np.ndarray]:
    """
    Wrapper for the communication step function.

    Args:
        comm: Cartesian MPI communicator

    Returns:
        function which implements the communication step
    """
    # Precompute the neighbor ranks for the halo exchange
    left_src, left_dst = comm.Shift(direction=0, disp=-1)
    right_src, right_dst = comm.Shift(direction=0, disp=1)
    bottom_src, bottom_dst = comm.Shift(direction=1, disp=-1)
    top_src, top_dst = comm.Shift(direction=1, disp=1)

    def communicate(f: np.ndarray) -> np.ndarray:
        """
        Implements the communication step. For details we refer to the report
        in the repository.

        Args:
            f: probability density function

        Returns:
            probability density function after communication
        """
        # send to left
        recvbuf = f[-1, ...].copy()
        comm.Sendrecv(f[1, ...].copy(), left_dst, recvbuf=recvbuf, source=left_src)
        f[-1, ...] = recvbuf
        # send to right
        recvbuf = f[0, ...].copy()
        comm.Sendrecv(f[-2, ...].copy(), right_dst, recvbuf=recvbuf, source=right_src)
        f[0, ...] = recvbuf
        # send to bottom
        recvbuf = f[:, -1, :].copy()
        comm.Sendrecv(f[:, 1, :].copy(), bottom_dst, recvbuf=recvbuf, source=bottom_src)
        f[:, -1, :] = recvbuf
        # send to top
        recvbuf = f[:, 0, :].copy()
        comm.Sendrecv(f[:, -2, :].copy(), top_dst, recvbuf=recvbuf, source=top_src)
        f[:, 0, :] = recvbuf
        return f

    return communicate
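# A minimal usage sketch for `communication` (not part of the original code):
# it assumes 4 ranks on a periodic 2x2 Cartesian grid and a D2Q9 lattice with
# one ghost layer per side; the domain size is an illustrative assumption.
def _example_communication() -> None:
    cartcomm = MPI.COMM_WORLD.Create_cart(dims=[2, 2], periods=[True, True])
    communicate = communication(cartcomm)
    # 50x50 local domain plus one ghost layer on each side, 9 velocity channels
    f = np.random.rand(52, 52, 9)
    f = communicate(f)  # ghost layers now hold the neighbors' boundary values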
def mpi_sync_checkpoint(comm: MPI.Intracomm, args, new_pi, old_pi):
    rank = comm.Get_rank()
    checkpath = get_modelpath(args, 'checkpoint')
    if rank == 0:
        torch.save(new_pi.pt_model.state_dict(), checkpath)
    comm.Barrier()
    # Update the other policy
    old_pi.pt_model.load_state_dict(torch.load(checkpath, map_location=args.device))
def save_mpiio(comm: MPI.Intracomm, fn: str, g_kl: np.ndarray):
    """
    Write a global two-dimensional array to a single file in the npy format
    using MPI I/O: https://docs.scipy.org/doc/numpy/neps/npy-format.html

    Arrays written with this function can be read with numpy.load.

    Args:
        comm: Cartesian MPI communicator
        fn: File name
        g_kl: Portion of the array on this MPI process. This needs to be a
            two-dimensional array.
    """
    from numpy.lib.format import dtype_to_descr, magic
    magic_str = magic(1, 0)

    local_nx, local_ny = g_kl.shape
    nx = np.empty_like(local_nx)
    ny = np.empty_like(local_ny)

    commx = comm.Sub((True, False))
    commy = comm.Sub((False, True))
    # The global dimensions are the sums of the local ones along each axis
    commx.Allreduce(np.asarray(local_nx), nx)
    commy.Allreduce(np.asarray(local_ny), ny)

    # Build the npy header and pad it so its total length is a multiple of 16
    arr_dict_str = str({
        'descr': dtype_to_descr(g_kl.dtype),
        'fortran_order': False,
        'shape': (int(nx), int(ny))
    })
    while (len(arr_dict_str) + len(magic_str) + 2) % 16 != 15:
        arr_dict_str += ' '
    arr_dict_str += '\n'
    header_len = len(arr_dict_str) + len(magic_str) + 2

    # Exclusive scans yield each rank's offset into the global file
    offsetx = np.zeros_like(local_nx)
    commx.Exscan(np.asarray(ny * local_nx), offsetx)
    offsety = np.zeros_like(local_ny)
    commy.Exscan(np.asarray(local_ny), offsety)

    file = MPI.File.Open(comm, fn, MPI.MODE_CREATE | MPI.MODE_WRONLY)
    if comm.Get_rank() == 0:
        file.Write(magic_str)
        file.Write(np.int16(len(arr_dict_str)))
        file.Write(arr_dict_str.encode('latin-1'))
    mpitype = MPI._typedict[g_kl.dtype.char]
    # Each local row is a block of local_ny items with a stride of ny in the file
    filetype = mpitype.Create_vector(g_kl.shape[0], g_kl.shape[1], ny)
    filetype.Commit()
    file.Set_view(header_len + (offsety + offsetx) * mpitype.Get_size(),
                  filetype=filetype)
    file.Write_all(g_kl.copy())
    filetype.Free()
    file.Close()
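# A minimal usage sketch for `save_mpiio` (not part of the original code),
# assuming 4 ranks on a 2x2 Cartesian grid; the file name and tile size are
# illustrative assumptions.
def _example_save_mpiio() -> None:
    cartcomm = MPI.COMM_WORLD.Create_cart(dims=[2, 2], periods=[True, True])
    # Each rank contributes a 10x10 tile of the 20x20 global array
    tile = np.full((10, 10), cartcomm.Get_rank(), dtype=np.float64)
    save_mpiio(cartcomm, 'tiles.npy', tile)
    # The file can later be read on a single process with np.load('tiles.npy')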
def _send_exit_signal(self, icomm: MPI.Intracomm) -> None:
    """Send exit signal to intercommunicator."""
    service_name = self._service_name(icomm)
    logging.debug('Sending exit signal to {}:{}'.format(icomm, service_name))
    icomm.send(None, dest=0)
    logging.debug('Sent exit signal to {}:{}'.format(icomm, service_name))
def mpi_disk_append_replay(comm: MPI.Intracomm, args, replays):
    rank = comm.Get_rank()
    # Ranks take turns so that only one process reads and rewrites the replay
    # file at a time
    for worker in range(comm.Get_size()):
        if rank == worker:
            if os.path.exists(args.replay_path):
                all_replays = load_replay(args.replay_path)
                all_replays.extend(replays)
            else:
                all_replays = replays
            with open(args.replay_path, 'wb') as f:
                pickle.dump(all_replays, f)
        comm.Barrier()
def generate_formatter(comm: MPI.Intracomm) -> Formatter:
    rank = comm.Get_rank()
    hostname = MPI.Get_processor_name()
    return Formatter(
        f"%(asctime)s[{hostname}][{rank}][%(levelname)s] - %(message)s"
    )
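# A minimal sketch of wiring the formatter into Python's logging module (not
# part of the original code); the root logger and DEBUG level are assumptions.
def _example_configure_logging(comm: MPI.Intracomm) -> None:
    handler = logging.StreamHandler()
    handler.setFormatter(generate_formatter(comm))
    root_logger = logging.getLogger()
    root_logger.addHandler(handler)
    root_logger.setLevel(logging.DEBUG)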
def mpi_log_debug(comm: MPI.Intracomm, s):
    """
    Only the first worker (rank 0) logs the message.
    :param comm: MPI communicator
    :param s: message to log
    :return:
    """
    rank = comm.Get_rank()
    if rank == 0:
        log_debug(s)
def mpi_sample_eventdata(comm: MPI.Intracomm, replay_path, batches, batchsize):
    """
    :param comm: MPI communicator
    :param replay_path: path to the pickled replay data
    :param batches: number of batches to sample
    :param batchsize: size of each batch
    :return: Batches of sample data, length of the total data that was sampled
    """
    # Workers sample data one at a time to avoid memory issues
    rank = comm.Get_rank()
    world_size = comm.Get_size()
    sample_data = None
    replay_len = None
    for worker in range(world_size):
        if rank == worker:
            replay = load_replay(replay_path)
            replay_len = len(replay)
            # Separate into black wins and black non-wins to ensure even
            # sampling between the two
            black_wins = list(filter(lambda traj: traj.get_winner() == 1, replay))
            black_nonwins = list(filter(lambda traj: traj.get_winner() != 1, replay))
            black_wins = replay_to_events(black_wins)
            black_nonwins = replay_to_events(black_nonwins)
            n = min(len(black_wins), len(black_nonwins))
            sample_size = min(batchsize * batches // 2, n)
            sample_data = random.sample(black_wins, sample_size) \
                + random.sample(black_nonwins, sample_size)
            # Save memory
            del replay
        comm.Barrier()

    random.shuffle(sample_data)
    sample_data = events_to_numpy(sample_data)

    sample_size = len(sample_data[0])
    for component in sample_data:
        assert len(component) == sample_size
    splits = max(sample_size // batchsize, 1)
    batched_sampledata = [np.array_split(component, splits) for component in sample_data]
    batched_sampledata = list(zip(*batched_sampledata))

    return batched_sampledata, replay_len
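# A minimal usage sketch (not part of the original code): the batch counts are
# illustrative assumptions, and each returned batch unpacks into the tuple
# layout consumed by `optimize` below.
def _example_sample(comm: MPI.Intracomm, args) -> None:
    batched_data, replay_len = mpi_sample_eventdata(comm, args.replay_path,
                                                    batches=32, batchsize=256)
    mpi_log_debug(comm, f'Replay size: {replay_len}, batches: {len(batched_data)}')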
def mpi_sync_data(comm: MPI.Intracomm, args):
    rank = comm.Get_rank()
    if rank == 0:
        # Clear worker data
        data.reset_replay(args)
        if args.baseline or args.customdir != '':
            if os.path.exists(args.checkpath):
                # Ensure that we do not use any old parameters
                os.remove(args.checkpath)
            if args.baseline:
                mpi_log_debug(comm, "Starting from baseline")
            else:
                mpi_log_debug(comm, f"Starting from model file: {args.custompath}")
        else:
            # Save new model
            _, new_model = baselines.create_policy(args, '')
            torch.save(new_model.state_dict(), args.checkpath)
            mpi_log_debug(comm, "Starting from scratch")
    comm.Barrier()
def optimize(self, comm: MPI.Intracomm, batched_data, optimizer):
    raw_metrics = []
    self.train()
    for states, actions, reward, children, terminal, wins, pi in batched_data:
        metrics = self.train_step(optimizer, states, actions, reward, children,
                                  terminal, wins, pi)
        raw_metrics.append(metrics)

    # Sync parameters
    average_model(comm, self)

    # Sync metrics
    world_size = comm.Get_size()
    raw_metrics = np.array(raw_metrics, dtype=float)
    with warnings.catch_warnings():
        # nanmean warns on all-NaN columns; ignore and keep the NaN result
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mean_metrics = np.nanmean(raw_metrics, axis=0)
    reduced_metrics = comm.allreduce(mean_metrics, op=MPI.SUM) / world_size
    metrics = ModelMetrics(*reduced_metrics)

    # Return metrics
    return metrics
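# `average_model` is used above but not defined in this file. A minimal sketch
# of what such a helper might look like (an assumption, not the original
# implementation): allreduce every parameter and divide by the world size.
def _average_model_sketch(comm: MPI.Intracomm, model) -> None:
    world_size = comm.Get_size()
    for param in model.parameters():
        local = param.data.cpu().numpy()
        # The lowercase allreduce sums numpy arrays elementwise across ranks
        summed = comm.allreduce(local, op=MPI.SUM)
        param.data.copy_(torch.from_numpy(summed / world_size))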
def __init__(self, grid_file: str, twitter_file: str, mpi_comm: MPI.Intracomm) -> None:
    """Initialization.

    Args:
        grid_file (str): file of grid data.
        twitter_file (str): file of twitter data.
        mpi_comm (MPI.Intracomm): MPI.COMM_WORLD.
    """
    self._grid_file = grid_file
    self._twitter_file = twitter_file
    self._mpi_comm = mpi_comm
    self._node_num = mpi_comm.Get_size()
    self._node_rank = mpi_comm.Get_rank()
    # root node
    if self._node_rank == 0:
        self._print_title("Task")
        print("task_time:   ", datetime.datetime.now())
        print("grid_file:   ", grid_file)
        print("twitter_file:", twitter_file)
        print("node_num:    ", self._node_num)
        print()
def mpi_play(comm: MPI.Intracomm, go_env, pi1, pi2, requested_episodes):
    """
    Plays games in parallel.
    :param comm: MPI communicator
    :param go_env: Go environment
    :param pi1: first policy
    :param pi2: second policy
    :param requested_episodes: total number of episodes across all workers
    :return: win rate of pi1, win rate of black, replay data
    """
    world_size = comm.Get_size()
    worker_episodes = int(math.ceil(requested_episodes / world_size))
    episodes = worker_episodes * world_size
    single_worker = comm.Get_size() <= 1

    timestart = time.time()
    p1wr, black_wr, replay, steps = game.play_games(go_env, pi1, pi2, worker_episodes,
                                                    progress=single_worker)
    timeend = time.time()

    duration = timeend - timestart
    avg_time = comm.allreduce(duration / worker_episodes, op=MPI.SUM) / world_size
    p1wr = comm.allreduce(p1wr, op=MPI.SUM) / world_size
    black_wr = comm.allreduce(black_wr, op=MPI.SUM) / world_size
    avg_steps = comm.allreduce(sum(steps), op=MPI.SUM) / episodes

    mpi_log_debug(
        comm, f'{pi1} V {pi2} | {episodes} GAMES, {avg_time:.1f} SEC/GAME, '
        f'{avg_steps:.0f} STEPS/GAME, '
        f'{100 * p1wr:.1f}% WIN({100 * black_wr:.1f}% BLACK_WIN)')

    return p1wr, black_wr, replay
def mpi_config_log(args, comm: MPI.Intracomm):
    if comm.Get_rank() == 0:
        config_log(args)
    comm.Barrier()