Exemplo n.º 1
0
    def on_stage_start(self, stage, epoch=None):
        """Prepare the metric trackers when a stage begins.

        When the training stage starts from the generator sub-stage, the
        discriminator is trained first (``train_discriminator`` calls
        ``fit()`` again) before generator training proceeds.

        Arguments
        ---------
        stage : sb.Stage
            The stage that is about to start.
        epoch : int
            The current epoch count; stored on ``self.epoch`` right before
            the discriminator is trained.

        Raises
        ------
        NotImplementedError
            If training starts and ``hparams.target_metric`` is neither
            'pesq' nor 'stoi'.
        """
        self.mse_metric = MetricStats(metric=self.hparams.compute_cost)
        self.metrics = {"G": [], "D": []}

        if stage != sb.Stage.TRAIN:
            # Evaluation stages track both quality metrics
            self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=30)
            self.stoi_metric = MetricStats(metric=stoi_loss)
            return

        metric_name = self.hparams.target_metric
        if metric_name == "pesq":
            self.target_metric = MetricStats(metric=pesq_eval, n_jobs=40)
        elif metric_name == "stoi":
            self.target_metric = MetricStats(metric=stoi_loss)
        else:
            raise NotImplementedError(
                "Right now we only support 'pesq' and 'stoi'"
            )

        # The discriminator passes run only when we enter from the
        # generator sub-stage, which keeps the nested fit() calls made by
        # train_discriminator() from triggering this again.
        if self.sub_stage == SubStage.GENERATOR:
            self.epoch = epoch
            self.train_discriminator()
            self.sub_stage = SubStage.GENERATOR
            print("Generator training by current data...")
Exemplo n.º 2
0
    def on_stage_start(self, stage, epoch=None):
        """Create fresh metric trackers at the start of a stage.

        The loss and STOI trackers are created for every stage; the PESQ
        tracker is only created for stages other than training.
        """
        self.loss_metric = MetricStats(metric=self.hparams.compute_cost)
        self.stoi_metric = MetricStats(metric=stoi_loss)

        if stage == sb.Stage.TRAIN:
            return

        # Takes (prediction, target) so it can be evaluated in parallel
        def pesq_eval(pred_wav, target_wav):
            """Scores one prediction against its target with wide-band PESQ"""
            reference = target_wav.cpu().numpy()
            degraded = pred_wav.cpu().numpy()
            return pesq(fs=16000, ref=reference, deg=degraded, mode="wb")

        self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=4)
Exemplo n.º 3
0
    def on_stage_start(self, stage, epoch=None):
        """Prepare the metric trackers when a stage begins.

        When the training stage starts from the generator sub-stage, the
        discriminator is trained first (``train_discriminator`` calls
        ``fit()`` again) before generator training proceeds.

        Arguments
        ---------
        stage : sb.Stage
            The stage that is about to start.
        epoch : int
            The current epoch count; stored on ``self.epoch`` right before
            the discriminator is trained.

        Raises
        ------
        NotImplementedError
            If training starts and ``hparams.target_metric`` is neither
            'srmr' nor 'dnsmos'.
        """
        self.metrics = {"G": [], "D": []}

        if stage != sb.Stage.TRAIN:
            # Evaluation stages track all four metrics
            n_jobs = hparams["n_jobs"]
            self.pesq_metric = MetricStats(
                metric=pesq_eval, n_jobs=n_jobs, batch_eval=False
            )
            self.stoi_metric = MetricStats(metric=stoi_loss)
            self.srmr_metric = MetricStats(
                metric=srmrpy_eval_valid, n_jobs=n_jobs, batch_eval=False
            )
            self.dnsmos_metric = MetricStats(
                metric=dnsmos_eval_valid, n_jobs=n_jobs, batch_eval=False
            )
            return

        # Pick the scoring function used as the discriminator target
        metric_name = self.hparams.target_metric
        if metric_name == "srmr":
            scorer = srmrpy_eval
        elif metric_name == "dnsmos":
            scorer = dnsmos_eval
        else:
            raise NotImplementedError(
                "Right now we only support 'srmr' and 'dnsmos'")
        self.target_metric = MetricStats(
            metric=scorer, n_jobs=hparams["n_jobs"], batch_eval=False
        )

        # The discriminator passes run only when we enter from the
        # generator sub-stage, which keeps the nested fit() calls made by
        # train_discriminator() from triggering this again.
        if self.sub_stage == SubStage.GENERATOR:
            self.epoch = epoch
            self.train_discriminator()
            self.sub_stage = SubStage.GENERATOR
            print("Generator training by current data...")
Exemplo n.º 4
0
    def on_stage_start(self, stage, epoch=None):
        """Create fresh metric trackers at the start of a stage.

        The loss and STOI trackers are created for every stage; the PESQ
        tracker is only created for stages other than training.
        """
        self.loss_metric = MetricStats(metric=self.hparams.compute_cost)
        self.stoi_metric = MetricStats(metric=stoi_loss)

        if stage == sb.Stage.TRAIN:
            return

        # Takes (prediction, target) so it can be evaluated in parallel
        def pesq_eval(pred_wav, target_wav):
            """Scores one prediction against its target with wide-band PESQ"""
            reference = target_wav.numpy()
            degraded = pred_wav.numpy()
            return pesq(fs=16000, ref=reference, deg=degraded, mode="wb")

        self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=4)
Exemplo n.º 5
0
def test_metric_stats():
    """Check the summary ``MetricStats`` reports for a two-utterance L1 loss."""
    from speechbrain.utils.metric_stats import MetricStats
    from speechbrain.nnet.losses import l1_loss

    def _close(value, expected):
        """Compare two floats with the tolerance used throughout this test."""
        return math.isclose(value, expected, rel_tol=1e-5)

    utterance_ids = ["utterance1", "utterance2"]
    predicted = torch.tensor([[0.1, 0.2], [0.1, 0.2]])
    expected = torch.tensor([[0.1, 0.3], [0.2, 0.3]])

    tracker = MetricStats(metric=l1_loss)
    tracker.append(
        ids=utterance_ids,
        predictions=predicted,
        targets=expected,
        length=torch.ones(2),
        reduction="batch",
    )
    report = tracker.summarize()

    assert _close(report["average"], 0.075)
    assert _close(report["min_score"], 0.05)
    assert report["min_id"] == "utterance1"
    assert _close(report["max_score"], 0.1)
    assert report["max_id"] == "utterance2"
Exemplo n.º 6
0
    def on_stage_start(self, stage, epoch=None):
        """Create the PESQ tracker for every stage except training."""
        if stage == sb.Stage.TRAIN:
            return

        # Takes (prediction, target) so it can be evaluated per utterance
        def pesq_eval(pred_wav, target_wav):
            """Scores one prediction with PESQ, returning 0 if scoring fails"""
            # Wide-band mode is selected only at 16 kHz, narrow-band otherwise
            band = "wb" if self.hparams.sample_rate == 16000 else "nb"
            try:
                result = pesq(
                    fs=self.hparams.sample_rate,
                    ref=target_wav.numpy(),
                    deg=pred_wav.numpy(),
                    mode=band,
                )
            except Exception:
                # Best effort: a failing item is reported and scored as 0
                print("pesq encountered an error for this data item")
                return 0
            return result

        self.pesq_metric = MetricStats(
            metric=pesq_eval, n_jobs=1, batch_eval=False
        )
Exemplo n.º 7
0
    def on_stage_start(self, stage, epoch=None):
        """Create fresh metric trackers at the start of a stage.

        One loss tracker is created per entry ("d1", "d2", "g3") of
        ``hparams.compute_cost`` together with a STOI tracker; the PESQ
        tracker is only created for stages other than training.
        """
        costs = self.hparams.compute_cost
        self.loss_metric_d1 = MetricStats(metric=costs["d1"])
        self.loss_metric_d2 = MetricStats(metric=costs["d2"])
        self.loss_metric_g3 = MetricStats(metric=costs["g3"])
        self.stoi_metric = MetricStats(metric=stoi_loss)

        if stage == sb.Stage.TRAIN:
            return

        # Takes (prediction, target) so it can be evaluated per utterance
        def pesq_eval(pred_wav, target_wav):
            """Scores one prediction against its target with wide-band PESQ"""
            reference = target_wav.numpy().squeeze()
            degraded = pred_wav.numpy().squeeze()
            return pesq(
                fs=hparams["sample_rate"],
                ref=reference,
                deg=degraded,
                mode="wb",
            )

        self.pesq_metric = MetricStats(
            metric=pesq_eval, batch_eval=False, n_jobs=1
        )
Exemplo n.º 8
0
class MetricGanBrain(sb.Brain):
    def compute_feats(self, wavs):
        """Turn waveforms into log1p-compressed spectral magnitudes.

        Pipeline: ``hparams.compute_STFT`` -> ``spectral_magnitude`` with
        ``power=0.5`` -> ``torch.log1p``.
        """
        stft = self.hparams.compute_STFT(wavs)
        magnitude = spectral_magnitude(stft, power=0.5)
        return torch.log1p(magnitude)

    def compute_forward(self, batch, stage):
        """Return the enhanced waveform for a batch.

        In the historical sub-stage the batch already carries enhanced
        signals, so they are returned as-is. Otherwise the generator
        predicts a mask for the noisy features and the masked features are
        resynthesized into a waveform.
        """
        batch = batch.to(self.device)

        if self.sub_stage == SubStage.HISTORICAL:
            enhanced_wav, _ = batch.enh_sig
            return enhanced_wav

        noisy_wav, rel_lens = batch.noisy_sig
        noisy_feats = self.compute_feats(noisy_wav)

        # Mask with "signal approximation (SA)", floored at hparams.min_mask
        raw_mask = self.modules.generator(noisy_feats, lengths=rel_lens)
        floored_mask = raw_mask.clamp(min=self.hparams.min_mask).squeeze(2)
        enhanced_feats = torch.mul(floored_mask, noisy_feats)

        # expm1 undoes the log1p applied in compute_feats before resynthesis
        return self.hparams.resynth(torch.expm1(enhanced_feats), noisy_wav)

    def compute_objectives(self, predictions, batch, stage, optim_name=""):
        """Compute the adversarial loss for the requested optimization target.

        Arguments
        ---------
        predictions : torch.Tensor
            The waveform returned by ``compute_forward``.
        batch : PaddedBatch
            The batch; ``clean_sig`` and ``id`` are always read, and
            ``noisy_sig`` / ``score`` are read by some branches.
        stage : sb.Stage
            The current stage; stages other than TRAIN also update the
            evaluation metrics (and write wavs when ``hparams.mode`` is
            'test').
        optim_name : str
            One of "generator", "" (evaluation), "D_clean", "D_enh" or
            "D_noisy". "D_enh" is only accepted in the CURRENT and
            HISTORICAL sub-stages.

        Returns
        -------
        The cost between the discriminator's estimated score and the target
        score. For "generator" the MSE cost, weighted by
        ``hparams.mse_weight``, is added to it.

        Raises
        ------
        ValueError
            If ``optim_name`` (combined with the current sub-stage) matches
            none of the supported cases.
        """
        predict_wav = predictions
        predict_spec = self.compute_feats(predict_wav)

        clean_wav, lens = batch.clean_sig
        clean_spec = self.compute_feats(clean_wav)
        mse_cost = self.hparams.compute_cost(predict_spec, clean_spec, lens)

        # For "D_enh" the ids get an "@<epoch>" suffix (see compute_ids)
        ids = self.compute_ids(batch.id, optim_name)

        # One is real, zero is fake
        if optim_name == "generator" or optim_name == "":
            # The generator aims for the top score on every item; the
            # evaluation case ("") uses a single 1x1 target instead of a
            # batch-sized one.
            if optim_name == "generator":
                target_score = torch.ones(self.batch_size, 1, device=self.device)
            else:
                target_score = torch.ones(1, 1, device=self.device)
            est_score = self.est_score(predict_spec, clean_spec)
            self.mse_metric.append(
                ids, predict_spec, clean_spec, lens, reduction="batch"
            )

        # D Learns to estimate the scores of clean speech
        elif optim_name == "D_clean":
            target_score = torch.ones(self.batch_size, 1, device=self.device)
            est_score = self.est_score(clean_spec, clean_spec)

        # D Learns to estimate the scores of enhanced speech
        elif optim_name == "D_enh" and self.sub_stage == SubStage.CURRENT:
            target_score = self.score(ids, predict_wav, clean_wav, lens)
            est_score = self.est_score(predict_spec, clean_spec)

            # Write enhanced wavs during discriminator training, because we
            # compute the actual score here and we can save it
            self.write_wavs(batch.id, ids, predict_wav, target_score, lens)

        # D Relearns to estimate the scores of previous epochs
        elif optim_name == "D_enh" and self.sub_stage == SubStage.HISTORICAL:
            # Historical batches carry the score recorded when they were written
            target_score = batch.score.unsqueeze(1).float()
            est_score = self.est_score(predict_spec, clean_spec)

        # D Learns to estimate the scores of noisy speech
        elif optim_name == "D_noisy":
            noisy_wav, _ = batch.noisy_sig
            noisy_spec = self.compute_feats(noisy_wav)
            target_score = self.score(ids, noisy_wav, clean_wav, lens)
            est_score = self.est_score(noisy_spec, clean_spec)

            # Save scores of noisy wavs
            self.save_noisy_scores(ids, target_score)

        else:
            raise ValueError(f"{optim_name} is not a valid 'optim_name'")

        # Compute the cost
        adv_cost = self.hparams.compute_cost(est_score, target_score)
        if optim_name == "generator":
            adv_cost += self.hparams.mse_weight * mse_cost
            self.metrics["G"].append(adv_cost.detach())
        else:
            # Every non-generator case, including evaluation (""), is
            # tracked under "D"
            self.metrics["D"].append(adv_cost.detach())

        # On validation data compute scores
        if stage != sb.Stage.TRAIN:
            # Evaluate speech quality/intelligibility
            if self.hparams.mode == 'val':
                # In 'val' mode only the configured target metric is updated
                if self.hparams.target_metric == "stoi":
                    self.stoi_metric.append(
                        batch.id, predict_wav, clean_wav, lens, reduction="batch"
                    )
                elif self.hparams.target_metric == "pesq":
                    self.pesq_metric.append(
                        batch.id, predict=predict_wav, target=clean_wav, lengths=lens
                    )

            # Write wavs to file, for evaluation
            elif self.hparams.mode == 'test':
                self.stoi_metric.append(
                        batch.id, predict_wav, clean_wav, lens, reduction="batch"
                    )
                self.pesq_metric.append(
                        batch.id, predict=predict_wav, target=clean_wav, lengths=lens
                    )
                # Scale the relative lengths by the padded sample count so
                # each wav can be trimmed before saving
                lens = lens * clean_wav.shape[1]
                for name, pred_wav, length in zip(batch.id, predict_wav, lens):
                    name += ".wav"
                    enhance_path = os.path.join(self.hparams.enhanced_folder, name)
                    # NOTE(review): the sample rate is hard-coded to 16000
                    # here while write_wavs uses hparams.Sample_rate --
                    # confirm the two agree.
                    torchaudio.save(
                        enhance_path,
                        torch.unsqueeze(pred_wav[: int(length)].cpu(), 0),
                        16000,
                    )

        # mse_cost only enters the returned loss for "generator" (added
        # above); for every other optim_name this is the plain score cost
        return adv_cost

    def compute_ids(self, batch_id, optim_name):
        """Return the batch ids, tagged with the epoch for "D_enh".

        For ``optim_name == "D_enh"`` each id becomes ``"<id>@<epoch>"``;
        for any other name the ids are returned unchanged.
        """
        if optim_name != "D_enh":
            return batch_id
        return [f"{uid}@{self.epoch}" for uid in batch_id]

    def save_noisy_scores(self, batch_id, scores):
        """Store each utterance's score in ``self.noisy_scores`` by id.

        Arguments
        ---------
        batch_id : list of str
            A list of the utterance ids for the batch
        scores : torch.Tensor
            The scores to store, iterated in step with ``batch_id``
        """
        for uid, value in zip(batch_id, scores):
            self.noisy_scores[uid] = value

    def score(self, batch_id, deg_wav, ref_wav, lens):
        """Returns actual metric score, either pesq or stoi

        Utterances whose id is already present in ``self.historical_set``
        or ``self.noisy_scores`` reuse the stored score; only the remaining
        ones are evaluated with ``self.target_metric``.

        Arguments
        ---------
        batch_id : list of str
            A list of the utterance ids for the batch
        deg_wav : torch.Tensor
            The degraded waveform to score
        ref_wav : torch.Tensor
            The reference waveform to use for scoring
        lens : torch.Tensor
            The relative lengths of the utterances

        Raises
        ------
        ValueError
            If some utterance needs scoring and ``hparams.target_metric``
            is neither 'pesq' nor 'stoi'.
        """
        # Batch positions of utterances that have no stored score yet
        fresh = [
            pos
            for pos, uid in enumerate(batch_id)
            if not (uid in self.historical_set or uid in self.noisy_scores)
        ]

        if fresh:
            metric_name = self.hparams.target_metric
            if metric_name == "pesq":
                self.target_metric.append(
                    ids=[batch_id[pos] for pos in fresh],
                    predict=deg_wav[fresh].detach(),
                    target=ref_wav[fresh].detach(),
                    lengths=lens[fresh],
                )
                score = torch.tensor(
                    [[s] for s in self.target_metric.scores],
                    device=self.device,
                )
            elif metric_name == "stoi":
                self.target_metric.append(
                    [batch_id[pos] for pos in fresh],
                    deg_wav[fresh],
                    ref_wav[fresh],
                    lens[fresh],
                    reduction="batch",
                )
                # The tracked stoi values are sign-flipped to form the score
                score = torch.tensor(
                    [[-s] for s in self.target_metric.scores],
                    device=self.device,
                )
            else:
                raise ValueError("Expected 'pesq' or 'stoi' for target_metric")

        # Clear metric scores to prepare for next batch
        self.target_metric.clear()

        # Maps a batch position to its row in the freshly computed scores
        row_of = {pos: row for row, pos in enumerate(fresh)}

        # Combine old scores and new, in batch order
        combined = []
        for pos, uid in enumerate(batch_id):
            if uid in self.historical_set:
                value = self.historical_set[uid]["score"]
            elif uid in self.noisy_scores:
                value = self.noisy_scores[uid]
            else:
                value = score[row_of[pos]]
            combined.append([value])

        return torch.tensor(combined, device=self.device)

    def est_score(self, deg_spec, ref_spec):
        """Returns score as estimated by discriminator

        The two spectra each get a new dimension at index 1 and are
        concatenated along it before being passed to the discriminator.

        Arguments
        ---------
        deg_spec : torch.Tensor
            The spectral features of the degraded utterance
        ref_spec : torch.Tensor
            The spectral features of the reference utterance
        """
        degraded = deg_spec.unsqueeze(1)
        reference = ref_spec.unsqueeze(1)
        paired = torch.cat([degraded, reference], 1)
        return self.modules.discriminator(paired)

    def write_wavs(self, clean_id, batch_id, wavs, scores, lens):
        """Write wavs to files, for historical discriminator training

        Each wav is trimmed to its length, saved under
        ``hparams.MetricGAN_folder`` and recorded in
        ``self.historical_set`` together with its score and the path of
        the matching clean file.

        Arguments
        ---------
        clean_id : list of str
            The original utterance ids; each is split once on '-' to build
            the clean file path
        batch_id : list of str
            A list of the utterance ids for the batch, used as file names
            and as record keys
        wavs : torch.Tensor
            The wavs to write to files
        scores : torch.Tensor
            The actual scores for the corresponding utterances
        lens : torch.Tensor
            The relative lengths of each utterance
        """
        # Scale the relative lengths by the padded sample count
        sample_counts = lens * wavs.shape[1]

        new_records = {}
        rows = zip(clean_id, batch_id, wavs, sample_counts)
        for row, (source_id, name, wav, n_samples) in enumerate(rows):
            enh_path = os.path.join(
                self.hparams.MetricGAN_folder, name + ".wav"
            )
            trimmed = wav[: int(n_samples)].cpu()
            torchaudio.save(
                enh_path, torch.unsqueeze(trimmed, 0), self.hparams.Sample_rate
            )

            # Make record of path and score for historical training
            score_value = float(scores[row][0])
            id_parts = source_id.split('-', 1)
            clean_file = os.path.join(
                self.hparams.train_clean_folder,
                id_parts[0],
                id_parts[1] + ".pkl",
            )
            new_records[name] = {
                "enh_wav": enh_path,
                "score": score_value,
                "clean_wav": clean_file,
            }

        # Update records for historical training
        self.historical_set.update(new_records)

    def fit_batch(self, batch):
        """Compute gradients and update either D or G based on sub-stage.

        Returns the detached loss of the update; in the CURRENT sub-stage
        it is the mean over the three discriminator updates (clean, enh,
        noisy). If the sub-stage matches none of the known ones, 0 is
        returned.
        """
        predictions = self.compute_forward(batch, sb.Stage.TRAIN)

        def _update(optim_name, optimizer_attr):
            """Run one loss/backward/step cycle and return the detached loss"""
            loss = self.compute_objectives(
                predictions, batch, sb.Stage.TRAIN, optim_name
            )
            optimizer = getattr(self, optimizer_attr)
            optimizer.zero_grad()
            loss.backward()
            # Only step when the gradient check passes
            if self.check_gradients(loss):
                optimizer.step()
            return loss.detach()

        total = 0
        if self.sub_stage == SubStage.CURRENT:
            for mode in ("clean", "enh", "noisy"):
                total += _update(f"D_{mode}", "d_optimizer") / 3
        elif self.sub_stage == SubStage.HISTORICAL:
            total += _update("D_enh", "d_optimizer")
        elif self.sub_stage == SubStage.GENERATOR:
            total += _update("generator", "g_optimizer")
        return total

    def on_stage_start(self, stage, epoch=None):
        """Prepare the metric trackers when a stage begins.

        When the training stage starts from the generator sub-stage, the
        discriminator is trained first (``train_discriminator`` calls
        ``fit()`` again) before generator training proceeds.

        Arguments
        ---------
        stage : sb.Stage
            The stage that is about to start.
        epoch : int
            The current epoch count; stored on ``self.epoch`` right before
            the discriminator is trained.

        Raises
        ------
        NotImplementedError
            If training starts and ``hparams.target_metric`` is neither
            'pesq' nor 'stoi'.
        """
        self.mse_metric = MetricStats(metric=self.hparams.compute_cost)
        self.metrics = {"G": [], "D": []}

        if stage != sb.Stage.TRAIN:
            # Evaluation stages track both quality metrics
            self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=30)
            self.stoi_metric = MetricStats(metric=stoi_loss)
            return

        metric_name = self.hparams.target_metric
        if metric_name == "pesq":
            self.target_metric = MetricStats(metric=pesq_eval, n_jobs=40)
        elif metric_name == "stoi":
            self.target_metric = MetricStats(metric=stoi_loss)
        else:
            raise NotImplementedError(
                "Right now we only support 'pesq' and 'stoi'"
            )

        # train_discriminator() switches the sub-stage while it calls
        # fit(), so this branch is skipped for those nested stages.
        if self.sub_stage == SubStage.GENERATOR:
            self.epoch = epoch
            self.train_discriminator()
            self.sub_stage = SubStage.GENERATOR
            print("Generator training by current data...")

    def train_discriminator(self):
        """A total of 3 data passes to update discriminator.

        The passes are: the train set, the historical set (skipped while
        it is empty), and the train set again. Each pass sets
        ``self.sub_stage`` and runs ``fit()`` for a single epoch.
        """

        def _run_pass(message, sub_stage, dataset):
            """Announce the pass, switch sub-stage and fit for one epoch"""
            print(message)
            self.sub_stage = sub_stage
            self.fit(
                range(1),
                dataset,
                train_loader_kwargs=self.hparams.dataloader_options,
            )

        # First, iterate train subset w/ updates for clean, enh, noisy
        _run_pass(
            "Discriminator training by current data...",
            SubStage.CURRENT,
            self.train_set,
        )

        # Next, iterate historical subset w/ updates for enh. The check
        # happens after the first pass has run.
        if self.historical_set:
            _run_pass(
                "Discriminator training by historical data...",
                SubStage.HISTORICAL,
                self.historical_set,
            )

        # Finally, iterate train set again. Should iterate same
        # samples as before, due to ReproducibleRandomSampler
        _run_pass(
            "Discriminator training by current data again...",
            SubStage.CURRENT,
            self.train_set,
        )

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Summarize a finished stage: print/log stats and manage checkpoints.

        Nothing happens unless the current sub-stage is GENERATOR, so the
        nested ``fit()`` calls used for discriminator training are ignored.

        Arguments
        ---------
        stage : sb.Stage
            The stage that just ended (TRAIN, VALID or TEST).
        stage_loss : float
            The average loss over the stage, as passed in by the caller.
        epoch : int
            The current epoch count; also used to select checkpoints by
            their ``meta['epoch']``.
        """
        # The predicates below close over ``epoch`` to select checkpoints
        # saved for this epoch, or for earlier epochs.
        def ckpt_predicate(ckpt):
            return ckpt.meta['epoch'] == epoch

        def ckpt_predicate_lessthan(ckpt):
            return ckpt.meta['epoch'] < epoch

        if self.sub_stage != SubStage.GENERATOR:
            return

        if stage == sb.Stage.TRAIN:
            self.train_loss = stage_loss
            g_loss = torch.tensor(self.metrics["G"])  # batch_size
            d_loss = torch.tensor(self.metrics["D"])  # batch_size
            print("Avg G loss: %.3f" % torch.mean(g_loss))
            print("Avg D loss: %.3f" % torch.mean(d_loss))
            print("MSE distance: %.3f" % self.mse_metric.summarize("average"))
            # A checkpoint is saved after every training stage, tagged with
            # the epoch so the TEST branch below can find it again.
            stats = {
                "epoch": epoch,
                "MSE distance": stage_loss,
                # "target_metric": self.target_metric.summarize("average")
            }
            self.checkpointer.save_checkpoint(meta=stats)

        # NOTE(review): if target_metric is neither 'pesq' nor 'stoi',
        # ``stats`` stays unbound for non-TRAIN stages and the code below
        # that reads it would fail -- confirm hparams always use one of
        # the two.
        elif self.hparams.target_metric == "pesq":
            stats = {
                "epoch": epoch,
                "MSE distance": stage_loss,
                # presumably maps a normalized PESQ back to its raw scale
                # -- TODO confirm against pesq_eval
                "pesq": 5 * self.pesq_metric.summarize("average") - 0.5,
            }
        elif self.hparams.target_metric == "stoi":
            stats = {
                "epoch": epoch,
                "MSE distance": stage_loss,
                # The tracker stores stoi_loss values; the sign is flipped
                # for reporting
                "stoi": -self.stoi_metric.summarize("average"),
            }

        if stage == sb.Stage.VALID:
            if self.hparams.use_tensorboard:
                valid_stats = {
                    "mse": stage_loss,
                    "pesq": 5 * self.pesq_metric.summarize("average") - 0.5,
                    "stoi": -self.stoi_metric.summarize("average"),
                }
                self.hparams.tensorboard_train_logger.log_stats(valid_stats)
            self.hparams.train_logger.log_stats(
                {"Epoch": epoch},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )
            # Keep only the checkpoint with the highest target metric
            self.checkpointer.save_and_keep_only(
                meta=stats, max_keys=[self.hparams.target_metric]
            )

        if stage == sb.Stage.TEST:
            if self.hparams.mode == 'val':
                # Exactly one checkpoint saved for this epoch is expected
                ckpts = self.checkpointer.find_checkpoints(ckpt_predicate=ckpt_predicate)
                assert len(ckpts) == 1
                # delete old ckpt from train
                self.checkpointer.delete_checkpoints(num_to_keep=0, ckpt_predicate=ckpt_predicate)

                if self.hparams.use_tensorboard:
                    valid_stats = stats
                    # NOTE(review): open question from the author -- could
                    # two tensorboard writers corrupt each other's logs?
                    self.hparams.tensorboard_train_logger.log_stats(valid_stats)
                self.hparams.train_logger.log_stats(
                    {"Epoch": epoch},
                    valid_stats=stats,
                )
                # Save this epoch again, now with the metric in its meta;
                # the predicate restricts the keep-only pruning to
                # checkpoints from earlier epochs.
                self.checkpointer.save_and_keep_only(
                    meta=stats, max_keys=[self.hparams.target_metric], ckpt_predicate=ckpt_predicate_lessthan
                    )

            else:
                print("Epoch loaded", self.hparams.epoch_counter.current)
                print("stoi", -self.stoi_metric.summarize("average"))
                print("pesq", 5 * self.pesq_metric.summarize("average") - 0.5)
                # NOTE(review): test_stats is built but never used; the
                # train_logger call that consumed it is commented out below.
                test_stats = {
                    "mse": stage_loss,
                    "pesq": 5 * self.pesq_metric.summarize("average") - 0.5,
                    "stoi": -self.stoi_metric.summarize("average"),
                }
                # self.hparams.train_logger.log_stats(
                #     {"Epoch loaded": self.hparams.epoch_counter.current},
                #     test_stats=test_stats,
                # )

    def make_dataloader(
        self, dataset, stage, ckpt_prefix="dataloader-", **loader_kwargs
    ):
        """Override dataloader to insert custom sampler/dataset

        For the training stage a reproducible sampler is installed in
        ``loader_kwargs``; in the historical sub-stage the dataset is first
        rebuilt from the given records. Other stages are passed through
        to the parent implementation unchanged.
        """
        if stage == sb.Stage.TRAIN:
            historical = self.sub_stage == SubStage.HISTORICAL

            # Create a new dataset each time, this set grows
            if historical:
                dataset = sb.dataio.dataset.DynamicItemDataset(
                    data=dataset,
                    dynamic_items=[enh_pipeline],
                    output_keys=["id", "enh_sig", "clean_sig", "score"],
                )

            if historical:
                num_samples = round(
                    len(dataset) * self.hparams.history_portion
                )
            else:
                num_samples = self.hparams.number_of_samples

            # Equal weights for all samples, we use "Weighted" so we can do
            # both "replacement=False" and a set number of samples,
            # reproducibly. Seeding with the epoch should give the same
            # samples for D and G.
            sampler = ReproducibleWeightedRandomSampler(
                torch.ones(len(dataset)),
                epoch=self.hparams.epoch_counter.current,
                replacement=False,
                num_samples=num_samples,
            )
            loader_kwargs["sampler"] = sampler

            if self.sub_stage == SubStage.GENERATOR:
                self.train_sampler = sampler

        # Make the dataloader as normal
        return super().make_dataloader(
            dataset, stage, ckpt_prefix, **loader_kwargs
        )

    def on_fit_start(self):
        """Run the parent ``on_fit_start`` only in the generator sub-stage.

        This keeps the parent hook from running again for the nested
        ``fit()`` calls used for discriminator training.
        """
        in_generator_stage = self.sub_stage == SubStage.GENERATOR
        if in_generator_stage:
            super().on_fit_start()

    def init_optimizers(self):
        """Initializes the generator and discriminator optimizers

        Both optimizers are registered with the checkpointer, when one is
        configured, under the names "g_opt" and "d_opt".
        """
        generator_params = self.modules.generator.parameters()
        self.g_optimizer = self.hparams.g_opt_class(generator_params)

        discriminator_params = self.modules.discriminator.parameters()
        self.d_optimizer = self.hparams.d_opt_class(discriminator_params)

        if self.checkpointer is None:
            return
        self.checkpointer.add_recoverable("g_opt", self.g_optimizer)
        self.checkpointer.add_recoverable("d_opt", self.d_optimizer)
Exemplo n.º 9
0
class SEBrain(sb.core.Brain):
    def compute_forward(self, batch, stage):
        """Compute the enhanced spectrum, plus the waveform outside training.

        Returns a ``(predict_spec, predict_wav)`` pair; ``predict_wav`` is
        ``None`` during the training stage, where no resynthesis is done.
        """
        batch = batch.to(self.device)
        noisy_wavs, _ = batch.noisy_sig

        stft = self.hparams.compute_STFT(noisy_wavs)
        log_magnitude = torch.log1p(spectral_magnitude(stft, power=0.5))
        predict_spec = self.hparams.model(log_magnitude)

        if stage == sb.Stage.TRAIN:
            return predict_spec, None

        # expm1 undoes the log1p compression before resynthesis
        predict_wav = self.hparams.resynth(
            torch.expm1(predict_spec), noisy_wavs
        )
        return predict_spec, predict_wav

    def compute_objectives(self, predictions, batch, stage):
        """Compute the spectral loss and update the metric trackers.

        The loss compares the predicted spectrum with the log1p-compressed
        magnitude of the clean signal. Outside training the STOI and PESQ
        trackers are updated too, and in the test stage the enhanced wavs
        are written to ``hparams.enhanced_folder``.
        """
        predict_spec, predict_wav = predictions
        ids = batch.id
        clean_wav, lens = batch.clean_sig

        clean_stft = self.hparams.compute_STFT(clean_wav)
        clean_spec = torch.log1p(spectral_magnitude(clean_stft, power=0.5))

        loss = self.hparams.compute_cost(predict_spec, clean_spec, lens)
        self.loss_metric.append(
            ids, predict_spec, clean_spec, lens, reduction="batch"
        )

        if stage != sb.Stage.TRAIN:
            # Evaluate speech quality/intelligibility
            self.stoi_metric.append(
                ids, predict_wav, clean_wav, lens, reduction="batch"
            )
            self.pesq_metric.append(
                batch.id, predict=predict_wav, target=clean_wav, lengths=lens
            )

        # Write wavs to file
        if stage == sb.Stage.TEST:
            # Scale the relative lengths by the padded sample count
            sample_counts = lens * clean_wav.shape[1]
            for name, wav, n_samples in zip(ids, predict_wav, sample_counts):
                out_path = os.path.join(self.hparams.enhanced_folder, name)
                if not out_path.endswith(".wav"):
                    out_path += ".wav"
                trimmed = wav[:int(n_samples)].cpu()
                torchaudio.save(
                    out_path,
                    torch.unsqueeze(trimmed, 0),
                    self.hparams.Sample_rate,
                )

        return loss

    def on_stage_start(self, stage, epoch=None):
        """Create fresh metric trackers at the start of a stage.

        The loss and STOI trackers are created for every stage; the PESQ
        tracker is only created for stages other than training.
        """
        self.loss_metric = MetricStats(metric=self.hparams.compute_cost)
        self.stoi_metric = MetricStats(metric=stoi_loss)

        if stage == sb.Stage.TRAIN:
            return

        # Takes (prediction, target) so it can be evaluated in parallel
        def pesq_eval(pred_wav, target_wav):
            """Scores one prediction against its target with wide-band PESQ"""
            reference = target_wav.cpu().numpy()
            degraded = pred_wav.cpu().numpy()
            return pesq(fs=16000, ref=reference, deg=degraded, mode="wb")

        self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=4)

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Record, log and checkpoint the results of a finished stage.

        TRAIN only stores the loss for later logging. VALID anneals the
        learning rate on the PESQ result, logs, and keeps the checkpoint
        with the best PESQ. TEST logs the final stats.
        """
        if stage == sb.Stage.TRAIN:
            self.train_loss = stage_loss
            self.train_stats = {"loss": self.loss_metric.scores}
            return

        stats = {
            "loss": stage_loss,
            "pesq": self.pesq_metric.summarize("average"),
            # The tracker stores stoi_loss values; flip the sign to report
            "stoi": -self.stoi_metric.summarize("average"),
        }

        if stage == sb.Stage.VALID:
            # The scheduler is driven by the gap between 4.5 and the PESQ
            old_lr, new_lr = self.hparams.lr_annealing(4.5 - stats["pesq"])
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)

            if self.hparams.use_tensorboard:
                per_item_stats = {
                    "loss": self.loss_metric.scores,
                    "stoi": self.stoi_metric.scores,
                    "pesq": self.pesq_metric.scores,
                }
                self.hparams.tensorboard_train_logger.log_stats(
                    {"Epoch": epoch, "lr": old_lr},
                    self.train_stats,
                    per_item_stats,
                )

            self.hparams.train_logger.log_stats(
                {"Epoch": epoch, "lr": old_lr},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )
            self.checkpointer.save_and_keep_only(meta=stats, max_keys=["pesq"])
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )
Exemplo n.º 10
0
class MetricGanBrain(sb.Brain):
    def load_history(self):
        """Restore ``self.historical_set`` from ``hparams.historical_file``.

        Nothing happens when that file does not exist; in that case
        ``self.historical_set`` is left untouched.
        """
        history_path = self.hparams.historical_file
        if not os.path.isfile(history_path):
            return

        # NOTE(review): pickle.load can execute arbitrary code, so the
        # history file must only ever come from a trusted location.
        with open(history_path, "rb") as handle:
            self.historical_set = pickle.load(handle)

    def compute_feats(self, wavs):
        """Turn waveforms into spectral-magnitude features.

        The waveforms go through ``hparams.compute_STFT`` and the result is
        passed to ``spectral_magnitude`` with ``power=0.5``.
        """
        stft = self.hparams.compute_STFT(wavs)
        return spectral_magnitude(stft, power=0.5)

    def compute_forward(self, batch, stage):
        """Given an input batch computes the enhanced signal

        Returns
        -------
        In the HISTORICAL sub-stage, the stored enhanced wav from
        ``batch.enh_sig``. Otherwise a ``(predict_wav, mask)`` pair produced
        by the generator.
        """
        batch = batch.to(self.device)

        # Historical batches already carry an enhanced signal, so the
        # generator is not run for them.
        if self.sub_stage == SubStage.HISTORICAL:
            stored_wav, _ = batch.enh_sig
            return stored_wav

        noisy_wav, rel_lens = batch.noisy_sig
        noisy_spec = self.compute_feats(noisy_wav)

        # The mask is bounded from below by hparams.min_mask
        raw_mask = self.modules.generator(noisy_spec, lengths=rel_lens)
        mask = raw_mask.clamp(min=self.hparams.min_mask).squeeze(2)
        enhanced_spec = torch.mul(mask, noisy_spec)

        # Also return predicted wav
        enhanced_wav = self.hparams.resynth(enhanced_spec, noisy_wav)
        return enhanced_wav, mask

    def compute_objectives(self, predictions, batch, stage, optim_name=""):
        """Given the network predictions and targets compute the total loss

        During training the loss is ``hparams.compute_cost`` between the
        discriminator's estimated score and a target score selected by
        ``optim_name`` and the current sub-stage; for the generator a
        weighted spectral cost is added on top. In the other stages the loss
        is ``hparams.compute_si_snr`` against the clean signal, the
        evaluation metrics are updated and the enhanced wavs are written to
        ``hparams.enhanced_folder``.

        Arguments
        ---------
        predictions
            Output of ``compute_forward``: the wav alone in the HISTORICAL
            sub-stage, otherwise a ``(predict_wav, mask)`` pair.
        batch
            The batch being processed. Which fields are read (``id``,
            ``noisy_sig``, ``score``, ``clean_sig``) depends on the branch.
        stage : sb.Stage
            The stage of the experiment.
        optim_name : str
            "generator", "D_enh" or "D_noisy" while training. With any other
            value during training no branch assigns ``est_score`` and
            ``target_score``, so the cost computation would fail.

        Returns
        -------
        The cost computed for the selected objective.
        """
        if self.sub_stage == SubStage.HISTORICAL:
            predict_wav = predictions
        else:
            predict_wav, mask = predictions
        predict_spec = self.compute_feats(predict_wav)

        # For "D_enh" the ids carry an "@<epoch>" suffix (see compute_ids)
        ids = self.compute_ids(batch.id, optim_name)
        if self.sub_stage != SubStage.HISTORICAL:
            noisy_wav, lens = batch.noisy_sig

        # G is trained so that D's estimate reaches hparams.target_score
        if optim_name == "generator":
            est_score = self.est_score(predict_spec)
            target_score = self.hparams.target_score * torch.ones(
                self.batch_size, 1, device=self.device
            )

            # Extra cost between the enhanced and the noisy spectrum
            noisy_wav, lens = batch.noisy_sig
            noisy_spec = self.compute_feats(noisy_wav)
            mse_cost = self.hparams.compute_cost(predict_spec, noisy_spec, lens)

        # D Learns to estimate the scores of enhanced speech
        elif optim_name == "D_enh" and self.sub_stage == SubStage.CURRENT:
            # The enhanced wav is passed as both degraded and reference
            target_score = self.score(
                ids, predict_wav, predict_wav, lens
            )  # no clean_wav is needed
            est_score = self.est_score(predict_spec)

            # Write enhanced wavs during discriminator training, because we
            # compute the actual score here and we can save it
            self.write_wavs(ids, predict_wav, target_score, lens)

        # D Relearns to estimate the scores of previous epochs
        elif optim_name == "D_enh" and self.sub_stage == SubStage.HISTORICAL:
            # The target is the score stored with the historical item
            target_score = batch.score.unsqueeze(1).float()
            est_score = self.est_score(predict_spec)

        # D Learns to estimate the scores of noisy speech
        elif optim_name == "D_noisy":
            noisy_spec = self.compute_feats(noisy_wav)
            target_score = self.score(
                ids, noisy_wav, noisy_wav, lens
            )  # no clean_wav is needed
            est_score = self.est_score(noisy_spec)
            # Save scores of noisy wavs
            self.save_noisy_scores(ids, target_score)

        if stage == sb.Stage.TRAIN:
            # Compute the cost
            cost = self.hparams.compute_cost(est_score, target_score)
            if optim_name == "generator":
                cost += self.hparams.mse_weight * mse_cost
                self.metrics["G"].append(cost.detach())
            else:
                self.metrics["D"].append(cost.detach())

        # Compute scores on validation data
        if stage != sb.Stage.TRAIN:
            clean_wav, lens = batch.clean_sig

            cost = self.hparams.compute_si_snr(predict_wav, clean_wav, lens)
            # Evaluate speech quality/intelligibility
            self.stoi_metric.append(
                batch.id, predict_wav, clean_wav, lens, reduction="batch"
            )
            self.pesq_metric.append(
                batch.id, predict=predict_wav, target=clean_wav, lengths=lens
            )
            if (
                self.hparams.calculate_dnsmos_on_validation_set
            ):  # Note: very time consuming (per the original author)
                # The prediction is passed as its own target here
                self.dnsmos_metric.append(
                    batch.id,
                    predict=predict_wav,
                    target=predict_wav,
                    lengths=lens,  # no clean_wav is needed
                )

            # Write wavs to file, for evaluation
            # lens is scaled by the clean wav's second dimension and then
            # used to cut each prediction before saving
            lens = lens * clean_wav.shape[1]
            for name, pred_wav, length in zip(batch.id, predict_wav, lens):
                name += ".wav"
                enhance_path = os.path.join(self.hparams.enhanced_folder, name)
                torchaudio.save(
                    enhance_path,
                    torch.unsqueeze(pred_wav[: int(length)].cpu(), 0),
                    16000,
                )

        return cost

    def compute_ids(self, batch_id, optim_name):
        """Return the utterance ids to use for the given optimizer name.

        For ``"D_enh"`` a new list is built in which every id is suffixed
        with ``@<self.epoch>``; for any other name ``batch_id`` is returned
        unchanged.
        """
        if optim_name != "D_enh":
            return batch_id
        return ["{}@{}".format(uid, self.epoch) for uid in batch_id]

    def save_noisy_scores(self, batch_id, scores):
        """Store each score in ``self.noisy_scores`` under its utterance id.

        Ids and scores are paired positionally; an id that is already
        present gets its score overwritten.
        """
        self.noisy_scores.update(zip(batch_id, scores))

    def score(self, batch_id, deg_wav, ref_wav, lens):
        """Returns the actual target-metric score (srmr or dnsmos)

        Scores already stored in ``self.historical_set`` or
        ``self.noisy_scores`` are reused; only the remaining utterances are
        sent to ``self.target_metric``.

        Arguments
        ---------
        batch_id : list of str
            A list of the utterance ids for the batch
        deg_wav : torch.Tensor
            The degraded waveform to score
        ref_wav : torch.Tensor
            The reference waveform to use for scoring
        lens : torch.Tensor
            The relative lengths of the utterances

        Returns
        -------
        torch.Tensor
            One row per utterance holding its score, on ``self.device``.

        Raises
        ------
        ValueError
            If uncached utterances have to be scored while
            ``hparams.target_metric`` is neither 'srmr' nor 'dnsmos'.
        """
        # Batch positions of the utterances that have no cached score yet
        new_ids = [
            i
            for i, d in enumerate(batch_id)
            if d not in self.historical_set and d not in self.noisy_scores
        ]

        fresh_scores = None
        if new_ids:
            # A membership test is required here: the previous condition
            # `== "srmr" or "dnsmos"` was always truthy, which made the
            # error branch unreachable.
            if self.hparams.target_metric not in ("srmr", "dnsmos"):
                raise ValueError(
                    "Expected 'srmr' or 'dnsmos' for target_metric"
                )
            self.target_metric.append(
                ids=[batch_id[i] for i in new_ids],
                predict=deg_wav[new_ids].detach(),
                # Original author's note: the metric function does not
                # actually use the target.
                target=ref_wav[new_ids].detach(),
                lengths=lens[new_ids],
            )
            fresh_scores = torch.tensor(
                [[s] for s in self.target_metric.scores], device=self.device,
            )

        # Clear metric scores to prepare for next batch
        self.target_metric.clear()

        # Batch position -> row in fresh_scores (avoids a list search per id)
        row_of = {batch_pos: row for row, batch_pos in enumerate(new_ids)}

        # Combine old scores and new
        final_score = []
        for i, d in enumerate(batch_id):
            if d in self.historical_set:
                final_score.append([self.historical_set[d]["score"]])
            elif d in self.noisy_scores:
                final_score.append([self.noisy_scores[d]])
            else:
                final_score.append([fresh_scores[row_of[i]]])

        return torch.tensor(final_score, device=self.device)

    def est_score(self, deg_spec):
        """Returns score as estimated by discriminator

        Only the degraded spectrum is fed to the discriminator; no reference
        spectrum is involved.

        Arguments
        ---------
        deg_spec : torch.Tensor
            The spectral features of the degraded utterance

        Returns
        -------
        The output of ``self.modules.discriminator``.
        """
        # A new dimension is inserted at position 1 before the call
        return self.modules.discriminator(deg_spec.unsqueeze(1))

    def write_wavs(self, batch_id, wavs, score, lens):
        """Write wavs to files, for historical discriminator training

        Every wav is saved to ``hparams.MetricGAN_folder`` and recorded,
        together with its score, in ``self.historical_set``, which is then
        pickled to ``hparams.historical_file``.

        Arguments
        ---------
        batch_id : list of str
            A list of the utterance ids for the batch
        wavs : torch.Tensor
            The wavs to write to files
        score : torch.Tensor
            The actual scores for the corresponding utterances
        lens : torch.Tensor
            The relative lengths of each utterance
        """
        # Scale the relative lengths by the wavs' second dimension; the
        # result is used to cut each wav before saving.
        abs_lens = lens * wavs.shape[1]
        record = {}
        for i, (name, pred_wav, length) in enumerate(
            zip(batch_id, wavs, abs_lens)
        ):
            path = os.path.join(self.hparams.MetricGAN_folder, name + ".wav")
            data = torch.unsqueeze(pred_wav[: int(length)].cpu(), 0)
            torchaudio.save(path, data, self.hparams.Sample_rate)

            # Make record of path and score for historical training.
            # `score` itself must not be rebound here: it is indexed again
            # on every following iteration.
            record[name] = {
                "enh_wav": path,
                "score": float(score[i][0]),
            }

        # Update records for historical training
        self.historical_set.update(record)

        with open(self.hparams.historical_file, "wb") as fp:  # Pickling
            pickle.dump(self.historical_set, fp)

    def fit_batch(self, batch):
        """Compute gradients and update either D or G based on sub-stage.

        Arguments
        ---------
        batch
            The training batch, forwarded to ``compute_forward`` and
            ``compute_objectives``.

        Returns
        -------
        The detached loss of this batch. In the CURRENT sub-stage this is
        the mean of the per-mode discriminator losses.
        """
        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
        loss_tracker = 0

        if self.sub_stage == SubStage.CURRENT:
            # One discriminator update per mode, reusing the same forward
            modes = ["enh", "noisy"]
            for mode in modes:
                loss = self.compute_objectives(
                    predictions, batch, sb.Stage.TRAIN, f"D_{mode}"
                )
                self.d_optimizer.zero_grad()
                loss.backward()
                if self.check_gradients(loss):
                    self.d_optimizer.step()
                # Average over the modes that actually run; the divisor
                # used to be a hard-coded 3 although there are two modes.
                loss_tracker += loss.detach() / len(modes)
        elif self.sub_stage == SubStage.HISTORICAL:
            loss = self.compute_objectives(
                predictions, batch, sb.Stage.TRAIN, "D_enh"
            )
            self.d_optimizer.zero_grad()
            loss.backward()
            if self.check_gradients(loss):
                self.d_optimizer.step()
            loss_tracker += loss.detach()
        elif self.sub_stage == SubStage.GENERATOR:
            # Cap the "Learnable_sigmoid" parameters at 3.5; the original
            # author's stated reason is to keep the gradient from blowing up
            for name, param in self.modules.generator.named_parameters():
                if "Learnable_sigmoid" in name:
                    param.data = torch.clamp(param, max=3.5)

            loss = self.compute_objectives(
                predictions, batch, sb.Stage.TRAIN, "generator"
            )
            self.g_optimizer.zero_grad()
            loss.backward()
            if self.check_gradients(loss):
                self.g_optimizer.step()
            loss_tracker += loss.detach()

        return loss_tracker

    def on_stage_start(self, stage, epoch=None):
        """Gets called at the beginning of each epoch

        Resets the loss trackers and builds the metric objects needed by the
        stage. When a TRAIN stage starts in the GENERATOR sub-stage, the
        discriminator is trained first (``train_discriminator`` calls
        ``fit()`` again) before generator training proceeds.

        Arguments
        ---------
        stage : sb.Stage
            The stage that is about to start.
        epoch
            The current epoch; stored on ``self.epoch`` before the
            discriminator is trained.

        Raises
        ------
        NotImplementedError
            If ``hparams.target_metric`` is neither 'srmr' nor 'dnsmos'
            when a TRAIN stage starts.
        """

        self.metrics = {"G": [], "D": []}

        # Read through self.hparams like every other hyperparameter,
        # instead of depending on a module-level `hparams` global.
        n_jobs = self.hparams.n_jobs

        if stage == sb.Stage.TRAIN:
            if self.hparams.target_metric == "srmr":
                target_fn = srmrpy_eval
            elif self.hparams.target_metric == "dnsmos":
                target_fn = dnsmos_eval
            else:
                raise NotImplementedError(
                    "Right now we only support 'srmr' and 'dnsmos'"
                )
            self.target_metric = MetricStats(
                metric=target_fn, n_jobs=n_jobs, batch_eval=False,
            )

            # Train discriminator before we start generator training
            if self.sub_stage == SubStage.GENERATOR:
                self.epoch = epoch
                self.train_discriminator()
                # train_discriminator changes the sub-stage; restore it
                self.sub_stage = SubStage.GENERATOR
                print("Generator training by current data...")

        if stage != sb.Stage.TRAIN:
            self.pesq_metric = MetricStats(
                metric=pesq_eval, n_jobs=n_jobs, batch_eval=False
            )
            self.stoi_metric = MetricStats(metric=stoi_loss)
            self.srmr_metric = MetricStats(
                metric=srmrpy_eval_valid, n_jobs=n_jobs, batch_eval=False,
            )
            self.dnsmos_metric = MetricStats(
                metric=dnsmos_eval_valid, n_jobs=n_jobs, batch_eval=False,
            )

    def train_discriminator(self):
        """Update the discriminator with up to three single-epoch passes.

        The passes are: the train set, the historical set (skipped while it
        is empty), and the train set once more. ``self.sub_stage`` is left
        at CURRENT when this method returns.
        """

        def run_pass(message, sub_stage, data):
            # Announce, switch sub-stage, then run one epoch over `data`
            print(message)
            self.sub_stage = sub_stage
            self.fit(
                range(1),
                data,
                train_loader_kwargs=self.hparams.dataloader_options,
            )

        # First, iterate train subset w/ updates for enh, noisy
        run_pass(
            "Discriminator training by current data...",
            SubStage.CURRENT,
            self.train_set,
        )

        # Next, iterate historical subset w/ updates for enh
        if self.historical_set:
            run_pass(
                "Discriminator training by historical data...",
                SubStage.HISTORICAL,
                self.historical_set,
            )

        # Finally, iterate train set again. Per the original author this
        # should visit the same samples as before, thanks to the
        # reproducible sampler.
        run_pass(
            "Discriminator training by current data again...",
            SubStage.CURRENT,
            self.train_set,
        )

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Called at the end of each stage to summarize progress

        Does nothing outside the GENERATOR sub-stage, i.e. for the nested
        discriminator passes.

        Arguments
        ---------
        stage : sb.Stage
            The stage that just ended.
        stage_loss
            The loss reported for the finished stage.
        epoch
            The current epoch, only used as a label in the logs.
        """
        if self.sub_stage != SubStage.GENERATOR:
            return

        if stage == sb.Stage.TRAIN:
            self.train_loss = stage_loss
            g_loss = torch.tensor(self.metrics["G"])  # batch_size
            d_loss = torch.tensor(self.metrics["D"])  # batch_size
            print("Avg G loss: %.3f" % torch.mean(g_loss))
            print("Avg D loss: %.3f" % torch.mean(d_loss))
            return

        # Built once and shared by every logger below, instead of
        # re-summarizing the metrics for each of them.
        stats = {
            "SI-SNR": -stage_loss,
            "pesq": 5 * self.pesq_metric.summarize("average") - 0.5,
            "stoi": -self.stoi_metric.summarize("average"),
        }
        if self.hparams.calculate_dnsmos_on_validation_set:
            stats["dnsmos"] = self.dnsmos_metric.summarize("average")

        if stage == sb.Stage.VALID:
            # The scheduler receives the gap between 5.0 and the PESQ stat
            old_lr, new_lr = self.hparams.lr_annealing(5.0 - stats["pesq"])
            sb.nnet.schedulers.update_learning_rate(self.g_optimizer, new_lr)

            if self.hparams.use_tensorboard:
                # A copy, so that the two loggers never share one dict
                self.hparams.tensorboard_train_logger.log_stats(
                    {"lr": old_lr}, dict(stats)
                )
            self.hparams.train_logger.log_stats(
                {"Epoch": epoch, "lr": old_lr},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )
            # Keep only the checkpoint with the highest "pesq" entry
            self.checkpointer.save_and_keep_only(meta=stats, max_keys=["pesq"])
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )

    def make_dataloader(
        self, dataset, stage, ckpt_prefix="dataloader-", **loader_kwargs
    ):
        """Override dataloader to insert custom sampler/dataset

        For TRAIN loaders a ``ReproducibleWeightedRandomSampler`` is put
        into ``loader_kwargs``; in the HISTORICAL sub-stage the given data
        is first wrapped in a ``DynamicItemDataset``. Loaders for other
        stages are built by the parent class unchanged.
        """
        if stage != sb.Stage.TRAIN:
            return super().make_dataloader(
                dataset, stage, ckpt_prefix, **loader_kwargs
            )

        if self.sub_stage == SubStage.HISTORICAL:
            # Create a new dataset each time, this set grows
            dataset = sb.dataio.dataset.DynamicItemDataset(
                data=dataset,
                dynamic_items=[enh_pipeline],
                output_keys=["id", "enh_sig", "score"],
            )
            num_samples = round(len(dataset) * self.hparams.history_portion)
        else:
            num_samples = self.hparams.number_of_samples

        # Seeded with the current epoch; per the original author this makes
        # D and G training see the same samples. Uniform weights are used
        # with a "Weighted" sampler only to get both replacement=False and
        # a fixed number of samples.
        sampler = ReproducibleWeightedRandomSampler(
            torch.ones(len(dataset)),
            epoch=self.hparams.epoch_counter.current,
            replacement=False,
            num_samples=num_samples,
        )
        loader_kwargs["sampler"] = sampler

        if self.sub_stage == SubStage.GENERATOR:
            self.train_sampler = sampler

        # Make the dataloader as normal
        return super().make_dataloader(
            dataset, stage, ckpt_prefix, **loader_kwargs
        )

    def on_fit_start(self):
        """Run the parent's fit-start hook only in the GENERATOR sub-stage.

        The nested ``fit()`` calls made for discriminator training skip it.
        """
        if self.sub_stage != SubStage.GENERATOR:
            return
        super().on_fit_start()

    def init_optimizers(self):
        """Create the generator and discriminator optimizers.

        Each optimizer is built from its ``hparams`` class over the
        parameters of the matching module. When a checkpointer is set, both
        are registered with it under "g_opt" and "d_opt".
        """
        generator_params = self.modules.generator.parameters()
        self.g_optimizer = self.hparams.g_opt_class(generator_params)

        discriminator_params = self.modules.discriminator.parameters()
        self.d_optimizer = self.hparams.d_opt_class(discriminator_params)

        if self.checkpointer is None:
            return

        recoverables = (
            ("g_opt", self.g_optimizer),
            ("d_opt", self.d_optimizer),
        )
        for key, optimizer in recoverables:
            self.checkpointer.add_recoverable(key, optimizer)
Exemplo n.º 11
0
class SEBrain(sb.Brain):
    def compute_forward(self, batch, stage):
        """Run the model on the noisy waveforms of a batch.

        A trailing dimension is added to the waveforms before they are
        given to the model; index 0 of the model output's third dimension
        is returned as the enhanced waveforms.
        """
        batch = batch.to(self.device)
        noisy_wavs, _ = batch.noisy_sig
        model_out = self.modules.model(noisy_wavs.unsqueeze(-1))
        return model_out[:, :, 0]

    def compute_objectives(self, predict_wavs, batch, stage):
        """Computes the loss given the predicted and targeted outputs

        Outside of training the STOI and PESQ metrics are updated as well,
        and in the TEST stage every prediction is peak-normalized and
        written to ``hparams.enhanced_folder``.

        Arguments
        ---------
        predict_wavs
            The enhanced waveforms returned by ``compute_forward``.
        batch
            The batch being processed; ``clean_sig`` and ``id`` are read.
        stage : sb.Stage
            The stage of the experiment.

        Returns
        -------
        The value of ``hparams.compute_cost`` for this batch.
        """
        clean_wavs, lens = batch.clean_sig

        loss = self.hparams.compute_cost(predict_wavs, clean_wavs, lens)
        self.loss_metric.append(
            batch.id, predict_wavs, clean_wavs, lens, reduction="batch"
        )

        if stage != sb.Stage.TRAIN:

            # Evaluate speech quality/intelligibility
            self.stoi_metric.append(
                batch.id, predict_wavs, clean_wavs, lens, reduction="batch"
            )
            self.pesq_metric.append(
                batch.id, predict=predict_wavs, target=clean_wavs, lengths=lens
            )

            # Write wavs to file
            if stage == sb.Stage.TEST:
                # Scale the lengths by the clean wavs' second dimension;
                # the result is used to cut each wav before saving.
                lens = lens * clean_wavs.shape[1]
                for name, pred_wav, length in zip(batch.id, predict_wavs, lens):
                    name += ".wav"
                    enhance_path = os.path.join(
                        self.hparams.enhanced_folder, name
                    )
                    # Scale the peak to 0.99. An all-zero prediction is
                    # left untouched: dividing by its zero peak would
                    # write a wav full of NaNs.
                    peak = torch.max(torch.abs(pred_wav))
                    if peak > 0:
                        pred_wav = pred_wav / peak * 0.99
                    torchaudio.save(
                        enhance_path,
                        torch.unsqueeze(pred_wav[: int(length)].cpu(), 0),
                        16000,
                    )

        return loss

    def on_stage_start(self, stage, epoch=None):
        """Create fresh metric trackers when a stage begins.

        The loss and STOI trackers are always created; the PESQ tracker is
        only created for stages other than TRAIN.
        """
        self.loss_metric = MetricStats(metric=self.hparams.compute_cost)
        self.stoi_metric = MetricStats(metric=stoi_loss)

        if stage == sb.Stage.TRAIN:
            return

        # Function taking (prediction, target), as handed to MetricStats
        def pesq_eval(pred_wav, target_wav):
            """Computes the PESQ evaluation metric"""
            reference = target_wav.numpy()
            degraded = pred_wav.numpy()
            return pesq(fs=16000, ref=reference, deg=degraded, mode="wb")

        self.pesq_metric = MetricStats(
            metric=pesq_eval, n_jobs=1, batch_eval=False
        )

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Summarize a finished stage: log its results and checkpoint.

        Arguments
        ---------
        stage : sb.Stage
            The stage that just ended.
        stage_loss
            The loss reported for the finished stage.
        epoch
            The current epoch, only used as a label in the logs.
        """
        if stage == sb.Stage.TRAIN:
            # Stored on the instance so they can be logged together with
            # the validation results once the VALID stage ends.
            self.train_loss = stage_loss
            self.train_stats = {"loss": self.loss_metric.scores}
            return

        avg_pesq = self.pesq_metric.summarize("average")
        avg_stoi = self.stoi_metric.summarize("average")
        # The averaged STOI value is negated before it is reported.
        stats = {"loss": stage_loss, "pesq": avg_pesq, "stoi": -avg_stoi}

        if stage == sb.Stage.VALID:
            if self.hparams.use_tensorboard:
                # TensorBoard gets the raw score lists, not the averages
                per_item_stats = {
                    "loss": self.loss_metric.scores,
                    "stoi": self.stoi_metric.scores,
                    "pesq": self.pesq_metric.scores,
                }
                self.hparams.tensorboard_train_logger.log_stats(
                    {"Epoch": epoch}, self.train_stats, per_item_stats
                )
            self.hparams.train_logger.log_stats(
                {"Epoch": epoch},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )
            # Keep only the checkpoint with the highest "pesq" entry
            self.checkpointer.save_and_keep_only(meta=stats, max_keys=["pesq"])
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )
Exemplo n.º 12
0
class SEBrain(sb.Brain):
    def compute_forward(self, batch, stage):
        """Enhance the noisy waveforms of a batch by spectral masking.

        Returns
        -------
        A ``(predict_spec, predict_wav)`` pair: the masked features and the
        waveform resynthesized from them.
        """
        batch = batch.to(self.device)
        noisy_wavs, _ = batch.noisy_sig
        noisy_feats = self.compute_feats(noisy_wavs)

        # mask with "signal approximation (SA)"
        mask = self.modules.model(noisy_feats).squeeze(2)
        predict_spec = torch.mul(mask, noisy_feats)

        # expm1 is applied to the masked features before resynthesis
        resynth_input = torch.expm1(predict_spec)
        predict_wav = self.hparams.resynth(resynth_input, noisy_wavs)

        return predict_spec, predict_wav

    def compute_feats(self, wavs):
        """Turn waveforms into log-compressed spectral features.

        The pipeline is ``hparams.compute_STFT``, then
        ``spectral_magnitude`` with ``power=0.5``, then ``torch.log1p``.
        """
        stft = self.hparams.compute_STFT(wavs)
        magnitude = spectral_magnitude(stft, power=0.5)
        return torch.log1p(magnitude)

    def compute_objectives(self, predictions, batch, stage):
        """Compute the training loss and update the evaluation metrics.

        The loss compares waveforms when ``hparams.waveform_target`` is set
        and truthy, and spectral features otherwise. Outside of training
        the STOI and PESQ metrics are updated, and in the TEST stage the
        predicted wavs are written to ``hparams.enhanced_folder``.

        Returns
        -------
        The value of ``hparams.compute_cost`` for this batch.
        """
        predict_spec, predict_wav = predictions
        clean_wavs, lens = batch.clean_sig

        # Pick the (estimate, reference) pair the cost is computed on
        if getattr(self.hparams, "waveform_target", False):
            estimate, reference = predict_wav, clean_wavs
        else:
            estimate, reference = predict_spec, self.compute_feats(clean_wavs)

        loss = self.hparams.compute_cost(estimate, reference, lens)
        self.loss_metric.append(
            batch.id, estimate, reference, lens, reduction="batch"
        )

        if stage == sb.Stage.TRAIN:
            return loss

        # Evaluate speech quality/intelligibility
        self.stoi_metric.append(
            batch.id, predict_wav, clean_wavs, lens, reduction="batch"
        )
        self.pesq_metric.append(
            batch.id, predict=predict_wav, target=clean_wavs, lengths=lens
        )

        # Write wavs to file
        if stage == sb.Stage.TEST:
            # Scale the lengths by the clean wavs' second dimension; the
            # result is used to cut each wav before saving.
            sample_lens = lens * clean_wavs.shape[1]
            for utt_id, wav, n_samples in zip(
                batch.id, predict_wav, sample_lens
            ):
                enhance_path = os.path.join(
                    self.hparams.enhanced_folder, utt_id + ".wav"
                )
                trimmed = wav[: int(n_samples)].cpu()
                torchaudio.save(
                    enhance_path, torch.unsqueeze(trimmed, 0), 16000
                )

        return loss

    def on_stage_start(self, stage, epoch=None):
        """Create fresh metric trackers when a stage begins.

        The loss and STOI trackers are always created; the PESQ tracker is
        only created for stages other than TRAIN.
        """
        self.loss_metric = MetricStats(metric=self.hparams.compute_cost)
        self.stoi_metric = MetricStats(metric=stoi_loss)

        if stage == sb.Stage.TRAIN:
            return

        # Function taking (prediction, target) for parallel eval
        def pesq_eval(pred_wav, target_wav):
            """Computes the PESQ evaluation metric"""
            reference = target_wav.numpy()
            degraded = pred_wav.numpy()
            return pesq(fs=16000, ref=reference, deg=degraded, mode="wb")

        self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=4)

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Log the results of a finished stage and checkpoint after VALID.

        Arguments
        ---------
        stage : sb.Stage
            The stage that just ended.
        stage_loss
            The loss reported for the finished stage.
        epoch
            The current epoch, only used as a label in the logs.
        """
        if stage == sb.Stage.TRAIN:
            # Remembered here and logged once validation has finished
            self.train_loss = stage_loss
            self.train_stats = {"loss": self.loss_metric.scores}
            return

        mean_pesq = self.pesq_metric.summarize("average")
        mean_stoi = self.stoi_metric.summarize("average")
        # The averaged STOI value is negated before it is reported.
        stats = {"loss": stage_loss, "pesq": mean_pesq, "stoi": -mean_stoi}

        if stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )
            return

        if stage != sb.Stage.VALID:
            return

        if self.hparams.use_tensorboard:
            # TensorBoard gets the raw score lists, not the averages
            raw_scores = {
                "loss": self.loss_metric.scores,
                "stoi": self.stoi_metric.scores,
                "pesq": self.pesq_metric.scores,
            }
            self.hparams.tensorboard_train_logger.log_stats(
                {"Epoch": epoch}, self.train_stats, raw_scores
            )

        self.hparams.train_logger.log_stats(
            {"Epoch": epoch},
            train_stats={"loss": self.train_loss},
            valid_stats=stats,
        )
        # Keep only the checkpoint with the highest "pesq" entry
        self.checkpointer.save_and_keep_only(meta=stats, max_keys=["pesq"])
Exemplo n.º 13
0
class Separation(sb.Brain):
    def compute_forward(self, mix, targets, stage, noise=None):
        """Runs the model on a mixture and returns the estimates and targets.

        In the training stage the targets can be augmented on the fly (speed
        perturbation / random shift, optional reverberation when
        ``"whamr"`` appears in ``hparams.data_folder``), in which case the
        mixture is rebuilt from the augmented targets and ``noise`` is added
        to it. Wave dropping and random cropping are applied afterwards when
        enabled in the hyperparameters.

        Arguments
        ---------
        mix : tuple
            Pair ``(signal, lengths)`` describing the input mixture.
        targets : list
            One entry per source; element ``[0]`` of each entry is used as
            the waveform of that source.
        stage : sb.Stage
            Current stage; augmentation only happens for ``sb.Stage.TRAIN``.
        noise : torch.Tensor, optional
            Noise added to the rebuilt mixture. It is accessed whenever
            speed perturbation or random shift is enabled during training.

        Returns
        -------
        list
            ``[est_source, sep_h]``: the estimated waveform and the masked
            representation it was decoded from.
        torch.Tensor
            The (possibly augmented) targets with the last dim squeezed.
        """
        hp = self.hparams
        spk_ids = range(hp.num_spks)

        # Unpack the mixture and move everything to the working device
        mix, mix_lens = mix
        mix = mix.to(self.device)
        mix_lens = mix_lens.to(self.device)

        # Put the sources side by side on a new trailing dimension
        targets = torch.stack(
            [targets[spk][0] for spk in spk_ids], dim=-1
        ).to(self.device)

        # On-the-fly augmentation (training only)
        if stage == sb.Stage.TRAIN:
            with torch.no_grad():
                if hp.use_speedperturb or hp.use_rand_shift:
                    mix, targets = self.add_speed_perturb(targets, mix_lens)

                    if "whamr" not in hp.data_folder:
                        mix = targets.sum(-1)
                    else:
                        try:
                            reverberated = [
                                hp.reverb(targets[:, :, spk], None)
                                for spk in spk_ids
                            ]
                        except Exception:
                            print("reverb error, not adding reverb")
                            reverberated = [
                                targets[:, :, spk] for spk in spk_ids
                            ]

                        reverberated = torch.stack(reverberated, dim=-1)
                        mix = reverberated.sum(-1)

                        # Without dereverberation the reverberant signals
                        # become the training targets
                        if not hp.dereverberate:
                            targets = reverberated

                    # Add the noise over the span both signals cover and
                    # trim the targets to that same span
                    noise = noise.to(self.device)
                    common_len = min(noise.shape[1], mix.shape[1])
                    mix = mix[:, :common_len] + noise[:, :common_len]
                    targets = targets[:, :common_len, :]

                if hp.use_wavedrop:
                    mix = hp.wavedrop(mix, mix_lens)

                if hp.limit_training_signal_len:
                    mix, targets = self.cut_signals(mix, targets)

        # Masking-based estimation
        freq_domain = self.use_freq_domain
        feats = self.compute_feats(mix) if freq_domain else hp.Encoder(mix)
        est_mask = self.modules.masknet(feats)

        if freq_domain:
            sep_h = feats * est_mask
            est_source = hp.resynth(torch.expm1(sep_h), mix)
        else:
            sep_h = torch.stack([feats] * hp.num_spks) * est_mask
            est_source = torch.stack(
                [hp.Decoder(sep_h[spk]) for spk in spk_ids], dim=-1
            )

        # The time dimension can change in the encoder (conv1d), so the
        # estimate is padded or cropped back to the mixture length
        len_in = mix.size(1)
        len_out = est_source.size(1)
        est_source = est_source.squeeze(-1)
        if len_in > len_out:
            est_source = F.pad(est_source, (0, len_in - len_out))
        else:
            est_source = est_source[:, :len_in]

        return [est_source, sep_h], targets.squeeze(-1)

    def compute_feats(self, wavs):
        """Returns the log-compressed spectral magnitude of ``wavs``.

        The input goes through ``hparams.Encoder``, then through
        ``spectral_magnitude`` with ``power=0.5`` and finally ``log1p``.
        """
        encoded = self.hparams.Encoder(wavs)
        return torch.log1p(spectral_magnitude(encoded, power=0.5))

    def compute_objectives(self, predictions, targets):
        """Computes the training loss for a pair of predictions.

        Arguments
        ---------
        predictions : sequence
            ``(predicted_wavs, predicted_specs)`` as returned by
            ``compute_forward``.
        targets : torch.Tensor
            Reference waveforms.

        Returns
        -------
        The value of ``hparams.loss``. In the frequency-domain setup it is
        evaluated on the features of the targets against the predicted
        spectra; otherwise on the waveforms (with a trailing dimension
        appended to both).
        """
        predicted_wavs, predicted_specs = predictions

        if not self.use_freq_domain:
            return self.hparams.loss(
                targets.unsqueeze(-1), predicted_wavs.unsqueeze(-1)
            )

        return self.hparams.loss(self.compute_feats(targets), predicted_specs)

    def _reduce_train_loss(self, loss):
        """Reduces the per-item losses to a single scalar.

        With ``hparams.threshold_byloss`` only the items whose loss exceeds
        ``hparams.threshold`` are averaged. When no item exceeds it, the
        plain mean is used so that the result is always a scalar (comparing
        an unreduced multi-element loss against the upper limit would raise).
        """
        if self.hparams.threshold_byloss:
            hard_items = loss[loss > self.hparams.threshold]
            if hard_items.nelement() > 0:
                return hard_items.mean()
        return loss.mean()

    def _skip_batch(self, loss):
        """Counts and logs a skipped batch and zeroes the reported loss."""
        self.nonfinite_count += 1
        logger.info(
            "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
                self.nonfinite_count
            )
        )
        loss.data = torch.tensor(0).to(self.device)

    def fit_batch(self, batch):
        """Runs one training step and returns the detached loss.

        The forward pass and the loss are computed under ``autocast`` when
        ``self.auto_mix_prec`` is set, in which case the gradient scaler
        drives the backward pass and the optimizer step. The update is only
        applied when the loss is below ``hparams.loss_upper_lim`` (a NaN loss
        fails that comparison as well); otherwise the batch is skipped and
        the reported loss is set to zero. Gradients are clipped when
        ``hparams.clip_grad_norm`` is non-negative.

        Arguments
        ---------
        batch : object
            Batch exposing ``mix_sig``, ``s1_sig``, ``s2_sig`` and
            ``noise_sig``.

        Returns
        -------
        torch.Tensor
            The detached loss, moved to the CPU.
        """
        # Unpacking batch list
        mixture = batch.mix_sig
        targets = [batch.s1_sig, batch.s2_sig]
        noise = batch.noise_sig[0]

        if self.auto_mix_prec:
            with autocast():
                predictions, targets = self.compute_forward(
                    mixture, targets, sb.Stage.TRAIN, noise
                )
                loss = self._reduce_train_loss(
                    self.compute_objectives(predictions, targets)
                )

            if loss < self.hparams.loss_upper_lim:
                self.scaler.scale(loss).backward()
                if self.hparams.clip_grad_norm >= 0:
                    # Gradients must be unscaled before their norm is clipped
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.modules.parameters(), self.hparams.clip_grad_norm,
                    )
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self._skip_batch(loss)
        else:
            predictions, targets = self.compute_forward(
                mixture, targets, sb.Stage.TRAIN, noise
            )
            loss = self._reduce_train_loss(
                self.compute_objectives(predictions, targets)
            )

            if loss < self.hparams.loss_upper_lim:
                loss.backward()
                if self.hparams.clip_grad_norm >= 0:
                    torch.nn.utils.clip_grad_norm_(
                        self.modules.parameters(), self.hparams.clip_grad_norm
                    )
                self.optimizer.step()
            else:
                self._skip_batch(loss)
        self.optimizer.zero_grad()

        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Evaluates one validation/test batch and returns its detached loss.

        The forward pass and the loss run without gradient tracking. Outside
        of training the estimated and reference waveforms are handed to the
        PESQ metric. During testing, when ``hparams.save_audio`` is set, the
        audio of the first item is written to disk; if the hyperparameters
        define ``n_audio_to_save`` it acts as a countdown limiting how many
        batches are saved.
        """
        utt_ids = batch.id
        mixture = batch.mix_sig
        targets = [batch.s1_sig, batch.s2_sig]

        with torch.no_grad():
            predictions, targets = self.compute_forward(mixture, targets, stage)
            loss = self.compute_objectives(predictions, targets)

        est_wavs = predictions[0]

        if stage != sb.Stage.TRAIN:
            self.pesq_metric.append(
                ids=utt_ids, predict=est_wavs.cpu(), target=targets.cpu()
            )

        # Manage audio file saving
        if stage == sb.Stage.TEST and self.hparams.save_audio:
            if not hasattr(self.hparams, "n_audio_to_save"):
                # No limit configured: save every batch
                self.save_audio(utt_ids[0], mixture, targets, est_wavs)
            elif self.hparams.n_audio_to_save > 0:
                self.save_audio(utt_ids[0], mixture, targets, est_wavs)
                self.hparams.n_audio_to_save -= 1

        return loss.detach()

    def on_stage_start(self, stage, epoch=None):
        """Creates a fresh PESQ tracker at the start of each non-training stage.

        Nothing is done for the training stage.
        """
        if stage == sb.Stage.TRAIN:
            return

        # Function taking (prediction, target), evaluated item by item
        def pesq_eval(pred_wav, target_wav):
            """Returns the PESQ score of one item, or 0 if PESQ fails.

            Wide-band mode is used at 16 kHz, narrow-band mode otherwise.
            """
            if self.hparams.sample_rate == 16000:
                psq_mode = "wb"
            else:
                psq_mode = "nb"

            try:
                score = pesq(
                    fs=self.hparams.sample_rate,
                    ref=target_wav.numpy(),
                    deg=pred_wav.numpy(),
                    mode=psq_mode,
                )
            except Exception:
                print("pesq encountered an error for this data item")
                score = 0
            return score

        self.pesq_metric = MetricStats(
            metric=pesq_eval, n_jobs=1, batch_eval=False
        )

    def on_stage_end(self, stage, stage_loss, epoch):
        """Handles logging, learning-rate annealing and checkpointing.

        After training only the loss is cached. After validation the
        learning rate is annealed (only for a ``ReduceLROnPlateau``
        scheduler), the statistics are logged and a checkpoint is written.
        After testing the statistics are logged.

        Arguments
        ---------
        stage : sb.Stage
            The stage that just ended.
        stage_loss : float
            Loss value handed over by the caller for this stage.
        epoch : int
            Current epoch, used by the scheduler and as logging metadata.
        """
        if stage == sb.Stage.TRAIN:
            self.train_stats = {"loss": stage_loss}
            return

        stats = {
            "loss": stage_loss,
            "pesq": self.pesq_metric.summarize("average"),
        }

        if stage == sb.Stage.VALID:
            # Learning rate annealing
            scheduler = self.hparams.lr_scheduler
            if isinstance(scheduler, schedulers.ReduceLROnPlateau):
                current_lr, next_lr = scheduler(
                    [self.optimizer], epoch, stage_loss
                )
                schedulers.update_learning_rate(self.optimizer, next_lr)
            else:
                # Any other scheduler leaves the learning rate untouched here
                current_lr = self.hparams.optimizer.optim.param_groups[0]["lr"]

            self.hparams.train_logger.log_stats(
                stats_meta={"epoch": epoch, "lr": current_lr},
                train_stats=self.train_stats,
                valid_stats=stats,
            )

            ckpt_meta = {"pesq": stats["pesq"]}
            if getattr(self.hparams, "save_all_checkpoints", False):
                self.checkpointer.save_checkpoint(meta=ckpt_meta)
            else:
                self.checkpointer.save_and_keep_only(
                    meta=ckpt_meta, max_keys=["pesq"],
                )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )

    def add_speed_perturb(self, targets, targ_lens):
        """Applies speed perturbation and/or random shift to every source.

        Both augmentations act independently on each source (last dimension
        of ``targets``). ``hparams.use_speedperturb`` and
        ``hparams.use_rand_shift`` can be enabled separately: previously the
        random shift was only applied when speed perturbation was enabled
        too, although the caller invokes this method when either is set.

        Arguments
        ---------
        targets : torch.Tensor
            Source signals; indexed as ``targets[:, :, i]`` per source.
        targ_lens : torch.Tensor
            Lengths forwarded to ``hparams.speedperturb``.

        Returns
        -------
        mix : torch.Tensor
            Sum of the (augmented) sources over the last dimension.
        targets : torch.Tensor
            The (augmented) sources, cropped to their shortest length.
        """
        use_speed = self.hparams.use_speedperturb
        use_shift = self.hparams.use_rand_shift

        if use_speed or use_shift:
            num_sources = targets.shape[-1]
            sources = [targets[:, :, i] for i in range(num_sources)]

            if use_speed:
                # Performing speed change (independently on each source)
                sources = [
                    self.hparams.speedperturb(src, targ_lens)
                    for src in sources
                ]

            if use_shift:
                # Performing random_shift (independently on each source)
                for i in range(num_sources):
                    rand_shift = torch.randint(
                        self.hparams.min_shift, self.hparams.max_shift, (1,)
                    )
                    shifted = sources[i].to(self.device)
                    sources[i] = torch.roll(
                        shifted, shifts=(rand_shift[0],), dims=1
                    )

            # Speed perturbation can change the length of each source, so
            # everything is cropped to the shortest one when recombining
            min_len = min(src.shape[-1] for src in sources)
            if use_speed:
                recombined = torch.zeros(
                    targets.shape[0],
                    min_len,
                    num_sources,
                    device=targets.device,
                    dtype=torch.float,
                )
            else:
                # Shift only: lengths are unchanged, keep shape and dtype
                recombined = torch.empty_like(targets)

            for i, src in enumerate(sources):
                recombined[:, :, i] = src[:, 0:min_len]
            targets = recombined

        mix = targets.sum(-1)
        return mix, targets

    def cut_signals(self, mixture, targets):
        """Selects a random segment of the mixture and the matching targets.

        The segment is at most ``hparams.training_signal_len`` samples long
        and its start is drawn uniformly among the positions that keep a
        full-length segment inside the mixture (position 0 if the mixture is
        shorter than the segment).

        Returns
        -------
        mixture, targets
            Both cropped to the same time span.
        """
        seg_len = self.hparams.training_signal_len
        last_start = max(0, mixture.shape[1] - seg_len)
        start = torch.randint(0, last_start + 1, (1,)).item()
        stop = start + seg_len

        targets = targets[:, start:stop, :]
        mixture = mixture[:, start:stop]
        return mixture, targets

    def reset_layer_recursively(self, layer):
        """Reinitializes the parameters of ``layer`` and of all its sub-modules.

        ``reset_parameters()`` is called on every module of the tree that
        defines it.

        Arguments
        ---------
        layer : torch.nn.Module
            Root of the module tree to reinitialize.
        """
        # ``modules()`` already walks the whole tree (the layer itself first,
        # then its descendants), so a single pass is enough. Recursing on
        # each yielded module, as done before, reset deep sub-modules an
        # exponentially growing number of times.
        for module in layer.modules():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()

    def save_results(self, test_data):
        """Computes SDR, SI-SNR and PESQ on the test data and saves them to csv.

        One row per test item is written to ``test_results.csv`` inside
        ``hparams.output_folder``, followed by an ``avg`` row holding the
        mean of every column. The means are also logged.

        NOTE(review): the ``.item()`` calls and ``snt_id[0]`` suggest the
        test loader is expected to yield batches of size 1 — confirm.

        Arguments
        ---------
        test_data
            Dataset passed to ``sb.dataio.dataloader.make_dataloader``
            together with ``hparams.dataloader_opts``.
        """
        # This package is required for SDR computation
        from mir_eval.separation import bss_eval_sources

        # File where the per-item results are stored
        save_file = os.path.join(self.hparams.output_folder, "test_results.csv")

        # Variable init
        all_sdrs = []
        all_sdrs_i = []
        all_sisnrs = []
        all_sisnrs_i = []
        all_pesqs = []
        csv_columns = ["snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i", "pesq"]

        test_loader = sb.dataio.dataloader.make_dataloader(
            test_data, **self.hparams.dataloader_opts
        )

        # newline="" lets the csv module control the line endings itself;
        # without it extra blank rows appear on platforms that translate "\n"
        with open(save_file, "w", newline="") as results_csv:
            writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
            writer.writeheader()

            # Loop over all test sentence
            with tqdm(test_loader, dynamic_ncols=True) as t:
                for batch in t:

                    # Apply Separation
                    mixture, _ = batch.mix_sig
                    snt_id = batch.id
                    targets = [batch.s1_sig, batch.s2_sig]
                    if self.hparams.num_spks == 3:
                        targets.append(batch.s3_sig)

                    with torch.no_grad():
                        predictions, targets = self.compute_forward(
                            batch.mix_sig, targets, sb.Stage.TEST
                        )

                    # Compute SI-SNR
                    sisnr = self.compute_objectives(predictions, targets)

                    # Compute SI-SNR improvement, using the unprocessed
                    # mixture as the estimate of every source
                    # NOTE(review): in the frequency-domain setup
                    # compute_objectives reads the second list element,
                    # which is None here — confirm this path is unused.
                    mixture_signal = torch.stack(
                        [mixture] * self.hparams.num_spks, dim=-1
                    )
                    mixture_signal = mixture_signal.to(targets.device)
                    sisnr_baseline = self.compute_objectives(
                        [mixture_signal.squeeze(-1), None], targets
                    )
                    sisnr_i = sisnr - sisnr_baseline

                    # Compute SDR
                    sdr, _, _, _ = bss_eval_sources(
                        targets[0].t().cpu().numpy(),
                        predictions[0][0].t().detach().cpu().numpy(),
                    )

                    sdr_baseline, _, _, _ = bss_eval_sources(
                        targets[0].t().cpu().numpy(),
                        mixture_signal[0].t().detach().cpu().numpy(),
                    )

                    sdr_i = sdr.mean() - sdr_baseline.mean()

                    # Compute PESQ
                    psq_mode = (
                        "wb" if self.hparams.sample_rate == 16000 else "nb"
                    )
                    psq = pesq(
                        self.hparams.sample_rate,
                        targets.squeeze().cpu().numpy(),
                        predictions[0].squeeze().cpu().numpy(),
                        mode=psq_mode,
                    )

                    # Saving on a csv file (the loss is negated to report
                    # SI-SNR rather than its negative)
                    row = {
                        "snt_id": snt_id[0],
                        "sdr": sdr.mean(),
                        "sdr_i": sdr_i,
                        "si-snr": -sisnr.item(),
                        "si-snr_i": -sisnr_i.item(),
                        "pesq": psq,
                    }
                    writer.writerow(row)

                    # Metric Accumulation
                    all_sdrs.append(sdr.mean())
                    all_sdrs_i.append(sdr_i.mean())
                    all_sisnrs.append(-sisnr.item())
                    all_sisnrs_i.append(-sisnr_i.item())
                    all_pesqs.append(psq)

                # Averages over the whole test set (computed once and
                # reused for both the csv row and the log)
                mean_sdr = np.array(all_sdrs).mean()
                mean_sdr_i = np.array(all_sdrs_i).mean()
                mean_sisnr = np.array(all_sisnrs).mean()
                mean_sisnr_i = np.array(all_sisnrs_i).mean()
                mean_pesq = np.array(all_pesqs).mean()

                row = {
                    "snt_id": "avg",
                    "sdr": mean_sdr,
                    "sdr_i": mean_sdr_i,
                    "si-snr": mean_sisnr,
                    "si-snr_i": mean_sisnr_i,
                    "pesq": mean_pesq,
                }
                writer.writerow(row)

        logger.info("Mean SISNR is {}".format(mean_sisnr))
        logger.info("Mean SISNRi is {}".format(mean_sisnr_i))
        logger.info("Mean SDR is {}".format(mean_sdr))
        logger.info("Mean SDRi is {}".format(mean_sdr_i))
        logger.info("Mean PESQ {}".format(mean_pesq))

    def save_audio(self, snt_id, mixture, targets, predictions):
        """Saves the test audio (mixture, target and estimated source) on disk.

        The first item of each tensor is peak-normalized and written as a
        wav file into the ``audio_results`` sub-folder of
        ``hparams.save_folder``, which is created if needed.

        Arguments
        ---------
        snt_id : str
            Identifier inserted in the three file names.
        mixture : tuple
            Mixture batch; element ``[0]`` holds the signals.
        targets : torch.Tensor
            Reference signals.
        predictions : torch.Tensor
            Estimated signals.
        """
        # Create output folder. makedirs with exist_ok avoids the race between
        # the existence check and the creation (e.g. several processes) and
        # also creates missing parent folders.
        save_path = os.path.join(self.hparams.save_folder, "audio_results")
        os.makedirs(save_path, exist_ok=True)

        # (file name template, signal): estimated source, original source
        # and mixture, in this order
        to_save = (
            ("item{}_sourcehat.wav", predictions[0, :]),
            ("item{}_source.wav", targets[0, :]),
            ("item{}_mix.wav", mixture[0][0, :]),
        )

        for name_template, signal in to_save:
            # Peak normalization
            signal = signal / signal.abs().max()
            save_file = os.path.join(save_path, name_template.format(snt_id))
            torchaudio.save(
                save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
            )
Exemplo n.º 14
0
class SEBrain(sb.Brain):
    def compute_forward_g(self, noisy_wavs):
        """Runs the generator on the noisy signal.

        The input is moved to ``self.device`` and passed to
        ``modules["model_g"]``; whatever the generator returns is handed
        back unchanged.
        """
        generator = self.modules["model_g"]
        return generator(noisy_wavs.to(self.device))

    def compute_forward_d(self, noisy_wavs, clean_wavs):
        """Runs the discriminator on a pair of signals.

        Both inputs are moved to ``self.device``, concatenated along the
        last dimension (first argument first) and passed to
        ``modules["model_d"]``, whose output is returned.
        """
        pair = torch.cat(
            (noisy_wavs.to(self.device), clean_wavs.to(self.device)), -1
        )
        discriminator = self.modules["model_d"]
        return discriminator(pair)

    def compute_objectives_d1(self, d_outs, batch):
        """Computes the first discriminator loss (target being clean).

        The loss is ``hparams.compute_cost["d1"]`` applied to the
        discriminator outputs; the same outputs are also recorded in
        ``self.loss_metric_d1`` under the ids of the batch.
        """
        d1_cost = self.hparams.compute_cost["d1"]
        d1_loss = d1_cost(d_outs)
        self.loss_metric_d1.append(batch.id, d_outs, reduction="batch")
        return d1_loss

    def compute_objectives_d2(self, d_outs, batch):
        """Computes the second discriminator loss (target being noisy).

        The loss is ``hparams.compute_cost["d2"]`` applied to the
        discriminator outputs; the same outputs are also recorded in
        ``self.loss_metric_d2`` under the ids of the batch.
        """
        d2_cost = self.hparams.compute_cost["d2"]
        d2_loss = d2_cost(d_outs)
        self.loss_metric_d2.append(batch.id, d_outs, reduction="batch")
        return d2_loss

    def compute_objectives_g3(
        self,
        d_outs,
        predict_wavs,
        clean_wavs,
        batch,
        stage,
        z_mean=None,
        z_logvar=None,
    ):
        """Computes the generator loss and, outside training, the metrics.

        The loss is ``hparams.compute_cost["g3"]`` and is also recorded in
        ``self.loss_metric_g3``. For validation and test the chunked signals
        are reshaped to one row per batch item and cropped to
        ``self.original_len`` before being passed to the STOI and PESQ
        metrics. During test the enhanced signals are additionally written
        to ``hparams.enhanced_folder``.

        Arguments
        ---------
        d_outs : torch.Tensor
            Discriminator outputs.
        predict_wavs : torch.Tensor
            Generator outputs.
        clean_wavs : torch.Tensor
            Reference signals matching ``predict_wavs``.
        batch : object
            Current batch; ``id`` and ``clean_sig`` are read from it.
        stage : sb.Stage
            Current stage.
        z_mean, z_logvar : optional
            Forwarded unchanged to the cost function.

        Returns
        -------
        The generator loss.
        """
        # Only the relative lengths of the batch are needed here
        _, lens = batch.clean_sig
        clean_wavs = clean_wavs.to(self.device)

        cost_kwargs = dict(
            l1LossCoeff=self.hparams.l1LossCoeff,
            klLossCoeff=self.hparams.klLossCoeff,
            z_mean=z_mean,
            z_logvar=z_logvar,
        )
        loss = self.hparams.compute_cost["g3"](
            d_outs, predict_wavs, clean_wavs, lens, **cost_kwargs
        )
        self.loss_metric_g3.append(
            batch.id,
            d_outs,
            predict_wavs,
            clean_wavs,
            lens,
            **cost_kwargs,
            reduction="batch",
        )

        if stage != sb.Stage.TRAIN:

            # Evaluate speech quality/intelligibility: back to one row per
            # item, without the padding added before chunking
            predict_wavs = predict_wavs.reshape(self.batch_current, -1)
            clean_wavs = clean_wavs.reshape(self.batch_current, -1)

            predict_wavs = predict_wavs[:, 0:self.original_len]
            clean_wavs = clean_wavs[:, 0:self.original_len]

            self.stoi_metric.append(batch.id,
                                    predict_wavs,
                                    clean_wavs,
                                    lens,
                                    reduction="batch")
            self.pesq_metric.append(batch.id,
                                    predict=predict_wavs.cpu(),
                                    target=clean_wavs.cpu())

            # Write enhanced test wavs to file
            if stage == sb.Stage.TEST:
                # Relative lengths scaled by the number of samples per row
                lens = lens * clean_wavs.shape[1]
                for name, pred_wav, length in zip(batch.id, predict_wavs,
                                                  lens):
                    enhance_path = os.path.join(self.hparams.enhanced_folder,
                                                name + ".wav")
                    print(enhance_path)

                    # Peak-normalize, leaving a small headroom
                    pred_wav = pred_wav / torch.max(torch.abs(pred_wav)) * 0.99
                    # The sample rate comes from self.hparams like every
                    # other setting, not from a module-level global
                    torchaudio.save(
                        enhance_path,
                        pred_wav[:int(length)].cpu().unsqueeze(0),
                        self.hparams.sample_rate,
                    )
        return loss

    def fit_batch(self, batch):
        """Runs the three-step adversarial update on one batch.

        1. The discriminator is updated on the (noisy, clean) pair with the
           ``d1`` cost.
        2. The generator output is computed and the discriminator is updated
           on the (generated, clean) pair with the ``d2`` cost.
        3. The generator is updated with the ``g3`` cost, using a fresh
           discriminator pass on the same generated signal.

        Each optimizer step is only taken when ``check_gradients`` accepts
        the corresponding loss.

        Arguments
        ---------
        batch : object
            Batch exposing ``noisy_sig`` and ``clean_sig`` as
            ``(signal, lengths)`` pairs.

        Returns
        -------
        torch.Tensor
            Sum of the three losses, detached and moved to the CPU.
        """
        noisy_wavs, _ = batch.noisy_sig
        clean_wavs, _ = batch.clean_sig

        # split sentences in smaller chunks
        chunk_size = self.hparams.chunk_size
        chunk_stride = self.hparams.chunk_stride
        noisy_wavs = create_chunks(
            noisy_wavs, chunk_size=chunk_size, chunk_stride=chunk_stride,
        )
        clean_wavs = create_chunks(
            clean_wavs, chunk_size=chunk_size, chunk_stride=chunk_stride,
        )

        # first of three step training process detailed in SEGAN paper
        out_d1 = self.compute_forward_d(noisy_wavs, clean_wavs)
        loss_d1 = self.compute_objectives_d1(out_d1, batch)
        loss_d1.backward()
        if self.check_gradients(loss_d1):
            self.optimizer_d.step()
        self.optimizer_d.zero_grad()

        # second training step
        z_mean = None
        z_logvar = None
        if self.modules["model_g"].latent_vae:
            out_g2, z_mean, z_logvar = self.compute_forward_g(noisy_wavs)
        else:
            out_g2 = self.compute_forward_g(noisy_wavs)
        out_d2 = self.compute_forward_d(out_g2, clean_wavs)
        loss_d2 = self.compute_objectives_d2(out_d2, batch)
        # The generator graph is needed again by the third step
        loss_d2.backward(retain_graph=True)
        if self.check_gradients(loss_d2):
            self.optimizer_d.step()
        self.optimizer_d.zero_grad()

        # third (last) training step
        # Drops the generator gradients accumulated by the second step
        self.optimizer_g.zero_grad()
        out_d3 = self.compute_forward_d(out_g2, clean_wavs)
        loss_g3 = self.compute_objectives_g3(
            out_d3,
            out_g2,
            clean_wavs,
            batch,
            sb.Stage.TRAIN,
            z_mean=z_mean,
            z_logvar=z_logvar,
        )
        loss_g3.backward()
        if self.check_gradients(loss_g3):
            self.optimizer_g.step()
        self.optimizer_g.zero_grad()
        self.optimizer_d.zero_grad()

        # detach()/cpu() return new tensors, so they have to be applied to
        # the value that is returned; otherwise the autograd graph of the
        # whole step stays referenced by the returned loss
        return (loss_d1 + loss_d2 + loss_g3).detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Evaluates one validation/test batch.

        The signals are zero-padded by one chunk so that the whole signal is
        processed, split into chunks with a stride equal to the chunk size,
        and passed through the discriminator and the generator. The three
        losses are computed as in training, without any parameter update;
        the quality metrics are collected inside
        ``compute_objectives_g3``.

        Arguments
        ---------
        batch : object
            Batch exposing ``noisy_sig`` and ``clean_sig`` as
            ``(signal, lengths)`` pairs.
        stage : sb.Stage
            The stage of the experiment: Stage.VALID, Stage.TEST

        Returns
        -------
        torch.Tensor
            Sum of the three losses, detached and moved to the CPU.
        """
        noisy_wavs, _ = batch.noisy_sig
        clean_wavs, _ = batch.clean_sig

        # Remembered so the chunked signals can be restored to one row per
        # item of the original length when the metrics are computed
        self.batch_current = clean_wavs.shape[0]
        self.original_len = clean_wavs.shape[1]

        chunk_size = self.hparams.chunk_size

        # Add padding to make sure all the signal will be processed.
        padding_elements = torch.zeros(clean_wavs.shape[0],
                                       chunk_size,
                                       device=clean_wavs.device)
        clean_wavs = torch.cat([clean_wavs, padding_elements], dim=1)
        noisy_wavs = torch.cat([noisy_wavs, padding_elements], dim=1)

        # Split sentences in smaller chunks
        noisy_wavs = create_chunks(
            noisy_wavs, chunk_size=chunk_size, chunk_stride=chunk_size,
        )
        clean_wavs = create_chunks(
            clean_wavs, chunk_size=chunk_size, chunk_stride=chunk_size,
        )

        # Perform speech enhancement with the current model
        out_d1 = self.compute_forward_d(noisy_wavs, clean_wavs)
        loss_d1 = self.compute_objectives_d1(out_d1, batch)

        z_mean = None
        z_logvar = None
        if self.modules["model_g"].latent_vae:
            out_g2, z_mean, z_logvar = self.compute_forward_g(noisy_wavs)
        else:
            out_g2 = self.compute_forward_g(noisy_wavs)
        out_d2 = self.compute_forward_d(out_g2, clean_wavs)
        loss_d2 = self.compute_objectives_d2(out_d2, batch)

        loss_g3 = self.compute_objectives_g3(
            out_d2,
            out_g2,
            clean_wavs,
            batch,
            stage=stage,
            z_mean=z_mean,
            z_logvar=z_logvar,
        )

        # detach()/cpu() return new tensors, so they are applied to the
        # value that is actually returned
        return (loss_d1 + loss_d2 + loss_g3).detach().cpu()

    def init_optimizers(self):
        """Creates one optimizer for the discriminator and one for the generator.

        Both are built with ``self.opt_class`` (nothing happens when it is
        ``None``) from the parameters of ``modules["model_d"]`` and
        ``modules["model_g"]`` respectively. When a checkpointer is
        available, both optimizers are registered as recoverables.
        """
        if self.opt_class is None:
            return

        discriminator_params = self.modules["model_d"].parameters()
        self.optimizer_d = self.opt_class(discriminator_params)

        generator_params = self.modules["model_g"].parameters()
        self.optimizer_g = self.opt_class(generator_params)

        if self.checkpointer is None:
            return

        # Make the optimizer states part of every checkpoint
        self.checkpointer.add_recoverable("optimizer_g", self.optimizer_g)
        self.checkpointer.add_recoverable("optimizer_d", self.optimizer_d)

    def on_stage_start(self, stage, epoch=None):
        """Creates fresh metric trackers at the beginning of each stage.

        The three loss trackers and the STOI tracker are created for every
        stage; the PESQ tracker only outside of training.
        """
        cost_fns = self.hparams.compute_cost
        self.loss_metric_d1 = MetricStats(metric=cost_fns["d1"])
        self.loss_metric_d2 = MetricStats(metric=cost_fns["d2"])
        self.loss_metric_g3 = MetricStats(metric=cost_fns["g3"])
        self.stoi_metric = MetricStats(metric=stoi_loss)

        # Define function taking (prediction, target) for parallel eval
        def pesq_eval(pred_wav, target_wav):
            """Computes the PESQ evaluation metric"""
            # The sample rate is read from self.hparams, like the other
            # settings of this class, instead of a module-level global
            return pesq(
                fs=self.hparams.sample_rate,
                ref=target_wav.numpy().squeeze(),
                deg=pred_wav.numpy().squeeze(),
                mode="wb",
            )

        if stage != sb.Stage.TRAIN:
            self.pesq_metric = MetricStats(metric=pesq_eval,
                                           batch_eval=False,
                                           n_jobs=1)

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Stores or logs the statistics of the stage that just finished.

        After training, the loss and the per-item scores of the three loss
        trackers are cached on the instance. After validation, the summary
        statistics are logged (and, if ``hparams.use_tensorboard`` is set,
        the per-item scores are sent to the tensorboard logger) and a
        checkpoint is saved with ``max_keys=["pesq"]``. After testing, the
        summary statistics are logged.

        Arguments
        ---------
        stage : sb.Stage
            The stage that just ended.
        stage_loss : float
            Loss value handed over by the caller for this stage.
        epoch : int, optional
            Epoch number used as logging metadata during validation.
        """
        if stage == sb.Stage.TRAIN:
            # Kept on the instance so the validation stage can log them
            self.train_loss = stage_loss
            self.train_stats = {
                "loss_d1": self.loss_metric_d1.scores,
                "loss_d2": self.loss_metric_d2.scores,
                "loss_g3": self.loss_metric_g3.scores,
            }
            return

        # Summary statistics; the STOI average is negated before reporting
        stats = {
            "loss": stage_loss,
            "pesq": self.pesq_metric.summarize("average"),
            "stoi": -self.stoi_metric.summarize("average"),
        }

        if stage == sb.Stage.VALID:
            if self.hparams.use_tensorboard:
                # Tensorboard receives the raw per-item scores
                per_item_scores = {
                    "loss_d1": self.loss_metric_d1.scores,
                    "loss_d2": self.loss_metric_d2.scores,
                    "loss_g3": self.loss_metric_g3.scores,
                    "stoi": self.stoi_metric.scores,
                    "pesq": self.pesq_metric.scores,
                }
                self.hparams.tensorboard_train_logger.log_stats(
                    {"Epoch": epoch}, self.train_stats, per_item_scores)
            self.hparams.train_logger.log_stats(
                {"Epoch": epoch},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )
            self.checkpointer.save_and_keep_only(meta=stats, max_keys=["pesq"])
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )