Example #1
def get_output_fn(sound, args):
    # args[1] is the source sample rate, args[2] the target rate
    output = kaldi.resample_waveform(sound, args[1], args[2])
    return output
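All of the examples on this page call torchaudio.compliance.kaldi.resample_waveform(waveform, orig_freq, new_freq), which resamples a (channels, time) float tensor from orig_freq to new_freq. A minimal standalone sketch, assuming an older torchaudio release that still ships load_wav and compliance.kaldi (the file path is a placeholder):

import torchaudio
from torchaudio.compliance import kaldi

# load_wav returns a (channels, time) tensor and its sample rate
waveform, sample_rate = torchaudio.load_wav("speech_8000hz.wav")  # placeholder path

# resample from the source rate to 16 kHz
resampled = kaldi.resample_waveform(waveform, sample_rate, 16000)
print(resampled.shape)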
Example #2
def test_resample_waveform_upsample_size(self):
    sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
    upsample_sound = kaldi.resample_waveform(sound, sample_rate,
                                             sample_rate * 2)
    # doubling the sample rate should double the number of samples
    self.assertTrue(upsample_sound.size(-1) == sound.size(-1) * 2)
Example #3
def test_resample_waveform_identity_size(self):
    sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
    # resampling to the same rate should leave the length unchanged
    downsample_sound = kaldi.resample_waveform(sound, sample_rate,
                                               sample_rate)
    self.assertTrue(downsample_sound.size(-1) == sound.size(-1))
Example #4
def test_resample_waveform_identity_size(self):
    # identity resample: the output length matches the input length
    downsample_sound = kaldi.resample_waveform(self.test1_signal,
                                               self.test1_signal_sr,
                                               self.test1_signal_sr)
    self.assertTrue(
        downsample_sound.size(-1) == self.test1_signal.size(-1))
Example #5
def test_resample_waveform_upsample_size(self):
    upsample_sound = kaldi.resample_waveform(self.test1_signal,
                                             self.test1_signal_sr,
                                             self.test1_signal_sr * 2)
    # upsampling by 2x should yield twice as many samples
    self.assertTrue(
        upsample_sound.size(-1) == self.test1_signal.size(-1) * 2)
Example #6
def get_output_fn(sound, args):
    # resample_waveform expects a floating-point waveform, so cast first;
    # args[1] is the source sample rate, args[2] the target rate
    output = kaldi.resample_waveform(sound.to(torch.float32), args[1],
                                     args[2])
    return output
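Note the cast in this variant: resample_waveform performs floating-point arithmetic, so integer PCM input (e.g. int16 from a raw decoder) should be converted to float32 first, as Example #6 does. A minimal sketch with a synthetic int16 buffer:

import torch
from torchaudio.compliance import kaldi

# synthetic int16 PCM buffer: 1 channel x 8000 samples at 8 kHz
pcm = torch.randint(-32768, 32767, (1, 8000), dtype=torch.int16)

# cast to float32, then upsample from 8 kHz to 16 kHz
resampled = kaldi.resample_waveform(pcm.to(torch.float32), 8000, 16000)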
Example #7

import torch
import torch.nn.functional as F
from torchaudio.compliance.kaldi import resample_waveform


def get_activation_inner(audio):
    # `sr`, `sampler`, `center`, `step_size`, `model`, and `layer` are
    # closed over from the enclosing scope in the original source
    model_srate = 16000  # the model is hardcoded to a 16 kHz sample rate
    if sr != model_srate:
        # resample audio if necessary
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)
        if not sampler:
            audio = resample_waveform(audio, sr, model_srate).squeeze()
        else:
            audio = sampler(audio).squeeze()

    if audio.dim() == 3 or audio.dim() == 2:
        audio = audio.squeeze()
    # if audio.dim() == 2:
    #     audio = audio.mean(dim=0)  # make mono
    audio = audio.float()

    # pad so that frames are centered around their timestamps
    # (i.e. the first frame is zero-centered)
    if center:
        audio = F.pad(audio, [512, 512])

    # make 1024-sample frames of the audio with a hop of `step_size` ms
    hop_length = int(model_srate * step_size / 1000)
    n_frames = 1 + int((audio.shape[-1] - 1024) / hop_length)
    if audio.dim() > 1:
        frames = []
        for channel in audio:
            frames_iter = torch.as_strided(
                channel,
                size=(1024, n_frames),
                stride=(1, hop_length)
            )
            frames.append(frames_iter)
        frames = torch.cat(frames, dim=1)
    else:
        frames = torch.as_strided(
            audio,
            size=(1024, n_frames),
            stride=(1, hop_length)
        )
    frames = frames.transpose(0, 1).contiguous()

    # normalize each frame -- this is expected by the model
    frames_mean = frames.mean(dim=1).unsqueeze(1)
    frames = frames - frames_mean

    # guard against division by zero on silent (zero-variance) frames
    frames_std = frames.std(dim=1).detach()
    frames_std_ = torch.ones_like(frames_std)
    frames_std_[frames_std != 0] = frames_std[frames_std != 0]
    frames_std_ = frames_std_.unsqueeze(1)
    frames = frames / frames_std_

    # run prediction and convert the frequency bin weights to Hz
    return model(frames.view(frames.shape[0], 1, -1), layer=layer)
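The framing step in Example #7 builds overlapping 1024-sample windows with torch.as_strided rather than copying slices: dimension 0 walks within a frame (stride 1) and dimension 1 jumps between frame starts (stride hop_length). A small self-contained sketch of the same trick with toy sizes:

import torch

frame_size, hop_length = 4, 2
audio = torch.arange(10.0)                                   # samples 0.0 .. 9.0
n_frames = 1 + (audio.shape[-1] - frame_size) // hop_length  # -> 4 frames

# element [j, i] is audio[j + i * hop_length], so column i is frame i
frames = torch.as_strided(audio, size=(frame_size, n_frames),
                          stride=(1, hop_length))
print(frames.transpose(0, 1))
# tensor([[0., 1., 2., 3.],
#         [2., 3., 4., 5.],
#         [4., 5., 6., 7.],
#         [6., 7., 8., 9.]])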