def forward(self, x):
    """ Encodes the given input into a q_phi(z|x) probability distribution,
    samples a latent vector from that distribution, and finally calls the decoder network.

    For compatibility, it returns zK_sampled = z_sampled and the log abs det jacobian(T) = 0.0 (T = identity)

    :returns: z_mu_logvar, z_sampled, zK_sampled=z_sampled, logabsdetjacT=0.0, x_out (reconstructed spectrogram)
    """
    with profiler.record_function("ENCODING") if self.is_profiled else contextlib.nullcontext():
        z_mu_logvar = self.encoder(x)
        n_minibatch = z_mu_logvar.size()[0]
        mu = z_mu_logvar[:, 0, :]
        sigma = torch.exp(z_mu_logvar[:, 1, :] / 2.0)
    with profiler.record_function("LATENT_SAMPLING") if self.is_profiled else contextlib.nullcontext():
        if self.training:
            # Sampling from the q_phi(z|x) probability distribution - with re-parametrization trick
            eps = Normal(torch.zeros(n_minibatch, self.dim_z, device=mu.device),
                         torch.ones(n_minibatch, self.dim_z, device=mu.device)).sample()
            z_sampled = mu + sigma * eps
        else:  # eval mode: no random sampling
            z_sampled = mu
    with profiler.record_function("DECODING") if self.is_profiled else contextlib.nullcontext():
        x_out = self.decoder(z_sampled)
    return z_mu_logvar, z_sampled, z_sampled, torch.zeros((z_sampled.shape[0], 1), device=x.device), x_out
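# Hedged, self-contained sketch (not from the original repository) of the conditional
# profiling pattern used above: record_function is only entered when profiling is enabled,
# otherwise contextlib.nullcontext() turns the `with` statement into a no-op.
# `profiled_section` and `run` are illustrative helper names, not part of the original code.
import contextlib

import torch
import torch.autograd.profiler as profiler


def profiled_section(name, enabled):
    # Returns a context manager recording `name` only when `enabled` is True.
    return profiler.record_function(name) if enabled else contextlib.nullcontext()


def run(model, x, is_profiled=False):
    with profiler.profile(enabled=is_profiled) as prof:
        with profiled_section("FORWARD", is_profiled):
            y = model(x)
    if is_profiled:
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
    return y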
def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
    """Performs a single optimization step (parameter update).

    Arguments:
        closure (callable): A closure that reevaluates the model and
            returns the loss. Optional for most optimizers.

    .. note: Any extra parameter is passed to the base optimizer as-is
    """
    # Sync oss param_groups attributes in case they've been updated by a scheduler.
    OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

    # Catch a possible change of devices in between OSS construction and step()
    with profiler.record_function("fairscale::oss::refresh_trainable"):
        if self._default_device.type != self.param_groups[0]["params"][0].device.type:
            logging.info("OSS detected that the parameters changed devices, re-allocating buffers")
            self._clear_cache()
            self.refresh_trainable()

    # Run the optimizer step on this shard only:
    with profiler.record_function("fairscale::oss::optim_step"):
        if closure is not None:
            loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
        else:
            loss = self.optim.step(**kwargs)

    # Sync all the updated shards in between the ranks
    self._broadcast_params()

    # Sync hypothetical new results from the wrapped optimizer to the exposed param_groups
    OSS._sync_param_groups(self.optim.param_groups, self.param_groups)

    return loss
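# Hedged usage sketch (not part of the snippet above) of how the sharded optimizer whose
# step() is shown above is typically constructed: fairscale's OSS wraps a base optimizer
# class and shards its state across ranks. The learning-rate/momentum values below are
# placeholders, and a torch.distributed process group is assumed to be initialized.
import torch
from fairscale.optim.oss import OSS


def build_sharded_optimizer(model: torch.nn.Module) -> OSS:
    # OSS takes the parameters, the base optimizer class, and the base optimizer's kwargs.
    return OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.9)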
def calculate_loss(self, interaction):
    user = interaction[self.USER_ID]
    item = interaction[self.ITEM_ID]
    label = interaction[self.LABEL]

    with profiler.record_function("REC output and loss"):
        output_rec = self.forward_rec(user, item)
        loss_rec = self.loss_rec(output_rec, label)

    with profiler.record_function("LM output"):
        output_lm = self.forward_lm(item)

    if self.variant == 3:
        label_lm = self.lm_gt[item].to_dense().to(self.device)
    if self.variant == 2:
        label_lm = self.lm_gt[item].to(self.device).to_dense()
    if self.variant == 1:
        with profiler.record_function("LM making label on GPU"):
            label_lm = self.lm_gt[item].to_dense()
            # label_lm = torch.zeros(len(item), self.vocab_size, device=self.device)
            # for i in range(len(item)):
            #     item_id = item[i]
            #     label_lm[i] = self.lm_gt[item_id].to_dense()

    with profiler.record_function("LM loss"):
        loss_lm = self.loss_lm(output_lm, label_lm)

    return loss_rec, self.alpha * loss_lm
def benchmark_model(model, optimizer, parameters, name):
    # Run
    step_times = []

    # Autograd profiler adds some overhead, so we time the forward pass with
    # and without enabling it.
    for profile_autograd in [False, True]:
        with profiler.profile(enabled=profile_autograd, use_cuda=(device == "cuda")) as prof:
            for i in range(15):
                # Warm up for five steps, reset step_times after this.
                if i == 5:
                    step_times = []

                with profiler.record_function("forward"):
                    loss = model(x).sum()

                with profiler.record_function("backward"):
                    loss.backward()

                if device == 'cuda':
                    torch.cuda.synchronize()
                start = time.time()

                with profiler.record_function("gradient_norm"):
                    torch.nn.utils.clip_grad_norm_(parameters, 0.1)

                with profiler.record_function("step"):
                    optimizer.step()

                with profiler.record_function("zero_grad"):
                    optimizer.zero_grad()

                if device == 'cuda':
                    torch.cuda.synchronize()
                step_times.append(time.time() - start)

        print(f"Mean step time: {sum(step_times) / 10} seconds. "
              f"(Autograd profiler enabled: {profile_autograd})")
        prof.export_chrome_trace(f"{name}_timeline.json")
def func(model, device, input_size, epochs=100, dendrite=False):
    batch_size = 4096
    use_cuda = device.type == "cuda"
    dummy_tensor = torch.rand((batch_size, input_size), device=device)
    wall_clock = 0.0
    for _ in range(epochs):
        if dendrite:
            dummy_context = torch.rand((batch_size, model.dim_context), device=device)
            s = time.time()
            with profiler.profile(record_shapes=True, use_cuda=use_cuda) as prof:
                with profiler.record_function("model_inference"):
                    res = model(dummy_tensor, dummy_context)
        else:
            s = time.time()
            with profiler.profile(record_shapes=True, use_cuda=use_cuda) as prof:
                with profiler.record_function("model_inference"):
                    res = model(dummy_tensor)
        wall_clock += time.time() - s

    print("Wall clock:", wall_clock / epochs)
    if device.type == "cuda":
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    else:
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

    dense_params, sparse_params = count_nonzero_params(model)
    print(f"Total params:{dense_params}, non-zero params:{sparse_params}")

    if res.sum() == 0:  # Just to make Python think we need res
        print(res.sum())

    return wall_clock / epochs
def calculate_loss(self, interaction):
    user = interaction[self.USER_ID]
    item = interaction[self.ITEM_ID]
    label = interaction[self.LABEL]

    with profiler.record_function("REC output and loss"):
        output_rec = self.forward_rec(user, item)
        loss_rec = self.loss_rec(output_rec, label)

    with profiler.record_function("LM output"):
        output_lm = self.forward_lm(item)

    if self.variant == 1:
        with profiler.record_function("LM making tensor on GPU"):
            label_lm_k = self.lm_gt[item].to(device=self.device)
            label_lm_len = self.lm_gt_len[item].to(device=self.device)
            label_lm = (label_lm_k.T / label_lm_len).T
    elif self.variant == 2:
        with profiler.record_function("LM making tensor on CPU"):
            label_lm_k = self.lm_gt[item]
            label_lm_len = self.lm_gt_len[item]
            label_lm_k = label_lm_k.T / label_lm_len
            label_lm = label_lm_k.to(device=self.device).T

    with profiler.record_function("LM loss"):
        loss_lm = self.loss_lm(output_lm, label_lm)

    return loss_rec, self.alpha * loss_lm
def calculate_loss(self, interaction):
    user = interaction[self.USER_ID]
    item = interaction[self.ITEM_ID]
    label = interaction[self.LABEL]

    with profiler.record_function("REC output and loss"):
        output_rec = self.forward_rec(user, item)
        loss_rec = self.loss_rec(output_rec, label)

    with profiler.record_function("LM output"):
        output_lm = self.forward_lm(item)

    if self.variant == 3:
        label_lm = self.lm_gt[item].to_dense().to(self.device)
    if self.variant == 2:
        label_lm = self.lm_gt[item].to(self.device).to_dense()
    if self.variant == 1:
        with profiler.record_function("LM making label on GPU"):
            label_lm = self.lm_gt[item].to_dense()

    with profiler.record_function("LM loss"):
        loss_lm = self.loss_lm(output_lm, label_lm)

    return loss_rec, self.alpha * loss_lm
def index(self, uv, cam_z=None, image_size=(), z_bounds=None):
    """
    Get pixel-aligned image features at 2D image coordinates
    :param uv (B, N, 2) image points (x,y)
    :param cam_z ignored (for compatibility)
    :param image_size image size, either (width, height) or single int.
    if not specified, assumes coords are in [-1, 1]
    :param z_bounds ignored (for compatibility)
    :return (B, L, N) L is latent size
    """
    with profiler.record_function("encoder_index"):
        if uv.shape[0] == 1 and self.latent.shape[0] > 1:
            uv = uv.expand(self.latent.shape[0], -1, -1)

        with profiler.record_function("encoder_index_pre"):
            if len(image_size) > 0:
                if len(image_size) == 1:
                    image_size = (image_size, image_size)
                scale = self.latent_scaling / image_size
                uv = uv * scale - 1.0

        uv = uv.unsqueeze(2)  # (B, N, 1, 2)
        samples = F.grid_sample(
            self.latent,
            uv,
            align_corners=True,
            mode=self.index_interp,
            padding_mode=self.index_padding,
        )
        return samples[:, :, :, 0]  # (B, C, N)
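# Minimal sketch of the F.grid_sample call that index() above relies on, under the
# assumption that the coordinates are already normalized to [-1, 1]. The latent shape,
# batch size and point count below are made-up values: sampling a (B, C, H, W) feature
# map at N 2D points yields pixel-aligned features of shape (B, C, N).
import torch
import torch.nn.functional as F

latent = torch.randn(2, 64, 32, 32)          # (B, C, H, W) feature map
uv = torch.rand(2, 100, 2) * 2.0 - 1.0       # (B, N, 2) coords in [-1, 1]

samples = F.grid_sample(
    latent,
    uv.unsqueeze(2),                         # (B, N, 1, 2) sampling grid
    align_corners=True,
    mode="bilinear",
    padding_mode="border",
)
features = samples[:, :, :, 0]               # (B, C, N)
print(features.shape)                        # torch.Size([2, 64, 100])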
def forward(self, input, mask):
    with profiler.record_function("LABEL1: linear pass"):
        out = self.linear(input)

    with profiler.record_function("LABEL2: masking"):
        threshold = out.sum(axis=1).mean()  # removed .item()
        hi_idx = (mask > threshold).nonzero(as_tuple=True)

    return out, hi_idx
def background_idf(self, term):
    with profiler.record_function("jterm"):
        jterm = self.JTerm("contents", term)
    with profiler.record_function("df"):
        df = self.reader.docFreq(jterm)
    with profiler.record_function("idf"):
        idf = np.log10((self.numdocs - df + 0.5) / (df + 0.5))
    return idf
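# Small self-contained check (hypothetical document counts, not from the original index)
# of the BM25-style IDF computed above: idf = log10((N - df + 0.5) / (df + 0.5)),
# which grows as the term becomes rarer in the collection.
import numpy as np

numdocs = 1_000_000
for df in (10, 1_000, 100_000):
    idf = np.log10((numdocs - df + 0.5) / (df + 0.5))
    print(df, round(float(idf), 3))
# df=10 -> ~4.98, df=1000 -> ~3.0, df=100000 -> ~0.95 (rarer terms score higher)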
def forward(self, input, mask):
    with profiler.record_function("LINEAR PASS"):
        out = self.linear(input)

    with profiler.record_function("MASK INDICES"):
        threshold = out.sum(axis=1).mean()
        hi_idx = (mask > threshold).nonzero(as_tuple=True)

    return out, hi_idx
def forward(self, src_tokens, src_lengths, prev_output_tokens):
    with profiler.record_function("Encoder_out"):
        encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
    with profiler.record_function("Decoder_out"):
        decoder_out = self.decoder(prev_output_tokens=prev_output_tokens, encoder_out=encoder_out)
    return decoder_out
def forward(self, x, mask):
    with profiler.record_function('Linear'):
        out = self.linear(x)
    with profiler.record_function('Mask'):
        threshold = out.sum(dim=1).mean().item()
        idx = np.argwhere(mask.cpu().numpy() > threshold)
        idx = torch.from_numpy(idx).cuda()
    return out, idx
def forward(self, input, mask):
    with profiler.record_function("LINEAR PASS"):
        out = self.linear(input)

    with profiler.record_function("MASK INDICES"):
        threshold = out.sum(axis=1).mean().item()
        hi_idx = np.argwhere(mask.cpu().numpy() > threshold)
        hi_idx = torch.from_numpy(hi_idx).cuda()

    return out, hi_idx
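# Hedged illustration (dummy shapes, not the original model) of why the masking step above
# is what the profiler labels flag as costly: .item() and the NumPy round-trip copy data to
# the CPU and force a device synchronization, which the earlier nonzero(as_tuple=True)
# variants in this collection avoid by staying on the GPU.
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
out = torch.randn(128, 10, device=device)
mask = torch.rand(128, 10, device=device)

# CPU round-trip, as in the snippet above
threshold = out.sum(axis=1).mean().item()          # device -> host synchronization
hi_idx_cpu = torch.from_numpy(np.argwhere(mask.cpu().numpy() > threshold)).to(device)

# GPU-only equivalent: comparison and index extraction stay on the device
hi_idx_gpu = (mask > out.sum(axis=1).mean()).nonzero(as_tuple=False)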
def forward(self, x, sample_info=None):
    """ Encodes the given input into a q_Z0(z_0|x) probability distribution, samples a latent
    vector from that distribution, transforms it into q_ZK(z_K|x) using an invertible normalizing flow,
    and finally calls the decoder network using the z_K samples.

    :param x: Single- or Multi-channel spectrogram tensor
    :param sample_info: Required for the MIDI pitch and velocity to be appended to the latent vector.
        On the last dim, index 0 should be a preset UID, index 1 a MIDI pitch, index 2 a MIDI velocity.
    :returns: z0_mu_logvar, z0_sampled, zK_sampled, logabsdetjacT, x_out (reconstructed spectrogram)
    """
    n_minibatch = x.size()[0]
    with profiler.record_function("ENCODING") if self.is_profiled else contextlib.nullcontext():
        # Don't ask for requires_grad or this tensor becomes a leaf variable (it will require grad later)
        z_0_mu_logvar = torch.empty((n_minibatch, 2, self.dim_z), device=x.device, requires_grad=False)
        if not self.concat_midi_to_z0:
            z_0_mu_logvar = self.encoder(x)
        else:  # insert midi notes if required
            z_0_mu_logvar[:, :, 2:] = self.encoder(x)
            if sample_info is None:  # missing MIDI notes are tolerated for graphs and summaries
                z_0_mu_logvar[:, :, [0, 1]] = 0.0
            else:  # MIDI pitch and velocity models: free-mean and unit-variance scaled in [-1.0, 1.0]
                # TODO extend this to work with multiple MIDI notes?
                # Mean is simply scaled to [-1.0, 1.0] (min/max normalization)
                midi_pitch_and_vel_mu = -1.0 + 2.0 * sample_info[:, [1, 2]].float() / 127.0
                z_0_mu_logvar[:, 0, [0, 1]] = midi_pitch_and_vel_mu
                # log(var) corresponds to a unit standard deviation in the original [0, 127] MIDI domain
                z_0_mu_logvar[:, 1, [0, 1]] = np.log(4.0 / (127 ** 2))
        # Separate mean and standard deviation
        mu0 = z_0_mu_logvar[:, 0, :]
        sigma0 = torch.exp(z_0_mu_logvar[:, 1, :] / 2.0)
    with profiler.record_function("LATENT_FLOW") if self.is_profiled else contextlib.nullcontext():
        if self.training:
            # Sampling from the q_phi(z|x) probability distribution - with re-parametrization trick
            eps = Normal(torch.zeros(n_minibatch, self.dim_z, device=mu0.device),
                         torch.ones(n_minibatch, self.dim_z, device=mu0.device)).sample()
            z_0_sampled = mu0 + sigma0 * eps
        else:  # eval mode: no random sampling
            z_0_sampled = mu0
        # Forward flow (fast with nflows MAF implementation - always fast with RealNVP)
        z_K_sampled, log_abs_det_jac = self.flow_transform(z_0_sampled)
    with profiler.record_function("DECODING") if self.is_profiled else contextlib.nullcontext():
        x_out = self.decoder(z_K_sampled)
    return z_0_mu_logvar, z_0_sampled, z_K_sampled, log_abs_det_jac, x_out
def start(self, action_name: str) -> None:
    if self.profiler is None:
        # close profiler if it is already opened. might happen if 2 profilers
        # are created and the first one did not call `describe`
        if torch.autograd._profiler_enabled():
            torch.autograd._disable_profiler()

        if self._schedule is not None:
            self._schedule.setup(action_name)

        self._create_profilers()

        profiler = self.profiler.__enter__()
        if profiler is not None:
            self.profiler = profiler

        if self._parent_profiler is not None:
            self._parent_profiler.__enter__()

        if self._lightning_module is not None and self._register is None and self._record_module_names:
            self._register = RegisterRecordFunction(self._lightning_module)
            self._register.__enter__()

    if self.profiler is not None and action_name not in self._recording_map:
        # Add [pl][profile] in name for pytorch profiler to recognize
        recording = record_function("[pl][profile]" + action_name)
        recording.__enter__()
        self._recording_map[action_name] = recording
def _broadcast(values, group):
    with torch.no_grad():
        coalesced = get_apex_wrapper().flatten(values)
        with record_function("torch.distributed.broadcast"):
            dist.broadcast(coalesced, 0, group=group)
        get_apex_wrapper().multi_tensor_scale(
            get_apex_wrapper().unflatten(coalesced, values), values, 1.0)
def dummy_train_3x1(device):
    model = nn.Sequential(
        spnn.Conv3d(4, 32, kernel_size=(3, 1, 3), stride=1),
        spnn.Conv3d(32, 64, kernel_size=(1, 3, 3), stride=1),
        spnn.Conv3d(64, 128, kernel_size=(3, 1, 3), stride=1),
        spnn.Conv3d(128, 256, kernel_size=(1, 3, 3), stride=1),
        spnn.Conv3d(256, 128, kernel_size=(3, 1, 3), stride=1, transpose=True),
        spnn.Conv3d(128, 64, kernel_size=(1, 3, 3), stride=1, transpose=True),
        spnn.Conv3d(64, 32, kernel_size=(3, 1, 3), stride=1, transpose=True),
        spnn.Conv3d(32, 10, kernel_size=(1, 3, 3), stride=1, transpose=True),
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss().to(device)

    print('Starting dummy_train_3x1 ...')
    time = datetime.now()
    with profiler.profile(profile_memory=True, use_cuda=True) as prof:
        with profiler.record_function("model_inference"):
            for i in range(10):
                feed_dict = generate_batched_random_point_clouds()
                inputs = feed_dict['lidar'].to(device)
                targets = feed_dict['targets'].F.to(device).long()
                outputs = model(inputs)
                optimizer.zero_grad()
                loss = criterion(outputs.F, targets)
                loss.backward()
                optimizer.step()
                # print('[step %d] loss = %f.' % (i, loss.item()))
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    prof.export_chrome_trace("trace_dummy_3x1.json")

    time = datetime.now() - time
    print('Finished dummy_train_3x1 in ', time)
def start(self, action_name: str) -> None:
    if self.profiler is None and action_name in self._record_functions_start:
        # close profiler if it is already opened. might happen if 2 profilers
        # are created and the first one did not call `describe`
        try:
            torch.autograd._disable_profiler()  # noqa
        except (AttributeError, RuntimeError):
            pass

        if self._schedule is not None:
            self._schedule.setup(action_name)

        self._create_profilers()

        profiler = self.profiler.__enter__()
        if profiler is not None:
            self.profiler = profiler

        if self._parent_profiler is not None:
            self._parent_profiler.__enter__()

        if self._register is not None:
            self._register.__enter__()

    if (
        self.profiler is not None
        and (action_name in self._record_functions or action_name.startswith(self.RECORD_FUNCTION_PREFIX))
        and action_name not in self._recording_map
    ):
        recording = record_function(action_name)
        recording.__enter__()
        self._recording_map[action_name] = recording
def eval():
    data_loader, _ = get_data_loader()
    lstm = load_model().to(device)
    acc = 0
    profiler_x = None
    lstm.eval()
    with torch.no_grad():
        for d in data_loader:
            x = d['x'].to(device)
            label = d['label'].to(device)
            if profiler_x is None:
                profiler_x = x
            pred = lstm(x).to(device)
            _acc = torch.true_divide((torch.argmax(pred, 1) == label).double().sum().item(), 8)
            acc += _acc.item()
        print(torch.true_divide(acc, len(data_loader)), len(data_loader))

    with profiler.profile(record_shapes=True, profile_memory=True, use_cuda=True) as prof:
        with profiler.record_function("model_inference"):
            lstm(profiler_x)
    print(prof.key_averages().table(row_limit=100))
    prof.export_chrome_trace("trace.json")
def step(self):
    """
    Signals the profiler that the next profiling step has started.
    """
    if self.record_steps and self.step_rec_fn:
        self.step_rec_fn.__exit__(None, None, None)
    prev_action = self.current_action
    self.step_num += 1
    self.current_action = self.schedule(self.step_num)

    if self.current_action == ProfilerAction.NONE:
        if prev_action == ProfilerAction.NONE:
            pass
        elif prev_action == ProfilerAction.WARMUP:
            warn("Incorrect schedule: WARMUP followed by NONE")
            self._start_trace()
            self._stop_trace()
        elif prev_action == ProfilerAction.RECORD:
            warn("Incorrect schedule: RECORD followed by NONE")
            self._stop_trace()
        else:
            assert prev_action == ProfilerAction.RECORD_AND_SAVE
            self._stop_trace()
            if self.on_trace_ready:
                self.on_trace_ready(self)
    elif self.current_action == ProfilerAction.WARMUP:
        if prev_action == ProfilerAction.NONE:
            self._start_warmup()
        elif prev_action == ProfilerAction.WARMUP:
            pass
        elif prev_action == ProfilerAction.RECORD:
            warn("Incorrect schedule: RECORD followed by WARMUP")
            self._stop_trace()
        else:
            assert prev_action == ProfilerAction.RECORD_AND_SAVE
            self._stop_trace()
            if self.on_trace_ready:
                self.on_trace_ready(self)
            self._start_warmup()
    elif self.current_action in [ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE]:
        if prev_action == ProfilerAction.NONE:
            self._start_warmup()
            self._start_trace()
        elif prev_action == ProfilerAction.WARMUP:
            self._start_trace()
        elif prev_action == ProfilerAction.RECORD:
            pass
        else:
            assert prev_action == ProfilerAction.RECORD_AND_SAVE
            self._stop_trace()
            if self.on_trace_ready:
                self.on_trace_ready(self)
            self._start_warmup()
            self._start_trace()

    if self.record_steps:
        self.step_rec_fn = prof.record_function("ProfilerStep#" + str(self.step_num))
        self.step_rec_fn.__enter__()
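# Hedged usage sketch of the scheduling state machine above, via the public torch.profiler
# API (the linear model and iteration count are placeholders): calling step() each
# iteration drives the schedule through NONE -> WARMUP -> RECORD/RECORD_AND_SAVE and
# labels every iteration with a "ProfilerStep#<n>" record_function.
import torch
from torch.profiler import ProfilerActivity, profile, schedule

model = torch.nn.Linear(32, 32)
x = torch.randn(8, 32)

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=2, repeat=1),
    on_trace_ready=lambda p: print(p.key_averages().table(row_limit=5)),
) as prof:
    for _ in range(6):
        model(x).sum().backward()
        prof.step()  # advances the schedule shown above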
def __enter__(self):
    self._enter_actions()
    if self.record_steps:
        self.step_rec_fn = prof.record_function("ProfilerStep#" + str(self.step_num))
        self.step_rec_fn.__enter__()
    return self
def refresh_trainable(self) -> None:
    """ If the module trainability has changed, update all the assumptions """

    # Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
    if functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False):
        logging.warning(
            "Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context"
        )

    with profiler.record_function("fairscale::sdp::refresh_trainable"):
        self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
        self._trainable_params.sort(key=lambda x: x.numel())

        self._trainable_param_to_rank = {}
        for optim in self._sharded_optimizers:
            # OSS may need to change the communication pattern
            optim.refresh_trainable()

            # Update ShardedDDP given the new partitions
            # all the params on this device (inc all ranks)
            for device_per_rank_params in optim._per_device_params.values():
                for device_params in device_per_rank_params:
                    for param in filter(lambda x: x.requires_grad, device_params):
                        self._trainable_param_to_rank[param] = optim._param_to_rank[param]

        self._setup_bucket_strategy()
        self._setup_backward_hooks()
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
    """
    Module forward pass, handles any DDP-specific work in the background. Primes the
    backward pass for gradient reduction to the proper ranks.
    """
    with profiler.record_function("fairscale::sdp::forward"):
        # Deferred initialization, or change detection
        needs_setup = len(self._grad_hooks) == 0 and self.training

        if self._auto_refresh_trainable:
            needs_setup |= self._detect_train_change()

        if needs_setup:
            self.refresh_trainable()

        if self._enable_broadcast_buffers:
            # NCCL communications are on a different stream, needs to be blocking
            # for the subsequent FW to be correct
            self.sync_buffers(blocking=True)

        # Reset all the grad reduce and bucket state flags
        self._clear_counters()

        # Normal FW on the base model
        return self.module(*inputs, **kwargs)
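# Hedged usage sketch for the forward() above: ShardedDDP wraps the module together with a
# sharded OSS optimizer so that gradient reduction targets the rank owning each shard.
# The SGD settings are placeholders, and a torch.distributed process group is assumed to
# be initialized before this helper is called.
import torch
from fairscale.nn.data_parallel import ShardedDDP
from fairscale.optim.oss import OSS


def wrap_sharded(model: torch.nn.Module):
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01)
    model = ShardedDDP(model, optimizer)  # the forward() shown above runs under this wrapper
    return model, optimizer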
def record_function_on_caller_rpc_async(dst_worker_name: str, block: str) -> Tensor:
    t: Tensor = torch.ones(1)
    with record_function(block) as rf:
        fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
        fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
        res = fut1.wait() + fut2.wait()
    return res
def _benchmark_op(
    self, op: Callable, args: List, kwargs: Dict[str, Any], tag: str, label_str: str
) -> Tuple[float, float]:
    logger.debug(f"benchmarking {label_str}")
    gpu_memory = 0
    timer = Timer(self.device)

    # flush cache
    if self.use_cuda:
        if not self.cuda_l2_cache:
            _clear_cache()
        # Reset to measure peak memory usage
        torch.cuda.reset_peak_memory_stats()

    with record_function(label_str):
        timer.start()
        if self.use_cuda:
            op_run_id_range = torch.cuda.nvtx.range_start(label_str)
        op(*args, **kwargs)
        timer.stop()
        if self.use_cuda:
            torch.cuda.nvtx.range_end(op_run_id_range)
            # Memory size in MB
            gpu_memory = torch.cuda.max_memory_allocated() / (1048576)

    # Return result in milliseconds.
    return timer.elapsed_time_ms(), gpu_memory
def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor:
    # Add [pl][module] in name for pytorch profiler to recognize
    record = record_function("[pl][module]" + record_name)
    record.__enter__()
    self._records[record_name] = record
    return input
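# Hedged usage sketch related to the hook above: in PyTorch Lightning, per-module
# "[pl][module]..." labels are emitted when the PyTorchProfiler is created with
# record_module_names=True. The Trainer arguments are placeholders, `model` is assumed
# to be a LightningModule defined elsewhere, and the import path may vary across
# Lightning versions.
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import PyTorchProfiler

profiler = PyTorchProfiler(record_module_names=True)
trainer = Trainer(profiler=profiler, max_epochs=1)
# trainer.fit(model)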
def forward(ctx, a2a_info, *inputs):
    global myreq
    with record_function("DLRM alltoall_req_fwd_single"):
        batch_split_lengths = a2a_info.global_batch_partition_slices
        if batch_split_lengths:
            batch_split_lengths = [
                m * a2a_info.emb_dim * a2a_info.local_table_num
                for m in batch_split_lengths
            ]
        table_split_lengths = a2a_info.global_table_wise_parition_slices
        if table_split_lengths:
            table_split_lengths = [
                a2a_info.local_batch_num * e * a2a_info.emb_dim
                for e in table_split_lengths
            ]
        input = torch.cat(inputs, dim=1).view([-1])
        output = input.new_empty([
            a2a_info.global_table_num * a2a_info.local_batch_num * a2a_info.emb_dim
        ])
        req = dist.all_to_all_single(
            output, input, table_split_lengths, batch_split_lengths, async_op=True
        )

        myreq.req = req
        myreq.tensor = []
        myreq.tensor.append(output)
        myreq.tensor = tuple(myreq.tensor)
        a2a_info.batch_split_lengths = batch_split_lengths
        a2a_info.table_split_lengths = table_split_lengths
        myreq.a2a_info = a2a_info
        ctx.a2a_info = a2a_info
        return myreq.tensor
def block(self, tag: str, time_tag: Optional[str] = None, monitor_gpu_utils: bool = False):
    self.debug(f"start {tag}")
    time_tag = time_tag or tag
    if not torch.cuda.is_available():
        monitor_gpu_utils = False
    if monitor_gpu_utils:
        torch.cuda.synchronize()
    begin = time.time()
    try:
        with record_function(f"{self.name}:{tag}"):
            yield
    finally:
        self.debug(f"finish {tag}")
        if time_tag not in self.elapsed_time_log:
            self.elapsed_time_log[time_tag] = (0, 0.0)
        if monitor_gpu_utils and time_tag not in self.gpu_elapsed_time_log:
            self.gpu_elapsed_time_log[time_tag] = (0, 0.0)
        elapsed_time = time.time() - begin
        n, t = self.elapsed_time_log[time_tag]
        self.elapsed_time_log[time_tag] = (n + 1, t + elapsed_time)
        if monitor_gpu_utils:
            torch.cuda.synchronize()
            gpu_elapsed_time = time.time() - begin
            n, t = self.gpu_elapsed_time_log[time_tag]
            self.gpu_elapsed_time_log[time_tag] = (n + 1, t + gpu_elapsed_time)
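# Hedged, self-contained sketch of the same idea as block() above, which is written as a
# generator and therefore presumably decorated with @contextlib.contextmanager: a context
# manager that opens a record_function label and accumulates wall-clock time per tag.
# This is an illustration with made-up names, not the original class.
import time
from contextlib import contextmanager

from torch.autograd.profiler import record_function

elapsed_time_log = {}  # tag -> (call_count, total_seconds)


@contextmanager
def timed_block(tag: str):
    begin = time.time()
    try:
        with record_function(tag):
            yield
    finally:
        n, t = elapsed_time_log.get(tag, (0, 0.0))
        elapsed_time_log[tag] = (n + 1, t + time.time() - begin)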
def clip_grad_norm(
    self,
    max_norm: Union[float, int],
    norm_type: Union[float, int] = 2.0,
    filter_params_fn: Callable[[Any], Any] = None,
) -> torch.Tensor:
    """
    Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Arguments:
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).

    .. note: This is analogous to `torch.nn.utils.clip_grad_norm_` but handles the partitioning and multiple
        devices per rank under the hood. The default torch util is not applicable here, because each rank only
        has a partial view of all the grads in the model, so calling it in the OSS context would lead to
        different scaling being applied per subset of model parameters.

    .. warning: This needs to be called on all ranks, since synchronization primitives will be used
    """

    # Compute the max norm for this shard's worth of gradients
    max_norm = float(max_norm)
    norm_type = float(norm_type)

    with profiler.record_function("fairscale::oss::clip_grad_norm"):
        # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
        # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
        # 'model_parallel' flag is set in Megatron-LM:
        # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
        local_params = filter_params_fn(self._local_params) if filter_params_fn is not None else self._local_params
        local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device)

        # Compute the norm on this grad set,
        # then sync all the norms from all ranks
        if norm_type == inf:
            total_norm = local_norm
            # all reduce over data parallel and model parallel workers
            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
        else:
            # local norm result can be accumulated with the remote ones if put to the right power
            # n_i = sum_rank(a^p)^1/p
            # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
            # all reduce over data parallel and model parallel workers
            total_norm = local_norm ** norm_type
            dist.all_reduce(total_norm)
            total_norm = total_norm ** (1.0 / norm_type)

        clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
        if clip_coef < 1:
            for device, device_params in self._per_device_params.items():
                for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
                    p.grad.detach().mul_(clip_coef.to(device))  # type: ignore  # mypy trips on the filter

        return total_norm
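# Small numeric check (synthetic tensors, no distributed setup) of the accumulation
# identity used above: for finite p-norms, per-shard norms combine as
# n_total = (sum_i n_i^p)^(1/p), which equals the norm of all gradients concatenated.
import torch

p = 2.0
shards = [torch.randn(5), torch.randn(7), torch.randn(3)]  # stand-ins for per-rank grads
local_norms = [s.norm(p) for s in shards]
combined = sum(n ** p for n in local_norms) ** (1.0 / p)
full = torch.cat(shards).norm(p)
assert torch.isclose(combined, full)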