Example #1
import torch
from pytorch_lightning.utilities.distributed import sync_ddp_if_available


def test_sync_reduce_simple():
    """Make sure sync-reduce works without DDP"""
    tensor = torch.tensor([1.], device='cpu')

    reduced_tensor = sync_ddp_if_available(tensor)

    assert torch.allclose(tensor, reduced_tensor), \
        'Sync-Reduce does not work properly without DDP and Tensors'
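For reference, here is a simplified sketch of the behaviour this test relies on: with no initialized process group there is nothing to reduce across, so the tensor should come back unchanged. This is an assumption-level sketch, not the library's actual implementation of sync_ddp_if_available.

import torch
import torch.distributed as dist

def sync_reduce_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # Sketch: without an initialized process group, pass the tensor
    # through untouched (the case exercised by the test above).
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    # Otherwise an all-reduce sums the tensor across all ranks.
    dist.all_reduce(tensor)
    return tensor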
Example #2
import torch
import torch.distributed as dist
from pytorch_lightning.utilities.distributed import sync_ddp_if_available


def _ddp_test_fn(rank, worldsize, add_offset: bool, reduction_mean=False):
    # _setup_ddp (test helper defined elsewhere in this module) initializes
    # the default process group for the given rank and world size.
    _setup_ddp(rank, worldsize)
    if add_offset:
        tensor = torch.tensor([float(rank)])
    else:
        tensor = torch.tensor([1.])
    if reduction_mean:
        reduced_tensor = sync_ddp_if_available(tensor, reduce_op='avg')

        manual_reduction = sum(range(dist.get_world_size())) / dist.get_world_size()
        assert reduced_tensor.item() == manual_reduction
    else:
        reduced_tensor = sync_ddp_if_available(tensor)

        assert reduced_tensor.item() == dist.get_world_size(), \
            'Sync-Reduce does not work properly with DDP and Tensors'
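For context, a rank-indexed worker like _ddp_test_fn is normally launched with torch.multiprocessing.spawn. The launcher below is a hypothetical sketch; the world size and argument values are assumptions, not part of the original test file.

import torch.multiprocessing as mp

def run_ddp_reduce_test(worldsize: int = 2):
    # Each spawned process receives its rank as the first positional
    # argument, followed by the entries of `args`.
    mp.spawn(_ddp_test_fn, args=(worldsize, True, True), nprocs=worldsize)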
Example #3
    def log(
        self,
        name: str,
        value: Any,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = False,
        on_epoch: bool = True,
        reduce_fx: Callable = torch.mean,
        tbptt_reduce_fx: Callable = torch.mean,
        tbptt_pad_token: int = 0,
        enable_graph: bool = False,
        sync_dist: bool = False,
        sync_dist_op: Union[Any, str] = 'mean',
        sync_dist_group: Optional[Any] = None,
    ):
        # no metrics should be logged with graphs
        if not enable_graph and isinstance(value, torch.Tensor):
            value = value.detach()

        # sync across ddp
        if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
            value = sync_ddp_if_available(value,
                                          group=sync_dist_group,
                                          reduce_op=sync_dist_op)

        if 'meta' not in self:
            self.__setitem__('meta', {})

        # if user requests both step and epoch, then we split the metric in two automatically
        # one will be logged per step. the other per epoch
        was_forked = False
        if on_step and on_epoch:
            was_forked = True

            # set step version
            step_name = f'step_{name}'
            self.__set_meta(step_name,
                            value,
                            prog_bar,
                            logger,
                            on_step=True,
                            on_epoch=False,
                            reduce_fx=reduce_fx,
                            tbptt_reduce_fx=tbptt_reduce_fx,
                            tbptt_pad_token=tbptt_pad_token,
                            forked=False)
            self.__setitem__(step_name, value)

            # set epoch version
            epoch_name = f'epoch_{name}'
            self.__set_meta(epoch_name,
                            value,
                            prog_bar,
                            logger,
                            on_step=False,
                            on_epoch=True,
                            reduce_fx=reduce_fx,
                            tbptt_reduce_fx=tbptt_reduce_fx,
                            tbptt_pad_token=tbptt_pad_token,
                            forked=False)
            self.__setitem__(epoch_name, value)

        # always log the original metric
        self.__set_meta(name,
                        value,
                        prog_bar,
                        logger,
                        on_step,
                        on_epoch,
                        reduce_fx,
                        tbptt_reduce_fx=tbptt_reduce_fx,
                        tbptt_pad_token=tbptt_pad_token,
                        forked=was_forked)

        # set the value
        self.__setitem__(name, value)
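As a usage sketch, assuming the method above is the Result.log implementation from an older pytorch_lightning release (where TrainResult wraps it), a training step might invoke it as follows. compute_loss is a hypothetical helper, and the exact class names depend on the installed version.

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical loss helper
        result = pl.TrainResult(minimize=loss)
        # Log per-step and per-epoch versions of the metric; with
        # sync_dist=True the value is reduced across DDP processes
        # (via sync_ddp_if_available) before being recorded.
        result.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True)
        return result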