def mag(self, noisy, inference_args):
        """Enhance ``noisy`` by predicting a magnitude and reusing the noisy phase.

        The STFT of ``noisy`` is split into magnitude and phase with
        ``mag_phase``. The magnitude is given an extra axis at position 1,
        passed through ``self.model``, and that axis is removed from the
        output. The predicted magnitude and the untouched noisy phase are
        then handed to ``self.torch_istft`` in magnitude/phase mode.

        Args:
            noisy: Input waveform tensor; its last dimension is used as the
                target length of the inverse STFT.
            inference_args: Accepted for interface compatibility; not read here.

        Returns:
            The enhanced signal as a NumPy array (detached, leading axis
            squeezed, moved to CPU).
        """
        spectrum = self.torch_stft(noisy)
        magnitude, phase = mag_phase(spectrum)

        # [B, F, T] => [B, 1, F, T] => model => squeeze the added axis again
        model_input = magnitude.unsqueeze(1)
        predicted_mag = self.model(model_input).squeeze(1)

        waveform = self.torch_istft(
            (predicted_mag, phase),
            length=noisy.size(-1),
            use_mag_phase=True,
        )
        return waveform.detach().squeeze(0).cpu().numpy()
    def scaled_mask(self, noisy, inference_args):
        """Enhance ``noisy`` by multiplying its STFT with a model-predicted mask.

        Args:
            noisy: Input waveform tensor; its last dimension is used as the
                target length of the inverse STFT.
            inference_args: Accepted for interface compatibility; not read here.

        Returns:
            The enhanced signal as a NumPy array (detached, leading axis
            squeezed, moved to CPU).
        """
        spectrum = self.torch_stft(noisy)
        magnitude, _ = mag_phase(spectrum)  # the phase is not needed on this path

        # [B, F, T] => [B, 1, F, T] => model => [B, 2, F, T] => [B, F, T, 2]
        mask = self.model(magnitude.unsqueeze(1)).permute(0, 2, 3, 1)

        # Element-wise product of the noisy spectrum and the predicted mask.
        masked_spectrum = spectrum * mask
        waveform = self.torch_istft(
            masked_spectrum,
            length=noisy.size(-1),
            use_mag_phase=False,
        )
        return waveform.detach().squeeze(0).cpu().numpy()
    def full_band_crm_mask(self, noisy, inference_args):
        """Enhance ``noisy`` with a model-predicted complex ratio mask.

        The model output is permuted so its two channels become the last
        axis, passed through ``decompress_cIRM``, and then applied to the
        noisy spectrum as a complex multiplication written out on the real
        and imaginary parts.

        Args:
            noisy: Input waveform tensor; its last dimension is used as the
                target length of the inverse STFT.
            inference_args: Accepted for interface compatibility; not read here.

        Returns:
            The enhanced signal as a NumPy array (detached, leading axis
            squeezed, moved to CPU).
        """
        spectrum = self.torch_stft(noisy)
        magnitude, _ = mag_phase(spectrum)

        # Add an axis at dim 1 for the model, then move the model's dim-1
        # output axis to the end so index 0/1 of the last axis select the
        # two mask components.
        crm = self.model(magnitude.unsqueeze(1)).permute(0, 2, 3, 1)
        crm = decompress_cIRM(crm)

        mask_re, mask_im = crm[..., 0], crm[..., 1]
        spec_re, spec_im = spectrum.real, spectrum.imag

        # (mask_re + j*mask_im) * (spec_re + j*spec_im)
        out_re = mask_re * spec_re - mask_im * spec_im
        out_im = mask_im * spec_re + mask_re * spec_im

        stacked = torch.stack((out_re, out_im), dim=-1)
        waveform = self.torch_istft(stacked, length=noisy.size(-1))
        return waveform.detach().squeeze(0).cpu().numpy()
Exemple #4
0
    def _train_epoch(self, epoch):
        """Run one training pass over ``self.train_dataloader``.

        For each (noisy, clean) batch, the target mask is built from the STFTs
        of both signals, the model predicts a mask from the noisy magnitude,
        and the optimizer is stepped through ``self.scaler`` with gradient-norm
        clipping. On rank 0 a progress bar is shown and the mean batch loss is
        written to ``self.writer`` under ``"Loss/Train"``.

        Args:
            epoch: Step value used when the mean loss is logged.
        """
        loss_total = 0.0
        progress_bar = None

        # Only rank 0 reports progress and writes the summary scalar.
        if self.rank == 0:
            progress_bar = tqdm(total=len(self.train_dataloader), desc="Training")

        try:
            for noisy, clean in self.train_dataloader:
                self.optimizer.zero_grad()

                noisy = noisy.to(self.rank)
                clean = clean.to(self.rank)

                noisy_complex = self.torch_stft(noisy)
                clean_complex = self.torch_stft(clean)

                noisy_mag, _ = mag_phase(noisy_complex)
                ground_truth_cIRM = build_complex_ideal_ratio_mask(noisy_complex, clean_complex)  # [B, F, T, 2]
                # drop_band is applied with the two mask components at dim 1,
                # then the layout is restored.
                # NOTE(review): `.module` implies the model is wrapped
                # (e.g. by a parallel wrapper) — confirm against the caller.
                ground_truth_cIRM = drop_band(
                    ground_truth_cIRM.permute(0, 3, 1, 2),  # [B, 2, F ,T]
                    self.model.module.num_groups_in_drop_band
                ).permute(0, 2, 3, 1)

                with autocast(enabled=self.use_amp):
                    # [B, F, T] => [B, 1, F, T] => model => [B, 2, F, T] => [B, F, T, 2]
                    noisy_mag = noisy_mag.unsqueeze(1)
                    cRM = self.model(noisy_mag)
                    cRM = cRM.permute(0, 2, 3, 1)
                    loss = self.loss_function(ground_truth_cIRM, cRM)

                self.scaler.scale(loss).backward()
                # Unscale before clipping so the norm threshold is compared
                # against the real (unscaled) gradients.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm_value)
                self.scaler.step(self.optimizer)
                self.scaler.update()

                loss_total += loss.item()

                if progress_bar is not None:
                    progress_bar.update(1)
        finally:
            # A manually driven tqdm bar must be closed explicitly; do it even
            # when a batch raises so the terminal line is finalised.
            if progress_bar is not None:
                progress_bar.close()

        if self.rank == 0:
            self.writer.add_scalar("Loss/Train", loss_total / len(self.train_dataloader), epoch)
Exemple #5
0
    def _validation_epoch(self, epoch):
        """Evaluate the model on ``self.valid_dataloader`` and log the results.

        Every batch must contain exactly one utterance. For each utterance the
        loss against the ideal mask is accumulated (overall and per speech
        type), the predicted mask is clipped and expanded, applied to the
        noisy spectrum, and the resulting waveform is collected for metric
        computation. The first ``n_samples`` utterances of each speech type
        are also sent to ``self.spec_audio_visualization``.

        Args:
            epoch: Step value used for all scalars and visualizations.

        Returns:
            The value returned by ``self.metrics_visualization`` for the
            ``"No_reverb"`` subset.
        """
        visualization_n_samples = self.visualization_config["n_samples"]
        visualization_num_workers = self.visualization_config["num_workers"]
        visualization_metrics = self.visualization_config["metrics"]

        speech_types = ("With_reverb", "No_reverb")

        loss_total = 0.0
        loss_list = {speech_type: 0.0 for speech_type in speech_types}
        item_idx_list = {speech_type: 0 for speech_type in speech_types}
        noisy_y_list = {speech_type: [] for speech_type in speech_types}
        clean_y_list = {speech_type: [] for speech_type in speech_types}
        enhanced_y_list = {speech_type: [] for speech_type in speech_types}
        validation_score_list = {speech_type: 0.0 for speech_type in speech_types}

        # Wrap the dataloader itself (not an enumerate over it) so tqdm can
        # read its length and display a total.
        for noisy, clean, name, speech_type in tqdm(self.valid_dataloader, desc="Validation"):
            assert len(name) == 1, "The batch size of validation stage must be one."
            name = name[0]
            speech_type = speech_type[0]

            noisy = noisy.to(self.rank)
            clean = clean.to(self.rank)

            noisy_complex = self.torch_stft(noisy)
            clean_complex = self.torch_stft(clean)

            noisy_mag, _ = mag_phase(noisy_complex)
            ground_truth_cIRM = build_complex_ideal_ratio_mask(noisy_complex, clean_complex)  # [B, F, T, 2]

            # [B, F, T] => [B, 1, F, T] => model => [B, 2, F, T] => [B, F, T, 2]
            noisy_mag = noisy_mag.unsqueeze(1)
            pred_cRM = self.model(noisy_mag)
            pred_cRM = pred_cRM.permute(0, 2, 3, 1)

            # Keep a plain float: accumulating the tensor itself would hold
            # on to it (and any attached graph) for the whole epoch.
            loss = self.loss_function(ground_truth_cIRM, pred_cRM).item()

            # Clip the mask to [-lim, lim] so the log argument below stays
            # positive and finite, then undo the compression.
            lim = 9.9
            pred_cRM = torch.clamp(pred_cRM, min=-lim, max=lim)
            pred_cRM = -10 * torch.log((10 - pred_cRM) / (10 + pred_cRM))

            # Complex multiplication of mask and noisy spectrum, with real and
            # imaginary parts stored at index 0 / 1 of the last axis.
            enhanced_real = pred_cRM[..., 0] * noisy_complex[..., 0] - pred_cRM[..., 1] * noisy_complex[..., 1]
            enhanced_imag = pred_cRM[..., 1] * noisy_complex[..., 0] + pred_cRM[..., 0] * noisy_complex[..., 1]
            enhanced_complex = torch.stack((enhanced_real, enhanced_imag), dim=-1)

            enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1), use_mag_phase=False)

            noisy = noisy.detach().squeeze(0).cpu().numpy()
            clean = clean.detach().squeeze(0).cpu().numpy()
            enhanced = enhanced.detach().squeeze(0).cpu().numpy()

            assert len(noisy) == len(clean) == len(enhanced)
            loss_total += loss

            # === === === Visualization === === ===
            # Separated loss
            loss_list[speech_type] += loss
            item_idx_list[speech_type] += 1

            if item_idx_list[speech_type] <= visualization_n_samples:
                self.spec_audio_visualization(noisy, enhanced, clean, name, epoch, mark=speech_type)

            noisy_y_list[speech_type].append(noisy)
            clean_y_list[speech_type].append(clean)
            enhanced_y_list[speech_type].append(enhanced)

        self.writer.add_scalar("Loss/Validation_Total", loss_total / len(self.valid_dataloader), epoch)

        for speech_type in speech_types:
            # NOTE(review): the per-type sum is divided by the size of the
            # whole loader, not by item_idx_list[speech_type], so the two
            # per-type values add up to the total — confirm this is intended.
            self.writer.add_scalar(
                f"Loss/{speech_type}",
                loss_list[speech_type] / len(self.valid_dataloader),
                epoch,
            )

            validation_score_list[speech_type] = self.metrics_visualization(
                noisy_y_list[speech_type],
                clean_y_list[speech_type],
                enhanced_y_list[speech_type],
                visualization_metrics,
                epoch,
                visualization_num_workers,
                mark=speech_type,
            )

        return validation_score_list["No_reverb"]
Exemple #6
0
    def _validation_epoch(self, epoch):
        """Evaluate the model on ``self.valid_dataloader`` and log the results.

        Every batch must contain exactly one utterance. For each utterance the
        loss against the ideal mask is accumulated (overall and per speech
        type), the predicted mask is passed through ``decompress_cIRM``,
        applied to the noisy spectrum, and the resulting waveform is collected
        for metric computation. The first ``n_samples`` utterances of each
        speech type are also sent to ``self.spec_audio_visualization``.

        Args:
            epoch: Step value used for all scalars and visualizations.

        Returns:
            The value returned by ``self.metrics_visualization`` for the
            ``"No_reverb"`` subset.
        """
        visualization_n_samples = self.visualization_config["n_samples"]
        visualization_num_workers = self.visualization_config["num_workers"]
        visualization_metrics = self.visualization_config["metrics"]

        speech_types = ("With_reverb", "No_reverb")

        loss_total = 0.0
        loss_list = {speech_type: 0.0 for speech_type in speech_types}
        item_idx_list = {speech_type: 0 for speech_type in speech_types}
        noisy_y_list = {speech_type: [] for speech_type in speech_types}
        clean_y_list = {speech_type: [] for speech_type in speech_types}
        enhanced_y_list = {speech_type: [] for speech_type in speech_types}
        validation_score_list = {speech_type: 0.0 for speech_type in speech_types}

        # Wrap the dataloader itself (not an enumerate over it) so tqdm can
        # read its length and display a total.
        for noisy, clean, name, speech_type in tqdm(self.valid_dataloader, desc="Validation"):
            assert len(name) == 1, "The batch size for the validation stage must be one."
            name = name[0]
            speech_type = speech_type[0]

            noisy = noisy.to(self.rank)
            clean = clean.to(self.rank)

            noisy_complex = self.torch_stft(noisy)
            clean_complex = self.torch_stft(clean)

            noisy_mag, _ = mag_phase(noisy_complex)
            cIRM = build_complex_ideal_ratio_mask(noisy_complex, clean_complex)  # [B, F, T, 2]

            # [B, F, T] => [B, 1, F, T] => model => [B, 2, F, T] => [B, F, T, 2]
            noisy_mag = noisy_mag.unsqueeze(1)
            cRM = self.model(noisy_mag)
            cRM = cRM.permute(0, 2, 3, 1)

            # Keep a plain float: accumulating the tensor itself would hold
            # on to it (and any attached graph) for the whole epoch.
            loss = self.loss_function(cIRM, cRM).item()

            cRM = decompress_cIRM(cRM)

            # Complex multiplication of the mask (components at index 0 / 1 of
            # the last axis) with the complex noisy spectrum.
            enhanced_real = cRM[..., 0] * noisy_complex.real - cRM[..., 1] * noisy_complex.imag
            enhanced_imag = cRM[..., 1] * noisy_complex.real + cRM[..., 0] * noisy_complex.imag
            enhanced_complex = torch.stack((enhanced_real, enhanced_imag), dim=-1)
            enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1))

            noisy = noisy.detach().squeeze(0).cpu().numpy()
            clean = clean.detach().squeeze(0).cpu().numpy()
            enhanced = enhanced.detach().squeeze(0).cpu().numpy()

            assert len(noisy) == len(clean) == len(enhanced)
            loss_total += loss

            # Separated loss
            loss_list[speech_type] += loss
            item_idx_list[speech_type] += 1

            if item_idx_list[speech_type] <= visualization_n_samples:
                self.spec_audio_visualization(noisy, enhanced, clean, name, epoch, mark=speech_type)

            noisy_y_list[speech_type].append(noisy)
            clean_y_list[speech_type].append(clean)
            enhanced_y_list[speech_type].append(enhanced)

        self.writer.add_scalar("Loss/Validation_Total", loss_total / len(self.valid_dataloader), epoch)

        for speech_type in speech_types:
            # NOTE(review): the per-type sum is divided by the size of the
            # whole loader, not by item_idx_list[speech_type], so the two
            # per-type values add up to the total — confirm this is intended.
            self.writer.add_scalar(
                f"Loss/{speech_type}",
                loss_list[speech_type] / len(self.valid_dataloader),
                epoch,
            )

            validation_score_list[speech_type] = self.metrics_visualization(
                noisy_y_list[speech_type],
                clean_y_list[speech_type],
                enhanced_y_list[speech_type],
                visualization_metrics,
                epoch,
                visualization_num_workers,
                mark=speech_type,
            )

        return validation_score_list["No_reverb"]