    def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
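        """Read a parameter's swapped optimizer state tensors from their swap
        files into dest_buffers, then point the state tensors at those buffers."""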
        swap_info = self._get_param_swap_info(parameter)
        if swap_info is None:
            return

        assert len(swap_info.tensors) <= len(dest_buffers)

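        # View each destination buffer at the I/O-aligned element count so the reads are aligned.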
        swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(
            swap_info.tensors)
        swap_buffers = get_sized_buffers(dest_buffers, swap_lengths)

        READ_TIMER = 'swap_submit_read_param'
        WAIT_TIMER = 'swap_wait_read_param'

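        # Submit asynchronous reads of each swapped tensor into its sized buffer.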
        self._start_timer(READ_TIMER)
        swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths)
        self._stop_timer(READ_TIMER)

        swap_bytes = sum([
            buffer.numel() * buffer.element_size() for buffer in swap_buffers
        ])

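        # Block until all submitted reads have completed.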
        self._start_timer(WAIT_TIMER)
        aio_handle.wait()
        self._stop_timer(WAIT_TIMER)

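        # Re-view the buffers at the unaligned (compute) length and point each state tensor at its buffer.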
        compute_lengths = [swap_info.numel()] * len(swap_info.tensors)
        compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)
        for t, buffer in zip(swap_info.tensors, compute_buffers):
            t.data = buffer.data

        self._log_timers([READ_TIMER, WAIT_TIMER])
        if DEBUG_MODE and torch.distributed.get_rank() == 0:
            logger.info(
                f'optimizer_param_swap_in: {(swap_bytes/(1024**3)):5.2f} GB')

    def _swap_in_optimizer_state(self, aio_handle, parameter):
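        """Allocate pinned buffers and submit asynchronous reads of a parameter's
        optimizer state (and swapped gradients, if any). Returns an OptimizerSwapOp
        tracking the in-flight reads, or None if the parameter is not swapped."""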
        param_info = self._get_param_swap_info(parameter)
        if param_info is None:
            return None

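        # One pinned buffer per optimizer state tensor, plus one more for the gradient if the parameter has gradients.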
        required_buffer_count = len(
            param_info.tensors) + (1 if param_info.has_gradients() else 0)
        aligned_numel = self._io_aligned_numel(param_info.numel())
        allocated_buffers = self.swap_buffer_manager.allocate(
            num_elems=aligned_numel,
            count=required_buffer_count,
            dtype=parameter.dtype)
        assert allocated_buffers is not None, \
            f"PipelinedOptimizerSwapper ran out of swap buffers, try increasing {OFFLOAD_OPTIMIZER_BUFFER_COUNT}"

        state_buffers = allocated_buffers[:len(param_info.tensors)]
        param_info.set_swap_buffers(state_buffers)

        swap_buffers = state_buffers.copy()
        swap_paths = param_info.swap_paths.copy()

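        # The gradient uses the extra buffer; gradient partitions already swapped out join the read list.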
        if param_info.has_gradients():
            parameter.grad = allocated_buffers[-1].narrow(
                0, 0, param_info.numel())
            if param_info.swapped_gradients:
                swap_buffers += param_info.get_swap_gradient_buffers(
                    parameter.grad)
                swap_paths += param_info.get_swap_gradient_paths()

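        # Submit asynchronous reads for all state and swapped-gradient buffers.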
        swap_in_tensors(aio_handle, swap_buffers, swap_paths)

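        # Gradient partitions that were never swapped out are restored directly into parameter.grad.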
        if param_info.unswapped_gradients:
            self._retrieve_unswapped_grad_partitions(
                swap_info=param_info, dest_buffer=parameter.grad)

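        # Package the handle, buffers, and op count so completion can be waited on later.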
        swap_in_op = OptimizerSwapOp(aio_handle=aio_handle,
                                     param_info=param_info,
                                     read_op=True,
                                     allocated_buffers=allocated_buffers,
                                     state_buffers=state_buffers,
                                     num_ops=len(swap_buffers))

        return swap_in_op

    def _swap_in_pinned_gradients(self, aio_handle, parameter,
                                  gradient_tensor):
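        """Read a parameter's swapped gradient partitions from their swap files
        into the pinned gradient_tensor at their recorded offsets."""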
        swap_info = self.swap_params_info[id(parameter)]
        param_gradients = swap_info.swapped_gradients.values()
        swap_buffers = [
            gradient_tensor.narrow(0, grad.offset, grad.length)
            for grad in param_gradients
        ]
        swap_paths = [grad.path for grad in param_gradients]
        SWAP_READ_GRADIENTS = 'swap_submit_read_gradient'
        SWAP_WAIT_GRADIENTS = 'swap_submit_wait_gradient'

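        # Submit reads for every swapped gradient partition, then wait and verify that all of them completed.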
        self._start_timer(SWAP_READ_GRADIENTS)
        swap_in_tensors(aio_handle, swap_buffers, swap_paths)
        self._stop_timer(SWAP_READ_GRADIENTS)

        self._start_timer(SWAP_WAIT_GRADIENTS)
        assert len(swap_buffers) == aio_handle.wait()
        self._stop_timer(SWAP_WAIT_GRADIENTS)

        self._log_timers([SWAP_READ_GRADIENTS, SWAP_WAIT_GRADIENTS])

    def _swap_in_fp16_params(self,
                             aio_handle,
                             fp16_num_elems,
                             fp16_partitions_info,
                             fp16_swap_buffers):
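        """Gather fp16 parameter partitions into pinned buffers: partitions with a
        swap path are read asynchronously, the rest are copied from their source
        tensors. Stops early if the pinned buffer pool is exhausted."""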
        assert len(fp16_num_elems) > 0

        swapped_fp16_tensors = []
        swap_tensors = []
        swap_paths = []
        unswapped_srcs = []
        unswapped_dsts = []

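        # Allocate one pinned buffer per fp16 parameter; stop once the buffer pool runs out.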
        for i, numel in enumerate(fp16_num_elems):
            pinned_tensor, _ = fp16_swap_buffers.allocate_tensor(numel, None, numel)
            if pinned_tensor is None:
                break

            swapped_fp16_tensors.append(pinned_tensor)
            offset = 0
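            # Split the pinned buffer into partition-sized views: partitions without a
            # swap path are copied from memory, the rest are queued for reads from their swap paths.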
            for tensor, partition_numel, partition_path in fp16_partitions_info[i]:
                dst_tensor = pinned_tensor.narrow(0, offset, partition_numel)
                if partition_path is None:
                    unswapped_srcs.append(tensor)
                    unswapped_dsts.append(dst_tensor)
                else:
                    swap_paths.append(partition_path)
                    swap_tensors.append(dst_tensor)
                offset += partition_numel

        assert len(swapped_fp16_tensors) + len(unswapped_srcs) > 0
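        # Submit async reads for swapped partitions, copy unswapped partitions directly,
        # then wait for all submitted reads to complete.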
        swap_in_tensors(aio_handle, swap_tensors, swap_paths)
        for src, dst in zip(unswapped_srcs, unswapped_dsts):
            dst.data.copy_(src.data)

        assert len(swap_tensors) == aio_handle.wait()

        return swapped_fp16_tensors