Exemplo n.º 1
0
    def apply_mask(self, x, padding_mask):
        """Apply the configured time-step and channel masks to ``x``.

        Both stages write the value 0 through ``index_put``. A stage is
        skipped when its probability (``self.mask_prob`` or
        ``self.mask_channel_prob``) is not greater than zero.

        Args:
            x: Tensor whose shape unpacks into three dimensions, read here
                as ``(batch, time, channel)``.
            padding_mask: Passed to ``compute_mask_indices`` for the
                time-step mask only; the channel mask is drawn with ``None``.

        Returns:
            The tensor returned by the last ``index_put`` call, or ``x``
            itself when both stages are disabled.
        """
        n_batch, n_steps, n_channels = x.shape

        if self.mask_prob > 0:
            # Draw a (batch, time) mask as a NumPy array, then move it to
            # the device that holds ``x`` before indexing with it.
            time_mask_np = compute_mask_indices(
                (n_batch, n_steps),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            time_mask = torch.from_numpy(time_mask_np).to(x.device)
            x = index_put(x, time_mask, 0)

        if self.mask_channel_prob > 0:
            # Draw a (batch, channel) mask and repeat it along the time
            # axis so the same channels are selected at every time step.
            channel_mask_np = compute_mask_indices(
                (n_batch, n_channels),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            channel_mask = torch.from_numpy(channel_mask_np).to(x.device)
            channel_mask = channel_mask.unsqueeze(1).expand(-1, n_steps, -1)
            x = index_put(x, channel_mask, 0)

        return x
Exemplo n.º 2
0
    def apply_mask_teacher(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        """Mask ``x`` for the teacher branch and report the time mask used.

        Up to three stages run in this order:

        1. Channel masking *before* the time mask (only when
           ``self.mask_channel_before`` is truthy). A fresh channel mask is
           always drawn here; a caller-supplied ``mask_channel_indices`` is
           not consulted in this stage.
        2. Time masking with ``self.mask_emb_teacher`` as the fill value,
           reusing ``mask_indices`` when the caller provides it.
        3. Channel masking *after* the time mask (only when
           ``self.mask_channel_before`` is falsy), reusing
           ``mask_channel_indices`` when the caller provides it.

        Args:
            x: Tensor whose shape unpacks into three dimensions, read here
                as ``(batch, time, channel)``.
            padding_mask: Passed to ``compute_mask_indices`` when a time
                mask has to be drawn.
            mask_indices: Optional precomputed time mask.
            mask_channel_indices: Optional precomputed channel mask.

        Returns:
            A ``(x, mask_indices)`` tuple. ``mask_indices`` is ``None``
            whenever ``self.mask_prob`` is not greater than zero, even if
            the caller passed one in.
        """
        n_batch, n_steps, n_channels = x.shape

        def _draw_channel_mask():
            # Sample a (batch, channel) mask and repeat it along the time
            # axis, on whatever device ``x`` lives on at call time.
            sampled = compute_mask_indices(
                (n_batch, n_channels),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            on_device = torch.from_numpy(sampled).to(x.device)
            return on_device.unsqueeze(1).expand(-1, n_steps, -1)

        # Stage 1: channel mask applied ahead of the time mask.
        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = _draw_channel_mask()
            # Direct indexed assignment (not ``index_put``) as in the
            # original implementation of this stage.
            x[mask_channel_indices] = 0

        # Stage 2: time mask filled with the teacher mask embedding.
        if self.mask_prob > 0:
            if mask_indices is None:
                sampled_time = compute_mask_indices(
                    (n_batch, n_steps),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=2,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                )
                mask_indices = torch.from_numpy(sampled_time).to(x.device)
            x = index_put(x, mask_indices, self.mask_emb_teacher)
        else:
            # Time masking disabled: report that no time mask was applied.
            mask_indices = None

        # Stage 3: channel mask applied after the time mask.
        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = _draw_channel_mask()
            x = index_put(x, mask_channel_indices, 0)

        return x, mask_indices
Exemplo n.º 3
0
 def apply_seg_dropout(self, inp, mask_prob, mask_leng, mask_type,
                       mask_val):
     """Overwrite randomly selected positions of ``inp`` with ``mask_val``.

     The write happens in place on ``inp``; the same tensor object is
     returned.

     Args:
         inp: Tensor whose ``size()`` unpacks into two dimensions.
         mask_prob: Masking probability forwarded to
             ``compute_mask_indices``; masking is skipped unless it is
             greater than zero.
         mask_leng: Mask length forwarded to ``compute_mask_indices``.
         mask_type: Mask selection mode forwarded to
             ``compute_mask_indices``.
         mask_val: Value assigned to the selected positions.

     Returns:
         A ``(inp, mask_indices)`` tuple. When masking is skipped,
         ``mask_indices`` is an all-False boolean tensor shaped like
         ``inp``.
     """
     n_rows, n_cols = inp.size()

     # Written as ``not (> 0)`` so every value that fails the original
     # ``mask_prob > 0`` check takes this path.
     if not mask_prob > 0:
         return inp, torch.zeros_like(inp, dtype=torch.bool)

     # No padding mask is supplied, so padded positions may be masked too.
     sampled = compute_mask_indices(
         (n_rows, n_cols),
         None,
         mask_prob,
         mask_leng,
         mask_type,
     )
     mask_indices = torch.from_numpy(sampled).to(inp.device)
     inp[mask_indices] = mask_val
     return inp, mask_indices