示例#1
0
    def on_stage_start(self, stage, epoch=None):
        """Set up the metric trackers needed by the stage that is starting.

        For the training stage this additionally runs
        ``self.train_discriminator()`` ahead of generator training.

        Arguments
        ---------
        stage
            Stage about to begin; compared against ``sb.Stage.TRAIN``.
        epoch : optional
            Stored on ``self.epoch`` right before the discriminator is trained.

        Raises
        ------
        NotImplementedError
            During training, when ``self.hparams.target_metric`` is neither
            ``"pesq"`` nor ``"stoi"``.
        """
        # These two are rebuilt for every stage, training or not.
        self.mse_metric = MetricStats(metric=self.hparams.compute_cost)
        self.metrics = {"G": [], "D": []}

        # Non-training stages only need the two evaluation trackers.
        if stage != sb.Stage.TRAIN:
            self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=30)
            self.stoi_metric = MetricStats(metric=stoi_loss)
            return

        # Training: choose the constructor arguments for the target metric.
        objective = self.hparams.target_metric
        if objective == "pesq":
            target_kwargs = {"metric": pesq_eval, "n_jobs": 40}
        elif objective == "stoi":
            target_kwargs = {"metric": stoi_loss}
        else:
            raise NotImplementedError(
                "Right now we only support 'pesq' and 'stoi'"
            )
        self.target_metric = MetricStats(**target_kwargs)

        # Train discriminator before we start generator training
        if self.sub_stage == SubStage.GENERATOR:
            self.epoch = epoch
            self.train_discriminator()
            # Put the sub-stage back explicitly; train_discriminator()
            # presumably changes it -- TODO confirm against its definition.
            self.sub_stage = SubStage.GENERATOR
            print("Generator training by current data...")
示例#2
0
    def on_stage_start(self, stage, epoch=None):
        """Create fresh metric trackers for the stage that is starting.

        ``self.loss_metric`` and ``self.stoi_metric`` are rebuilt for every
        stage; ``self.pesq_metric`` is only built when the stage is not
        ``sb.Stage.TRAIN``.

        Arguments
        ---------
        stage
            Stage about to begin; compared against ``sb.Stage.TRAIN``.
        epoch : optional
            Not used by this method.
        """
        self.loss_metric = MetricStats(metric=self.hparams.compute_cost)
        self.stoi_metric = MetricStats(metric=stoi_loss)

        if stage != sb.Stage.TRAIN:

            def pesq_eval(pred_wav, target_wav):
                """Score ``pred_wav`` against ``target_wav`` with ``pesq``.

                Both inputs are moved to the CPU and converted to NumPy
                arrays; the call uses ``fs=16000`` and ``mode="wb"``.
                """
                reference = target_wav.cpu().numpy()
                degraded = pred_wav.cpu().numpy()
                return pesq(fs=16000, ref=reference, deg=degraded, mode="wb")

            self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=4)
示例#3
0
    def on_stage_start(self, stage, epoch=None):
        """Gets called at the beginning of each epoch.

        Rebuilds ``self.loss_metric`` and ``self.stoi_metric`` for every
        stage, and additionally ``self.pesq_metric`` for any stage other
        than ``sb.Stage.TRAIN``.

        Arguments
        ---------
        stage
            Stage about to begin; compared against ``sb.Stage.TRAIN``.
        epoch : optional
            Not used by this method.
        """
        self.loss_metric = MetricStats(metric=self.hparams.compute_cost)
        self.stoi_metric = MetricStats(metric=stoi_loss)

        # Define function taking (prediction, target) for parallel eval
        def pesq_eval(pred_wav, target_wav):
            """Computes the PESQ evaluation metric.

            Calls ``pesq`` with ``fs=16000`` and ``mode="wb"`` on NumPy
            copies/views of the two tensors.
            """
            # ``Tensor.numpy()`` refuses CUDA tensors and tensors that
            # require grad; ``detach().cpu()`` makes the conversion safe and
            # leaves plain CPU tensors' values untouched.
            return pesq(
                fs=16000,
                ref=target_wav.detach().cpu().numpy(),
                deg=pred_wav.detach().cpu().numpy(),
                mode="wb",
            )

        if stage != sb.Stage.TRAIN:
            self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=4)
示例#4
0
    def on_stage_start(self, stage, epoch=None):
        """Gets called at the beginning of each epoch.

        Rebuilds the three loss trackers (``d1``, ``d2``, ``g3`` entries of
        ``self.hparams.compute_cost``) and ``self.stoi_metric`` for every
        stage, and additionally ``self.pesq_metric`` for any stage other
        than ``sb.Stage.TRAIN``.

        Arguments
        ---------
        stage
            Stage about to begin; compared against ``sb.Stage.TRAIN``.
        epoch : optional
            Not used by this method.
        """
        self.loss_metric_d1 = MetricStats(
            metric=self.hparams.compute_cost["d1"])
        self.loss_metric_d2 = MetricStats(
            metric=self.hparams.compute_cost["d2"])
        self.loss_metric_g3 = MetricStats(
            metric=self.hparams.compute_cost["g3"])
        self.stoi_metric = MetricStats(metric=stoi_loss)

        # Define function taking (prediction, target) for parallel eval
        def pesq_eval(pred_wav, target_wav):
            """Computes the PESQ evaluation metric.

            Calls ``pesq`` with ``mode="wb"`` on squeezed NumPy versions of
            the two tensors.
            """
            # NOTE(review): ``hparams`` is a bare module-level name here, not
            # ``self.hparams`` -- confirm it is defined wherever this class
            # is used.
            # ``Tensor.numpy()`` refuses CUDA tensors and tensors that
            # require grad; ``detach().cpu()`` makes the conversion safe and
            # leaves plain CPU tensors' values untouched.
            return pesq(
                fs=hparams["sample_rate"],
                ref=target_wav.detach().cpu().numpy().squeeze(),
                deg=pred_wav.detach().cpu().numpy().squeeze(),
                mode="wb",
            )

        if stage != sb.Stage.TRAIN:
            self.pesq_metric = MetricStats(metric=pesq_eval,
                                           batch_eval=False,
                                           n_jobs=1)
示例#5
0
def test_metric_stats():
    """Check ``MetricStats.summarize()`` for an L1 loss over two utterances.

    The expected summary has average 0.075, with ``utterance1`` as the
    minimum (0.05) and ``utterance2`` as the maximum (0.1).
    """
    from speechbrain.utils.metric_stats import MetricStats
    from speechbrain.nnet.losses import l1_loss

    utterance_ids = ["utterance1", "utterance2"]
    predicted = torch.tensor([[0.1, 0.2], [0.1, 0.2]])
    reference = torch.tensor([[0.1, 0.3], [0.2, 0.3]])

    tracker = MetricStats(metric=l1_loss)
    tracker.append(
        ids=utterance_ids,
        predictions=predicted,
        targets=reference,
        length=torch.ones(2),
        reduction="batch",
    )
    report = tracker.summarize()

    # Numeric fields are compared with a relative tolerance.
    expected_scores = {"average": 0.075, "min_score": 0.05, "max_score": 0.1}
    for field, expected in expected_scores.items():
        assert math.isclose(report[field], expected, rel_tol=1e-5)

    # The extreme scores must be attributed to the right utterances.
    assert report["min_id"] == "utterance1"
    assert report["max_id"] == "utterance2"
示例#6
0
    def on_stage_start(self, stage, epoch=None):
        """Set up the metric trackers needed by the stage that is starting.

        For the training stage this additionally runs
        ``self.train_discriminator()`` ahead of generator training.

        Arguments
        ---------
        stage
            Stage about to begin; compared against ``sb.Stage.TRAIN``.
        epoch : optional
            Stored on ``self.epoch`` right before the discriminator is trained.

        Raises
        ------
        NotImplementedError
            During training, when ``self.hparams.target_metric`` is neither
            ``"srmr"`` nor ``"dnsmos"``.
        """
        self.metrics = {"G": [], "D": []}

        # Non-training stages: build the four evaluation trackers and stop.
        if stage != sb.Stage.TRAIN:
            # NOTE(review): ``hparams`` is a bare module-level name here, not
            # ``self.hparams`` -- confirm it is defined wherever this is used.
            shared_kwargs = {"n_jobs": hparams["n_jobs"], "batch_eval": False}
            self.pesq_metric = MetricStats(metric=pesq_eval, **shared_kwargs)
            self.stoi_metric = MetricStats(metric=stoi_loss)
            self.srmr_metric = MetricStats(
                metric=srmrpy_eval_valid, **shared_kwargs
            )
            self.dnsmos_metric = MetricStats(
                metric=dnsmos_eval_valid, **shared_kwargs
            )
            return

        # Training: resolve the metric callable first, then build the tracker.
        objective = self.hparams.target_metric
        if objective == "srmr":
            metric_fn = srmrpy_eval
        elif objective == "dnsmos":
            metric_fn = dnsmos_eval
        else:
            raise NotImplementedError(
                "Right now we only support 'srmr' and 'dnsmos'"
            )
        self.target_metric = MetricStats(
            metric=metric_fn, n_jobs=hparams["n_jobs"], batch_eval=False
        )

        # Train discriminator before we start generator training
        if self.sub_stage == SubStage.GENERATOR:
            self.epoch = epoch
            self.train_discriminator()
            # Put the sub-stage back explicitly; train_discriminator()
            # presumably changes it -- TODO confirm against its definition.
            self.sub_stage = SubStage.GENERATOR
            print("Generator training by current data...")
示例#7
0
    def on_stage_start(self, stage, epoch=None):
        """Gets called at the beginning of each epoch.

        For any stage other than ``sb.Stage.TRAIN`` this builds a fresh
        ``self.pesq_metric`` tracker; for the training stage it does nothing.

        Arguments
        ---------
        stage
            Stage about to begin; compared against ``sb.Stage.TRAIN``.
        epoch : optional
            Not used by this method.
        """
        if stage != sb.Stage.TRAIN:
            # Metric callable taking (prediction, target)
            def pesq_eval(pred_wav, target_wav):
                """Computes the PESQ evaluation metric.

                Uses mode ``"wb"`` when ``self.hparams.sample_rate`` is 16000
                and ``"nb"`` otherwise. Any exception is reported with a
                printed message and scored as 0 (best effort).
                """
                psq_mode = "wb" if self.hparams.sample_rate == 16000 else "nb"
                try:
                    # ``Tensor.numpy()`` refuses CUDA tensors and tensors that
                    # require grad. Without ``detach().cpu()`` that failure
                    # would be swallowed by the handler below and every item
                    # would silently score 0.
                    return pesq(
                        fs=self.hparams.sample_rate,
                        ref=target_wav.detach().cpu().numpy(),
                        deg=pred_wav.detach().cpu().numpy(),
                        mode=psq_mode,
                    )
                except Exception:
                    print("pesq encountered an error for this data item")
                    return 0

            self.pesq_metric = MetricStats(
                metric=pesq_eval, n_jobs=1, batch_eval=False
            )