示例#1
0
    def validation_step(self, batch, batch_idx):
        """
        Validate one batch and report the loss together with STOI/PESQ scores.

        The model is run twice on the same inputs: once with ``repeats=1``
        (this pass also provides the output used for the loss) and once with
        ``repeats`` set to the configured ``repeat_validation`` value. For
        every item of the batch the reference waveform is read from disk and
        compared against the waveform reconstructed from each of the two
        model outputs.

        Args:
            batch: 7-item tuple unpacked as
                ``(x, mag, max_length, y, T_ys, length, path_speech)``.
            batch_idx: index of the batch (not used here).

        Returns:
            dict with ``val_loss``, the batch-averaged ``stoi`` / ``pesq`` of
            the single-repeat output, and ``stoi_x<N>`` / ``pesq_x<N>`` for
            the output produced with ``N = repeat_validation`` repeats.
        """
        self.set_operation_mode(OperationMode.validation)
        n_repeats = self._cfg.train_params.repeat_validation

        x, mag, max_length, y, T_ys, length, path_speech = batch
        forward_args = dict(x=x, mag=mag, max_length=max_length)
        output_loss, single_pass, _ = self(**forward_args, repeats=1)
        _, repeated_pass, _ = self(**forward_args, repeats=n_repeats)

        loss = self.calc_loss(output_loss, y, T_ys)
        batch_size = x.shape[0]

        # Running [STOI, PESQ] sums, one pair per model output.
        single_totals = [0.0, 0.0]
        repeated_totals = [0.0, 0.0]
        scored_outputs = (
            (single_pass, single_totals),
            (repeated_pass, repeated_totals),
        )

        for idx in range(batch_size):
            # Reference waveform for this item, loaded from its file path.
            reference = sf.read(path_speech[idx])[0].astype(np.float32)
            n_sample = length[idx]

            for prediction, totals in scored_outputs:
                spec = self.postprocess(prediction, T_ys, idx)
                estimate = reconstruct_wave(spec,
                                            kwargs_istft=self.kwargs_istft,
                                            n_sample=n_sample)
                scores = eval_tts_scores(reference, estimate)
                totals[0] += torch.tensor(scores['STOI'])
                totals[1] += torch.tensor(scores['PESQ'])

        return {
            "val_loss": loss,
            "stoi": single_totals[0] / batch_size,
            "pesq": single_totals[1] / batch_size,
            "stoi_x%d" % n_repeats: repeated_totals[0] / batch_size,
            "pesq_x%d" % n_repeats: repeated_totals[1] / batch_size,
        }
示例#2
0
    def validation_step(self, batch, batch_idx):
        """
        Validate one batch: compute the losses and, optionally, STOI/PESQ.

        The target spectrogram is converted to mel, fed through the model to
        predict a spectrogram, and the prediction is converted back to mel.
        The reported ``val_loss`` is::

            loss_L1 + lreg_factor * loss_reg + sum(smooth losses)

        where ``loss_L1`` compares predicted and target spectrograms,
        ``loss_reg`` compares the input mel with the mel of the prediction,
        and one smooth loss is added per ``(k, s)`` pair in ``self.f_specs``.

        When ``self._cfg.train_params.validate_scores`` is truthy, waveforms
        are estimated with Griffin-Lim from both the predicted and the target
        spectrograms and scored with ``eval_tts_scores``.

        Args:
            batch: 7-item tuple; only item 1 (``y_spec``), item 4 (``T_ys``)
                and item 6 (``path_speech``) are used.
            batch_idx: index of the batch (not used here).

        Returns:
            dict with ``val_loss``, ``loss_L1``, ``loss_reg``, one
            ``loss_{k}_{s}`` entry per pair in ``self.f_specs`` and, when
            score validation is enabled, ``stoi_real``, ``pesq_real``,
            ``stoi_est`` and ``pesq_est`` (batch averages).
        """
        _, y_spec, _, _, T_ys, _, path_speech = batch

        x_mel = self.ed_mel2spec.spec_to_mel(y_spec)

        x_spec = self(mel=x_mel)
        z_mel = self.ed_mel2spec.spec_to_mel(x_spec)

        loss_L1 = self.calc_loss(x_spec, y_spec, T_ys, self.criterion)
        loss_reg = self.calc_loss(x_mel, z_mel, T_ys, self.criterion)

        loss = loss_L1 + self.lreg_factor * loss_reg

        output = {
            'val_loss': loss,
            'loss_L1': loss_L1,
            'loss_reg': loss_reg,
        }

        if self._cfg.train_params.validate_scores:
            # For validation, estimate the waves using standard Griffin-Lim
            # and compare the prediction both with the real recording and
            # with the Griffin-Lim estimate of the target spectrogram.
            cnt = x_spec.shape[0]
            # detach() first: numpy() raises on a tensor that requires grad,
            # which happens if this step is ever run outside a no_grad scope.
            np_x = x_spec.detach().to('cpu').numpy()
            np_y = y_spec.detach().to('cpu').numpy()
            stoi_real, pesq_real, stoi_est, pesq_est = (0.0, 0.0, 0.0, 0.0)

            for p in range(cnt):
                y_wav_path = path_speech[p]
                wav = sf.read(y_wav_path)[0].astype(np.float32)

                # Index 0 on the second axis selects the slice handed to
                # Griffin-Lim for item p.
                y_est_wav = griffin_lim(np_y[p, 0, :, :])
                x_est_wav = griffin_lim(np_x[p, 0, :, :])

                # Trim all three signals to their common length before
                # scoring them against each other.
                min_size = min(wav.shape[0], x_est_wav.shape[0],
                               y_est_wav.shape[0])
                wav = wav[0:min_size, ...]
                y_est_wav = y_est_wav[0:min_size, ...]
                x_est_wav = x_est_wav[0:min_size, ...]

                # Prediction vs. the real recording.
                measure = eval_tts_scores(x_est_wav, wav)
                stoi_real += torch.tensor(measure['STOI'])
                pesq_real += torch.tensor(measure['PESQ'])

                # Prediction vs. Griffin-Lim estimate of the target.
                measure = eval_tts_scores(x_est_wav, y_est_wav)
                stoi_est += torch.tensor(measure['STOI'])
                pesq_est += torch.tensor(measure['PESQ'])

            output['stoi_real'] = stoi_real / cnt
            output['pesq_real'] = pesq_real / cnt
            output['stoi_est'] = stoi_est / cnt
            output['pesq_est'] = pesq_est / cnt

        # Each smooth loss is reported on its own and added to the total.
        for (k, s) in self.f_specs:
            new_loss = self.calc_loss_smooth(x_spec, y_spec, T_ys, k, s)
            output[f'loss_{k}_{s}'] = new_loss
            loss = loss + new_loss

        # Overwrite the preliminary value with the total including the
        # smooth losses.
        output['val_loss'] = loss

        return output