コード例 #1
0
class FeatureCreator(nn.Module):
    """Turn a batch of mixture / speech / noise waveforms into model inputs.

    The forward pass moves the three waveforms held by ``batch_info`` onto the
    first configured CUDA device (``CUDA_ID[0]``), transforms them with an
    STFT, and returns the mixture magnitude together with the label built by
    ``LabelHelper`` from the speech and noise spectra.
    """

    def __init__(self):
        super().__init__()
        stft = STFT(FILTER_LENGTH, HOP_LENGTH)
        self.stft = stft.cuda(CUDA_ID[0])
        self.label_helper = LabelHelper()

    def forward(self, batch_info):
        """Compute features and labels for one batch.

        Args:
            batch_info: batch object exposing ``mix``, ``speech``, ``noise``
                and ``nframe``. NOTE: ``mix``, ``speech`` and ``noise`` are
                overwritten in place with their CUDA copies.

        Returns:
            ``(mix_mag, label, batch_info.nframe)`` where ``mix_mag`` is the
            linear (not log-compressed) mixture magnitude with a singleton
            axis inserted at dim 1.
        """
        # Replace each waveform on the batch object with its GPU copy.
        for field in ('mix', 'speech', 'noise'):
            setattr(batch_info, field, getattr(batch_info, field).cuda(CUDA_ID[0]))

        mix_spec, speech_spec, noise_spec = (
            self.stft.transform(batch_info.mix),
            self.stft.transform(batch_info.speech),
            self.stft.transform(batch_info.noise),
        )

        # The last STFT axis is indexed as (real, imag) below.
        real_part = mix_spec[:, :, :, 0]
        imag_part = mix_spec[:, :, :, 1]
        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2).unsqueeze(1)

        label = self.label_helper(speech_spec, noise_spec)
        return magnitude, label, batch_info.nframe
コード例 #2
0
 def __init__(self, scp_file, window='hann', nfft=256, window_length=256, hop_length=64, center=False, is_mag=True, is_log=True, chunk_size=32000, least=16000):
     """Read the scp index, configure the STFT, then chunk and transform the audio.

     Args:
         scp_file: path of the scp file, parsed with ``ut.read_scp``.
         window, nfft, window_length, hop_length, center: ``STFT`` settings.
         is_mag, is_log: stored on the instance.
         chunk_size, least: stored on the instance; presumably consumed by
             ``self.split()`` (not visible here — confirm).
     """
     self.wave = ut.read_scp(scp_file)
     self.wave_keys = list(self.wave.keys())
     stft_kwargs = dict(window=window, nfft=nfft, window_length=window_length,
                        hop_length=hop_length, center=center)
     self.STFT = STFT(**stft_kwargs)
     self.is_mag, self.is_log = is_mag, is_log
     self.samp_list, self.samp_stft = [], []
     self.chunk_size, self.least = chunk_size, least
     # Eager preprocessing: build the chunks first, then their spectra.
     self.split()
     self.stft()
コード例 #3
0
 def __init__(self,
              scp_file,
              window='hann',
              nfft=256,
              window_length=256,
              hop_length=64,
              center=False,
              is_mag=True,
              is_log=True):
     """Read the scp index and build the STFT used for every wave file.

     Args:
         scp_file: path of the scp file, parsed with ``ut.read_scp``.
         window, nfft, window_length, hop_length, center: ``STFT`` settings.
         is_mag, is_log: stored on the instance for later STFT calls.
     """
     self.wave = ut.read_scp(scp_file)
     self.wave_keys = list(self.wave.keys())
     stft_kwargs = {
         'window': window,
         'nfft': nfft,
         'window_length': window_length,
         'hop_length': hop_length,
         'center': center,
     }
     self.STFT = STFT(**stft_kwargs)
     self.is_mag, self.is_log = is_mag, is_log
コード例 #4
0
class AudioData(object):
    """Chunk every wave file listed in an scp file and precompute its STFT.

    All work happens eagerly in ``__init__``: each file is read, cut into
    fixed-length chunks (see ``split``), and each chunk is transformed (see
    ``stft``). Items are addressed by integer chunk index.

    Args:
        scp_file: the scp file path (parsed with ``ut.read_scp``).
        window, nfft, window_length, hop_length, center: forwarded to ``STFT``.
        is_mag: if True, abs(stft); forwarded to ``STFT.stft``.
        is_log: forwarded to ``STFT.stft``.
        chunk_size: chunk length in samples; shorter files are zero-padded.
        least: files shorter than this many samples are skipped; it is also
            the hop between consecutive chunk starts (must be positive).
    """

    def __init__(self, scp_file, window='hann', nfft=256, window_length=256, hop_length=64, center=False, is_mag=True, is_log=True, chunk_size=32000, least=16000):
        self.wave = ut.read_scp(scp_file)
        self.wave_keys = list(self.wave.keys())
        self.STFT = STFT(window=window, nfft=nfft,
                         window_length=window_length, hop_length=hop_length, center=center)
        self.is_mag = is_mag
        self.is_log = is_log
        self.samp_list = []
        self.samp_stft = []
        self.chunk_size = chunk_size
        self.least = least
        self.split()
        self.stft()

    def __len__(self):
        # Count chunks, not files: __getitem__/__iter__ index self.samp_stft,
        # and split() can produce zero or several chunks per file.
        return len(self.samp_stft)

    def split(self):
        """Cut every listed file into ``chunk_size``-sample chunks.

        Files shorter than ``least`` are skipped. Files shorter than
        ``chunk_size`` are zero-padded at the end into a single chunk. Longer
        files yield every full chunk starting at multiples of ``least``; a
        trailing partial chunk is dropped.
        """
        for key in self.wave_keys:
            samp = ut.read_wav(self.wave[key])
            length = samp.shape[0]
            if length < self.least:
                continue
            if length < self.chunk_size:
                gap = self.chunk_size - length
                self.samp_list.append(np.pad(samp, (0, gap), mode='constant'))
                continue
            # Every start with start + chunk_size <= length. range() raises on
            # least == 0 instead of looping forever.
            for start in range(0, length - self.chunk_size + 1, self.least):
                self.samp_list.append(samp[start:start + self.chunk_size])

    def stft(self):
        """Transform every chunk, honouring the configured is_mag / is_log flags."""
        for samp in self.samp_list:
            self.samp_stft.append(
                self.STFT.stft(samp, is_mag=self.is_mag, is_log=self.is_log))

    def __iter__(self):
        yield from self.samp_stft

    def __getitem__(self, index):
        return self.samp_stft[index]
コード例 #5
0
 def __init__(self, root_dir):
     """Initialise the dataset and record the files found in ``root_dir``.

     :param root_dir: root directory of the data files. Every entry returned
         by ``os.listdir`` is kept as-is (no filtering or sorting).
     """
     self.stft = STFT(filter_length=FILTER_LENGTH, hop_length=HOP_LENGTH)
     self.root_dir = root_dir
     # Directory listing is taken once, at construction time.
     self.files = os.listdir(self.root_dir)
コード例 #6
0
    def run(self):
        """Separate every mixture in ``self.waves`` and write one stereo WAV per speaker.

        For each ``(spec_m, spec_l, spec_r)`` item:
        - ``self._cluster`` receives the CMVN-normalised log magnitude of
          ``spec_m`` and a non-silence mask, and returns one mask per speaker.
        - Each mask is applied to the left and right spectra.
        - The masked spectra are inverted with the STFT settings from
          ``opt['datasets']['audio_setting']``.
        - The result is written as 16-bit PCM at 8 kHz to
          ``<save_file>/<name[:-4]>_<speaker>.wav``.
        """
        audio_setting = self.opt['datasets']['audio_setting']
        stft_settings = {key: audio_setting[key] for key in
                         ('window', 'nfft', 'window_length', 'hop_length', 'center')}
        stft_istft = STFT(**stft_settings)

        # Loop-invariant: load the CMVN statistics once, closing the file.
        # NOTE(review): pickle.load can execute arbitrary code — only point
        # cmvn_file at trusted files.
        with open(self.opt['datasets']['dataloader_setting']['cmvn_file'], 'rb') as cmvn_fp:
            cmvn = pickle.load(cmvn_fp)
        epsilon = np.finfo(np.float32).eps

        index = 0
        for spec_m, spec_l, spec_r in tqdm(self.waves):
            # Log magnitude, floored at float32 eps to avoid log(0).
            log_spec = np.log(np.maximum(np.abs(spec_m), epsilon))
            cmvn_wave = util.apply_cmvn(log_spec, cmvn)
            # np.bool was removed in NumPy 1.24; builtin bool is the same dtype.
            non_silent = util.compute_non_silent(log_spec).astype(bool)

            target_mask = self._cluster(cmvn_wave, non_silent)
            for i, mask in enumerate(target_mask):
                name = self.keys[index]
                i_stft_l = stft_istft.istft(mask * spec_l)
                i_stft_r = stft_istft.istft(mask * spec_r)

                # Stack the two channels as rows.
                # NOTE(review): the right channel is sign-inverted here —
                # confirm this is intentional.
                i_stft = np.concatenate((np.reshape(i_stft_l, (1, -1)),
                                         np.reshape(-1 * i_stft_r, (1, -1))), axis=0)
                output_file = self.save_file
                os.makedirs(output_file, exist_ok=True)

                # name[:-4] strips a 4-character extension (e.g. '.wav');
                # assumes every key carries one — confirm.
                sf.write(output_file + '/' + name[:-4] + '_' + str(i + 1) + '.wav',
                         i_stft.T, 8000, 'PCM_16')
            index += 1
        print('Processing {} utterances'.format(index))
コード例 #7
0
def validation(path, net):
    """Enhance every test file in ``path`` with ``net`` and return mean PESQ scores.

    Each file is loaded with ``loadmat`` and must contain ``speech`` and
    ``noise`` arrays; their sum is the noisy mixture. Per file, the following
    are written to the current working directory at 16 kHz and scored with
    ``pesq`` against ``clean.wav``:
    - the clean speech, as ``clean.wav``;
    - the mixture, as ``mix.wav``;
    - the enhanced estimate, as ``est.wav``.

    ``net`` is switched to eval mode for the run and always switched back to
    train mode afterwards, even if an exception escapes.

    Args:
        path: directory containing the test files.
        net: model applied to the mixture magnitude; its output is used as
            the estimated magnitude of the enhanced signal.

    Returns:
        ``[mean PESQ of the mixture, mean PESQ of the estimate]``, averaged
        over the files PESQ could score. Files that fail scoring are reported
        and skipped. Both values are NaN if no file could be scored.
    """
    eps = 1e-8  # keeps the phase-restoring division finite on silent bins
    net.eval()
    try:
        files = os.listdir(path)
        pesq_unprocess = 0
        pesq_res = 0
        n_scored = 0
        bar = progressbar.ProgressBar(0, len(files))
        stft = STFT(filter_length=FILTER_LENGTH, hop_length=HOP_LENGTH)
        for i, fname in enumerate(files):
            bar.update(i)
            with torch.no_grad():
                data = loadmat(os.path.join(path, fname))
                speech = data['speech']
                noise = data['noise']
                mix = speech + noise

                sf.write('clean.wav', speech, 16000)
                sf.write('mix.wav', mix, 16000)

                # Scale the network input; the estimate is divided by c again
                # before being written.
                c = get_alpha(mix)
                mix *= c

                mix = stft.transform(torch.Tensor(mix.T).cuda(CUDA_ID[0]))
                mix_real = mix[:, :, :, 0]
                mix_imag = mix[:, :, :, 1]
                mix_mag = torch.sqrt(mix_real ** 2 + mix_imag ** 2)

                mapping_out = net(mix_mag.unsqueeze(0).cuda(CUDA_ID[0]))

                # Reuse the mixture phase: scale the unit phasor
                # (real, imag) / |mix| by the estimated magnitude.
                denom = mix_mag.clamp(min=eps)
                res_real = mapping_out * mix_real / denom
                res_imag = mapping_out * mix_imag / denom

                res = torch.stack([res_real, res_imag], 3)
                output = stft.inverse(res)
                output = output.permute(1, 0).detach().cpu().numpy()

                # The signal written after the inverse STFT must be in (F, T) layout.
                sf.write('est.wav', output / c, 16000)
            try:
                p1 = pesq('clean.wav', 'mix.wav', 16000)
                p2 = pesq('clean.wav', 'est.wav', 16000)
            except Exception:
                # Skip the item: otherwise p1/p2 would be undefined or left
                # over from the previous file.
                print('wrong test item : ' + fname)
                continue
            pesq_unprocess += p1[0]
            pesq_res += p2[0]
            n_scored += 1
        bar.finish()
    finally:
        net.train()

    if n_scored == 0:
        return [float('nan'), float('nan')]
    return [pesq_unprocess / n_scored, pesq_res / n_scored]
コード例 #8
0
    def run(self):
        """Separate every spectrogram in ``self.waves`` and write one WAV per speaker.

        For each item:
        - ``self._cluster`` receives a CMVN-normalised log magnitude and a
          non-silence mask, and returns one mask per speaker.
        - Each mask is applied to the spectrogram.
        - The masked spectrogram is inverted with the STFT settings from
          ``opt['audio_setting']``.
        - The result is written at 8 kHz to
          ``<save_file>/<opt['name']>/spk<N>/<key>``.
        """
        audio_setting = self.opt['audio_setting']
        stft_settings = {
            key: audio_setting[key]
            for key in ('window', 'nfft', 'window_length', 'hop_length', 'center')
        }
        stft_istft = STFT(**stft_settings)

        # Loop-invariant: load the CMVN statistics once, closing the file.
        # NOTE(review): pickle.load can execute arbitrary code — only point
        # cmvn_file at trusted files.
        with open(self.opt['cmvn_file'], 'rb') as cmvn_fp:
            cmvn = pickle.load(cmvn_fp)
        epsilon = np.finfo(np.float32).eps

        index = 0
        for wave in tqdm(self.waves):
            # Log magnitude, floored at float32 eps to avoid log(0).
            log_wave = np.log(np.maximum(np.abs(wave), epsilon))
            cmvn_wave = util.apply_cmvn(log_wave, cmvn)
            # np.bool was removed in NumPy 1.24; builtin bool is the same dtype.
            non_silent = util.compute_non_silent(log_wave).astype(bool)

            target_mask = self._cluster(cmvn_wave, non_silent)
            for i, mask in enumerate(target_mask):
                name = self.keys[index]
                i_stft = stft_istft.istft(mask * wave)
                output_file = os.path.join(self.save_file, self.opt['name'],
                                           'spk' + str(i + 1))
                os.makedirs(output_file, exist_ok=True)

                # NOTE(review): librosa.output was removed in librosa 0.8; this
                # call needs librosa < 0.8 (or a switch to soundfile.write).
                librosa.output.write_wav(output_file + '/' + name, i_stft,
                                         8000)
            index += 1
        print('Processing {} utterances'.format(index))
コード例 #9
0
class AudioData(object):
    """Lazily compute the STFT of wave files listed in an scp file.

    Construction only parses the scp index and builds the STFT. Each access
    (by key or by iteration) reads the corresponding wave file and transforms
    it on the fly.

    Args:
        scp_file: the scp file path, parsed with ``ut.read_scp``.
        window, nfft, window_length, hop_length, center: ``STFT`` settings.
        is_mag: if True, abs(stft); forwarded to ``STFT.stft``.
        is_log: forwarded to ``STFT.stft``.
    """

    def __init__(self, scp_file, window='hann', nfft=256, window_length=256,
                 hop_length=64, center=False, is_mag=True, is_log=True):
        self.wave = ut.read_scp(scp_file)
        self.wave_keys = list(self.wave.keys())
        self.STFT = STFT(window=window, nfft=nfft, window_length=window_length,
                         hop_length=hop_length, center=center)
        self.is_mag = is_mag
        self.is_log = is_log

    def __len__(self):
        return len(self.wave_keys)

    def stft(self, wave_path):
        """Read ``wave_path`` and return its STFT using the configured flags."""
        samples = ut.read_wav(wave_path)
        return self.STFT.stft(samples, self.is_mag, self.is_log)

    def __iter__(self):
        # Files are read one at a time, in scp order.
        for wave_path in (self.wave[k] for k in self.wave_keys):
            yield self.stft(wave_path)

    def __getitem__(self, key):
        # Lookup is by scp key; unknown keys raise ValueError.
        if key in self.wave_keys:
            return self.stft(self.wave[key])
        raise ValueError
コード例 #10
0
 def __init__(self):
     """Initialise the module with an STFT built from FILTER_LENGTH / HOP_LENGTH."""
     super().__init__()
     self.stft = STFT(FILTER_LENGTH, HOP_LENGTH)
コード例 #11
0
 def __init__(self):
     """Create the STFT on device CUDA_ID[0] and the label helper."""
     super().__init__()
     stft = STFT(FILTER_LENGTH, HOP_LENGTH)
     self.stft = stft.cuda(CUDA_ID[0])
     self.label_helper = LabelHelper()
コード例 #12
0
    def __init__(self):
        """Build the convolutional encoder, LSTM bottleneck and transposed-conv decoder.

        Encoder: five Conv2d layers with kernel (1, 3), stride (1, 2) and no
        padding. Each roughly halves the last axis and leaves the
        second-to-last axis unchanged. Channels go 1 -> 16 -> 32 -> 64 -> 128 -> 256.

        Decoder: five ConvTranspose2d layers with the same kernel/stride.
        Output channels go 128 -> 64 -> 32 -> 16 -> 1. Each decoder layer's
        in_channels is twice the previous layer's out_channels.
        NOTE(review): presumably forward() concatenates the matching encoder
        output (skip connection) along the channel axis. forward() is not
        visible here — confirm.

        NOTE(review): if the last axis has 161 bins (e.g. FILTER_LENGTH == 320,
        which is not visible here — confirm), the sizes are:
        - encoder: 161 -> 80 -> 39 -> 19 -> 9 -> 4, so conv5 yields
          256 * 4 = 1024 features per step, matching the LSTM input_size;
        - decoder: 4 -> 9 -> 19 -> 39 -> 80 (via output_padding) -> 161.
        """
        super(CRNN, self).__init__()
        self.stft = STFT(FILTER_LENGTH, HOP_LENGTH)
        # Encoder
        self.conv1 = nn.Conv2d(in_channels=1,
                               out_channels=16,
                               kernel_size=(1, 3),
                               stride=(1, 2))
        self.bn1 = nn.BatchNorm2d(num_features=16)
        self.conv2 = nn.Conv2d(in_channels=16,
                               out_channels=32,
                               kernel_size=(1, 3),
                               stride=(1, 2))
        self.bn2 = nn.BatchNorm2d(num_features=32)
        self.conv3 = nn.Conv2d(in_channels=32,
                               out_channels=64,
                               kernel_size=(1, 3),
                               stride=(1, 2))
        self.bn3 = nn.BatchNorm2d(num_features=64)
        self.conv4 = nn.Conv2d(in_channels=64,
                               out_channels=128,
                               kernel_size=(1, 3),
                               stride=(1, 2))
        self.bn4 = nn.BatchNorm2d(num_features=128)
        self.conv5 = nn.Conv2d(in_channels=128,
                               out_channels=256,
                               kernel_size=(1, 3),
                               stride=(1, 2))
        self.bn5 = nn.BatchNorm2d(num_features=256)

        # LSTM: two stacked layers, 1024 -> 1024, inputs laid out batch-first.
        # NOTE(review): 1024 presumably equals conv5 channels (256) times the
        # remaining last-axis size (4) — see the docstring; confirm.
        self.LSTM1 = nn.LSTM(input_size=1024,
                             hidden_size=1024,
                             num_layers=2,
                             batch_first=True)

        # Decoder
        # in_channels=512 is twice conv5's 256 output channels.
        # NOTE(review): presumably the reshaped LSTM output is concatenated
        # with the conv5 output; confirm in forward().
        self.convT1 = nn.ConvTranspose2d(in_channels=512,
                                         out_channels=128,
                                         kernel_size=(1, 3),
                                         stride=(1, 2))
        self.bnT1 = nn.BatchNorm2d(num_features=128)
        self.convT2 = nn.ConvTranspose2d(in_channels=256,
                                         out_channels=64,
                                         kernel_size=(1, 3),
                                         stride=(1, 2))
        self.bnT2 = nn.BatchNorm2d(num_features=64)
        self.convT3 = nn.ConvTranspose2d(in_channels=128,
                                         out_channels=32,
                                         kernel_size=(1, 3),
                                         stride=(1, 2))
        self.bnT3 = nn.BatchNorm2d(num_features=32)
        # output_padding must be 1 on the last axis; otherwise the computed
        # size is 79 instead of the encoder's matching size (80 under the
        # 161-bin assumption in the docstring).
        self.convT4 = nn.ConvTranspose2d(in_channels=64,
                                         out_channels=16,
                                         kernel_size=(1, 3),
                                         stride=(1, 2),
                                         output_padding=(0, 1))
        self.bnT4 = nn.BatchNorm2d(num_features=16)
        self.convT5 = nn.ConvTranspose2d(in_channels=32,
                                         out_channels=1,
                                         kernel_size=(1, 3),
                                         stride=(1, 2))
        self.bnT5 = nn.BatchNorm2d(num_features=1)