def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
    swap_info = self._get_param_swap_info(parameter)
    if swap_info is None:
        return

    assert len(swap_info.tensors) <= len(dest_buffers)

    # File reads fill buffers padded to the I/O alignment; compute tensors
    # later view only the true element count.
    swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(swap_info.tensors)
    swap_buffers = get_sized_buffers(dest_buffers, swap_lengths)

    READ_TIMER = 'swap_submit_read_param'
    WAIT_TIMER = 'swap_wait_read_param'

    self._start_timer(READ_TIMER)
    swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths)
    self._stop_timer(READ_TIMER)

    swap_bytes = sum([buffer.numel() * buffer.element_size() for buffer in swap_buffers])

    self._start_timer(WAIT_TIMER)
    aio_handle.wait()
    self._stop_timer(WAIT_TIMER)

    # Point the optimizer state tensors at the freshly read data.
    compute_lengths = [swap_info.numel()] * len(swap_info.tensors)
    compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)
    for t, buffer in zip(swap_info.tensors, compute_buffers):
        t.data = buffer.data

    self._log_timers([READ_TIMER, WAIT_TIMER])

    if DEBUG_MODE and torch.distributed.get_rank() == 0:
        logger.info(f'optimizer_param_swap_in: {(swap_bytes/(1024**3)):5.2f} GB')
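# Hedged sketch (illustration only, not part of this class) of the sizing
# split above: reads target an I/O-aligned region of each buffer, while the
# optimizer computes on a view of just the real element count. The numbers
# below are hypothetical; _io_aligned_numel is assumed to round up to the
# aio block size.
#
#   aligned = self._io_aligned_numel(1000)       # e.g. 1024
#   swap_view = buffer.narrow(0, 0, aligned)     # target of the file read
#   compute_view = buffer.narrow(0, 0, 1000)     # what t.data points at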
def _swap_in_optimizer_state(self, aio_handle, parameter):
    param_info = self._get_param_swap_info(parameter)
    if param_info is None:
        return None

    # One buffer per optimizer state tensor, plus one for the gradient if present.
    required_buffer_count = len(param_info.tensors) + (1 if param_info.has_gradients() else 0)
    aligned_numel = self._io_aligned_numel(param_info.numel())
    allocated_buffers = self.swap_buffer_manager.allocate(num_elems=aligned_numel,
                                                          count=required_buffer_count,
                                                          dtype=parameter.dtype)
    assert allocated_buffers is not None, \
        f"PipelinedOptimizerSwapper ran out of swap buffers, try increasing {OFFLOAD_OPTIMIZER_BUFFER_COUNT}"

    state_buffers = allocated_buffers[:len(param_info.tensors)]
    param_info.set_swap_buffers(state_buffers)

    swap_buffers = state_buffers.copy()
    swap_paths = param_info.swap_paths.copy()
    if param_info.has_gradients():
        parameter.grad = allocated_buffers[-1].narrow(0, 0, param_info.numel())
        if param_info.swapped_gradients:
            swap_buffers += param_info.get_swap_gradient_buffers(parameter.grad)
            swap_paths += param_info.get_swap_gradient_paths()

    swap_in_tensors(aio_handle, swap_buffers, swap_paths)

    # Gradient slices that were never swapped out are copied from memory instead.
    if param_info.unswapped_gradients:
        self._retrieve_unswapped_grad_partitions(swap_info=param_info,
                                                 dest_buffer=parameter.grad)

    # Return the in-flight read so the caller can overlap it with compute.
    swap_in_op = OptimizerSwapOp(aio_handle=aio_handle,
                                 param_info=param_info,
                                 read_op=True,
                                 allocated_buffers=allocated_buffers,
                                 state_buffers=state_buffers,
                                 num_ops=len(swap_buffers))

    return swap_in_op
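# Hedged usage sketch (assumed caller pattern, not shown in this section):
# the returned OptimizerSwapOp carries the submitted-but-unwaited reads so
# a pipelined caller can overlap them with other work; OptimizerSwapOp.wait()
# is assumed to verify that aio_handle.wait() completed num_ops reads.
#
#   swap_op = self._swap_in_optimizer_state(aio_handle, parameter)
#   if swap_op is not None:
#       ...                # overlapped compute while reads are in flight
#       swap_op.wait()     # block before touching the swapped-in state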
def _swap_in_pinned_gradients(self, aio_handle, parameter, gradient_tensor):
    swap_info = self.swap_params_info[id(parameter)]
    param_gradients = swap_info.swapped_gradients.values()
    # Each swapped gradient slice is read directly into its window of the
    # flat gradient tensor.
    swap_buffers = [
        gradient_tensor.narrow(0, grad.offset, grad.length) for grad in param_gradients
    ]
    swap_paths = [grad.path for grad in param_gradients]

    SWAP_READ_GRADIENTS = 'swap_submit_read_gradient'
    SWAP_WAIT_GRADIENTS = 'swap_submit_wait_gradient'

    self._start_timer(SWAP_READ_GRADIENTS)
    swap_in_tensors(aio_handle, swap_buffers, swap_paths)
    self._stop_timer(SWAP_READ_GRADIENTS)

    self._start_timer(SWAP_WAIT_GRADIENTS)
    assert len(swap_buffers) == aio_handle.wait()
    self._stop_timer(SWAP_WAIT_GRADIENTS)

    self._log_timers([SWAP_READ_GRADIENTS, SWAP_WAIT_GRADIENTS])
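# Hedged sketch of the bookkeeping assumed above: each entry of
# swap_info.swapped_gradients records a disjoint window of the flat
# gradient plus the file backing it, so each read lands in its own slice.
# Field names follow the attribute accesses above; the values are
# hypothetical.
#
#   grad.offset = 0       # start of this slice within gradient_tensor
#   grad.length = 4096    # elements in this slice
#   grad.path = '...'     # swap file backing this slice (hypothetical)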
def _swap_in_fp16_params(self, aio_handle, fp16_num_elems, fp16_partitions_info, fp16_swap_buffers):
    assert len(fp16_num_elems) > 0

    swapped_fp16_tensors = []
    swap_tensors = []
    swap_paths = []
    unswapped_srcs = []
    unswapped_dsts = []

    for i, numel in enumerate(fp16_num_elems):
        pinned_tensor, _ = fp16_swap_buffers.allocate_tensor(numel, None, numel)
        # Pinned buffer pool exhausted; stop and return what was scheduled.
        if pinned_tensor is None:
            break

        swapped_fp16_tensors.append(pinned_tensor)
        offset = 0
        for tensor, partition_numel, partition_path in fp16_partitions_info[i]:
            dst_tensor = pinned_tensor.narrow(0, offset, partition_numel)
            # Partitions with no swap path are still in memory and are
            # copied directly rather than read from file.
            if partition_path is None:
                unswapped_srcs.append(tensor)
                unswapped_dsts.append(dst_tensor)
            else:
                swap_paths.append(partition_path)
                swap_tensors.append(dst_tensor)
            offset += partition_numel

    assert len(swapped_fp16_tensors) + len(unswapped_srcs) > 0

    swap_in_tensors(aio_handle, swap_tensors, swap_paths)

    for src, dst in zip(unswapped_srcs, unswapped_dsts):
        dst.data.copy_(src.data)

    assert len(swap_tensors) == aio_handle.wait()

    return swapped_fp16_tensors
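# Hedged sketch of the fp16_partitions_info layout implied above: one list
# per fp16 tensor, each entry a (tensor, partition_numel, partition_path)
# triple. A None path is assumed to mean the partition never left memory
# and is copied host-side instead of being read via aio. Paths below are
# hypothetical.
#
#   fp16_partitions_info = [
#       [(t0, 256, '/nvme/p0.swp'),    # read asynchronously from file
#        (t1, 256, None)],             # copied directly from t1
#   ]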