示例#1
0
def test_SpecAuc(apply_time_warp, apply_freq_mask, apply_time_mask):
    """Build SpecAug for one combination of flags and run it on random data.

    When every augmentation flag is off, constructing SpecAug must raise
    ValueError. Otherwise the module is built and applied to a random
    tensor of shape (2, 1000, 80).
    """
    kwargs = dict(
        apply_time_warp=apply_time_warp,
        apply_freq_mask=apply_freq_mask,
        apply_time_mask=apply_time_mask,
    )
    nothing_enabled = not (apply_time_warp or apply_time_mask or apply_freq_mask)
    if nothing_enabled:
        with pytest.raises(ValueError):
            SpecAug(**kwargs)
        return

    augmenter = SpecAug(**kwargs)
    features = torch.randn(2, 1000, 80)
    augmenter(features)
示例#2
0
def test_SpecAuc_repr(
    apply_time_warp,
    apply_freq_mask,
    apply_time_mask,
    time_mask_width_range,
    time_mask_width_ratio_range,
):
    """Construct a SpecAug from the given options and print it.

    Invalid combinations are skipped by returning early:
      * every augmentation flag is off, or
      * time masking is on and the two width options are either both
        None or both given.
    """
    all_disabled = not (apply_time_warp or apply_time_mask or apply_freq_mask)
    has_width_range = time_mask_width_range is not None
    has_width_ratio = time_mask_width_ratio_range is not None
    # Time masking needs exactly one of the two width specifications.
    ambiguous_time_mask = apply_time_mask and has_width_range == has_width_ratio
    if all_disabled or ambiguous_time_mask:
        return

    augmenter = SpecAug(
        apply_time_warp=apply_time_warp,
        apply_freq_mask=apply_freq_mask,
        apply_time_mask=apply_time_mask,
        time_mask_width_range=time_mask_width_range,
        time_mask_width_ratio_range=time_mask_width_ratio_range,
    )
    print(augmenter)
示例#3
0
    def __init__(self, input_size: int, vocab_size: int,
                 token_list: Union[Tuple[str, ...],
                                   List[str]], device, config):
        """Assemble the frontend, optional augmentation/normalization and E2E model.

        Args:
            input_size: forwarded to ``E2E`` as its input size.
            vocab_size: vocabulary size; its last index is used for both
                ``sos`` and ``eos``, and it is forwarded to ``E2E``.
            token_list: accepted but not read by this constructor.
            device: forwarded to ``E2E``.
            config: options object; this constructor reads ``ignore_id``,
                ``mtlalpha``, ``char_list``, ``specaug`` and ``normalize``
                from it and also forwards it to ``E2E``.
        """
        super().__init__()

        # The last vocabulary index serves as both start and end symbol.
        last_index = vocab_size - 1
        self.sos = last_index
        self.eos = last_index
        self.vocab_size = vocab_size
        self.ignore_id = config.ignore_id
        self.ctc_weight = config.mtlalpha
        # NOTE(review): the ``token_list`` argument is ignored; the tokens are
        # copied from ``config.char_list`` instead — confirm this is intended.
        self.token_list = config.char_list.copy()

        if config.specaug:
            self.specaug = SpecAug()
        else:
            self.specaug = None
        if config.normalize:
            self.normalize = UtteranceMVN()
        else:
            self.normalize = None

        print(self.specaug, self.normalize)

        # Hop of 0.01 and window of 0.03 times SAMPLE_RATE samples
        # (10 ms / 30 ms, assuming SAMPLE_RATE is in Hz).
        hop_samples = int(0.01 * SAMPLE_RATE)
        window_samples = int(0.03 * SAMPLE_RATE)
        self.frontend = CustomFrontend(
            fs=SAMPLE_RATE,
            n_fft=512,
            normalized=True,
            hop_length=hop_samples,
            win_length=window_samples,
            n_mels=80,
        )

        self.model = E2E(input_size, vocab_size, config, device)
示例#4
0
def test_SpecAuc_repr(apply_time_warp, apply_freq_mask, apply_time_mask):
    """Print a SpecAug built from the given flags.

    Nothing is constructed when every augmentation flag is off.
    """
    if apply_time_warp or apply_time_mask or apply_freq_mask:
        print(
            SpecAug(
                apply_time_warp=apply_time_warp,
                apply_freq_mask=apply_freq_mask,
                apply_time_mask=apply_time_mask,
            )
        )
示例#5
0
 def __init__(self, sos, eos, is_spec_aug=False, dtype='FLOAT16'):
     """Store token ids and feature-augmentation settings.

     Args:
         sos: sos token id
         eos: eos token id
         is_spec_aug (bool): whether feature spec augmentation should be
             used. A SpecAug instance is created regardless of this flag.
         dtype (str): stored unchanged as ``self.dtype``.
     """
     self.sos, self.eos = sos, eos
     self.is_spec_aug = is_spec_aug
     # Built unconditionally; presumably callers check ``is_spec_aug``
     # before applying it — verify against usage.
     self.spec_aug = SpecAug()
     self.dtype = dtype
示例#6
0
def test_SpecAuc(
    apply_time_warp,
    apply_freq_mask,
    apply_time_mask,
    time_mask_width_range,
    time_mask_width_ratio_range,
):
    """Build SpecAug for one option combination and exercise it.

    Invalid combinations must raise ValueError at construction:
      * every augmentation flag is off, or
      * time masking is on and the two width options are either both
        None or both given.
    Valid combinations are built and applied to a random (2, 1000, 80)
    tensor.
    """
    options = dict(
        apply_time_warp=apply_time_warp,
        apply_freq_mask=apply_freq_mask,
        apply_time_mask=apply_time_mask,
        time_mask_width_range=time_mask_width_range,
        time_mask_width_ratio_range=time_mask_width_ratio_range,
    )
    all_disabled = not (apply_time_warp or apply_time_mask or apply_freq_mask)
    # With time masking on, exactly one width option must be provided.
    width_ambiguous = apply_time_mask and (
        (time_mask_width_range is None) == (time_mask_width_ratio_range is None)
    )

    if all_disabled or width_ambiguous:
        with pytest.raises(ValueError):
            SpecAug(**options)
        return

    augmenter = SpecAug(**options)
    augmenter(torch.randn(2, 1000, 80))
示例#7
0
def get_specaug():
    """Return a SpecAug with time warping and frequency masking enabled.

    Time masking is explicitly disabled; every other option is left at
    SpecAug's own default.
    """
    options = {
        "apply_time_warp": True,
        "apply_freq_mask": True,
        "apply_time_mask": False,
    }
    return SpecAug(**options)