Example #1
def _ddp_test_fn(rank, worldsize, add_offset: bool, reduction_mean=False):
    """Per-process body: reduce a tensor across ranks and check the result."""
    _setup_ddp(rank, worldsize)  # test helper that joins the process group
    if add_offset:
        # each rank contributes its own rank id
        tensor = torch.tensor([float(rank)])
    else:
        # every rank contributes 1.
        tensor = torch.tensor([1.])

    if reduction_mean:
        reduced_tensor = _sync_ddp_if_available(tensor, reduce_op='avg')

        # expected mean when each rank contributed its rank id 0..world_size-1
        manual_reduction = sum(range(dist.get_world_size())) / dist.get_world_size()
        assert reduced_tensor.item() == manual_reduction
    else:
        # default reduction is a sum, so world_size ranks each adding 1. gives world_size
        reduced_tensor = _sync_ddp_if_available(tensor)

        assert reduced_tensor.item() == dist.get_world_size(), \
            'Sync-Reduce does not work properly with DDP and Tensors'
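
The DDP test body above only covers the per-rank logic; the process-group setup and the launcher are not shown on this page. Below is a minimal sketch of what the missing _setup_ddp helper and a driver test could look like, assuming the gloo backend, a hard-coded local port, and two processes (these details are assumptions, not the original test suite's code):

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def _setup_ddp(rank, worldsize):
    # Assumed helper: point every process at the same local rendezvous and
    # join the process group before any collective call runs.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=worldsize)


def test_sync_reduce_ddp_mean():
    """Spawn one process per simulated worker and run the per-rank body above."""
    worldsize = 2
    # args are (worldsize, add_offset, reduction_mean); rank is supplied by spawn
    mp.spawn(_ddp_test_fn, args=(worldsize, True, True), nprocs=worldsize)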
Example #2
def test_sync_reduce_simple():
    """Make sure sync-reduce works without DDP"""
    tensor = torch.tensor([1.], device='cpu')

    reduced_tensor = _sync_ddp_if_available(tensor)

    assert torch.allclose(tensor, reduced_tensor), \
        'Sync-Reduce does not work properly without DDP and Tensors'
Example #3
def _ddp_test_fn(rank, worldsize):
    _setup_ddp(rank, worldsize)
    tensor = torch.tensor([1.], device='cuda:0')

    reduced_tensor = _sync_ddp_if_available(tensor)

    assert reduced_tensor.item() == dist.get_world_size(), \
        'Sync-Reduce does not work properly with DDP and Tensors'
Example #4
    def log(self,
            name: str,
            value: Any,
            prog_bar: bool = False,
            logger: bool = True,
            on_step: bool = False,
            on_epoch: bool = True,
            reduce_fx: Callable = torch.mean,
            enable_graph: bool = False,
            sync_ddp: bool = False,
            sync_ddp_op: Union[Any, str] = 'mean',
            sync_ddp_group: Optional[Any] = None):
        # no metrics should be logged with graphs
        if not enable_graph and isinstance(value, torch.Tensor):
            value = value.detach()

        # sync across ddp
        if sync_ddp and isinstance(value, (torch.Tensor, numbers.Number)):
            value = _sync_ddp_if_available(value,
                                           group=sync_ddp_group,
                                           reduce_op=sync_ddp_op)

        if 'meta' not in self:
            self.__setitem__('meta', {})

        # if user requests both step and epoch, then we split the metric in two automatically
        # one will be logged per step. the other per epoch
        if on_step and on_epoch:
            # set step version
            step_name = f'step_{name}'
            self.__set_meta(step_name,
                            value,
                            prog_bar,
                            logger,
                            on_step=True,
                            on_epoch=False,
                            reduce_fx=reduce_fx)
            self.__setitem__(step_name, value)

            # set epoch version
            epoch_name = f'epoch_{name}'
            self.__set_meta(epoch_name,
                            value,
                            prog_bar,
                            logger,
                            on_step=False,
                            on_epoch=True,
                            reduce_fx=reduce_fx)
            self.__setitem__(epoch_name, value)
        else:
            self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch,
                            reduce_fx)

            # set the value
            self.__setitem__(name, value)
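
The log method above belongs to PyTorch Lightning's 0.9-era Result/TrainResult objects. As a rough, non-authoritative sketch of how it is typically reached from a training step (the model, layer sizes, and metric names here are illustrative assumptions):

import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    # Illustrative module; only the logging call is the point here.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        result = pl.TrainResult(minimize=loss)
        # With on_step=True and on_epoch=True, the log() method above stores
        # the value twice, as 'step_train_loss' and 'epoch_train_loss', after
        # an optional mean-reduction across DDP processes.
        result.log('train_loss', loss,
                   on_step=True, on_epoch=True,
                   sync_ddp=True, sync_ddp_op='avg')
        return result

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)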
Example #5
    def validation_epoch_end(self, outputs):
        """Compute and log validation loss and accuracy at the epoch level."""
        metrics = {}

        for key in outputs[0].keys():
            # average the per-batch values, then average across DDP processes
            metrics[key] = torch.stack([output[key] for output in outputs]).mean()
            metrics[key] = _sync_ddp_if_available(metrics[key], reduce_op='avg')

        # log against the epoch index instead of the global step
        metrics['step'] = self.current_epoch

        return {'log': metrics}
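
The outputs list consumed above is built from whatever validation_step returns per batch. A minimal sketch of a matching step, assuming a classification model; the val_loss and val_acc names are illustrative:

import torch.nn.functional as F


def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    # Each per-batch dict becomes one element of `outputs` in
    # validation_epoch_end, so its keys are the ones stacked there.
    return {
        'val_loss': F.cross_entropy(logits, y),
        'val_acc': (logits.argmax(dim=1) == y).float().mean(),
    }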
Example #6
    def validation_epoch_end(self, outputs):
        """Compute and log validation loss and accuracy at the epoch level.

        Average statistics across GPUs in case of DDP.
        """
        keys = outputs[0].keys()
        metrics = {}
        for metric_name in keys:
            # per-batch mean first, then a mean-reduction across DDP processes
            epoch_value = torch.stack([output[metric_name] for output in outputs]).mean()
            metrics[metric_name] = _sync_ddp_if_available(epoch_value, reduce_op='avg')

        metrics['step'] = self.current_epoch

        return {'log': metrics}