Example #1
0
class Layer(nn.Module):
    """Identity-like module that simulates GPU compute and times itself.

    Each ``forward`` call brackets a fake workload
    (``torch.cuda._sleep(compute_cycles)``) with two CUDA timing events so
    the duration can be read back afterwards via :meth:`get_time`.

    Args:
        compute_cycles: value passed to ``torch.cuda._sleep``; the sleep is
            skipped when it is not greater than 0.
        has_params: when true, the layer owns a single 1-element parameter
            that is added to the input so it participates in autograd.
    """

    def __init__(self, compute_cycles, has_params: bool):
        super().__init__()
        self.sleep_cycles = compute_cycles
        self.optional_param = None
        if has_params:
            self.optional_param = nn.Parameter(torch.rand(1))
        # Timing events are (re)created by every forward(); they stay None
        # until the first call so get_time() can detect misuse.
        self.e1 = None
        self.e2 = None

    def forward(self, x):
        """Run the fake compute on ``x`` and record start/end events.

        Returns ``x`` unchanged when the layer has no parameter, otherwise
        ``x + optional_param``.
        """
        # Get 2 events.
        self.e1 = Event(enable_timing=True)
        self.e2 = Event(enable_timing=True)

        # Record the fake forward compute time.
        self.e1.record()
        if self.sleep_cycles > 0:
            torch.cuda._sleep(self.sleep_cycles)
        if self.optional_param is not None:
            x = x + self.optional_param  # force the param to be part of the graph
        self.e2.record()
        return x

    def get_time(self):
        """Return the duration of the most recent ``forward`` in milliseconds.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if self.e1 is None or self.e2 is None:
            raise RuntimeError("Layer.get_time() called before forward()")
        # elapsed_time() is only valid once both events have completed;
        # waiting on the end event covers the start event as well because
        # both were recorded in order by forward().
        self.e2.synchronize()
        return self.e1.elapsed_time(self.e2)
Example #2
0
    def forward(self, x):
        """Apply the fake workload to ``x`` while timing it with two events.

        The start and end events are stored on ``self.e1`` / ``self.e2`` so
        the duration can be queried after the call.
        """
        # Fresh pair of timing events for this call.
        start_evt = self.e1 = Event(enable_timing=True)
        stop_evt = self.e2 = Event(enable_timing=True)

        # Everything between the two record() calls counts as "compute".
        start_evt.record()
        cycles = self.sleep_cycles
        if cycles > 0:
            torch.cuda._sleep(cycles)
        param = self.optional_param
        if param is not None:
            # Touch the parameter so it becomes part of the autograd graph.
            x = x + param
        stop_evt.record()
        return x
Example #3
0
    def tick(self, name='default'):
        """Toggle the interval called ``name``.

        If no interval with that name is running, one is started and ``0.0``
        is returned. Otherwise the running interval is closed, its length is
        added to the running total for ``name`` and that total (in seconds)
        is returned.
        """
        running = self.current_ticks
        if name in running:
            totals = self.cumulative_secs
            totals.setdefault(name, 0)
            stop = self.end
            stop.record()
            stop.synchronize()
            # elapsed_time() reports milliseconds; totals are kept in seconds.
            totals[name] += running[name].elapsed_time(stop) / 1000.
            del running[name]
            return totals[name]

        # No interval running under this name: open one.
        begin = Event(enable_timing=True, blocking=True)
        begin.record()
        running[name] = begin
        return 0.0
Example #4
0
class CUDATimer(object):
    """Named stopwatch built on CUDA timing events.

    ``tick(name)`` alternately opens and closes an interval for ``name``;
    closed intervals are summed per name. ``tock(name)`` closes any open
    interval, reports the accumulated total and resets that name.

    Args:
        silent: when true, ``tock`` does not print the result.
    """

    def __init__(self, silent=False):
        # name -> seconds accumulated over all closed intervals
        self.cumulative_secs = {}
        # name -> start Event of the interval that is currently open
        self.current_ticks = {}
        self.silent = silent
        # Single end marker reused for every measurement.
        self.end = Event(enable_timing=True, blocking=True)

    def tick(self, name='default'):
        """Open or close the interval called ``name``.

        Returns:
            ``0.0`` when an interval was opened, otherwise the total number
            of seconds accumulated for ``name`` including the interval that
            was just closed.
        """
        if name not in self.current_ticks:
            start = Event(enable_timing=True, blocking=True)
            start.record()
            self.current_ticks[name] = start

            return 0.0

        if name not in self.cumulative_secs:
            self.cumulative_secs[name] = 0
        self.end.record()
        self.end.synchronize()
        # elapsed_time() reports milliseconds; totals are kept in seconds.
        self.cumulative_secs[name] += self.current_ticks[
            name].elapsed_time(self.end) / 1000.
        self.current_ticks.pop(name)

        return self.cumulative_secs[name]

    def tock(self, name='default'):
        """Finish timing ``name``, report the total and reset it.

        Returns:
            The accumulated seconds for ``name``.

        Raises:
            KeyError: if nothing has been measured for ``name``.
        """
        self.tick(name)
        # tick() either closed the running interval or, when none was
        # running, opened a new one. That new interval is never wanted here,
        # so drop it *before* the lookup below can raise; otherwise a stray
        # start event would be left behind and corrupt the next tick().
        self.current_ticks.pop(name, None)

        value = self.cumulative_secs.pop(name)
        if not self.silent:
            print('Time taken for {0}: {1:.8f}s'.format(name, value))

        return value
Example #5
0
    def run(compute_cycles, all_gather_cycles):
        """Time 20 forward/backward iterations of a freshly created model.

        The model comes from ``_create_model(fsdp_config, compute_cycles,
        has_params)``; it only has parameters when ``all_gather_cycles > 0``.
        During the forward pass ``torch.distributed.all_gather`` is patched
        with a wrapper that first calls ``torch.cuda._sleep(all_gather_cycles)``
        and then the original implementation.

        Returns:
            dict with the ``avg()`` of four ``Min10`` accumulators:
            ``cpu_iter`` and ``cpu_wait`` are differences of
            ``time.process_time()`` (seconds), ``gpu_compute`` and
            ``gpu_total`` come from CUDA ``Event.elapsed_time`` (milliseconds).
        """
        has_params = all_gather_cycles > 0
        model = _create_model(fsdp_config, compute_cycles, has_params)

        # Get the input and sets the input's requires_grad to True because
        # we have a fake compute in the forward pass.
        batch = torch.rand(1).cuda()
        batch.requires_grad = True

        # We run 20 iterations but only collect timing data from the minimal 10
        # data points because nondeterministic system events can disturb the timing.
        cpu_iter = Min10()
        cpu_wait = Min10()
        gpu_compute = Min10()
        gpu_total = Min10()
        for _ in range(20):
            # Get two events for measuring the overall time.
            e1 = Event(enable_timing=True)
            e2 = Event(enable_timing=True)

            cpu_start = time.process_time()

            # Reset per iteration; flipped by the patched all_gather below.
            all_gather_called = False

            def _delayed_all_gather(*args, **kwargs):
                # Stand-in for all_gather: note the call, insert the fake
                # delay, then defer to the real implementation.
                nonlocal all_gather_called
                all_gather_called = True
                torch.cuda._sleep(all_gather_cycles)
                return orig_all_gather(*args, **kwargs)

            # forward pass
            #
            # Even though both e1 & e2 are on the compute stream, since
            # compute depends on all_gather, e2-e1 includes all_gather time.
            e1.record()
            with patch("torch.distributed.all_gather", _delayed_all_gather):
                out = model(batch)
                # all_gather is expected exactly when there are parameters
                # and more than one rank.
                if has_params and world_size > 1:
                    assert all_gather_called
                else:
                    assert not all_gather_called
            e2.record()

            # backward pass
            out.backward()
            # Drop gradients so iterations do not accumulate; older torch
            # versions lack zero_grad(set_to_none=True).
            if torch_version() >= (1, 7, 0):
                model.zero_grad(set_to_none=True)
            else:
                for p in model.parameters():
                    p.grad = None

            # CPU time spent issuing forward + backward for this iteration.
            cpu_iter_time = time.process_time() - cpu_start

            # wait for gpu
            # (.item() copies the value to the host, so it has to wait for the result)
            out.item()
            cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time

            # get sum of the compute time
            times = []
            for mod in model.modules():
                if not isinstance(mod, Layer):
                    continue
                times.append(mod.get_time())

            # get gpu compute + all_gather time
            overall_gpu_time = e1.elapsed_time(e2)

            cpu_iter.add(cpu_iter_time)
            cpu_wait.add(cpu_wait_for_gpu_time)
            gpu_compute.add(sum(times))
            gpu_total.add(overall_gpu_time)

        del model
        return {
            "cpu_iter": cpu_iter.avg(),
            "cpu_wait": cpu_wait.avg(),
            "gpu_compute": gpu_compute.avg(),
            "gpu_total": gpu_total.avg(),
        }
Example #6
0
 def __init__(self, silent=False):
     """Set up empty timer state and the shared end-of-interval event.

     ``silent`` is stored as-is on the instance.
     """
     self.silent = silent
     # Both tables are keyed by timer name and start out empty.
     self.current_ticks = dict()
     self.cumulative_secs = dict()
     # One blocking, timing-enabled event kept on the instance as the end marker.
     self.end = Event(enable_timing=True, blocking=True)
    def fetch_sub_module(self, current_submodule: Module) -> None:
        """Make the parameters of ``current_submodule`` available, then prefetch.

        This method does the following (in order):
        1. kick off fetch (all-gather) for parameters in the immediately
           required sub module
        2. block on those parameters until they are available
        3. if a complete trace exists, kick off fetch for the next few
           parameters we will need later (prefetch)

        The step counter is incremented once at the end.

        Args:
            current_submodule: the module about to run; its ``id`` attribute
                and the parameters yielded by ``iter_params`` are used.

        Raises:
            RuntimeError: if the parameters popped from the fetch queue do
                not match the not-yet-fetched parameters of this module.
        """
        debug_rank0(
            f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} "
            +
            str({
                "avail": f"{self.__n_available_params:.1e}",
                "queue_sz": f"{len(self.__param_queue or [])}",
                "inflight": [p.ds_id for p in self.__inflight_param_registry],
            }))

        params_to_fetch = frozenset(iter_params(current_submodule))

        # kick off all gather for params in the immediately required submodule
        for param in params_to_fetch:
            debug_rank0(f"-fetch: {param.ds_summary()}")
        self.__all_gather_params(params_to_fetch)

        # wait for parameters in the immediately needed submodule to become available
        for param in params_to_fetch:
            # mark this submodule as a current user of the parameter
            param.ds_active_sub_modules.add(current_submodule.id)
            debug_rank0(f"-wait: {param.ds_summary()}")
            if param in self.__inflight_param_registry:
                with torch.cuda.stream(self.__allgather_stream):
                    # drop events at the front of the queue that have already completed
                    while self.__ongoing_fetch_events and self.__ongoing_fetch_events[
                            0].query():
                        self.__ongoing_fetch_events.popleft()
                    # too many fetches still outstanding: block on the oldest one
                    if len(self.__ongoing_fetch_events
                           ) > self.__max_ongoing_fetch_events:
                        self.__ongoing_fetch_events.popleft().synchronize()

                    # remove the in-flight handle for this param and wait on it
                    self.__inflight_param_registry.pop(param).wait()

                    # record an event after the wait and track it as an ongoing fetch
                    event = Event()
                    event.record()
                    self.__ongoing_fetch_events.append(event)

            assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary(
            )
        # make the current stream wait for work queued on the all-gather stream
        torch.cuda.current_stream().wait_stream(self.__allgather_stream)

        # kick off parameter prefetches for upcoming modules
        # don't prefetch if we dont have a completed model trace
        if self.is_complete_trace():
            # go through the parameters we need for the current module and pop them
            # off the fetch queue so that they aren't prefetched later.
            # if params have already been popped off the fetch queue by earlier
            # prefetches we won't look for them here
            discarded_from_prefetch_queue = set()
            params_not_already_fetched = set(
                filter(
                    lambda p: self.__most_recent_step_id_param_fetched_for[p] <
                    self.__step_id, params_to_fetch))
            while self.__param_queue and len(
                    discarded_from_prefetch_queue) < len(
                        params_not_already_fetched):
                param_in_trace = self.__param_queue.popleft()
                self.__most_recent_step_id_param_fetched_for[
                    param_in_trace.param] = param_in_trace.step_id_last_used_at
                discarded_from_prefetch_queue.add(param_in_trace.param)

            # the queue front must be exactly this module's not-yet-fetched params
            if discarded_from_prefetch_queue != params_not_already_fetched:
                raise RuntimeError(
                    f"tracing error at step {self.__step_id}: \n"
                    f"module id: {current_submodule.id}, training: {current_submodule.training}\n"
                    f"expected the next {len(params_not_already_fetched)} parameters in the "
                    f"parameter fetch queue to be {tuple(p.ds_summary(use_debug_name=True) for p in params_not_already_fetched)} \n"
                    f"but got \n {tuple(p.ds_summary(use_debug_name=True) for p in discarded_from_prefetch_queue)}."
                )

            def _is_currently_on_nvme(param):
                # True only when the param has an NVMe swapper, its partition's
                # final location is NVMe, and the partition is not available.
                if param.nvme_swapper is None:
                    return False

                return param.ds_tensor.final_location == OffloadDeviceEnum.nvme \
                    and param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE

            # kick off all gather for params in the next few submodules (prefetch)
            if self.__prefetch_bucket_sz > 0:
                # budget is compared against summed ds_numel below: the smaller of
                # the remaining availability headroom and the prefetch bucket size
                max_params_to_prefetch = min(
                    self.__max_n_available_params - self.__n_available_params,
                    self.__prefetch_bucket_sz)
                params_to_prefetch = set()
                numel_prefetching = 0
                while self.__param_queue and numel_prefetching < max_params_to_prefetch:
                    param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft(
                    )

                    if _is_currently_on_nvme(param_in_trace.param):
                        # nvme prefetch is handled elsewhere. Need to break here to preserve fetch order
                        self.__param_queue.appendleft(param_in_trace)
                        break

                    # only params that are not available yet need an all-gather
                    do_prefetch = param_in_trace.param.ds_status == ZeroParamStatus.NOT_AVAILABLE
                    if param_in_trace.param in params_to_prefetch:
                        # Avoid duplicates
                        do_prefetch = False

                    # never move the recorded step id backwards
                    self.__most_recent_step_id_param_fetched_for[param_in_trace.param] = \
                        max(self.__most_recent_step_id_param_fetched_for[param_in_trace.param],
                            param_in_trace.step_id_last_used_at)

                    if do_prefetch:
                        params_to_prefetch.add(param_in_trace.param)
                        numel_prefetching += param_in_trace.param.ds_numel

                for param in params_to_prefetch:
                    debug_rank0(f"-prefetch: {param.ds_summary()}")
                self.__all_gather_params(params_to_prefetch)

                if self.__prefetch_nvme:
                    self.__prefetch_nvme_param_partitions()

        self.__step_id += 1
Example #8
0
def run_on_gpu(kernel, data, repeats, no_grad, fwd_bwd):
    """Measure both GPU runtime and peak memory usage of a kernel.

    The kernel is first run once to capture peak memory, then ``repeats``
    times back-to-back with CUDA events between runs to capture timings.

    Args:
        kernel: callable invoked as ``kernel(weight, tile_factor=16)``; the
            object it returns is called as ``obj(input, target)``.
        data: ``(input, weight, target)`` triple. ``input`` and ``target``
            are moved to the GPU; ``weight`` is only used for its shape and
            dtype, a fresh ``nn.Linear`` weight is created on the GPU.
        repeats: number of timed runs. The first timed run is treated as
            warm-up and excluded from the average, so this should be at
            least 2.
        no_grad: run the kernel under ``torch.no_grad()``.
        fwd_bwd: also call ``backward()`` on the output and check gradients;
            incompatible with ``no_grad``.

    Returns:
        ``(peak_mem, [cpu_ms, first_run_ms, ms_per_call, us_per_token])``
        where ``peak_mem`` is in MiB, ``cpu_ms`` is the host time spent
        queueing all runs, ``first_run_ms`` is the GPU time of the first
        timed run and the last two are averages over the remaining runs.
    """
    tokens = data[0].shape[0]

    def get_cuda_data():
        """Move the data from CPU to GPU. We make a new weight parameter with this call."""
        with torch.no_grad():
            i, w, t = data  # i, t are tensors, w is a param
            w = nn.Linear(w.shape[1],
                          w.shape[0],
                          bias=False,
                          dtype=w.dtype,
                          device="cuda").weight
            assert w.requires_grad
            return i.cuda().requires_grad_(True), w, t.cuda()

    def _test(kernel_obj, event):
        """Forward and backward passes."""
        # suppress() with no arguments is used as a do-nothing context.
        context = contextlib.suppress()
        if no_grad:
            context = torch.no_grad()
        with context:
            if event is not None:
                event.record()
            out = kernel_obj(input, target)
            if fwd_bwd:
                assert not no_grad
                out.backward()
            del out
        if fwd_bwd:
            assert input.grad is not None, input
            assert weight.grad is not None, weight
            assert target.grad is None, target
            # Reset so the next run starts without accumulated gradients.
            input.grad = None
            weight.grad = None

    def _get_kernel():
        """Get a kernel instance."""
        return kernel(weight, tile_factor=16)

    #
    # Run the test once to measure memory.
    #

    # Ensure GPU memory is clean, empty, 0.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    cur_mem_before = round(torch.cuda.memory_allocated() / 1024 / 1024)
    assert cur_mem_before == 0, cur_mem_before

    # Move tensors to GPU.
    input, weight, target = get_cuda_data()

    # Create the kernel
    k = _get_kernel()
    _test(k, None)

    # Might wait for gpu here
    torch.cuda.synchronize()

    # Free memory, ensure everything is clean, no leak.
    del k
    del input
    del weight
    del target
    cur_mem_after = round(torch.cuda.memory_allocated() / 1024 / 1024)
    assert cur_mem_after == 0, cur_mem_after

    # Get peak mem
    peak_mem_after = round(torch.cuda.max_memory_allocated() / 1024 / 1024)
    peak_mem = peak_mem_after - cur_mem_before

    #
    # Run multiple times to get both CPU timing and average GPU timing.
    #

    # Move tensors to GPU and get k, again.
    input, weight, target = get_cuda_data()
    k = _get_kernel()

    # One event before each run plus a final one closing the last run.
    events = [Event(enable_timing=True) for _ in range(repeats + 1)]

    # Queue the ops to GPU. perf_counter() is monotonic, so the interval
    # cannot be distorted by wall-clock adjustments.
    cpu_start_time = time.perf_counter()
    for start_event in events[:-1]:
        _test(k, start_event)
    # End time of the last run. Indexing from the list (instead of the
    # leaked loop variable) keeps this valid even when the loop never ran.
    events[-1].record()
    # CPU could be done much sooner than the GPU here.
    cpu_time = time.perf_counter() - cpu_start_time
    # Might wait for gpu here
    torch.cuda.synchronize()

    # Get the durations
    durations = [cpu_time * 1000]  # convert CPU time, from seconds to ms.
    durations.extend(x.elapsed_time(y) for x, y in zip(events, events[1:]))
    assert len(durations) == repeats + 1

    # Free memory
    del k
    input, weight, target = None, None, None
    cur_mem_after = round(torch.cuda.memory_allocated() / 1024 / 1024)
    assert cur_mem_after == 0, cur_mem_after

    # Skip 2 for cpu time and first warm up time to compute the average.
    time_per_call = mean(durations[2:])  # ms
    time_per_token = time_per_call * 1000 / tokens  # us
    return peak_mem, durations[:2] + [time_per_call, time_per_token]