Example #1
    def reduce(self,
               output,
               group: Optional[Any] = None,
               reduce_op: Optional[Union[ReduceOp, str]] = None):
        # Only tensors can be synced across processes; anything else is
        # returned unchanged.
        if isinstance(output, torch.Tensor):
            output = sync_ddp_if_available(output, group, reduce_op)
        return output
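
Since only tensors are synced, a short usage sketch may help. Here, plugin is an illustrative name for any object exposing the reduce method above, and outside a distributed run sync_ddp_if_available simply returns its input:

    import torch

    loss = torch.tensor(0.25)
    reduced = plugin.reduce(loss)           # tensor: all-reduced when DDP is active
    metrics = plugin.reduce({"acc": 0.9})   # non-tensor: returned as-is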
Example #2
    def sync_tensor(self,
                    tensor: torch.Tensor,
                    group: Optional[Any] = None,
                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
        """Sync (and reduce) a tensor across distributed processes, if available."""
        return sync_ddp_if_available(tensor, group, reduce_op)
Example #3
    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            The reduced value. If the input is not a tensor, it is returned unchanged.
        """
        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor
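
The default reduction averages the tensor across processes; passing reduce_op="sum" accumulates instead. A minimal sketch, again assuming an object (here called plugin, an illustrative name) that defines this method:

    import torch

    batch_size = torch.tensor(32.0)
    # Sum the per-process batch sizes rather than averaging them.
    total = plugin.reduce(batch_size, reduce_op="sum")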
Example #4
    def log(
        self,
        name: str,
        value: Any,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = False,
        on_epoch: bool = True,
        reduce_fx: Callable = torch.mean,
        tbptt_reduce_fx: Callable = torch.mean,
        tbptt_pad_token: int = 0,
        enable_graph: bool = False,
        sync_dist: bool = False,
        sync_dist_op: Any = 'mean',
        sync_dist_group: Optional[Any] = None,
    ):
        # no metrics should be logged with graphs
        if not enable_graph and isinstance(value, torch.Tensor):
            value = value.detach()

        # sync across ddp
        if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
            value = sync_ddp_if_available(value, group=sync_dist_group, reduce_op=sync_dist_op)

        if 'meta' not in self:
            self.__setitem__('meta', {})

        # if user requests both step and epoch, then we split the metric in two automatically
        # one will be logged per step. the other per epoch
        was_forked = False
        if on_step and on_epoch:
            was_forked = True

            # set step version
            step_name = f'step_{name}'
            self.__set_meta(
                step_name,
                value,
                prog_bar,
                logger,
                on_step=True,
                on_epoch=False,
                reduce_fx=reduce_fx,
                tbptt_reduce_fx=tbptt_reduce_fx,
                tbptt_pad_token=tbptt_pad_token,
                forked=False
            )
            self.__setitem__(step_name, value)

            # set epoch version
            epoch_name = f'epoch_{name}'
            self.__set_meta(
                epoch_name,
                value,
                prog_bar,
                logger,
                on_step=False,
                on_epoch=True,
                reduce_fx=reduce_fx,
                tbptt_reduce_fx=tbptt_reduce_fx,
                tbptt_pad_token=tbptt_pad_token,
                forked=False
            )
            self.__setitem__(epoch_name, value)

        # always log the original metric
        self.__set_meta(
            name,
            value,
            prog_bar,
            logger,
            on_step,
            on_epoch,
            reduce_fx,
            tbptt_reduce_fx=tbptt_reduce_fx,
            tbptt_pad_token=tbptt_pad_token,
            forked=was_forked
        )

        # set the value
        self.__setitem__(name, value)
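
This log method sits behind the familiar self.log(...) calls in a LightningModule, which forwards these same parameters in the Lightning versions these snippets appear to come from. A hedged sketch of triggering the sync path (compute_loss is a hypothetical helper, not from the source):

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        # sync_dist=True routes the value through sync_ddp_if_available,
        # averaging it across processes before it is logged.
        self.log('train_loss', loss, on_step=True, on_epoch=True,
                 sync_dist=True, sync_dist_op='mean')
        return loss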
Example #5
    def __init__(
        self,
        punct=True,
        stresses=False,
        spaces=True,
        chars=False,
        *,
        space=' ',
        silence=None,
        apostrophe=True,
        oov=Base.OOV,
        sep='|',  # Needed to distinguish between two- and three-letter codes.
        add_blank_at="last_but_one",
        pad_with_space=False,
        improved_version_g2p=False,
        phoneme_dict_path=None,
    ):
        labels = []
        self.space, labels = len(labels), labels + [space]  # Space

        if silence is not None:
            self.silence, labels = len(labels), labels + [silence]  # Silence

        labels.extend(self.CONSONANTS)
        vowels = list(self.VOWELS)

        if stresses:
            vowels = [
                f'{p}{s}' for p, s in itertools.product(vowels, (0, 1, 2))
            ]
        labels.extend(vowels)

        if chars:
            labels.extend(string.ascii_lowercase)

        if apostrophe:
            labels.append("'")  # Apostrophe

        if punct:
            labels.extend(self.PUNCT)

        super().__init__(labels, oov=oov, sep=sep, add_blank_at=add_blank_at)

        self.punct = punct
        self.stresses = stresses
        self.spaces = spaces
        self.pad_with_space = pad_with_space

        download_corpora()
        _ = sync_ddp_if_available(torch.tensor(0))  # Barrier until rank 0 downloads the corpora

        # g2p_en tries to run download_corpora() on import, but that call is not rank-zero guarded
        import g2p_en  # noqa pylint: disable=import-outside-toplevel

        _g2p = g2p_en.G2p()
        _g2p.variables = None

        if improved_version_g2p:
            self.g2p = G2p(_g2p, phoneme_dict_path)
        else:
            self.g2p = _g2p
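
Note the barrier idiom above: sync_ddp_if_available is called on a dummy tensor purely for its side effect, since the underlying all-reduce blocks every rank until rank 0 finishes downloading the corpora. A minimal standalone sketch of the same pattern, assuming the import path used by the Lightning versions in these examples and that download_corpora is rank-zero guarded (as the original comment implies):

    import torch
    from pytorch_lightning.utilities.distributed import sync_ddp_if_available

    download_corpora()  # assumed to run only on rank 0
    _ = sync_ddp_if_available(torch.tensor(0))  # all ranks wait here (acts as a barrier)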