class Layer(nn.Module):
    """Fake compute layer whose forward pass is bracketed by timing events.

    ``forward`` optionally busy-waits the GPU via ``torch.cuda._sleep`` for
    ``compute_cycles`` cycles and, when ``has_params`` is true, adds a
    one-element learnable parameter to the input. ``get_time`` reports the
    elapsed time between the two events recorded by the latest forward call.
    NOTE(review): assumes ``Event`` is ``torch.cuda.Event`` (imported
    elsewhere in the file) - confirm.
    """

    def __init__(self, compute_cycles, has_params: bool):
        super().__init__()
        self.sleep_cycles = compute_cycles
        # Register a parameter only when requested; otherwise keep a plain None
        # attribute so the module owns no parameters at all.
        self.optional_param = nn.Parameter(torch.rand(1)) if has_params else None

    def forward(self, x):
        """Run the fake compute on ``x`` between two freshly recorded events."""
        # New events every call; get_time() always reads the most recent pair.
        start = self.e1 = Event(enable_timing=True)
        end = self.e2 = Event(enable_timing=True)

        start.record()
        if self.sleep_cycles > 0:
            torch.cuda._sleep(self.sleep_cycles)
        # Adding the parameter forces it to be part of the autograd graph.
        result = x if self.optional_param is None else x + self.optional_param
        end.record()
        return result

    def get_time(self):
        """Return the elapsed time between the events of the last forward.

        Delegates to ``Event.elapsed_time`` (milliseconds for CUDA events).
        Presumably the recorded work must have finished first; ``run`` calls
        ``out.item()`` before reading these timings.
        """
        begin, finish = self.e1, self.e2
        return begin.elapsed_time(finish)
def run(compute_cycles, all_gather_cycles):
    """Time 20 forward/backward iterations of a fake-compute model.

    The model comes from ``_create_model(fsdp_config, compute_cycles,
    has_params)``. Parameters exist only when ``all_gather_cycles > 0``.
    During the forward pass ``torch.distributed.all_gather`` is patched with
    a wrapper that sleeps the GPU for ``all_gather_cycles`` cycles before
    delegating to ``orig_all_gather``. The wrapper is asserted to have been
    called exactly when the model has params and ``world_size > 1``.

    Returns a dict with the ``Min10`` average of four metrics:
      * ``cpu_iter``: CPU process time to issue forward + backward + grad reset.
      * ``cpu_wait``: CPU process time spent in ``out.item()`` waiting for GPU.
      * ``gpu_compute``: sum of ``Layer.get_time()`` over all ``Layer`` modules.
      * ``gpu_total``: event time spanning the whole forward pass.
    """
    has_params = all_gather_cycles > 0
    model = _create_model(fsdp_config, compute_cycles, has_params)

    # The forward is fake compute. Without parameters it may build no autograd
    # graph, so the input itself must require grad for out.backward() to work.
    batch = torch.rand(1).cuda()
    batch.requires_grad_()

    # 20 iterations run, but only the minimal 10 data points per metric feed
    # the result (via Min10), because nondeterministic system events can
    # disturb the timing.
    metric_names = ("cpu_iter", "cpu_wait", "gpu_compute", "gpu_total")
    trackers = {name: Min10() for name in metric_names}

    for _ in range(20):
        # Events bracketing the whole forward pass.
        fwd_start = Event(enable_timing=True)
        fwd_end = Event(enable_timing=True)

        cpu_start = time.process_time()

        all_gather_called = False

        def _delayed_all_gather(*args, **kwargs):
            # Flag the call and inject a GPU delay ahead of the real collective.
            nonlocal all_gather_called
            all_gather_called = True
            torch.cuda._sleep(all_gather_cycles)
            return orig_all_gather(*args, **kwargs)

        # Forward pass. Both events sit on the compute stream, but compute
        # depends on all_gather, so fwd_end - fwd_start includes all_gather time.
        fwd_start.record()
        with patch("torch.distributed.all_gather", _delayed_all_gather):
            out = model(batch)
            expect_all_gather = has_params and world_size > 1
            assert all_gather_called == expect_all_gather
        fwd_end.record()

        # Backward pass, then drop gradients (set_to_none needs torch >= 1.7).
        out.backward()
        if torch_version() < (1, 7, 0):
            for p in model.parameters():
                p.grad = None
        else:
            model.zero_grad(set_to_none=True)

        cpu_iter_time = time.process_time() - cpu_start

        # out.item() forces a host sync, so this measures waiting on the GPU.
        out.item()
        cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time

        # Per-Layer compute times first, then compute + all_gather overall.
        gpu_compute_time = sum(
            mod.get_time() for mod in model.modules() if isinstance(mod, Layer)
        )
        overall_gpu_time = fwd_start.elapsed_time(fwd_end)

        samples = (cpu_iter_time, cpu_wait_for_gpu_time, gpu_compute_time, overall_gpu_time)
        for name, value in zip(metric_names, samples):
            trackers[name].add(value)

    del model
    return {name: tracker.avg() for name, tracker in trackers.items()}