def all_gather(self, collectiveArgs, retFlag=False):
    retObj = xm.all_gather(collectiveArgs.ipTensor, dim=0)
    collectiveArgs.opTensor = retObj
    if collectiveArgs.asyncOp:
        collectiveArgs.waitObj.append(retObj)
    if retFlag:
        return retObj
def compute(self) -> Any:
    """Compute precision, recall, f1 score and support.

    Computes micro, macro and weighted averages for the metrics.

    Returns:
        list of aggregated metrics: per-class, micro, macro and weighted
        averaging of precision, recall, f1 score and support metrics
    """
    # ddp hotfix, could be done better,
    # but the metric must handle DDP on its own
    if self._ddp_backend == "xla":
        device = get_device()
        for key in self.statistics:
            key_statistics = torch.tensor([self.statistics[key]], device=device)
            key_statistics = xm.all_gather(key_statistics).sum(dim=0).cpu().numpy()
            self.statistics[key] = key_statistics
    elif self._ddp_backend == "ddp":
        for key in self.statistics:
            value: List[np.ndarray] = all_gather(self.statistics[key])
            value: np.ndarray = np.sum(np.vstack(value), axis=0)
            self.statistics[key] = value

    per_class, micro, macro, weighted = get_aggregated_metrics(
        tp=self.statistics["tp"],
        fp=self.statistics["fp"],
        fn=self.statistics["fn"],
        support=self.statistics["support"],
        zero_division=self.zero_division,
    )
    return per_class, micro, macro, weighted
def compute(self) -> Any:
    """
    Returns:
        Confusion matrix of K rows and K columns, where rows correspond to
        ground-truth targets and columns correspond to predicted targets.
    """
    # ddp hotfix, could be done better,
    # but the metric must handle DDP on its own
    if self._ddp_backend == "xla":
        # if you get "RuntimeError: Aborted: Session XXX is not found" here,
        # please ask Google for a more powerful TPU setup ;)
        device = get_device()
        value = torch.tensor([self.conf], device=device)
        self.conf = xm.all_gather(value).sum(0).cpu().detach().numpy()
    elif self._ddp_backend == "ddp":
        value: List[np.ndarray] = all_gather(self.conf)
        value: np.ndarray = np.sum(np.stack(value, axis=0), axis=0)
        self.conf = value

    if self.normalize:
        conf = self.conf.astype(np.float32)
        return conf / conf.sum(1).clip(min=1e-12)[:, None]
    else:
        return self.conf
def xla_all_gather(data, device):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        device: the XLA device backing the communication tensors

    Returns:
        list[data]: list of data gathered from each rank
    """
    import torch_xla.core.xla_model as xla_model

    world_size = xla_model.xrt_world_size()
    if world_size == 1:
        return [data]

    # serialize the object into a uint8 tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device)

    # obtain the tensor size of each rank
    # (xm.all_gather concatenates along dim 0, one entry per rank)
    local_size = torch.tensor([tensor.numel()], device=device)
    size_list = [int(size.item()) for size in xla_model.all_gather(local_size)]
    max_size = max(size_list)

    # pad to a common length because all_gather does not support
    # gathering tensors of different shapes
    if tensor.numel() != max_size:
        padding = torch.zeros((max_size - tensor.numel(),), dtype=torch.uint8, device=device)
        tensor = torch.cat((tensor, padding), dim=0)

    # gather the padded payloads and split them back into per-rank chunks
    gathered = xla_model.all_gather(tensor)
    tensor_list = torch.split(gathered, max_size)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))
    return data_list
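# Hypothetical usage sketch for xla_all_gather above: every replica contributes an
# arbitrary picklable object (here a dict of per-rank counters) and receives the full
# list back. The counter names and this helper are illustrative assumptions, not part
# of the original snippet.
def report_counters(local_counters):
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    gathered = xla_all_gather(local_counters, device)
    if xm.get_ordinal() == 0:
        # merge per-rank dicts by summing values key-wise
        merged = {k: sum(c[k] for c in gathered) for k in gathered[0]}
        print(merged)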
def broadcast(self, obj: object, src: int = 0) -> object:
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    data = bytearray(buffer.getbuffer())
    data_tensor = torch.tensor(data).to(xm.xla_device(), dtype=torch.float)
    data = xm.all_gather(data_tensor)
    buffer = io.BytesIO(data.cpu().byte().numpy())
    obj = torch.load(buffer)
    return obj
def allgather(self, output_tensors_list, input_tensors):
    for input_tensor, output_tensors in zip(input_tensors, output_tensors_list):
        result = xm.all_gather(input_tensor, groups=self._mesh)
        for i, slice in enumerate(torch.split(result, input_tensor.shape[0])):
            output_tensors[i].copy_(slice)
    return WorkXla([t for sublist in output_tensors_list for t in sublist])
def broadcast(self, obj: object, src: int = 0) -> object:
    if not self.is_distributed:
        return obj
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    data = bytearray(buffer.getbuffer())
    data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float)
    data = xm.all_gather(data_tensor)
    buffer = io.BytesIO(data.cpu().byte().numpy())
    obj = torch.load(buffer)
    return obj
def broadcast(self, obj, src=0):
    if self.trainer.tpu_id is not None:
        # running on a single core
        return obj
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    data = bytearray(buffer.getbuffer())
    data_tensor = torch.tensor(data).to(xm.xla_device(), dtype=torch.float)
    data = xm.all_gather(data_tensor)
    buffer = io.BytesIO(data.cpu().byte().numpy())
    obj = torch.load(buffer)
    return obj
def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
    """
    Function to gather a tensor from several distributed processes.

    Args:
        tensor: tensor of shape (batch, ...)
        group: not available with TPUs
        sync_grads: not available with TPUs

    Return:
        A tensor of shape (world_size, batch, ...)
    """
    if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
        tensor = tensor.unsqueeze(0)
    return xm.all_gather(tensor)
def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
    """
    Function to gather a tensor from several distributed processes.

    Args:
        tensor: tensor of shape (batch, ...)
        group: the process group to gather results from. Defaults to all processes (world)
        sync_grads: flag that allows users to synchronize gradients for the all_gather op

    Return:
        A tensor of shape (world_size, batch, ...)
    """
    return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
    """
    Function to gather a tensor from several distributed processes.

    Args:
        tensor: tensor of shape (batch, ...)
        group: not available with TPUs
        sync_grads: not available with TPUs

    Return:
        A tensor of shape (world_size, batch, ...)
    """
    # todo: Add support for backward with all_gather
    if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
        return xm.all_gather(tensor).view(-1, *tensor.shape)
    return tensor
def _mp_fn(index):
    device = xm.xla_device()
    world_size = xm.xrt_world_size()
    if xm.xla_device_hw(device) in ('TPU', 'GPU'):
        # Testing with a single replica group
        ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
        result = xm.all_gather(ordinal_tensor, dim=0)

        cpu_result = result.cpu()
        expected = torch.arange(0, world_size, dtype=torch.float)
        if not cpu_result.allclose(expected):
            print('xm.all_gather() produced wrong reductions', file=sys.stderr)
            print(f'[{index}] {cpu_result}', file=sys.stderr)
            sys.exit(1)

        # Testing with two replica groups
        if world_size % 2 == 0 and world_size > 1:
            mp_groups = [[n for n in range(world_size) if n % 2 == 0],
                         [n for n in range(world_size) if n % 2 == 1]]
            group_size = len(mp_groups[0])
            replica_id = int(index % 2 == 1)

            result = xm.all_gather(ordinal_tensor, dim=0, groups=mp_groups)

            cpu_result = result.cpu()
            expected = torch.arange(replica_id, world_size, step=2, dtype=torch.float)
            if not cpu_result.allclose(expected):
                print('xm.all_gather() produced wrong reductions', file=sys.stderr)
                print(f'[{index}] {cpu_result}', file=sys.stderr)
                sys.exit(1)
        else:
            print(f'Failed to create two replica groups with {world_size} replicas',
                  file=sys.stderr)
    else:
        print(f'{device} is not a TPU or GPU device', file=sys.stderr)
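# A typical launch stanza for the replica-group test above (a sketch; it assumes the
# file already imports sys, torch, and torch_xla.core.xla_model as xm).
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())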
def compute(self) -> Tuple[torch.Tensor, float, float, float]:
    """Computes the AUC metric based on saved statistics."""
    targets = torch.cat(self.targets)
    scores = torch.cat(self.scores)

    # ddp hotfix, could be done better,
    # but the metric must handle DDP on its own
    if self._ddp_backend == "xla":
        # if you get "RuntimeError: Aborted: Session XXX is not found" here,
        # please ask Google for a more powerful TPU setup ;)
        device = get_device()
        scores = xm.all_gather(scores.to(device)).cpu().detach()
        targets = xm.all_gather(targets.to(device)).cpu().detach()
    elif self._ddp_backend == "ddp":
        scores = torch.cat(all_gather(scores))
        targets = torch.cat(all_gather(targets))

    scores, targets, _ = process_multilabel_components(outputs=scores, targets=targets)
    per_class = auc(scores=scores, targets=targets)
    micro = binary_auc(scores=scores.view(-1), targets=targets.view(-1))[0]
    macro = per_class.mean().item()
    weights = targets.sum(axis=0) / len(targets)
    weighted = (per_class * weights).sum().item()
    return per_class, micro, macro, weighted
def _mp_fn(index):
    device = xm.xla_device()
    if xm.xla_device_hw(device) != 'CPU':
        ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
        result = xm.all_gather(ordinal_tensor)

        cpu_result = result.cpu()
        expected = torch.arange(0, xm.xrt_world_size(), dtype=torch.float)
        if not cpu_result.allclose(expected):
            print('xm.all_gather() produced wrong reductions', file=sys.stderr)
            print('[{}] {}'.format(index, cpu_result), file=sys.stderr)
            sys.exit(1)
    else:
        print('Default device {} does not support replication'.format(device),
              file=sys.stderr)
def compute(self):
    """
    Compute metrics with accumulated statistics.

    Returns:
        tuple of metrics: per_class, micro_metric, macro_metric,
        weighted_metric (None if weights is None)
    """
    per_class = []
    total_statistics = {}
    macro_metric = 0
    weighted_metric = 0

    # ddp hotfix, could be done better,
    # but the metric must handle DDP on its own
    # TODO: optimise speed
    if self._ddp_backend == "xla":
        device = get_device()
        for _, statistics in self.statistics.items():
            for key in statistics:
                value = torch.tensor([statistics[key]], device=device)
                statistics[key] = xm.all_gather(value).sum(dim=0)
    elif self._ddp_backend == "ddp":
        for _, statistics in self.statistics.items():
            for key in statistics:
                value: List[torch.Tensor] = all_gather(statistics[key])
                value: torch.Tensor = torch.sum(torch.vstack(value), dim=0)
                statistics[key] = value

    for class_idx, statistics in self.statistics.items():
        value = self.metric_fn(**statistics)
        per_class.append(value)
        macro_metric += value
        if self.weights is not None:
            weighted_metric += value * self.weights[class_idx]
        for stats_name, value in statistics.items():
            total_statistics[stats_name] = total_statistics.get(stats_name, 0) + value

    macro_metric /= len(self.statistics)
    micro_metric = self.metric_fn(**total_statistics)

    if self.weights is None:
        weighted_metric = None

    if self.compute_per_class_metrics:
        return per_class, micro_metric, macro_metric, weighted_metric
    else:
        return [], micro_metric, macro_metric, weighted_metric
def gather_tensor(tensor):
    world_size = get_world_size()
    if world_size < 2:
        return tensor

    with torch.no_grad():
        tensor_list = []

        if is_xla():
            tensor_list = xm.all_gather(tensor)
            tensor_list = tensor_list.view(world_size, *tensor.size())
        else:
            for _ in range(world_size):
                tensor_list.append(torch.zeros_like(tensor))
            dist.all_gather(tensor_list, tensor)
            tensor_list = torch.stack(tensor_list, dim=0)

    return tensor_list
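# Hedged usage sketch for gather_tensor above: each process contributes its local
# feature batch and receives the stacked result. The shapes and this helper are
# illustrative assumptions, not part of the original snippet.
def gather_features(local_features):
    # local_features: (batch, feat_dim) on the current process
    stacked = gather_tensor(local_features)  # (world_size, batch, feat_dim) when world_size >= 2
    if stacked.dim() == 3:
        # flatten the world and batch dimensions into one axis
        return stacked.reshape(-1, local_features.size(-1))
    return stacked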
def all_gather(tensor, group, return_tensor=False):
    """Perform an all-gather operation."""
    if use_xla():
        result = xm.all_gather(tensor, groups=group[1])
        world_size = get_world_size(group=group)
        result = result.view(world_size, *tensor.size())
        if return_tensor:
            return result
        else:
            return [result[i] for i in range(world_size)]
    else:
        world_size = get_world_size(group=group)
        rank = get_rank(group=group)
        tensor_list = [
            tensor if i == rank else torch.empty_like(tensor)
            for i in range(world_size)
        ]
        dist.all_gather(tensor_list, tensor, group=group)
        if return_tensor:
            return torch.stack(tensor_list, dim=0)
        else:
            return tensor_list
def _train_one_epoch(self, loader):
    loader_time = .0
    train_time = .0
    curr_time = time.time()
    self.epoch_storage = defaultdict(list)
    for key in ['approx', 'target', 'loss', 'batch_metric']:
        self.epoch_storage[key] = []

    if self.fp16:
        scaler = amp.GradScaler()

    self.model.train()
    if self.progress_bar and self.rank == 0:
        iterator = enumerate(tqdm(loader, desc='train'))
    else:
        iterator = enumerate(loader)

    for batch_i, inputs in iterator:
        loader_time += time.time() - curr_time
        curr_time = time.time()
        self.optimizer.zero_grad()
        batches_done = len(loader) * (self.global_epoch - 1) + batch_i
        inputs = [t.to(self.device) for t in inputs]

        if self.fp16:
            with amp.autocast():
                loss = self.forward_train(self, inputs)
            loss = loss / self.grad_accumulations
            if self.parallel == 'ddp' and self.ddp_sync_last \
                    and batch_i != len(loader) - 1:
                with self.model.no_sync():
                    scaler.scale(loss).backward()
            else:  # sync only the last batch
                scaler.scale(loss).backward()
            if (batch_i + 1) % self.grad_accumulations == 0:
                scaler.step(self.optimizer)
                scaler.update()
        else:
            loss = self.forward_train(self, inputs)
            loss = loss / self.grad_accumulations
            if self.parallel == 'ddp' and self.ddp_sync_last \
                    and batch_i != len(loader) - 1:
                with self.model.no_sync():
                    loss.backward()
            else:  # sync only the last batch
                loss.backward()
            if (batch_i + 1) % self.grad_accumulations == 0:
                if self.xla:
                    xm.optimizer_step(self.optimizer, barrier=True)
                else:
                    self.optimizer.step()

        if self.batch_scheduler:
            self.scheduler.step()

        if self.parallel == 'ddp' and self.ddp_average_loss:
            if self.xla:
                loss_batch = xm.all_gather(
                    loss.detach().clone().view(1)).mean().item()
            else:
                loss_batch = comm.gather_tensor(
                    loss.detach().clone().view(1)).mean().item()
        else:  # use loss on device 0
            loss_batch = loss.item()

        # TensorBoard logging
        learning_rate = [param_group['lr']
                         for param_group in self.optimizer.param_groups]
        logs = [
            ('batch_train_loss', loss_batch),
            ('batch_train_lr', learning_rate)
        ]
        if len(self.epoch_storage['batch_metric']) > 0:
            metric = self.epoch_storage['batch_metric'][-1]
            logs.append(('batch_valid_metric', metric))
        self.tb_logger.list_of_scalars_summary(logs, batches_done)
        self.epoch_storage['loss'].append(loss_batch)

        train_time += time.time() - curr_time
        curr_time = time.time()

    if self.debug and self.rank == 0:
        self.logger(
            f'loader: {loader_time:.1f} s | train: {train_time:.1f} s')

    for key, val in self.epoch_storage.items():
        if len(val) > 0:
            if isinstance(val[0], torch.Tensor):
                self.epoch_storage[key] = torch.cat(val)
            else:
                self.epoch_storage[key] = torch.tensor(val).to(self.device)

    loss_total = self.epoch_storage['loss'].mean().item()

    if self.parallel == 'ddp':
        # gather tensors
        for key, val in self.epoch_storage.items():
            if len(val) > 0:
                if self.xla:
                    self.epoch_storage[key] = xm.all_gather(val)
                else:
                    self.epoch_storage[key] = comm.gather_tensor(val)
        metric_total, monitor_metrics_total = self.evaluate_epoch(self)
    else:
        metric_total, monitor_metrics_total = self.evaluate_epoch(self)

    if metric_total is None:
        metric_total = loss_total

    # TensorBoard logging
    logs = [
        ('epoch_train_loss', loss_total),
        ('epoch_train_metric', metric_total),
    ]
    self.tb_logger.list_of_scalars_summary(logs, self.global_epoch)

    return loss_total, metric_total, monitor_metrics_total
def _valid_one_epoch(self, loader):
    self.epoch_storage = defaultdict(list)
    for key in ['approx', 'target', 'loss', 'batch_metric']:
        self.epoch_storage[key] = []

    self.model.eval()
    if self.progress_bar and self.rank == 0:
        iterator = enumerate(tqdm(loader, desc='valid'))
    else:
        iterator = enumerate(loader)

    with torch.no_grad():
        for batch_i, inputs in iterator:
            batches_done = len(loader) * (self.global_epoch - 1) + batch_i
            inputs = [t.to(self.device) for t in inputs]
            loss, approx = self.forward_train(self, inputs)
            self.evaluate_batch(self, inputs, approx)

            if self.parallel == 'ddp' and self.ddp_average_loss:
                if self.xla:
                    loss_batch = xm.all_gather(
                        loss.detach().clone().view(1)).mean().item()
                else:
                    loss_batch = comm.gather_tensor(
                        loss.detach().clone().view(1)).mean().item()
            else:  # use loss on device 0
                loss_batch = loss.item()

            logs = [
                ('batch_valid_loss', loss_batch),
            ]
            if len(self.epoch_storage['batch_metric']) > 0:
                metric = self.epoch_storage['batch_metric'][-1]
                logs.append(('batch_valid_metric', metric))
            self.tb_logger.list_of_scalars_summary(logs, batches_done)
            self.epoch_storage['loss'].append(loss_batch)

    for key, val in self.epoch_storage.items():
        if len(val) > 0:
            if isinstance(val[0], torch.Tensor):
                self.epoch_storage[key] = torch.cat(val)
            else:
                self.epoch_storage[key] = torch.tensor(val).to(self.device)

    loss_total = self.epoch_storage['loss'].mean().item()

    if self.parallel == 'ddp':
        for key, val in self.epoch_storage.items():
            if len(val) > 0:
                if self.xla:
                    self.epoch_storage[key] = xm.all_gather(val)
                else:
                    self.epoch_storage[key] = comm.gather_tensor(val)
        metric_total, monitor_metrics_total = self.evaluate_epoch(self)
    else:
        metric_total, monitor_metrics_total = self.evaluate_epoch(self)

    if metric_total is None:
        metric_total = loss_total

    logs = [
        ('epoch_valid_loss', loss_total),
        ('epoch_valid_metric', metric_total),
    ]
    self.tb_logger.list_of_scalars_summary(logs, self.global_epoch)

    return loss_total, metric_total, monitor_metrics_total
def _train_one_epoch(self, loader):
    loader_time = .0
    train_time = .0
    curr_time = time.time()
    self.epoch_storage = defaultdict(list)
    for key in ['approx', 'target', 'loss', 'batch_metric']:
        self.epoch_storage[key] = []

    if self.fp16:
        scaler = amp.GradScaler()

    self.model.train()
    if self.progress_bar and self.rank == 0:
        iterator = enumerate(tqdm(loader, desc='train'))
    else:
        iterator = enumerate(loader)

    for batch_i, inputs in iterator:
        loader_time += time.time() - curr_time
        curr_time = time.time()
        self.optimizer.zero_grad()
        batches_done = len(loader) * (self.global_epoch - 1) + batch_i
        inputs = [t.to(self.device) for t in inputs]

        # forward and backward
        if self.fp16:
            with amp.autocast():
                loss, approx = self.forward_train(self, inputs)
            self.evaluate_batch(self, inputs, approx)  # evaluation
            loss = loss / self.grad_accumulations
            scaler.scale(loss).backward()
            if (batch_i + 1) % self.grad_accumulations == 0:
                if self.sam:
                    # first step
                    optimizer_state = scaler._per_optimizer_states[id(self.optimizer)]
                    scaler.unscale_(self.optimizer)
                    if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
                        self.optimizer.first_step(zero_grad=True)
                    optimizer_state["stage"] = 2
                    scaler.update()
                    # second step
                    with amp.autocast():
                        loss2, _ = self.forward_train(self, inputs)
                    scaler.scale(loss2).backward()
                    scaler.unscale_(self.optimizer)
                    if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
                        self.optimizer.second_step(zero_grad=True)
                    optimizer_state["stage"] = 2
                    scaler.update()
                else:
                    scaler.step(self.optimizer)
                    scaler.update()
        else:
            loss, approx = self.forward_train(self, inputs)
            self.evaluate_batch(self, inputs, approx)  # evaluation
            loss = loss / self.grad_accumulations
            loss.backward()
            if (batch_i + 1) % self.grad_accumulations == 0:
                if self.xla:
                    if self.sam:
                        raise RuntimeError(
                            'SAM optimizer on XLA device is not available.')
                    else:
                        xm.optimizer_step(self.optimizer, barrier=True)
                else:
                    if self.sam:
                        self.optimizer.first_step(zero_grad=True)
                        loss2, _ = self.forward_train(self, inputs)
                        loss2.backward()
                        self.optimizer.second_step(zero_grad=True)
                    else:
                        self.optimizer.step()

        if self.batch_scheduler:
            self.scheduler.step()

        if self.parallel == 'ddp' and self.ddp_average_loss:
            if self.xla:
                loss_batch = xm.all_gather(
                    loss.detach().clone().view(1)).mean().item()
            else:
                loss_batch = comm.gather_tensor(
                    loss.detach().clone().view(1)).mean().item()
        else:  # use loss on device 0
            loss_batch = loss.item()

        # logging
        learning_rate = [
            param_group['lr'] for param_group in self.optimizer.param_groups
        ]
        logs = [('batch_train_loss', loss_batch),
                ('batch_train_lr', learning_rate)]
        if len(self.epoch_storage['batch_metric']) > 0:
            metric = self.epoch_storage['batch_metric'][-1]
            logs.append(('batch_valid_metric', metric))
        self.tb_logger.list_of_scalars_summary(logs, batches_done)
        self.epoch_storage['loss'].append(loss_batch)

        train_time += time.time() - curr_time
        curr_time = time.time()

    if self.debug and self.rank == 0:
        self.logger(
            f'loader: {loader_time:.1f} s | train: {train_time:.1f} s')

    for key, val in self.epoch_storage.items():
        if len(val) > 0:
            if isinstance(val[0], torch.Tensor):
                self.epoch_storage[key] = torch.cat(val)
            else:
                self.epoch_storage[key] = torch.tensor(val).to(self.device)

    loss_total = self.epoch_storage['loss'].mean().item()

    if self.parallel == 'ddp':
        # gather tensors
        for key, val in self.epoch_storage.items():
            if len(val) > 0:
                if self.xla:
                    self.epoch_storage[key] = xm.all_gather(val)
                else:
                    self.epoch_storage[key] = comm.gather_tensor(val)
        metric_total, monitor_metrics_total = self.evaluate_epoch(self)
    else:
        metric_total, monitor_metrics_total = self.evaluate_epoch(self)

    if metric_total is None:
        metric_total = loss_total

    # logging
    logs = [
        ('epoch_train_loss', loss_total),
        ('epoch_train_metric', metric_total),
    ]
    self.tb_logger.list_of_scalars_summary(logs, self.global_epoch)

    return loss_total, metric_total, monitor_metrics_total
def _train_one_epoch(self, loader):
    loader_time = .0
    train_time = .0
    start_time = time.time()
    curr_time = time.time()
    self.epoch_storage = defaultdict(list)
    for key in ['approx', 'target', 'loss', 'batch_metric']:
        self.epoch_storage[key] = []

    if self.fp16:
        scaler = amp.GradScaler()

    self.model.train()
    if self.progress_bar and self.rank == 0:
        iterator = enumerate(tqdm(loader, desc='train'))
    else:
        iterator = enumerate(loader)
    batch_total = len(loader)
    ett_disp = False

    for batch_i, inputs in iterator:
        loader_time += time.time() - curr_time
        curr_time = time.time()
        elapsed_time = curr_time - start_time

        if self.rank == 0 and self.state['epoch'] == 0 and elapsed_time > 30 and not ett_disp:
            # show ETA
            ett = elapsed_time * batch_total // batch_i
            self.logger(f'Estimated epoch training time: {int(ett)} s')
            try:
                ram_usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss // 1024
                self.logger(f'Maximum RAM usage: {ram_usage} MB')
            except Exception:
                self.logger('Failed to get RAM usage.')
            try:
                gram_usage = int(max(get_gpu_memory().values()))
                self.logger(f'Maximum GRAM usage: {gram_usage} MB')
            except Exception:
                self.logger('Failed to get GRAM usage.')
            ett_disp = True

        self.optimizer.zero_grad()
        batches_done = batch_total * (self.global_epoch - 1) + batch_i
        inputs = [t.to(self.device) for t in inputs]

        # forward and backward
        if self.fp16:
            with amp.autocast():
                loss, approx = self.forward_train(self, inputs)
            self.evaluate_batch(self, inputs, approx)  # evaluation
            loss = loss / self.grad_accumulations
            scaler.scale(loss).backward()
            if self.clip_grad is not None:
                dispatch_clip_grad(self.model.parameters(), self.max_grad_norm, mode=self.clip_grad)
            if (batch_i + 1) % self.grad_accumulations == 0:
                if self.sam:
                    # first step
                    optimizer_state = scaler._per_optimizer_states[id(self.optimizer)]
                    scaler.unscale_(self.optimizer)
                    if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
                        self.optimizer.first_step(zero_grad=True)
                    optimizer_state["stage"] = 2
                    scaler.update()
                    # second step
                    with amp.autocast():
                        loss2, _ = self.forward_train(self, inputs)
                    scaler.scale(loss2).backward()
                    scaler.unscale_(self.optimizer)
                    if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
                        self.optimizer.second_step(zero_grad=True)
                    optimizer_state["stage"] = 2
                    scaler.update()
                else:
                    scaler.step(self.optimizer)
                    scaler.update()
        else:
            loss, approx = self.forward_train(self, inputs)
            self.evaluate_batch(self, inputs, approx)  # evaluation
            loss = loss / self.grad_accumulations
            loss.backward()
            if self.clip_grad is not None:
                dispatch_clip_grad(self.model.parameters(), self.max_grad_norm, mode=self.clip_grad)
            if (batch_i + 1) % self.grad_accumulations == 0:
                if self.xla:
                    if self.sam:
                        raise RuntimeError('SAM optimizer on XLA device is not available.')
                    else:
                        xm.optimizer_step(self.optimizer, barrier=True)
                else:
                    if self.sam:
                        self.optimizer.first_step(zero_grad=True)
                        loss2, _ = self.forward_train(self, inputs)
                        loss2.backward()
                        self.optimizer.second_step(zero_grad=True)
                    else:
                        self.optimizer.step()

        if self.batch_scheduler:
            self.scheduler.step()

        if torch.isnan(loss).any():
            self.logger(f'{torch.isnan(loss).sum()} NaN detected in loss. ({batch_i}/{len(loader)})')
        if torch.isnan(approx).any():
            self.logger(f'{torch.isnan(approx).sum()} NaN detected in output tensor. ({batch_i}/{len(loader)})')

        if self.parallel == 'ddp' and self.ddp_average_loss:
            if self.xla:
                loss_batch = xm.all_gather(
                    loss.detach().clone().view(1)).mean().item()
            else:
                loss_batch = comm.gather_tensor(
                    loss.detach().clone().view(1)).mean().item()
        else:  # use loss on device 0
            loss_batch = loss.item()

        # logging
        learning_rate = [param_group['lr'] for param_group in self.optimizer.param_groups]
        logs = [
            ('batch_train_loss', loss_batch),
            ('batch_train_lr', learning_rate)
        ]
        if len(self.epoch_storage['batch_metric']) > 0:
            metric = self.epoch_storage['batch_metric'][-1]
            logs.append(('batch_valid_metric', metric))
        self.tb_logger.list_of_scalars_summary(logs, batches_done)
        self.epoch_storage['loss'].append(loss_batch)

        train_time += time.time() - curr_time
        curr_time = time.time()

    if self.debug and self.rank == 0:
        self.logger(
            f'loader: {loader_time:.1f} s | train: {train_time:.1f} s')

    for key, val in self.epoch_storage.items():
        if len(val) > 0:
            if isinstance(val[0], torch.Tensor):
                self.epoch_storage[key] = torch.cat(val)
            else:
                self.epoch_storage[key] = torch.tensor(val).to(self.device)

    loss_total = self.epoch_storage['loss'].mean().item()

    if self.parallel == 'ddp':
        # gather tensors
        for key, val in self.epoch_storage.items():
            if len(val) > 0:
                if self.xla:
                    self.epoch_storage[key] = xm.all_gather(val)
                else:
                    self.epoch_storage[key] = comm.gather_tensor(val)
        metric_total, monitor_metrics_total = self.evaluate_epoch(self)
    else:
        metric_total, monitor_metrics_total = self.evaluate_epoch(self)

    if metric_total is None:
        metric_total = loss_total

    # logging
    logs = [
        ('epoch_train_loss', loss_total),
        ('epoch_train_metric', metric_total),
    ]
    self.tb_logger.list_of_scalars_summary(logs, self.global_epoch)

    return loss_total, metric_total, monitor_metrics_total
def forward(ctx, input, dim):
    ctx.dim = dim
    ctx.ordinal = xm.get_ordinal()
    ctx.world_size = xm.xrt_world_size()
    return xm.all_gather(input, dim=dim)
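# A minimal backward sketch matching the forward above (an assumption, not part of
# the original snippet): since xm.all_gather concatenates equally sized per-replica
# slices along `dim`, the gradient for this replica is the slice at its ordinal.
@staticmethod
def backward(ctx, grad_output):
    slice_size = grad_output.size(ctx.dim) // ctx.world_size
    grad_input = grad_output.narrow(ctx.dim, ctx.ordinal * slice_size, slice_size)
    return grad_input, None  # no gradient for the `dim` argument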
def all_gather(self, collectiveArgs, retFlag=False):
    retObj = xm.all_gather(collectiveArgs.ipTensor, dim=0)
    collectiveArgs.opTensor = retObj
    if retFlag:
        return retObj
def forward(ctx, x, dim):
    ctx.dim = dim
    tensor_list = xm.all_gather(x.unsqueeze(dim), dim=dim)
    return tensor_list