def _swap_out_unpinned_tensors(self, aio_handle, unpinned_tensors, dest_paths, pinned_buffers):
    """Write unpinned tensors to storage by staging them through pinned buffers.

    The tensors are processed in batches of ``len(pinned_buffers)``: each batch
    is copied into appropriately sized views of the pinned buffers, submitted
    for writing to the matching slice of ``dest_paths``, and waited on before
    the buffers are reused for the next batch.
    """
    num_buffers = len(pinned_buffers)
    total = len(unpinned_tensors)

    for start in range(0, total, num_buffers):
        batch_size = min(total - start, num_buffers)
        batch = unpinned_tensors[start:(start + batch_size)]

        # Stage: copy each source tensor into a pinned view sized to its numel.
        copy_views = get_sized_buffers(pinned_buffers, [t.numel() for t in batch])
        for dst, src in zip(copy_views, batch):
            dst.data.copy_(src.data)

        # Submit the writes using IO-aligned lengths over the same buffers,
        # then block until the whole batch has landed.
        aligned_lengths = [self._io_aligned_numel(t.numel()) for t in batch]
        io_views = get_sized_buffers(pinned_buffers, aligned_lengths)
        swap_out_tensors(aio_handle, io_views, dest_paths[start:(start + batch_size)])
        assert aio_handle.wait() == batch_size
def _swap_out_ready_buffers(self):
    """Submit asynchronous writes for every ready buffer, then move those
    buffer indices from the ready list to the swapping list."""
    for index in self.ready_buffer_index:
        ready_buffer = self._get_buffer(index)
        tensors = ready_buffer.get_swap_tensors()
        paths = ready_buffer.get_swap_paths()
        # Track outstanding IO so completion can be accounted for later.
        self.num_pending_swaps += len(tensors)
        swap_out_tensors(self.aio_handle, tensors, paths)

    self.swapping_buffer_index.extend(self.ready_buffer_index)
    self.ready_buffer_index = []
def swap_out_optimizer_state(self, parameter, async_swap=False):
    """Swap out the optimizer state tensors of `parameter` to storage.

    Pinned tensors are written directly; unpinned tensors are staged through
    pinned buffers obtained from the swap buffer manager. Each tensor's
    storage is released (replaced with an empty ``torch.Tensor()``) after it
    has been written. No-op when the parameter has no swap info.

    NOTE(review): `async_swap` is accepted but not used in this body —
    presumably kept for interface compatibility; confirm against callers.
    """
    swap_info = self._get_param_swap_info(parameter=parameter)
    if swap_info is None:
        return

    self._start_timer(SWAP_OUT_PARAM_TIMER)
    pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths = self._separate_pinned_tensors(swap_info)
    # Total bytes to be written, computed with IO-aligned sizes (used only
    # for the debug log message at the end).
    swap_bytes = sum([
        self._io_aligned_numel(t.numel()) * t.element_size() for t in swap_info.tensors
    ])

    WRITE_TIMER = 'swap_submit_write'
    self._start_timer(WRITE_TIMER)

    # Pinned tensors can be written directly; block until all writes land.
    swap_out_tensors(self.aio_handle, pinned_tensors, pinned_paths)
    assert self.aio_handle.wait() == len(pinned_tensors)
    # Release storage of tensors that are now safely on disk.
    for t in pinned_tensors:
        t.data = torch.Tensor()

    if len(unpinned_tensors) > 0:
        # Unpinned tensors must be staged through pinned buffers before IO.
        pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
        self._swap_out_unpinned_tensors(aio_handle=self.aio_handle,
                                        unpinned_tensors=unpinned_tensors,
                                        dest_paths=unpinned_paths,
                                        pinned_buffers=pinned_buffers)
        self.allocated_swap_buffers += pinned_buffers

        for t in unpinned_tensors:
            t.data = torch.Tensor()
    self._stop_timer(WRITE_TIMER)

    # Return all staging buffers to the manager.
    self.swap_buffer_manager.free(self.allocated_swap_buffers)
    self.allocated_swap_buffers = []

    self._stop_timer(SWAP_OUT_PARAM_TIMER)
    self.timer_names.add(SWAP_OUT_PARAM_TIMER)
    self._log_timers([WRITE_TIMER])

    if DEBUG_MODE and torch.distributed.get_rank() == 0:
        logger.info(f'optimizer_param_swap_out: {(swap_bytes/(1024**3)):5.2f} GB')
def _swap_out_optimizer_state(self, aio_handle, parameter, swap_in_op):
    """Submit asynchronous writes of `parameter`'s optimizer state back to
    its swap paths, reusing the buffers from the matching swap-in op.

    Unpinned state tensors are first copied into freshly allocated pinned
    buffers so that every write goes through pinned memory. Returns an
    ``OptimizerSwapOp`` describing the submitted (write) operation.
    """
    assert swap_in_op.is_parameter(parameter)

    allocated_buffers = swap_in_op.allocated_buffers.copy()
    state_buffers = swap_in_op.state_buffers.copy()

    info = swap_in_op.param_info
    self._update_param_state_info(info, parameter)

    unpinned = info.get_unpinned_state_tensors()
    if len(unpinned) > 0:
        # Stage unpinned state through newly allocated pinned buffers.
        extra_buffers = self.swap_buffer_manager.allocate(num_elems=self._io_aligned_numel(info.numel()),
                                                          count=len(unpinned),
                                                          dtype=info.dtype())
        assert extra_buffers is not None
        allocated_buffers += extra_buffers
        state_buffers += extra_buffers

        for pinned_dst, src in zip(extra_buffers, unpinned):
            sized_dst = get_sized_buffer(pinned_dst, src.numel())
            sized_dst.data.copy_(src.data)

    paths = info.swap_paths.copy()
    assert len(paths) == len(state_buffers)
    swap_out_tensors(aio_handle, state_buffers, paths)

    return OptimizerSwapOp(aio_handle=aio_handle,
                           param_info=info,
                           read_op=False,
                           allocated_buffers=allocated_buffers,
                           state_buffers=state_buffers,
                           num_ops=len(state_buffers))