def all_gather(self, collectiveArgs, retFlag=False):
    """Gather ``collectiveArgs.ipTensor`` from every replica along dim 0.

    The gathered tensor is stored in ``collectiveArgs.opTensor``. When
    ``collectiveArgs.asyncOp`` is set, it is also appended to
    ``collectiveArgs.waitObj``.

    Returns:
        The gathered tensor if ``retFlag`` is true, otherwise None.
    """
    gathered = xm.all_gather(collectiveArgs.ipTensor, dim=0)
    collectiveArgs.opTensor = gathered
    if collectiveArgs.asyncOp:
        # NOTE(review): this appends the result tensor itself, not a work
        # handle; confirm that whoever drains waitObj can cope with tensors.
        collectiveArgs.waitObj.append(gathered)
    return gathered if retFlag else None
# Example no. 2 (score: 0)
    def compute(self) -> Any:
        """
        Compute precision, recall, f1 score and support.

        When running distributed (``self._ddp_backend`` is ``"xla"`` or
        ``"ddp"``), every entry of ``self.statistics`` is first summed across
        processes and written back. The "tp", "fp", "fn" and "support"
        statistics are then passed to ``get_aggregated_metrics``.

        Returns:
            tuple ``(per_class, micro, macro, weighted)`` as produced by
            ``get_aggregated_metrics``: per-class, micro, macro and weighted
            averages of precision, recall, f1 score and support.
        """
        # ddp hotfix, could be done better
        # but metric must handle DDP on its own
        backend = self._ddp_backend
        if backend == "xla":
            device = get_device()
            for name in self.statistics:
                local = torch.tensor([self.statistics[name]], device=device)
                # one leading row per process after the gather; sum over them
                self.statistics[name] = xm.all_gather(local).sum(dim=0).cpu().numpy()
        elif backend == "ddp":
            for name in self.statistics:
                per_rank = all_gather(self.statistics[name])
                self.statistics[name] = np.vstack(per_rank).sum(axis=0)

        stats = self.statistics
        per_class, micro, macro, weighted = get_aggregated_metrics(
            tp=stats["tp"],
            fp=stats["fp"],
            fn=stats["fn"],
            support=stats["support"],
            zero_division=self.zero_division,
        )
        return per_class, micro, macro, weighted
    def compute(self) -> Any:
        """
        Build the (optionally row-normalized) confusion matrix.

        Under a distributed backend, the per-process matrices in ``self.conf``
        are first summed across processes, and the sum is stored back into
        ``self.conf``.

        Returns:
            Confusion matrix of K rows and K columns, where rows correspond
            to ground-truth targets and columns correspond to predicted
            targets. When ``self.normalize`` is set, the matrix is cast to
            float32 and each row is divided by its total; rows that sum to
            zero stay zero.
        """
        backend = self._ddp_backend
        # ddp hotfix, could be done better
        # but metric must handle DDP on its own
        if backend == "xla":
            # if you have "RuntimeError: Aborted: Session XXX is not found" here
            # please, ask Google for a more powerful TPU setup ;)
            local = torch.tensor([self.conf], device=get_device())
            self.conf = xm.all_gather(local).sum(0).cpu().detach().numpy()
        elif backend == "ddp":
            self.conf = np.stack(all_gather(self.conf), axis=0).sum(axis=0)

        if not self.normalize:
            return self.conf
        matrix = self.conf.astype(np.float32)
        # clip avoids division by zero for rows with no samples
        row_totals = matrix.sum(1).clip(min=1e-12)
        return matrix / row_totals[:, None]
# Example no. 4 (score: 0)
def xla_all_gather(data, device):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Each rank pickles ``data`` into a 1-D uint8 tensor on ``device``. XLA
    collectives need identical shapes on every replica, so the payload
    lengths are exchanged first and every payload is padded to the longest
    one. The padded payloads are gathered, trimmed back to their true
    lengths and unpickled.

    Args:
        data: any picklable object
        device: device on which the intermediate tensors are created
            (presumably the XLA device of this process)

    Returns:
        list[data]: list of data gathered from each rank, in gather order

    Note:
        ``pickle.loads`` can execute arbitrary code; only exchange data
        between trusted processes of the same job.
    """
    # The original `import torch_xla.core.xla_model` bound only `torch_xla`,
    # leaving `xla_model` undefined.
    import torch_xla.core.xla_model as xm

    world_size = xm.xrt_world_size()
    if world_size == 1:
        return [data]

    # serialize to a 1-D uint8 tensor
    payload = torch.tensor(bytearray(pickle.dumps(data)), dtype=torch.uint8, device=device)
    local_numel = payload.numel()

    # xm.all_gather returns the concatenation along dim 0 (it does not fill
    # a list like torch.distributed.all_gather), so gather sizes first.
    local_size = torch.tensor([local_numel], dtype=torch.long, device=device)
    size_list = [int(size) for size in xm.all_gather(local_size).cpu().tolist()]
    max_size = max(size_list)

    # pad to the common length; the padding is dropped again after the gather
    if local_numel < max_size:
        padding = torch.zeros(max_size - local_numel, dtype=torch.uint8, device=device)
        payload = torch.cat((payload, padding), dim=0)

    gathered = xm.all_gather(payload).cpu().view(len(size_list), max_size)

    # NOTE(security): unpickling data received from other processes.
    return [
        pickle.loads(row[:size].numpy().tobytes())
        for size, row in zip(size_list, gathered)
    ]
 def broadcast(self, obj: object, src: int = 0) -> object:
     """Send ``obj`` from rank ``src`` to every rank and return it.

     The object is serialized with ``torch.save`` into a float tensor with
     one element per byte and exchanged with ``xm.all_gather``. Only rank
     ``src``'s slice is deserialized. Payload lengths may differ between
     ranks, and XLA collectives need identical shapes, so the lengths are
     gathered first and every payload is zero-padded to the longest one.

     Args:
         obj: object to broadcast; only the value on rank ``src`` is used.
         src: index of the sending rank in the gathered result (default 0).

     Returns:
         The object as serialized on rank ``src``.
     """
     device = xm.xla_device()
     buffer = io.BytesIO()
     torch.save(obj, buffer)
     payload = torch.tensor(bytearray(buffer.getbuffer()), dtype=torch.float).to(device)

     local_size = torch.tensor([payload.numel()], dtype=torch.long, device=device)
     sizes = [int(size) for size in xm.all_gather(local_size).cpu().tolist()]
     max_size = max(sizes)
     if payload.numel() < max_size:
         padding = torch.zeros(max_size - payload.numel(), dtype=payload.dtype, device=device)
         payload = torch.cat((payload, padding))

     # xm.all_gather concatenates along dim 0: one row of max_size per rank
     gathered = xm.all_gather(payload).cpu().view(len(sizes), max_size)
     # float -> uint8 is exact here: every element is a byte value in [0, 255]
     src_bytes = gathered[src, :sizes[src]].byte().numpy().tobytes()
     # NOTE(review): newer torch defaults torch.load to weights_only=True,
     # which rejects arbitrary objects; confirm against the torch in use.
     return torch.load(io.BytesIO(src_bytes))
# Example no. 6 (score: 0)
    def allgather(self, output_tensors_list, input_tensors):
        """Gather each input tensor across ``self._mesh`` into its output list.

        For every ``(input_tensor, output_tensors)`` pair, the result of
        ``xm.all_gather`` is split into chunks of ``input_tensor.shape[0]``
        rows, and the i-th chunk is copied into ``output_tensors[i]``.

        Returns:
            ``WorkXla`` wrapping every output tensor, flattened in order.
        """
        for src, dests in zip(input_tensors, output_tensors_list):
            gathered = xm.all_gather(src, groups=self._mesh)
            chunks = torch.split(gathered, src.shape[0])
            for idx, chunk in enumerate(chunks):
                dests[idx].copy_(chunk)

        flat_outputs = [tensor for dests in output_tensors_list for tensor in dests]
        return WorkXla(flat_outputs)
# Example no. 7 (score: 0)
 def broadcast(self, obj: object, src: int = 0) -> object:
     """Send ``obj`` from rank ``src`` to every rank and return it.

     Without distribution, ``obj`` is returned as is. Otherwise the object is
     serialized with ``torch.save`` into a float tensor with one element per
     byte on ``self.root_device`` and exchanged with ``xm.all_gather``. Only
     rank ``src``'s slice is deserialized. Payload lengths may differ between
     ranks, and XLA collectives need identical shapes, so the lengths are
     gathered first and every payload is zero-padded to the longest one.

     Args:
         obj: object to broadcast; only the value on rank ``src`` is used.
         src: index of the sending rank in the gathered result (default 0).

     Returns:
         The object as serialized on rank ``src``.
     """
     if not self.is_distributed:
         return obj
     device = self.root_device
     buffer = io.BytesIO()
     torch.save(obj, buffer)
     payload = torch.tensor(bytearray(buffer.getbuffer()), device=device, dtype=torch.float)

     local_size = torch.tensor([payload.numel()], dtype=torch.long, device=device)
     sizes = [int(size) for size in xm.all_gather(local_size).cpu().tolist()]
     max_size = max(sizes)
     if payload.numel() < max_size:
         padding = torch.zeros(max_size - payload.numel(), dtype=payload.dtype, device=device)
         payload = torch.cat((payload, padding))

     # xm.all_gather concatenates along dim 0: one row of max_size per rank
     gathered = xm.all_gather(payload).cpu().view(len(sizes), max_size)
     # float -> uint8 is exact here: every element is a byte value in [0, 255]
     src_bytes = gathered[src, :sizes[src]].byte().numpy().tobytes()
     # NOTE(review): newer torch defaults torch.load to weights_only=True,
     # which rejects arbitrary objects; confirm against the torch in use.
     return torch.load(io.BytesIO(src_bytes))
# Example no. 8 (score: 0)
 def broadcast(self, obj, src=0):
     """Send ``obj`` from rank ``src`` to every TPU core and return it.

     When ``self.trainer.tpu_id`` is set, training runs on a single core and
     ``obj`` is returned as is. Otherwise the object is serialized with
     ``torch.save`` into a float tensor with one element per byte and
     exchanged with ``xm.all_gather``. Only rank ``src``'s slice is
     deserialized. Payload lengths may differ between ranks, and XLA
     collectives need identical shapes, so the lengths are gathered first
     and every payload is zero-padded to the longest one.

     Args:
         obj: object to broadcast; only the value on rank ``src`` is used.
         src: index of the sending rank in the gathered result (default 0).

     Returns:
         The object as serialized on rank ``src``.
     """
     if self.trainer.tpu_id is not None:
         # running on a single core
         return obj
     device = xm.xla_device()
     buffer = io.BytesIO()
     torch.save(obj, buffer)
     payload = torch.tensor(bytearray(buffer.getbuffer()), dtype=torch.float).to(device)

     local_size = torch.tensor([payload.numel()], dtype=torch.long, device=device)
     sizes = [int(size) for size in xm.all_gather(local_size).cpu().tolist()]
     max_size = max(sizes)
     if payload.numel() < max_size:
         padding = torch.zeros(max_size - payload.numel(), dtype=payload.dtype, device=device)
         payload = torch.cat((payload, padding))

     # xm.all_gather concatenates along dim 0: one row of max_size per rank
     gathered = xm.all_gather(payload).cpu().view(len(sizes), max_size)
     # float -> uint8 is exact here: every element is a byte value in [0, 255]
     src_bytes = gathered[src, :sizes[src]].byte().numpy().tobytes()
     # NOTE(review): newer torch defaults torch.load to weights_only=True,
     # which rejects arbitrary objects; confirm against the torch in use.
     return torch.load(io.BytesIO(src_bytes))
# Example no. 9 (score: 0)
 def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
     """
     Gather ``tensor`` from all TPU processes with ``xm.all_gather``.

     Args:
         tensor: tensor of shape (batch, ...); a 0-dim tensor is first
             promoted to shape (1,) because the gather runs along dim 0
         group: not available with TPUs (ignored)
         sync_grads: not available with TPUs (ignored)
     Return:
         The result of ``xm.all_gather`` on dim 0; since that call
         concatenates along the dimension, expect (world_size * batch, ...).
     """
     is_scalar = isinstance(tensor, torch.Tensor) and tensor.dim() == 0
     return xm.all_gather(tensor.unsqueeze(0) if is_scalar else tensor)
# Example no. 10 (score: 0)
 def all_gather(self,
                tensor: torch.Tensor,
                group: Optional[Any] = None,
                sync_grads: bool = False) -> torch.Tensor:
     """
     Function to gather a tensor from several distributed processes

     Args:
         tensor: tensor of shape (batch, ...); a 0-dim tensor is first
             promoted to shape (1,) because the gather runs along dim 0
         group: not supported by the XLA gather used here and ignored
             (all replicas participate)
         sync_grads: if True, use the autograd-aware
             ``torch_xla.core.functions.all_gather`` so gradients flow back
             through the gather; otherwise use ``xm.all_gather``
     Return:
         The tensors of all processes concatenated along dim 0,
         i.e. (world_size * batch, ...).
     """
     # xm.all_gather accepts neither `group` nor `sync_grads`; forwarding them
     # raised TypeError on every call.
     if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
         tensor = tensor.unsqueeze(0)
     if sync_grads:
         import torch_xla.core.functions as xf
         return xf.all_gather(tensor)
     return xm.all_gather(tensor)
# Example no. 11 (score: 0)
 def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
     """
     Function to gather a tensor from several distributed processes

     Args:
         tensor: tensor of shape (batch, ...); 0-dim tensors are supported
         group: not available with TPUs (ignored)
         sync_grads: not available with TPUs (ignored)
     Return:
         A tensor of shape (world_size, *tensor.shape) when the training type
         plugin is a distributed ``TPUSpawnPlugin``; otherwise ``tensor`` is
         returned unchanged.
     """
     # todo: Add support for backward with all_gather
     plugin = self.training_type_plugin
     if isinstance(plugin, TPUSpawnPlugin) and plugin.is_distributed:
         original_shape = tensor.shape
         # xm.all_gather concatenates along dim 0, which a 0-dim tensor lacks
         if tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)
         # regroup the concatenated result into one slice per process
         return xm.all_gather(tensor).view(-1, *original_shape)
     return tensor
# Example no. 12 (score: 0)
def _mp_fn(index):
  """Check xm.all_gather() on one replica group and on two (even/odd) groups.

  Each replica contributes its ordinal ``index`` as a float. On mismatch the
  process prints the offending result to stderr and exits with status 1.
  Devices other than TPU/GPU are reported on stderr and skipped.
  """
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  if xm.xla_device_hw(device) not in ('TPU', 'GPU'):
    print(f'{device} is not a TPU or GPU device', file=sys.stderr)
    return

  def _verify(result, expected):
    # Abort this replica if the gathered ordinals differ from `expected`.
    cpu_result = result.cpu()
    if not cpu_result.allclose(expected):
      print('xm.all_gather() produced wrong reductions', file=sys.stderr)
      print(f'[{index}] {cpu_result}', file=sys.stderr)
      sys.exit(1)

  # Single replica group: every ordinal, in order.
  ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
  _verify(xm.all_gather(ordinal_tensor, dim=0),
          torch.arange(0, world_size, dtype=torch.float))

  # Two replica groups (even and odd ordinals) need an even world size > 1.
  if world_size % 2 != 0 or world_size <= 1:
    print(
        f'Failed to create two replica groups with {world_size} replicas',
        file=sys.stderr)
    return

  evens = [n for n in range(world_size) if n % 2 == 0]
  odds = [n for n in range(world_size) if n % 2 == 1]
  replica_id = int(index % 2 == 1)
  _verify(xm.all_gather(ordinal_tensor, dim=0, groups=[evens, odds]),
          torch.arange(replica_id, world_size, step=2, dtype=torch.float))
# Example no. 13 (score: 0)
    def compute(self) -> Tuple[torch.Tensor, float, float, float]:
        """Compute the AUC metric from the accumulated scores and targets.

        The stored per-batch ``self.targets`` / ``self.scores`` tensors are
        concatenated. Under a distributed backend they are then gathered from
        every process before evaluation.

        Returns:
            ``(per_class, micro, macro, weighted)``:
            - ``per_class``: per-class AUC from ``auc``;
            - ``micro``: AUC of the flattened scores/targets;
            - ``macro``: unweighted mean of the per-class AUCs;
            - ``weighted``: per-class AUCs weighted by the column sums of
              ``targets`` divided by the number of samples.
        """
        targets = torch.cat(self.targets)
        scores = torch.cat(self.scores)

        # ddp hotfix, could be done better
        # but metric must handle DDP on its own
        backend = self._ddp_backend
        if backend == "xla":
            # if you have "RuntimeError: Aborted: Session XXX is not found" here
            # please, ask Google for a more powerful TPU setup ;)
            device = get_device()
            scores = xm.all_gather(scores.to(device)).cpu().detach()
            targets = xm.all_gather(targets.to(device)).cpu().detach()
        elif backend == "ddp":
            scores = torch.cat(all_gather(scores))
            targets = torch.cat(all_gather(targets))

        scores, targets, _ = process_multilabel_components(outputs=scores, targets=targets)
        per_class = auc(scores=scores, targets=targets)
        flat_scores, flat_targets = scores.view(-1), targets.view(-1)
        micro = binary_auc(scores=flat_scores, targets=flat_targets)[0]
        macro = per_class.mean().item()
        class_weights = targets.sum(axis=0) / len(targets)
        weighted = (per_class * class_weights).sum().item()
        return per_class, micro, macro, weighted
def _mp_fn(index):
    """Check that xm.all_gather() collects every replica ordinal in order.

    On a CPU device the check is skipped with a message on stderr. On a
    mismatch the gathered values are printed to stderr and the process
    exits with status 1.
    """
    device = xm.xla_device()
    if xm.xla_device_hw(device) == 'CPU':
        print('Default device {} does not support replication'.format(device),
              file=sys.stderr)
        return

    ordinal = torch.tensor([index], dtype=torch.float).to(device)
    gathered = xm.all_gather(ordinal).cpu()
    expected = torch.arange(0, xm.xrt_world_size(), dtype=torch.float)
    if gathered.allclose(expected):
        return
    print('xm.all_gather() produced wrong reductions', file=sys.stderr)
    print('[{}] {}'.format(index, gathered), file=sys.stderr)
    sys.exit(1)
# Example no. 15 (score: 0)
    def compute(self):
        """
        Compute metrics with accumulated statistics.

        ``self.statistics`` maps a class index to a dict of keyword statistics
        for ``self.metric_fn``. Under a distributed backend every statistic is
        first summed across processes in place.

        Returns:
            tuple ``(per_class, micro_metric, macro_metric, weighted_metric)``:
            - ``per_class``: list of per-class values, or ``[]`` when
              ``self.compute_per_class_metrics`` is false;
            - ``micro_metric``: ``metric_fn`` applied to the summed statistics;
            - ``macro_metric``: mean of the per-class values;
            - ``weighted_metric``: sum of per-class values times
              ``self.weights[class_idx]``, or None when ``self.weights``
              is None.
        """
        # ddp hotfix, could be done better
        # but metric must handle DDP on its own
        # TODO: optimise speed
        if self._ddp_backend == "xla":
            device = get_device()
            for stats in self.statistics.values():
                for name in stats:
                    local = torch.tensor([stats[name]], device=device)
                    stats[name] = xm.all_gather(local).sum(dim=0)
        elif self._ddp_backend == "ddp":
            for stats in self.statistics.values():
                for name in stats:
                    stats[name] = torch.sum(torch.vstack(all_gather(stats[name])), dim=0)

        per_class = []
        totals = {}
        macro_metric = 0
        weighted_metric = 0
        for class_idx, stats in self.statistics.items():
            score = self.metric_fn(**stats)
            per_class.append(score)
            macro_metric += score
            if self.weights is not None:
                weighted_metric += score * self.weights[class_idx]
            for name, stat in stats.items():
                totals[name] = totals.get(name, 0) + stat

        macro_metric /= len(self.statistics)
        micro_metric = self.metric_fn(**totals)

        if self.weights is None:
            weighted_metric = None
        reported = per_class if self.compute_per_class_metrics else []
        return reported, micro_metric, macro_metric, weighted_metric
# Example no. 16 (score: 0)
def gather_tensor(tensor):
    """Stack ``tensor`` from every process into shape (world_size, *tensor.shape).

    With fewer than two processes the input is returned unchanged. The
    gather runs under ``torch.no_grad()``. On XLA it uses ``xm.all_gather``
    followed by a reshape; otherwise it uses ``torch.distributed.all_gather``
    followed by ``torch.stack``.
    """
    world_size = get_world_size()
    if world_size < 2:
        return tensor

    with torch.no_grad():
        if is_xla():
            # xm.all_gather concatenates along dim 0; regroup one slice per rank
            return xm.all_gather(tensor).view(world_size, *tensor.size())
        gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor)
        return torch.stack(gathered, dim=0)
# Example no. 17 (score: 0)
def all_gather(tensor, group, return_tensor=False):
    """Perform an all-gather operation.

    Args:
        tensor: the local tensor to share.
        group: process group. On XLA, ``group[1]`` is passed as the replica
            ``groups`` argument (presumably a (tag, replica-groups) pair;
            confirm against the group constructor).
        return_tensor: if true, return one stacked tensor of shape
            (world_size, *tensor.shape); otherwise return a list with one
            tensor per rank.
    """
    if use_xla():
        stacked = xm.all_gather(tensor, groups=group[1])
        world_size = get_world_size(group=group)
        stacked = stacked.view(world_size, *tensor.size())
        return stacked if return_tensor else list(stacked)

    world_size = get_world_size(group=group)
    rank = get_rank(group=group)
    # this rank's slot reuses the local tensor; others are receive buffers
    gathered = [
        tensor if i == rank else torch.empty_like(tensor)
        for i in range(world_size)
    ]
    dist.all_gather(gathered, tensor, group=group)
    return torch.stack(gathered, dim=0) if return_tensor else gathered
# Example no. 18 (score: 0)
    def _train_one_epoch(self, loader):
        loader_time = .0
        train_time = .0
        curr_time = time.time()

        self.epoch_storage = defaultdict(list)
        for key in ['approx', 'target', 'loss', 'batch_metric']:
            self.epoch_storage[key] = []
        if self.fp16:
            scaler = amp.GradScaler()

        self.model.train()
        if self.progress_bar and self.rank == 0:
            iterator = enumerate(tqdm(loader, desc='train'))
        else:
            iterator = enumerate(loader)

        for batch_i, inputs in iterator:
            loader_time += time.time() - curr_time
            curr_time = time.time()

            self.optimizer.zero_grad()
            batches_done = len(loader) * (self.global_epoch-1) + batch_i
            inputs = [t.to(self.device) for t in inputs]

            if self.fp16:
                with amp.autocast():
                    loss = self.forward_train(self, inputs)
                loss = loss / self.grad_accumulations
                if self.parallel == 'ddp' and self.ddp_sync_last \
                    and batch_i != len(loader) - 1:
                    with self.model.no_sync():
                        scaler.scale(loss).backward()
                else: # sync only last batch
                    scaler.scale(loss).backward()
                if (batch_i + 1) % self.grad_accumulations == 0:
                    scaler.step(self.optimizer)
                    scaler.update()
            else:
                loss = self.forward_train(self, inputs)
                loss = loss / self.grad_accumulations
                if self.parallel == 'ddp' and self.ddp_sync_last \
                    and batch_i != len(loader) - 1:
                    with self.model.no_sync():
                        loss.backward()
                else:  # sync only last batch
                    loss.backward()

                if (batch_i + 1) % self.grad_accumulations == 0:
                    if self.xla:
                        xm.optimizer_step(self.optimizer, barrier=True)
                    else:
                        self.optimizer.step()
                    if self.batch_scheduler:
                        self.scheduler.step()

            if self.parallel == 'ddp' and self.ddp_average_loss:
                if self.xla:
                    loss_batch = xm.all_gather(
                        loss.detach().clone().view(1)).mean().item()
                else:
                    loss_batch = comm.gather_tensor(
                        loss.detach().clone().view(1)).mean().item()
            else:  # Use loss on device: 0
                loss_batch = loss.item()

            ''' TensorBoard logging '''
            learning_rate = [param_group['lr']
                             for param_group in self.optimizer.param_groups]
            logs = [
                ('batch_train_loss', loss_batch),
                ('batch_train_lr', learning_rate)
            ]
            if len(self.epoch_storage['batch_metric']) > 0:
                metric = self.epoch_storage['batch_metric'][-1]
                logs.append(('batch_valid_mertric', metric))
            self.tb_logger.list_of_scalars_summary(logs, batches_done)

            self.epoch_storage['loss'].append(loss_batch)

            train_time += time.time() - curr_time
            curr_time = time.time()

        if self.debug and self.rank == 0:
            self.logger(
                f'loader: {loader_time:.1f} s | train: {train_time:.1f} s')

        for key, val in self.epoch_storage.items():
            if len(val) > 0:
                if isinstance(val[0], torch.Tensor):
                    self.epoch_storage[key] = torch.cat(val)
                else:
                    self.epoch_storage[key] = torch.tensor(val).to(self.device)

        loss_total = self.epoch_storage['loss'].mean().item()

        if self.parallel == 'ddp':
            ''' Gather tensors '''
            for key, val in self.epoch_storage.items():
                if len(val) > 0:
                    if self.xla:
                        self.epoch_storage[key] = xm.all_gather(val)
                    else:
                        self.epoch_storage[key] = comm.gather_tensor(val)

            metric_total, monitor_metrics_total = self.evaluate_epoch(self)

        else:
            metric_total, monitor_metrics_total = self.evaluate_epoch(self)

        if metric_total is None:
            metric_total = loss_total

        ''' TensorBoard logging '''
        logs = [
            ('epoch_train_loss', loss_total),
            ('epoch_train_metric', metric_total),
        ]
        self.tb_logger.list_of_scalars_summary(logs, self.global_epoch)
        return loss_total, metric_total, monitor_metrics_total
# Example no. 19 (score: 0)
    def _valid_one_epoch(self, loader):
        self.epoch_storage = defaultdict(list)
        for key in ['approx', 'target', 'loss', 'batch_metric']:
            self.epoch_storage[key] = []

        self.model.eval()
        if self.progress_bar and self.rank == 0:
            iterator = enumerate(tqdm(loader, desc='valid'))
        else:
            iterator = enumerate(loader)

        with torch.no_grad():
            for batch_i, inputs in iterator:
                batches_done = len(loader) * (self.global_epoch - 1) + batch_i
                inputs = [t.to(self.device) for t in inputs]
                loss, approx = self.forward_train(self, inputs)
                self.evaluate_batch(self, inputs, approx)
                if self.parallel == 'ddp' and self.ddp_average_loss:
                    if self.xla:
                        loss_batch = xm.all_gather(
                            loss.detach().clone().view(1)).mean().item()
                    else:
                        loss_batch = comm.gather_tensor(
                            loss.detach().clone().view(1)).mean().item()
                else:  # Use loss on device: 0
                    loss_batch = loss.item()

                logs = [
                    ('batch_valid_loss', loss_batch),
                ]
                if len(self.epoch_storage['batch_metric']) > 0:
                    metric = self.epoch_storage['batch_metric'][-1]
                    logs.append(('batch_valid_mertric', metric))
                self.tb_logger.list_of_scalars_summary(logs, batches_done)
                self.epoch_storage['loss'].append(loss_batch)

        for key, val in self.epoch_storage.items():
            if len(val) > 0:
                if isinstance(val[0], torch.Tensor):
                    self.epoch_storage[key] = torch.cat(val)
                else:
                    self.epoch_storage[key] = torch.tensor(val).to(self.device)

        loss_total = self.epoch_storage['loss'].mean().item()

        if self.parallel == 'ddp':
            for key, val in self.epoch_storage.items():
                if len(val) > 0:
                    if self.xla:
                        self.epoch_storage[key] = xm.all_gather(val)
                    else:
                        self.epoch_storage[key] = comm.gather_tensor(val)

            metric_total, monitor_metrics_total = self.evaluate_epoch(self)

        else:
            metric_total, monitor_metrics_total = self.evaluate_epoch(self)

        if metric_total is None:
            metric_total = loss_total

        logs = [
            ('epoch_valid_loss', loss_total),
            ('epoch_valid_metric', metric_total),
        ]
        self.tb_logger.list_of_scalars_summary(logs, self.global_epoch)
        return loss_total, metric_total, monitor_metrics_total
# Example no. 20 (score: 0)
    def _train_one_epoch(self, loader):
        loader_time = .0
        train_time = .0
        curr_time = time.time()

        self.epoch_storage = defaultdict(list)
        for key in ['approx', 'target', 'loss', 'batch_metric']:
            self.epoch_storage[key] = []
        if self.fp16:
            scaler = amp.GradScaler()

        self.model.train()
        if self.progress_bar and self.rank == 0:
            iterator = enumerate(tqdm(loader, desc='train'))
        else:
            iterator = enumerate(loader)

        for batch_i, inputs in iterator:
            loader_time += time.time() - curr_time
            curr_time = time.time()

            self.optimizer.zero_grad()
            batches_done = len(loader) * (self.global_epoch - 1) + batch_i
            inputs = [t.to(self.device) for t in inputs]

            # forward and backward
            if self.fp16:
                with amp.autocast():
                    loss, approx = self.forward_train(self, inputs)
                    self.evaluate_batch(self, inputs, approx)  # evaluation
                loss = loss / self.grad_accumulations
                scaler.scale(loss).backward()
                if (batch_i + 1) % self.grad_accumulations == 0:
                    if self.sam:
                        # first step
                        optimizer_state = scaler._per_optimizer_states[id(
                            self.optimizer)]
                        scaler.unscale_(self.optimizer)
                        if not sum(v.item() for v in
                                   optimizer_state["found_inf_per_device"].
                                   values()):
                            self.optimizer.first_step(zero_grad=True)
                        optimizer_state["stage"] = 2
                        scaler.update()
                        # second step
                        with amp.autocast():
                            loss2, _ = self.forward_train(self, inputs)
                        scaler.scale(loss2).backward()
                        scaler.unscale_(self.optimizer)
                        if not sum(v.item() for v in
                                   optimizer_state["found_inf_per_device"].
                                   values()):
                            self.optimizer.second_step(zero_grad=True)
                        optimizer_state["stage"] = 2
                        scaler.update()
                    else:
                        scaler.step(self.optimizer)
                        scaler.update()
            else:
                loss, approx = self.forward_train(self, inputs)
                self.evaluate_batch(self, inputs, approx)  # evaluation
                loss = loss / self.grad_accumulations
                loss.backward()
                if (batch_i + 1) % self.grad_accumulations == 0:
                    if self.xla:
                        if self.sam:
                            raise RuntimeError(
                                'SAM optimizer on XLA device is not available.'
                            )
                        else:
                            xm.optimizer_step(self.optimizer, barrier=True)
                    else:
                        if self.sam:
                            self.optimizer.first_step(zero_grad=True)
                            loss2, _ = self.forward_train(self, inputs)
                            loss2.backward()
                            self.optimizer.second_step(zero_grad=True)
                        else:
                            self.optimizer.step()
                    if self.batch_scheduler:
                        self.scheduler.step()

            if self.parallel == 'ddp' and self.ddp_average_loss:
                if self.xla:
                    loss_batch = xm.all_gather(
                        loss.detach().clone().view(1)).mean().item()
                else:
                    loss_batch = comm.gather_tensor(
                        loss.detach().clone().view(1)).mean().item()
            else:  # Use loss on device: 0
                loss_batch = loss.item()

            # logging
            learning_rate = [
                param_group['lr']
                for param_group in self.optimizer.param_groups
            ]
            logs = [('batch_train_loss', loss_batch),
                    ('batch_train_lr', learning_rate)]
            if len(self.epoch_storage['batch_metric']) > 0:
                metric = self.epoch_storage['batch_metric'][-1]
                logs.append(('batch_valid_mertric', metric))
            self.tb_logger.list_of_scalars_summary(logs, batches_done)
            self.epoch_storage['loss'].append(loss_batch)

            train_time += time.time() - curr_time
            curr_time = time.time()

        if self.debug and self.rank == 0:
            self.logger(
                f'loader: {loader_time:.1f} s | train: {train_time:.1f} s')

        for key, val in self.epoch_storage.items():
            if len(val) > 0:
                if isinstance(val[0], torch.Tensor):
                    self.epoch_storage[key] = torch.cat(val)
                else:
                    self.epoch_storage[key] = torch.tensor(val).to(self.device)

        loss_total = self.epoch_storage['loss'].mean().item()

        if self.parallel == 'ddp':
            # gather tensors
            for key, val in self.epoch_storage.items():
                if len(val) > 0:
                    if self.xla:
                        self.epoch_storage[key] = xm.all_gather(val)
                    else:
                        self.epoch_storage[key] = comm.gather_tensor(val)

            metric_total, monitor_metrics_total = self.evaluate_epoch(self)

        else:
            metric_total, monitor_metrics_total = self.evaluate_epoch(self)

        if metric_total is None:
            metric_total = loss_total

        # logging
        logs = [
            ('epoch_train_loss', loss_total),
            ('epoch_train_metric', metric_total),
        ]
        self.tb_logger.list_of_scalars_summary(logs, self.global_epoch)
        return loss_total, metric_total, monitor_metrics_total
# Example no. 21 (score: 0)
    def _train_one_epoch(self, loader):
        loader_time = .0
        train_time = .0
        start_time = time.time()
        curr_time = time.time()

        self.epoch_storage = defaultdict(list)
        for key in ['approx', 'target', 'loss', 'batch_metric']:
            self.epoch_storage[key] = []
        if self.fp16:
            scaler = amp.GradScaler()

        self.model.train()
        if self.progress_bar and self.rank == 0:
            iterator = enumerate(tqdm(loader, desc='train'))
        else:
            iterator = enumerate(loader)
        batch_total = len(loader)
        ett_disp = False

        for batch_i, inputs in iterator:
            loader_time += time.time() - curr_time
            curr_time = time.time()
            elapsed_time = curr_time - start_time
            if self.rank == 0 and self.state['epoch'] == 0 and elapsed_time > 30 and not ett_disp: # show ETA
                ett = elapsed_time * batch_total // batch_i
                self.logger(f'Estimated epoch training time: {int(ett)} s')
                try:
                    ram_usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss // 1024
                    self.logger(f'Maximum RAM usage: {ram_usage} MB')
                except:
                    self.logger('Failed to get RAM usage.')
                try:
                    gram_usage = int(max(get_gpu_memory().values()))
                    self.logger(f'Maximum GRAM usage: {gram_usage} MB')
                except:
                    self.logger('Failed to get GRAM usage.')
                ett_disp = True

            self.optimizer.zero_grad()
            batches_done = batch_total * (self.global_epoch-1) + batch_i
            inputs = [t.to(self.device) for t in inputs]
            
            # forward and backward
            if self.fp16:
                with amp.autocast():
                    loss, approx = self.forward_train(self, inputs)
                    self.evaluate_batch(self, inputs, approx) # evaluation
                loss = loss / self.grad_accumulations
                scaler.scale(loss).backward()
                if self.clip_grad is not None:
                    dispatch_clip_grad(self.model.parameters(), self.max_grad_norm, mode=self.clip_grad)
                if (batch_i + 1) % self.grad_accumulations == 0:
                    if self.sam:
                        # first step
                        optimizer_state = scaler._per_optimizer_states[id(self.optimizer)]
                        scaler.unscale_(self.optimizer)
                        if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
                            self.optimizer.first_step(zero_grad=True)
                        optimizer_state["stage"] = 2
                        scaler.update()
                        # second step
                        with amp.autocast():
                            loss2, _ = self.forward_train(self, inputs)
                        scaler.scale(loss2).backward()
                        scaler.unscale_(self.optimizer)
                        if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
                            self.optimizer.second_step(zero_grad=True)
                        optimizer_state["stage"] = 2
                        scaler.update()
                    else:
                        scaler.step(self.optimizer)
                        scaler.update()
            else:
                loss, approx = self.forward_train(self, inputs)
                self.evaluate_batch(self, inputs, approx) # evaluation
                loss = loss / self.grad_accumulations
                loss.backward()
                if self.clip_grad is not None:
                    dispatch_clip_grad(self.model.parameters(), self.max_grad_norm, mode=self.clip_grad)
                if (batch_i + 1) % self.grad_accumulations == 0:
                    if self.xla:
                        if self.sam:
                            raise RuntimeError('SAM optimizer on XLA device is not available.')
                        else:
                            xm.optimizer_step(self.optimizer, barrier=True)
                    else:
                        if self.sam:
                            self.optimizer.first_step(zero_grad=True)
                            loss2, _ = self.forward_train(self, inputs)
                            loss2.backward()
                            self.optimizer.second_step(zero_grad=True)
                        else:
                            self.optimizer.step()
                    if self.batch_scheduler:
                        self.scheduler.step()
            
            if torch.isnan(loss).any():
                self.logger(f'{torch.isnan(loss).sum()} NaN detected in loss. ({batch_i}/{len(loader)})')
            if torch.isnan(approx).any():
                self.logger(f'{torch.isnan(approx).sum()} NaN detected in output tensor. ({batch_i}/{len(loader)})')
            
            if self.parallel == 'ddp' and self.ddp_average_loss:
                if self.xla:
                    loss_batch = xm.all_gather(
                        loss.detach().clone().view(1)).mean().item()
                else:
                    loss_batch = comm.gather_tensor(
                        loss.detach().clone().view(1)).mean().item()
            else:  # Use loss on device: 0
                loss_batch = loss.item()

            # logging
            learning_rate = [param_group['lr']
                             for param_group in self.optimizer.param_groups]
            logs = [
                ('batch_train_loss', loss_batch),
                ('batch_train_lr', learning_rate)
            ]
            if len(self.epoch_storage['batch_metric']) > 0:
                metric = self.epoch_storage['batch_metric'][-1]
                logs.append(('batch_valid_mertric', metric))
            self.tb_logger.list_of_scalars_summary(logs, batches_done)
            self.epoch_storage['loss'].append(loss_batch)

            train_time += time.time() - curr_time
            curr_time = time.time()

        if self.debug and self.rank == 0:
            self.logger(
                f'loader: {loader_time:.1f} s | train: {train_time:.1f} s')

        for key, val in self.epoch_storage.items():
            if len(val) > 0:
                if isinstance(val[0], torch.Tensor):
                    self.epoch_storage[key] = torch.cat(val)
                else:
                    self.epoch_storage[key] = torch.tensor(val).to(self.device)

        loss_total = self.epoch_storage['loss'].mean().item()

        if self.parallel == 'ddp':
            # gather tensors
            for key, val in self.epoch_storage.items():
                if len(val) > 0:
                    if self.xla:
                        self.epoch_storage[key] = xm.all_gather(val)
                    else:
                        self.epoch_storage[key] = comm.gather_tensor(val)

            metric_total, monitor_metrics_total = self.evaluate_epoch(self)

        else:
            metric_total, monitor_metrics_total = self.evaluate_epoch(self)

        if metric_total is None:
            metric_total = loss_total

        # logging
        logs = [
            ('epoch_train_loss', loss_total),
            ('epoch_train_metric', metric_total),
        ]
        self.tb_logger.list_of_scalars_summary(logs, self.global_epoch)
        return loss_total, metric_total, monitor_metrics_total
Esempio n. 22
0
 def forward(ctx, input, dim):
     """All-gather ``input`` across XLA replicas along ``dim``.

     Before gathering, the gather dimension, this replica's ordinal and the
     replica world size are recorded on ``ctx``. A matching ``backward``
     presumably uses them to select this replica's slice of the incoming
     gradient; TODO confirm against the backward implementation.

     Args:
         ctx: autograd context object used to stash values for backward.
         input: tensor contributed by this replica.
         dim: dimension along which the per-replica tensors are joined.

     Returns:
         The tensor produced by ``xm.all_gather(input, dim=dim)``.
     """
     ctx.dim = dim
     # Right-hand side is evaluated left to right, so the query order matches
     # two separate assignments.
     ctx.ordinal, ctx.world_size = xm.get_ordinal(), xm.xrt_world_size()
     gathered = xm.all_gather(input, dim=dim)
     return gathered
Esempio n. 23
0
 def all_gather(self, collectiveArgs, retFlag=False):
     """All-gather ``collectiveArgs.ipTensor`` across XLA replicas along dim 0.

     The gathered tensor is stored in ``collectiveArgs.opTensor``. When
     ``collectiveArgs.asyncOp`` is truthy, the result is also appended to
     ``collectiveArgs.waitObj``, as the other ``all_gather`` variant in this
     file does. Presumably the caller's wait/completion logic consumes that
     list; confirm against the caller.

     Args:
         collectiveArgs: collective-arguments object. Reads ``ipTensor`` and
             (optionally) ``asyncOp``. Writes ``opTensor`` and, for async
             ops, appends to ``waitObj``.
         retFlag: when true, also return the gathered tensor.

     Returns:
         The gathered tensor if ``retFlag`` is true, otherwise ``None``.
     """
     retObj = xm.all_gather(collectiveArgs.ipTensor, dim=0)
     collectiveArgs.opTensor = retObj
     # Register async ops so they are tracked like other collectives.
     # getattr keeps argument objects without an ``asyncOp`` field working.
     if getattr(collectiveArgs, "asyncOp", False):
         collectiveArgs.waitObj.append(retObj)
     if retFlag:
         return retObj
Esempio n. 24
0
 def forward(ctx, x, dim):
     """Stack ``x`` from every XLA replica along a new axis at ``dim``.

     ``x`` gets a size-1 axis inserted at ``dim``. The result is then
     all-gathered along that same axis, so the per-replica tensors end up
     side by side there. The gather dimension is kept on ``ctx``, presumably
     for the backward pass; TODO confirm.

     Args:
         ctx: autograd context object used to stash values for backward.
         x: tensor contributed by this replica.
         dim: position of the inserted axis that replicas are gathered along.

     Returns:
         The tensor produced by ``xm.all_gather`` (a single tensor, not a list).
     """
     ctx.dim = dim
     expanded = x.unsqueeze(dim)
     return xm.all_gather(expanded, dim=dim)