def on_stage_start(self, stage, epoch=None):
    """Reset the per-stage trackers and, when training, update D before G.

    For the training stage the target-metric tracker is created and, if the
    brain is currently in the generator sub-stage, the discriminator is
    trained first (which re-enters ``fit()``) before the generator epoch runs.
    For every other stage only the PESQ and STOI trackers are created.

    Arguments
    ---------
    stage : sb.Stage
        The stage that is about to start.
    epoch : int
        The current epoch; stored on ``self.epoch`` before the discriminator
        is trained.
    """
    self.mse_metric = MetricStats(metric=self.hparams.compute_cost)
    self.metrics = {"G": [], "D": []}

    # Evaluation stages: only the quality/intelligibility trackers are needed.
    if stage != sb.Stage.TRAIN:
        self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=30)
        self.stoi_metric = MetricStats(metric=stoi_loss)
        return

    objective = self.hparams.target_metric
    if objective not in ("pesq", "stoi"):
        raise NotImplementedError(
            "Right now we only support 'pesq' and 'stoi'"
        )

    if objective == "pesq":
        self.target_metric = MetricStats(metric=pesq_eval, n_jobs=40)
    else:
        self.target_metric = MetricStats(metric=stoi_loss)

    # The discriminator passes must finish before the generator epoch begins.
    if self.sub_stage == SubStage.GENERATOR:
        self.epoch = epoch
        self.train_discriminator()
        # train_discriminator() switches sub-stages; restore the generator one.
        self.sub_stage = SubStage.GENERATOR
        print("Generator training by current data...")
def on_stage_start(self, stage, epoch=None):
    """Create fresh loss/STOI trackers, plus a PESQ tracker outside training.

    Arguments
    ---------
    stage : sb.Stage
        The stage that is about to start.
    epoch : int
        The current epoch (unused here).
    """
    self.loss_metric = MetricStats(metric=self.hparams.compute_cost)
    self.stoi_metric = MetricStats(metric=stoi_loss)

    if stage == sb.Stage.TRAIN:
        return

    # Function taking (prediction, target), as needed for parallel evaluation
    def pesq_eval(pred_wav, target_wav):
        reference = target_wav.cpu().numpy()
        degraded = pred_wav.cpu().numpy()
        return pesq(fs=16000, ref=reference, deg=degraded, mode="wb")

    self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=4)
def on_stage_start(self, stage, epoch=None):
    """Gets called at the beginning of each epoch

    This method calls ``fit()`` again to train the discriminator
    before proceeding with generator training.

    For the training stage the target-metric tracker ('srmr' or 'dnsmos') is
    created; for every other stage the PESQ, STOI, SRMR and DNSMOS trackers
    are created.

    Arguments
    ---------
    stage : sb.Stage
        The stage that is about to start.
    epoch : int
        The current epoch; stored on ``self.epoch`` before the discriminator
        is trained.

    Raises
    ------
    NotImplementedError
        If ``self.hparams.target_metric`` is neither 'srmr' nor 'dnsmos'
        when the training stage starts.
    """

    def per_utterance_stats(metric):
        # The worker count comes from this Brain's own hyperparameters rather
        # than from a module-level ``hparams`` global, which only exists when
        # the recipe file is executed as a script.
        return MetricStats(
            metric=metric, n_jobs=self.hparams.n_jobs, batch_eval=False,
        )

    self.metrics = {"G": [], "D": []}

    if stage == sb.Stage.TRAIN:
        if self.hparams.target_metric == "srmr":
            self.target_metric = per_utterance_stats(srmrpy_eval)
        elif self.hparams.target_metric == "dnsmos":
            self.target_metric = per_utterance_stats(dnsmos_eval)
        else:
            raise NotImplementedError(
                "Right now we only support 'srmr' and 'dnsmos'"
            )

        # Train discriminator before we start generator training
        if self.sub_stage == SubStage.GENERATOR:
            self.epoch = epoch
            self.train_discriminator()
            # train_discriminator() switches sub-stages; restore generator.
            self.sub_stage = SubStage.GENERATOR
            print("Generator training by current data...")

    if stage != sb.Stage.TRAIN:
        self.pesq_metric = per_utterance_stats(pesq_eval)
        self.stoi_metric = MetricStats(metric=stoi_loss)
        self.srmr_metric = per_utterance_stats(srmrpy_eval_valid)
        self.dnsmos_metric = per_utterance_stats(dnsmos_eval_valid)
def on_stage_start(self, stage, epoch=None):
    """Set up the metric trackers used during the upcoming stage.

    A loss tracker and a STOI tracker are always created; a PESQ tracker is
    added for any stage other than training.
    """
    self.loss_metric = MetricStats(metric=self.hparams.compute_cost)
    self.stoi_metric = MetricStats(metric=stoi_loss)

    # Function taking (prediction, target), as needed for parallel evaluation
    def pesq_eval(pred_wav, target_wav):
        """Return the wide-band PESQ of ``pred_wav`` against ``target_wav``."""
        reference, degraded = target_wav.numpy(), pred_wav.numpy()
        return pesq(fs=16000, ref=reference, deg=degraded, mode="wb")

    is_training = stage == sb.Stage.TRAIN
    if not is_training:
        self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=4)
def test_metric_stats():
    """Check the summary MetricStats produces for a two-utterance L1 batch."""
    from speechbrain.utils.metric_stats import MetricStats
    from speechbrain.nnet.losses import l1_loss

    def close_to(actual, wanted):
        # All float comparisons in this test share the same tolerance.
        return math.isclose(actual, wanted, rel_tol=1e-5)

    tracker = MetricStats(metric=l1_loss)
    utterance_ids = ["utterance1", "utterance2"]
    predicted = torch.tensor([[0.1, 0.2], [0.1, 0.2]])
    expected = torch.tensor([[0.1, 0.3], [0.2, 0.3]])
    tracker.append(
        ids=utterance_ids,
        predictions=predicted,
        targets=expected,
        length=torch.ones(2),
        reduction="batch",
    )

    report = tracker.summarize()
    assert close_to(report["average"], 0.075)
    assert close_to(report["min_score"], 0.05)
    assert report["min_id"] == "utterance1"
    assert close_to(report["max_score"], 0.1)
    assert report["max_id"] == "utterance2"
def on_stage_start(self, stage, epoch=None):
    """Create the PESQ tracker for any stage other than training.

    Nothing is set up for the training stage.
    """
    if stage == sb.Stage.TRAIN:
        return

    # Function taking (prediction, target), as needed for parallel evaluation
    def pesq_eval(pred_wav, target_wav):
        """Return the PESQ score, or 0 when the PESQ call fails."""
        rate = self.hparams.sample_rate
        # Wide-band mode is selected only for 16 kHz audio.
        if rate == 16000:
            psq_mode = "wb"
        else:
            psq_mode = "nb"

        try:
            reference = target_wav.numpy()
            degraded = pred_wav.numpy()
            return pesq(fs=rate, ref=reference, deg=degraded, mode=psq_mode)
        except Exception:
            # Best effort: a failing item scores 0 instead of aborting the run.
            print("pesq encountered an error for this data item")
            return 0

    self.pesq_metric = MetricStats(
        metric=pesq_eval, n_jobs=1, batch_eval=False
    )
def on_stage_start(self, stage, epoch=None):
    """Gets called at the beginning of each epoch

    Creates one loss tracker for each of the "d1", "d2" and "g3" entries of
    ``self.hparams.compute_cost``, a STOI tracker, and (outside training)
    a PESQ tracker.

    Arguments
    ---------
    stage : sb.Stage
        The stage that is about to start.
    epoch : int
        The current epoch (unused here).
    """
    cost_fns = self.hparams.compute_cost
    self.loss_metric_d1 = MetricStats(metric=cost_fns["d1"])
    self.loss_metric_d2 = MetricStats(metric=cost_fns["d2"])
    self.loss_metric_g3 = MetricStats(metric=cost_fns["g3"])
    self.stoi_metric = MetricStats(metric=stoi_loss)

    # Define function taking (prediction, target) for parallel eval
    def pesq_eval(pred_wav, target_wav):
        """Computes the PESQ evaluation metric"""
        # The sample rate is read from this Brain's hyperparameters rather
        # than from a module-level ``hparams`` global, which only exists when
        # the recipe file is executed as a script.
        return pesq(
            fs=self.hparams.sample_rate,
            ref=target_wav.numpy().squeeze(),
            deg=pred_wav.numpy().squeeze(),
            mode="wb",
        )

    if stage != sb.Stage.TRAIN:
        self.pesq_metric = MetricStats(
            metric=pesq_eval, batch_eval=False, n_jobs=1
        )
class MetricGanBrain(sb.Brain):
    """Brain that alternates discriminator (D) and generator (G) updates.

    ``self.sub_stage`` selects what a training pass updates: ``CURRENT`` and
    ``HISTORICAL`` update the discriminator, ``GENERATOR`` updates the
    generator. ``self.historical_set`` and ``self.noisy_scores`` are read and
    written here but are initialised outside this class.
    """

    def compute_feats(self, wavs):
        """Feature computation pipeline

        Returns ``log1p`` of the power-0.5 spectral magnitude of the STFT
        of ``wavs``.
        """
        feats = self.hparams.compute_STFT(wavs)
        feats = spectral_magnitude(feats, power=0.5)
        feats = torch.log1p(feats)
        return feats

    def compute_forward(self, batch, stage):
        """Given an input batch computes the enhanced signal

        In the HISTORICAL sub-stage the previously stored enhanced waveform
        (``batch.enh_sig``) is returned as-is; otherwise the generator mask is
        applied to the noisy features and the result is resynthesised.
        """
        batch = batch.to(self.device)

        if self.sub_stage == SubStage.HISTORICAL:
            predict_wav, lens = batch.enh_sig
        else:
            noisy_wav, lens = batch.noisy_sig
            noisy_spec = self.compute_feats(noisy_wav)

            # mask with "signal approximation (SA)"
            mask = self.modules.generator(noisy_spec, lengths=lens)
            # Floor the mask at hparams.min_mask before applying it.
            mask = mask.clamp(min=self.hparams.min_mask).squeeze(2)
            predict_spec = torch.mul(mask, noisy_spec)

            # Also return predicted wav; expm1 undoes the log1p of
            # compute_feats before resynthesis.
            predict_wav = self.hparams.resynth(
                torch.expm1(predict_spec), noisy_wav
            )

        return predict_wav

    def compute_objectives(self, predictions, batch, stage, optim_name=""):
        """Given the network predictions and targets compute the total loss

        Arguments
        ---------
        predictions : torch.Tensor
            The waveform returned by ``compute_forward``.
        batch : PaddedBatch
            Batch providing ``id`` and ``clean_sig`` (and, depending on
            ``optim_name``, ``noisy_sig`` or ``score``).
        stage : sb.Stage
            The current stage; metrics/wavs are only recorded outside TRAIN.
        optim_name : str
            One of "generator", "", "D_clean", "D_enh", "D_noisy".

        Returns
        -------
        The adversarial cost (plus the weighted MSE cost for "generator").

        Raises
        ------
        ValueError
            If ``optim_name`` (together with the sub-stage) matches no branch.
        """
        predict_wav = predictions
        predict_spec = self.compute_feats(predict_wav)

        clean_wav, lens = batch.clean_sig
        clean_spec = self.compute_feats(clean_wav)
        mse_cost = self.hparams.compute_cost(predict_spec, clean_spec, lens)

        ids = self.compute_ids(batch.id, optim_name)

        # One is real, zero is fake
        if optim_name == "generator" or optim_name == "":
            if optim_name == "generator":
                target_score = torch.ones(self.batch_size, 1, device=self.device)
            else:
                # Empty optim_name (the default): a single 1x1 target is used.
                target_score = torch.ones(1, 1, device=self.device)
            est_score = self.est_score(predict_spec, clean_spec)
            self.mse_metric.append(
                ids, predict_spec, clean_spec, lens, reduction="batch"
            )

        # D Learns to estimate the scores of clean speech
        elif optim_name == "D_clean":
            target_score = torch.ones(self.batch_size, 1, device=self.device)
            est_score = self.est_score(clean_spec, clean_spec)

        # D Learns to estimate the scores of enhanced speech
        elif optim_name == "D_enh" and self.sub_stage == SubStage.CURRENT:
            target_score = self.score(ids, predict_wav, clean_wav, lens)
            est_score = self.est_score(predict_spec, clean_spec)

            # Write enhanced wavs during discriminator training, because we
            # compute the actual score here and we can save it
            self.write_wavs(batch.id, ids, predict_wav, target_score, lens)

        # D Relearns to estimate the scores of previous epochs
        elif optim_name == "D_enh" and self.sub_stage == SubStage.HISTORICAL:
            target_score = batch.score.unsqueeze(1).float()
            est_score = self.est_score(predict_spec, clean_spec)

        # D Learns to estimate the scores of noisy speech
        elif optim_name == "D_noisy":
            noisy_wav, _ = batch.noisy_sig
            noisy_spec = self.compute_feats(noisy_wav)
            target_score = self.score(ids, noisy_wav, clean_wav, lens)
            est_score = self.est_score(noisy_spec, clean_spec)

            # Save scores of noisy wavs
            self.save_noisy_scores(ids, target_score)

        else:
            raise ValueError(f"{optim_name} is not a valid 'optim_name'")

        # Compute the cost
        adv_cost = self.hparams.compute_cost(est_score, target_score)
        if optim_name == "generator":
            adv_cost += self.hparams.mse_weight * mse_cost
            self.metrics["G"].append(adv_cost.detach())
        else:
            self.metrics["D"].append(adv_cost.detach())

        # On validation data compute scores
        if stage != sb.Stage.TRAIN:
            # Evaluate speech quality/intelligibility
            if self.hparams.mode == 'val':
                # Only the tracker matching the target metric is updated.
                if self.hparams.target_metric == "stoi":
                    self.stoi_metric.append(
                        batch.id, predict_wav, clean_wav, lens, reduction="batch"
                    )
                elif self.hparams.target_metric == "pesq":
                    self.pesq_metric.append(
                        batch.id, predict=predict_wav, target=clean_wav, lengths=lens
                    )
            # Write wavs to file, for evaluation
            elif self.hparams.mode == 'test':
                self.stoi_metric.append(
                    batch.id, predict_wav, clean_wav, lens, reduction="batch"
                )
                self.pesq_metric.append(
                    batch.id, predict=predict_wav, target=clean_wav, lengths=lens
                )
                # Relative lengths -> absolute sample counts.
                lens = lens * clean_wav.shape[1]
                for name, pred_wav, length in zip(batch.id, predict_wav, lens):
                    name += ".wav"
                    enhance_path = os.path.join(self.hparams.enhanced_folder, name)
                    # NOTE(review): rate is hard-coded to 16000 here while
                    # write_wavs uses self.hparams.Sample_rate -- confirm.
                    torchaudio.save(
                        enhance_path,
                        torch.unsqueeze(pred_wav[: int(length)].cpu(), 0),
                        16000,
                    )

        # NOTE(review): the original comment here said "we do not use mse_cost
        # to update model", yet the "generator" branch above adds
        # mse_weight * mse_cost to the returned cost -- confirm the intent.
        return adv_cost

    def compute_ids(self, batch_id, optim_name):
        """Returns the list of ids, edited via optimizer name.

        For "D_enh" each id gets an ``@<epoch>`` suffix; otherwise the ids are
        returned unchanged.
        """
        if optim_name == "D_enh":
            return [f"{uid}@{self.epoch}" for uid in batch_id]
        return batch_id

    def save_noisy_scores(self, batch_id, scores):
        """Store each score in ``self.noisy_scores`` under its utterance id."""
        for i, score in zip(batch_id, scores):
            self.noisy_scores[i] = score

    def score(self, batch_id, deg_wav, ref_wav, lens):
        """Returns actual metric score, either pesq or stoi

        Scores already present in ``self.historical_set`` or
        ``self.noisy_scores`` are reused; only the remaining utterances are
        passed to ``self.target_metric``.

        Arguments
        ---------
        batch_id : list of str
            A list of the utterance ids for the batch
        deg_wav : torch.Tensor
            The degraded waveform to score
        ref_wav : torch.Tensor
            The reference waveform to use for scoring
        lens : torch.Tensor
            The relative lengths of the utterances

        Returns
        -------
        torch.Tensor
            One single-element row per entry of ``batch_id``.

        Raises
        ------
        ValueError
            If there are new utterances and ``target_metric`` is neither
            'pesq' nor 'stoi'.
        """
        # Positions (not ids) of utterances that have no stored score yet.
        new_ids = [
            i
            for i, d in enumerate(batch_id)
            if d not in self.historical_set and d not in self.noisy_scores
        ]

        if len(new_ids) == 0:
            pass
        elif self.hparams.target_metric == "pesq":
            self.target_metric.append(
                ids=[batch_id[i] for i in new_ids],
                predict=deg_wav[new_ids].detach(),
                target=ref_wav[new_ids].detach(),
                lengths=lens[new_ids],
            )
            score = torch.tensor(
                [[s] for s in self.target_metric.scores], device=self.device,
            )
        elif self.hparams.target_metric == "stoi":
            self.target_metric.append(
                [batch_id[i] for i in new_ids],
                deg_wav[new_ids],
                ref_wav[new_ids],
                lens[new_ids],
                reduction="batch",
            )
            # The tracked stoi values are negated to obtain the score.
            score = torch.tensor(
                [[-s] for s in self.target_metric.scores], device=self.device,
            )
        else:
            raise ValueError("Expected 'pesq' or 'stoi' for target_metric")

        # Clear metric scores to prepare for next batch
        self.target_metric.clear()

        # Combine old scores and new
        final_score = []
        for i, d in enumerate(batch_id):
            if d in self.historical_set:
                final_score.append([self.historical_set[d]["score"]])
            elif d in self.noisy_scores:
                final_score.append([self.noisy_scores[d]])
            else:
                # Map the batch position back to its row in the new scores.
                final_score.append([score[new_ids.index(i)]])

        return torch.tensor(final_score, device=self.device)

    def est_score(self, deg_spec, ref_spec):
        """Returns score as estimated by discriminator

        Arguments
        ---------
        deg_spec : torch.Tensor
            The spectral features of the degraded utterance
        ref_spec : torch.Tensor
            The spectral features of the reference utterance
        """
        # Stack degraded and reference features along a new dim 1.
        combined_spec = torch.cat(
            [deg_spec.unsqueeze(1), ref_spec.unsqueeze(1)], 1
        )
        return self.modules.discriminator(combined_spec)

    def write_wavs(self, clean_id, batch_id, wavs, scores, lens):
        """Write wavs to files, for historical discriminator training

        Arguments
        ---------
        clean_id : list of str
            Original utterance ids; each is split once on '-' to build the
            path of the clean item under ``hparams.train_clean_folder``.
        batch_id : list of str
            A list of the utterance ids for the batch
        wavs : torch.Tensor
            The wavs to write to files
        scores : torch.Tensor
            The actual scores for the corresponding utterances
        lens : torch.Tensor
            The relative lengths of each utterance
        """
        # Relative lengths -> absolute sample counts.
        lens = lens * wavs.shape[1]
        record = {}
        for i, (cleanid, name, pred_wav, length) in enumerate(
            zip(clean_id, batch_id, wavs, lens)
        ):
            path = os.path.join(self.hparams.MetricGAN_folder, name + ".wav")
            data = torch.unsqueeze(pred_wav[: int(length)].cpu(), 0)
            torchaudio.save(path, data, self.hparams.Sample_rate)

            # Make record of path and score for historical training
            score = float(scores[i][0])
            clean_path = cleanid.split('-', 1)
            clean_path = os.path.join(
                self.hparams.train_clean_folder, clean_path[0], clean_path[1] + ".pkl"
            )
            record[name] = {
                "enh_wav": path,
                "score": score,
                "clean_wav": clean_path,
            }

        # Update records for historical training
        self.historical_set.update(record)

    def fit_batch(self, batch):
        """Compute gradients and update either D or G based on sub-stage.

        Returns the detached loss; in the CURRENT sub-stage it is the mean of
        the three discriminator losses (clean, enh, noisy).
        """
        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
        loss_tracker = 0
        if self.sub_stage == SubStage.CURRENT:
            for mode in ["clean", "enh", "noisy"]:
                loss = self.compute_objectives(
                    predictions, batch, sb.Stage.TRAIN, f"D_{mode}"
                )
                self.d_optimizer.zero_grad()
                loss.backward()
                if self.check_gradients(loss):
                    self.d_optimizer.step()
                loss_tracker += loss.detach() / 3
        elif self.sub_stage == SubStage.HISTORICAL:
            loss = self.compute_objectives(
                predictions, batch, sb.Stage.TRAIN, "D_enh"
            )
            self.d_optimizer.zero_grad()
            loss.backward()
            if self.check_gradients(loss):
                self.d_optimizer.step()
            loss_tracker += loss.detach()
        elif self.sub_stage == SubStage.GENERATOR:
            loss = self.compute_objectives(
                predictions, batch, sb.Stage.TRAIN, "generator"
            )
            self.g_optimizer.zero_grad()
            loss.backward()
            if self.check_gradients(loss):
                self.g_optimizer.step()
            loss_tracker += loss.detach()

        return loss_tracker

    def on_stage_start(self, stage, epoch=None):
        """Gets called at the beginning of each epoch

        This method calls ``fit()`` again to train the discriminator
        before proceeding with generator training.
        """
        self.mse_metric = MetricStats(metric=self.hparams.compute_cost)
        self.metrics = {"G": [], "D": []}

        if stage == sb.Stage.TRAIN:
            if self.hparams.target_metric == "pesq":
                self.target_metric = MetricStats(metric=pesq_eval, n_jobs=40)
            elif self.hparams.target_metric == "stoi":
                self.target_metric = MetricStats(metric=stoi_loss)
            else:
                raise NotImplementedError(
                    "Right now we only support 'pesq' and 'stoi'"
                )

            # Train discriminator before we start generator training
            if self.sub_stage == SubStage.GENERATOR:
                self.epoch = epoch
                self.train_discriminator()
                # train_discriminator() changes sub_stage; switch back.
                self.sub_stage = SubStage.GENERATOR
                print("Generator training by current data...")

        if stage != sb.Stage.TRAIN:
            self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=30)
            self.stoi_metric = MetricStats(metric=stoi_loss)

    def train_discriminator(self):
        """A total of 3 data passes to update discriminator."""
        # First, iterate train subset w/ updates for clean, enh, noisy
        print("Discriminator training by current data...")
        self.sub_stage = SubStage.CURRENT
        self.fit(
            range(1),
            self.train_set,
            train_loader_kwargs=self.hparams.dataloader_options,
        )

        # Next, iterate historical subset w/ updates for enh
        if self.historical_set:
            print("Discriminator training by historical data...")
            self.sub_stage = SubStage.HISTORICAL
            self.fit(
                range(1),
                self.historical_set,
                train_loader_kwargs=self.hparams.dataloader_options,
            )

            # Finally, iterate train set again. Should iterate same
            # samples as before, due to ReproducibleRandomSampler
            print("Discriminator training by current data again...")
            self.sub_stage = SubStage.CURRENT
            self.fit(
                range(1),
                self.train_set,
                train_loader_kwargs=self.hparams.dataloader_options,
            )

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Called at the end of each stage to summarize progress

        Does nothing unless the brain is in the GENERATOR sub-stage, so the
        nested discriminator ``fit()`` passes do not log or checkpoint.
        """

        # Checkpoint filters keyed on the epoch stored in checkpoint meta.
        def ckpt_predicate(ckpt):
            return ckpt.meta['epoch'] == epoch

        def ckpt_predicate_lessthan(ckpt):
            return ckpt.meta['epoch'] < epoch

        if self.sub_stage != SubStage.GENERATOR:
            return

        if stage == sb.Stage.TRAIN:
            self.train_loss = stage_loss
            g_loss = torch.tensor(self.metrics["G"])  # batch_size
            d_loss = torch.tensor(self.metrics["D"])  # batch_size
            print("Avg G loss: %.3f" % torch.mean(g_loss))
            print("Avg D loss: %.3f" % torch.mean(d_loss))
            print("MSE distance: %.3f" % self.mse_metric.summarize("average"))
            # A checkpoint tagged with the epoch is saved after every
            # training stage.
            stats = {
                "epoch": epoch,
                "MSE distance": stage_loss,
            }
            self.checkpointer.save_checkpoint(meta=stats)
        elif self.hparams.target_metric == "pesq":
            # The tracked average is mapped through 5 * x - 0.5 for reporting.
            stats = {
                "epoch": epoch,
                "MSE distance": stage_loss,
                "pesq": 5 * self.pesq_metric.summarize("average") - 0.5,
            }
        elif self.hparams.target_metric == "stoi":
            stats = {
                "epoch": epoch,
                "MSE distance": stage_loss,
                "stoi": -self.stoi_metric.summarize("average"),
            }

        if stage == sb.Stage.VALID:
            if self.hparams.use_tensorboard:
                # NOTE(review): both trackers are summarised here, but in
                # 'val' mode compute_objectives only appends to the tracker
                # matching target_metric -- confirm the other is non-empty.
                valid_stats = {
                    "mse": stage_loss,
                    "pesq": 5 * self.pesq_metric.summarize("average") - 0.5,
                    "stoi": -self.stoi_metric.summarize("average"),
                }
                self.hparams.tensorboard_train_logger.log_stats(valid_stats)
            self.hparams.train_logger.log_stats(
                {"Epoch": epoch},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )
            self.checkpointer.save_and_keep_only(
                meta=stats, max_keys=[self.hparams.target_metric]
            )

        if stage == sb.Stage.TEST:
            if self.hparams.mode == 'val':
                # Exactly one checkpoint from this epoch's training stage is
                # expected; it is replaced by one carrying the metric stats.
                ckpts = self.checkpointer.find_checkpoints(ckpt_predicate=ckpt_predicate)
                assert len(ckpts) == 1
                # delete old ckpt from train
                self.checkpointer.delete_checkpoints(num_to_keep=0, ckpt_predicate=ckpt_predicate)
                if self.hparams.use_tensorboard:
                    valid_stats = stats
                    self.hparams.tensorboard_train_logger.log_stats(valid_stats)
                self.hparams.train_logger.log_stats(
                    {"Epoch": epoch},
                    valid_stats=stats,
                )
                # Save the current checkpoint again, now with metric stats;
                # pruning only considers checkpoints from earlier epochs.
                self.checkpointer.save_and_keep_only(
                    meta=stats, max_keys=[self.hparams.target_metric],
                    ckpt_predicate=ckpt_predicate_lessthan
                )
            else:
                print("Epoch loaded", self.hparams.epoch_counter.current)
                print("stoi", -self.stoi_metric.summarize("average"))
                print("pesq", 5 * self.pesq_metric.summarize("average") - 0.5)
                # NOTE(review): test_stats is built but never logged; the
                # train_logger.log_stats call that consumed it is disabled.
                test_stats = {
                    "mse": stage_loss,
                    "pesq": 5 * self.pesq_metric.summarize("average") - 0.5,
                    "stoi": -self.stoi_metric.summarize("average"),
                }

    def make_dataloader(
        self, dataset, stage, ckpt_prefix="dataloader-", **loader_kwargs
    ):
        """Override dataloader to insert custom sampler/dataset

        For the TRAIN stage a reproducible sampler is injected; in the
        HISTORICAL sub-stage ``dataset`` is first wrapped into a
        ``DynamicItemDataset``.
        """
        if stage == sb.Stage.TRAIN:

            # Create a new dataset each time, this set grows
            if self.sub_stage == SubStage.HISTORICAL:
                dataset = sb.dataio.dataset.DynamicItemDataset(
                    data=dataset,
                    dynamic_items=[enh_pipeline],
                    output_keys=["id", "enh_sig", "clean_sig", "score"],
                )
                # Only a portion of the history is sampled per pass.
                samples = round(len(dataset) * self.hparams.history_portion)
            else:
                samples = self.hparams.number_of_samples

            # This sampler should give the same samples for D and G
            epoch = self.hparams.epoch_counter.current

            # Equal weights for all samples, we use "Weighted" so we can do
            # both "replacement=False" and a set number of samples, reproducibly
            weights = torch.ones(len(dataset))
            sampler = ReproducibleWeightedRandomSampler(
                weights, epoch=epoch, replacement=False, num_samples=samples
            )
            loader_kwargs["sampler"] = sampler

            if self.sub_stage == SubStage.GENERATOR:
                self.train_sampler = sampler

        # Make the dataloader as normal
        return super().make_dataloader(
            dataset, stage, ckpt_prefix, **loader_kwargs
        )

    def on_fit_start(self):
        """Override to prevent this from running for D training"""
        if self.sub_stage == SubStage.GENERATOR:
            super().on_fit_start()

    def init_optimizers(self):
        """Initializes the generator and discriminator optimizers

        Both optimizers are registered with the checkpointer when one is set.
        """
        self.g_optimizer = self.hparams.g_opt_class(
            self.modules.generator.parameters()
        )
        self.d_optimizer = self.hparams.d_opt_class(
            self.modules.discriminator.parameters()
        )

        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("g_opt", self.g_optimizer)
            self.checkpointer.add_recoverable("d_opt", self.d_optimizer)
class SEBrain(sb.core.Brain):
    """Spectral-mapping speech enhancement brain with PESQ/STOI evaluation."""

    def compute_forward(self, batch, stage):
        """Run the model on the noisy batch.

        Returns the predicted log-compressed spectrum and, outside of the
        training stage, the resynthesised waveform (``None`` during training).
        """
        batch = batch.to(self.device)
        noisy_wavs, lens = batch.noisy_sig

        # log1p of the power-0.5 magnitude of the STFT
        stft = self.hparams.compute_STFT(noisy_wavs)
        log_mag = torch.log1p(spectral_magnitude(stft, power=0.5))
        predict_spec = self.hparams.model(log_mag)

        # The waveform is only needed for evaluation
        predict_wav = None
        if stage != sb.Stage.TRAIN:
            predict_wav = self.hparams.resynth(
                torch.expm1(predict_spec), noisy_wavs
            )

        return predict_spec, predict_wav

    def compute_objectives(self, predictions, batch, stage):
        """Compute the spectral loss and record evaluation metrics.

        Outside of training the STOI and PESQ trackers are updated; in the
        test stage the enhanced waveforms are also written to
        ``hparams.enhanced_folder``.
        """
        predict_spec, predict_wav = predictions
        ids = batch.id
        clean_wav, lens = batch.clean_sig

        clean_stft = self.hparams.compute_STFT(clean_wav)
        target_spec = torch.log1p(spectral_magnitude(clean_stft, power=0.5))

        loss = self.hparams.compute_cost(predict_spec, target_spec, lens)
        self.loss_metric.append(
            ids, predict_spec, target_spec, lens, reduction="batch"
        )

        if stage == sb.Stage.TRAIN:
            return loss

        # Evaluate speech quality/intelligibility
        self.stoi_metric.append(
            ids, predict_wav, clean_wav, lens, reduction="batch"
        )
        self.pesq_metric.append(
            ids, predict=predict_wav, target=clean_wav, lengths=lens
        )

        if stage != sb.Stage.TEST:
            return loss

        # Write wavs to file; relative lengths become sample counts first
        sample_counts = lens * clean_wav.shape[1]
        for utt_id, enhanced, n_samples in zip(ids, predict_wav, sample_counts):
            out_path = os.path.join(self.hparams.enhanced_folder, utt_id)
            if not out_path.endswith(".wav"):
                out_path += ".wav"
            trimmed = enhanced[:int(n_samples)].cpu()
            torchaudio.save(
                out_path,
                torch.unsqueeze(trimmed, 0),
                self.hparams.Sample_rate,
            )

        return loss

    def on_stage_start(self, stage, epoch=None):
        """Create fresh loss/STOI trackers, plus a PESQ tracker outside training."""
        self.loss_metric = MetricStats(metric=self.hparams.compute_cost)
        self.stoi_metric = MetricStats(metric=stoi_loss)

        if stage == sb.Stage.TRAIN:
            return

        # Function taking (prediction, target), as needed for parallel eval
        def pesq_eval(pred_wav, target_wav):
            """Return the wide-band PESQ of the prediction against the target."""
            reference = target_wav.cpu().numpy()
            degraded = pred_wav.cpu().numpy()
            return pesq(fs=16000, ref=reference, deg=degraded, mode="wb")

        self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=4)

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Summarise the finished stage: log, anneal the LR and checkpoint."""
        if stage == sb.Stage.TRAIN:
            # Kept for the report written at the end of validation
            self.train_loss = stage_loss
            self.train_stats = {"loss": self.loss_metric.scores}
            return

        pesq_average = self.pesq_metric.summarize("average")
        stoi_average = -self.stoi_metric.summarize("average")
        stats = {
            "loss": stage_loss,
            "pesq": pesq_average,
            "stoi": stoi_average,
        }

        if stage == sb.Stage.VALID:
            # The scheduler is driven by the distance from PESQ 4.5
            old_lr, new_lr = self.hparams.lr_annealing(4.5 - stats["pesq"])
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)

            if self.hparams.use_tensorboard:
                valid_stats = {
                    "loss": self.loss_metric.scores,
                    "stoi": self.stoi_metric.scores,
                    "pesq": self.pesq_metric.scores,
                }
                self.hparams.tensorboard_train_logger.log_stats(
                    {"Epoch": epoch, "lr": old_lr},
                    self.train_stats,
                    valid_stats,
                )

            self.hparams.train_logger.log_stats(
                {"Epoch": epoch, "lr": old_lr},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )
            self.checkpointer.save_and_keep_only(meta=stats, max_keys=["pesq"])
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )
class MetricGanBrain(sb.Brain): def load_history(self): if os.path.isfile(self.hparams.historical_file): with open(self.hparams.historical_file, "rb") as fp: # Unpickling self.historical_set = pickle.load(fp) def compute_feats(self, wavs): """Feature computation pipeline""" feats = self.hparams.compute_STFT(wavs) spec = spectral_magnitude(feats, power=0.5) return spec def compute_forward(self, batch, stage): "Given an input batch computes the enhanced signal" batch = batch.to(self.device) if self.sub_stage == SubStage.HISTORICAL: predict_wav, lens = batch.enh_sig return predict_wav else: noisy_wav, lens = batch.noisy_sig noisy_spec = self.compute_feats(noisy_wav) mask = self.modules.generator(noisy_spec, lengths=lens) mask = mask.clamp(min=self.hparams.min_mask).squeeze(2) predict_spec = torch.mul(mask, noisy_spec) # Also return predicted wav predict_wav = self.hparams.resynth(predict_spec, noisy_wav) return predict_wav, mask def compute_objectives(self, predictions, batch, stage, optim_name=""): "Given the network predictions and targets compute the total loss" if self.sub_stage == SubStage.HISTORICAL: predict_wav = predictions else: predict_wav, mask = predictions predict_spec = self.compute_feats(predict_wav) ids = self.compute_ids(batch.id, optim_name) if self.sub_stage != SubStage.HISTORICAL: noisy_wav, lens = batch.noisy_sig if optim_name == "generator": est_score = self.est_score(predict_spec) target_score = self.hparams.target_score * torch.ones( self.batch_size, 1, device=self.device ) noisy_wav, lens = batch.noisy_sig noisy_spec = self.compute_feats(noisy_wav) mse_cost = self.hparams.compute_cost(predict_spec, noisy_spec, lens) # D Learns to estimate the scores of enhanced speech elif optim_name == "D_enh" and self.sub_stage == SubStage.CURRENT: target_score = self.score( ids, predict_wav, predict_wav, lens ) # no clean_wav is needed est_score = self.est_score(predict_spec) # Write enhanced wavs during discriminator training, because we # compute the 
actual score here and we can save it self.write_wavs(ids, predict_wav, target_score, lens) # D Relearns to estimate the scores of previous epochs elif optim_name == "D_enh" and self.sub_stage == SubStage.HISTORICAL: target_score = batch.score.unsqueeze(1).float() est_score = self.est_score(predict_spec) # D Learns to estimate the scores of noisy speech elif optim_name == "D_noisy": noisy_spec = self.compute_feats(noisy_wav) target_score = self.score( ids, noisy_wav, noisy_wav, lens ) # no clean_wav is needed est_score = self.est_score(noisy_spec) # Save scores of noisy wavs self.save_noisy_scores(ids, target_score) if stage == sb.Stage.TRAIN: # Compute the cost cost = self.hparams.compute_cost(est_score, target_score) if optim_name == "generator": cost += self.hparams.mse_weight * mse_cost self.metrics["G"].append(cost.detach()) else: self.metrics["D"].append(cost.detach()) # Compute scores on validation data if stage != sb.Stage.TRAIN: clean_wav, lens = batch.clean_sig cost = self.hparams.compute_si_snr(predict_wav, clean_wav, lens) # Evaluate speech quality/intelligibility self.stoi_metric.append( batch.id, predict_wav, clean_wav, lens, reduction="batch" ) self.pesq_metric.append( batch.id, predict=predict_wav, target=clean_wav, lengths=lens ) if ( self.hparams.calculate_dnsmos_on_validation_set ): # Note: very time consuming........ 
self.dnsmos_metric.append( batch.id, predict=predict_wav, target=predict_wav, lengths=lens, # no clean_wav is needed ) # Write wavs to file, for evaluation lens = lens * clean_wav.shape[1] for name, pred_wav, length in zip(batch.id, predict_wav, lens): name += ".wav" enhance_path = os.path.join(self.hparams.enhanced_folder, name) torchaudio.save( enhance_path, torch.unsqueeze(pred_wav[: int(length)].cpu(), 0), 16000, ) return cost def compute_ids(self, batch_id, optim_name): """Returns the list of ids, edited via optimizer name.""" if optim_name == "D_enh": return [f"{uid}@{self.epoch}" for uid in batch_id] return batch_id def save_noisy_scores(self, batch_id, scores): for i, score in zip(batch_id, scores): self.noisy_scores[i] = score def score(self, batch_id, deg_wav, ref_wav, lens): """Returns actual metric score, either pesq or stoi Arguments --------- batch_id : list of str A list of the utterance ids for the batch deg_wav : torch.Tensor The degraded waveform to score ref_wav : torch.Tensor The reference waveform to use for scoring length : torch.Tensor The relative lengths of the utterances """ new_ids = [ i for i, d in enumerate(batch_id) if d not in self.historical_set and d not in self.noisy_scores ] if len(new_ids) == 0: pass elif self.hparams.target_metric == "srmr" or "dnsmos": self.target_metric.append( ids=[batch_id[i] for i in new_ids], predict=deg_wav[new_ids].detach(), target=ref_wav[ new_ids ].detach(), # target is not used in the function !!! 
lengths=lens[new_ids], ) score = torch.tensor( [[s] for s in self.target_metric.scores], device=self.device, ) else: raise ValueError("Expected 'srmr' or 'dnsmos' for target_metric") # Clear metric scores to prepare for next batch self.target_metric.clear() # Combine old scores and new final_score = [] for i, d in enumerate(batch_id): if d in self.historical_set: final_score.append([self.historical_set[d]["score"]]) elif d in self.noisy_scores: final_score.append([self.noisy_scores[d]]) else: final_score.append([score[new_ids.index(i)]]) return torch.tensor(final_score, device=self.device) def est_score(self, deg_spec): """Returns score as estimated by discriminator Arguments --------- deg_spec : torch.Tensor The spectral features of the degraded utterance ref_spec : torch.Tensor The spectral features of the reference utterance """ """ combined_spec = torch.cat( [deg_spec.unsqueeze(1), ref_spec.unsqueeze(1)], 1 ) """ return self.modules.discriminator(deg_spec.unsqueeze(1)) def write_wavs(self, batch_id, wavs, score, lens): """Write wavs to files, for historical discriminator training Arguments --------- batch_id : list of str A list of the utterance ids for the batch wavs : torch.Tensor The wavs to write to files score : torch.Tensor The actual scores for the corresponding utterances lens : torch.Tensor The relative lengths of each utterance """ lens = lens * wavs.shape[1] record = {} for i, (name, pred_wav, length) in enumerate(zip(batch_id, wavs, lens)): path = os.path.join(self.hparams.MetricGAN_folder, name + ".wav") data = torch.unsqueeze(pred_wav[: int(length)].cpu(), 0) torchaudio.save(path, data, self.hparams.Sample_rate) # Make record of path and score for historical training score = float(score[i][0]) record[name] = { "enh_wav": path, "score": score, } # Update records for historical training self.historical_set.update(record) with open(self.hparams.historical_file, "wb") as fp: # Pickling pickle.dump(self.historical_set, fp) def fit_batch(self, batch): 
"Compute gradients and update either D or G based on sub-stage." predictions = self.compute_forward(batch, sb.Stage.TRAIN) loss_tracker = 0 if self.sub_stage == SubStage.CURRENT: for mode in ["enh", "noisy"]: loss = self.compute_objectives( predictions, batch, sb.Stage.TRAIN, f"D_{mode}" ) self.d_optimizer.zero_grad() loss.backward() if self.check_gradients(loss): self.d_optimizer.step() loss_tracker += loss.detach() / 3 elif self.sub_stage == SubStage.HISTORICAL: loss = self.compute_objectives( predictions, batch, sb.Stage.TRAIN, "D_enh" ) self.d_optimizer.zero_grad() loss.backward() if self.check_gradients(loss): self.d_optimizer.step() loss_tracker += loss.detach() elif self.sub_stage == SubStage.GENERATOR: for name, param in self.modules.generator.named_parameters(): if "Learnable_sigmoid" in name: param.data = torch.clamp( param, max=3.5 ) # to prevent gradient goes to infinity loss = self.compute_objectives( predictions, batch, sb.Stage.TRAIN, "generator" ) self.g_optimizer.zero_grad() loss.backward() if self.check_gradients(loss): self.g_optimizer.step() loss_tracker += loss.detach() return loss_tracker def on_stage_start(self, stage, epoch=None): """Gets called at the beginning of each epoch This method calls ``fit()`` again to train the discriminator before proceeding with generator training. 
""" self.metrics = {"G": [], "D": []} if stage == sb.Stage.TRAIN: if self.hparams.target_metric == "srmr": self.target_metric = MetricStats( metric=srmrpy_eval, n_jobs=hparams["n_jobs"], batch_eval=False, ) elif self.hparams.target_metric == "dnsmos": self.target_metric = MetricStats( metric=dnsmos_eval, n_jobs=hparams["n_jobs"], batch_eval=False, ) else: raise NotImplementedError( "Right now we only support 'srmr' and 'dnsmos'" ) # Train discriminator before we start generator training if self.sub_stage == SubStage.GENERATOR: self.epoch = epoch self.train_discriminator() self.sub_stage = SubStage.GENERATOR print("Generator training by current data...") if stage != sb.Stage.TRAIN: self.pesq_metric = MetricStats( metric=pesq_eval, n_jobs=hparams["n_jobs"], batch_eval=False ) self.stoi_metric = MetricStats(metric=stoi_loss) self.srmr_metric = MetricStats( metric=srmrpy_eval_valid, n_jobs=hparams["n_jobs"], batch_eval=False, ) self.dnsmos_metric = MetricStats( metric=dnsmos_eval_valid, n_jobs=hparams["n_jobs"], batch_eval=False, ) def train_discriminator(self): """A total of 3 data passes to update discriminator.""" # First, iterate train subset w/ updates for enh, noisy print("Discriminator training by current data...") self.sub_stage = SubStage.CURRENT self.fit( range(1), self.train_set, train_loader_kwargs=self.hparams.dataloader_options, ) # Next, iterate historical subset w/ updates for enh if self.historical_set: print("Discriminator training by historical data...") self.sub_stage = SubStage.HISTORICAL self.fit( range(1), self.historical_set, train_loader_kwargs=self.hparams.dataloader_options, ) # Finally, iterate train set again. 
Should iterate same # samples as before, due to ReproducibleRandomSampler print("Discriminator training by current data again...") self.sub_stage = SubStage.CURRENT self.fit( range(1), self.train_set, train_loader_kwargs=self.hparams.dataloader_options, ) def on_stage_end(self, stage, stage_loss, epoch=None): "Called at the end of each stage to summarize progress" if self.sub_stage != SubStage.GENERATOR: return if stage == sb.Stage.TRAIN: self.train_loss = stage_loss g_loss = torch.tensor(self.metrics["G"]) # batch_size d_loss = torch.tensor(self.metrics["D"]) # batch_size print("Avg G loss: %.3f" % torch.mean(g_loss)) print("Avg D loss: %.3f" % torch.mean(d_loss)) else: if self.hparams.calculate_dnsmos_on_validation_set: stats = { "SI-SNR": -stage_loss, "pesq": 5 * self.pesq_metric.summarize("average") - 0.5, "stoi": -self.stoi_metric.summarize("average"), "dnsmos": self.dnsmos_metric.summarize("average"), } else: stats = { "SI-SNR": -stage_loss, "pesq": 5 * self.pesq_metric.summarize("average") - 0.5, "stoi": -self.stoi_metric.summarize("average"), } if stage == sb.Stage.VALID: old_lr, new_lr = self.hparams.lr_annealing(5.0 - stats["pesq"]) sb.nnet.schedulers.update_learning_rate(self.g_optimizer, new_lr) if self.hparams.use_tensorboard: if ( self.hparams.calculate_dnsmos_on_validation_set ): # Note: very time consuming........ 
valid_stats = { "SI-SNR": -stage_loss, "pesq": 5 * self.pesq_metric.summarize("average") - 0.5, "stoi": -self.stoi_metric.summarize("average"), "dnsmos": self.dnsmos_metric.summarize("average"), } else: valid_stats = { "SI-SNR": -stage_loss, "pesq": 5 * self.pesq_metric.summarize("average") - 0.5, "stoi": -self.stoi_metric.summarize("average"), } self.hparams.tensorboard_train_logger.log_stats( {"lr": old_lr}, valid_stats ) self.hparams.train_logger.log_stats( {"Epoch": epoch, "lr": old_lr}, train_stats={"loss": self.train_loss}, valid_stats=stats, ) self.checkpointer.save_and_keep_only(meta=stats, max_keys=["pesq"]) if stage == sb.Stage.TEST: self.hparams.train_logger.log_stats( {"Epoch loaded": self.hparams.epoch_counter.current}, test_stats=stats, ) def make_dataloader( self, dataset, stage, ckpt_prefix="dataloader-", **loader_kwargs ): "Override dataloader to insert custom sampler/dataset" if stage == sb.Stage.TRAIN: # Create a new dataset each time, this set grows if self.sub_stage == SubStage.HISTORICAL: dataset = sb.dataio.dataset.DynamicItemDataset( data=dataset, dynamic_items=[enh_pipeline], output_keys=["id", "enh_sig", "score"], ) samples = round(len(dataset) * self.hparams.history_portion) else: samples = self.hparams.number_of_samples # This sampler should give the same samples for D and G epoch = self.hparams.epoch_counter.current # Equal weights for all samples, we use "Weighted" so we can do # both "replacement=False" and a set number of samples, reproducibly weights = torch.ones(len(dataset)) sampler = ReproducibleWeightedRandomSampler( weights, epoch=epoch, replacement=False, num_samples=samples ) loader_kwargs["sampler"] = sampler if self.sub_stage == SubStage.GENERATOR: self.train_sampler = sampler # Make the dataloader as normal return super().make_dataloader( dataset, stage, ckpt_prefix, **loader_kwargs ) def on_fit_start(self): "Override to prevent this from running for D training" if self.sub_stage == SubStage.GENERATOR: 
super().on_fit_start() def init_optimizers(self): "Initializes the generator and discriminator optimizers" self.g_optimizer = self.hparams.g_opt_class( self.modules.generator.parameters() ) self.d_optimizer = self.hparams.d_opt_class( self.modules.discriminator.parameters() ) if self.checkpointer is not None: self.checkpointer.add_recoverable("g_opt", self.g_optimizer) self.checkpointer.add_recoverable("d_opt", self.d_optimizer)
class SEBrain(sb.Brain):
    """Brain for a model that maps noisy waveforms directly to enhanced ones."""

    def compute_forward(self, batch, stage):
        """Run the model on the noisy waveforms of ``batch``.

        A trailing axis is appended to the input, and index 0 of the last
        output axis is returned.
        """
        batch = batch.to(self.device)
        noisy_wavs, lens = batch.noisy_sig
        model_out = self.modules.model(noisy_wavs.unsqueeze(-1))
        return model_out[:, :, 0]

    def compute_objectives(self, predict_wavs, batch, stage):
        """Return the loss; track metrics and, on TEST, write enhanced wavs."""
        clean_wavs, lens = batch.clean_sig

        loss = self.hparams.compute_cost(predict_wavs, clean_wavs, lens)
        self.loss_metric.append(
            batch.id, predict_wavs, clean_wavs, lens, reduction="batch"
        )

        # Training only needs the loss.
        if stage == sb.Stage.TRAIN:
            return loss

        # Speech quality / intelligibility on valid and test data.
        self.stoi_metric.append(
            batch.id, predict_wavs, clean_wavs, lens, reduction="batch"
        )
        self.pesq_metric.append(
            batch.id, predict=predict_wavs, target=clean_wavs, lengths=lens
        )

        if stage != sb.Stage.TEST:
            return loss

        # Relative lengths scaled by the padded length give sample counts.
        sample_counts = lens * clean_wavs.shape[1]
        for utt_id, wav, n_valid in zip(batch.id, predict_wavs, sample_counts):
            out_path = os.path.join(
                self.hparams.enhanced_folder, utt_id + ".wav"
            )
            # Peak-normalize just below full scale before trimming.
            scaled = wav / wav.abs().max() * 0.99
            trimmed = scaled[: int(n_valid)].cpu()
            torchaudio.save(out_path, trimmed.unsqueeze(0), 16000)

        return loss

    def on_stage_start(self, stage, epoch=None):
        """Create fresh metric trackers at the start of every stage."""
        self.loss_metric = MetricStats(metric=self.hparams.compute_cost)
        self.stoi_metric = MetricStats(metric=stoi_loss)

        if stage == sb.Stage.TRAIN:
            return

        def pesq_eval(pred_wav, target_wav):
            """Wide-band PESQ of one prediction against its reference."""
            return pesq(
                fs=16000,
                ref=target_wav.numpy(),
                deg=pred_wav.numpy(),
                mode="wb",
            )

        self.pesq_metric = MetricStats(
            metric=pesq_eval, n_jobs=1, batch_eval=False
        )

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Store train stats, or log (and checkpoint) valid/test stats."""
        if stage == sb.Stage.TRAIN:
            self.train_loss = stage_loss
            self.train_stats = {"loss": self.loss_metric.scores}
            return

        stats = {
            "loss": stage_loss,
            "pesq": self.pesq_metric.summarize("average"),
            "stoi": -self.stoi_metric.summarize("average"),
        }

        if stage == sb.Stage.VALID:
            if self.hparams.use_tensorboard:
                per_item_stats = {
                    "loss": self.loss_metric.scores,
                    "stoi": self.stoi_metric.scores,
                    "pesq": self.pesq_metric.scores,
                }
                self.hparams.tensorboard_train_logger.log_stats(
                    {"Epoch": epoch}, self.train_stats, per_item_stats
                )
            self.hparams.train_logger.log_stats(
                {"Epoch": epoch},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )
            self.checkpointer.save_and_keep_only(meta=stats, max_keys=["pesq"])
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )
class SEBrain(sb.Brain):
    """Brain for a model that predicts a mask over log-compressed spectra."""

    def compute_forward(self, batch, stage):
        """Mask the noisy features and resynthesize a waveform.

        Returns the masked features and the waveform obtained from
        ``hparams.resynth``.
        """
        batch = batch.to(self.device)
        noisy_wavs, lens = batch.noisy_sig
        noisy_feats = self.compute_feats(noisy_wavs)

        # mask with "signal approximation (SA)"
        mask = self.modules.model(noisy_feats).squeeze(2)
        predict_spec = mask * noisy_feats

        # expm1 undoes the log1p applied in compute_feats
        predict_wav = self.hparams.resynth(
            torch.expm1(predict_spec), noisy_wavs
        )
        return predict_spec, predict_wav

    def compute_feats(self, wavs):
        """Return log1p of the STFT magnitude (power 0.5) of ``wavs``."""
        stft = self.hparams.compute_STFT(wavs)
        magnitude = spectral_magnitude(stft, power=0.5)
        return torch.log1p(magnitude)

    def compute_objectives(self, predictions, batch, stage):
        """Return the loss; track metrics and, on TEST, write enhanced wavs.

        The loss compares waveforms when ``hparams.waveform_target`` is set
        and truthy, and compares spectral features otherwise.
        """
        predict_spec, predict_wav = predictions
        clean_wavs, lens = batch.clean_sig

        # Pick the (estimate, reference) pair the loss is computed on.
        if getattr(self.hparams, "waveform_target", False):
            estimate, reference = predict_wav, clean_wavs
        else:
            estimate, reference = predict_spec, self.compute_feats(clean_wavs)

        loss = self.hparams.compute_cost(estimate, reference, lens)
        self.loss_metric.append(
            batch.id, estimate, reference, lens, reduction="batch"
        )

        if stage == sb.Stage.TRAIN:
            return loss

        # Speech quality / intelligibility on valid and test data.
        self.stoi_metric.append(
            batch.id, predict_wav, clean_wavs, lens, reduction="batch"
        )
        self.pesq_metric.append(
            batch.id, predict=predict_wav, target=clean_wavs, lengths=lens
        )

        if stage != sb.Stage.TEST:
            return loss

        # Relative lengths scaled by the padded length give sample counts.
        sample_counts = lens * clean_wavs.shape[1]
        for utt_id, wav, n_valid in zip(batch.id, predict_wav, sample_counts):
            out_path = os.path.join(
                self.hparams.enhanced_folder, utt_id + ".wav"
            )
            trimmed = wav[: int(n_valid)].cpu()
            torchaudio.save(out_path, trimmed.unsqueeze(0), 16000)

        return loss

    def on_stage_start(self, stage, epoch=None):
        """Create fresh metric trackers at the start of every stage."""
        self.loss_metric = MetricStats(metric=self.hparams.compute_cost)
        self.stoi_metric = MetricStats(metric=stoi_loss)

        if stage == sb.Stage.TRAIN:
            return

        def pesq_eval(pred_wav, target_wav):
            """Wide-band PESQ of one prediction against its reference."""
            return pesq(
                fs=16000,
                ref=target_wav.numpy(),
                deg=pred_wav.numpy(),
                mode="wb",
            )

        self.pesq_metric = MetricStats(metric=pesq_eval, n_jobs=4)

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Store train stats, or log (and checkpoint) valid/test stats."""
        if stage == sb.Stage.TRAIN:
            self.train_loss = stage_loss
            self.train_stats = {"loss": self.loss_metric.scores}
            return

        stats = {
            "loss": stage_loss,
            "pesq": self.pesq_metric.summarize("average"),
            "stoi": -self.stoi_metric.summarize("average"),
        }

        if stage == sb.Stage.VALID:
            if self.hparams.use_tensorboard:
                per_item_stats = {
                    "loss": self.loss_metric.scores,
                    "stoi": self.stoi_metric.scores,
                    "pesq": self.pesq_metric.scores,
                }
                self.hparams.tensorboard_train_logger.log_stats(
                    {"Epoch": epoch}, self.train_stats, per_item_stats
                )
            self.hparams.train_logger.log_stats(
                {"Epoch": epoch},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )
            self.checkpointer.save_and_keep_only(meta=stats, max_keys=["pesq"])
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )
class Separation(sb.Brain):
    """Brain that estimates source signals from a (noisy) mixture.

    Works either on encoder features directly or, when
    ``self.use_freq_domain`` is set, on log-compressed magnitude features.
    """

    def compute_forward(self, mix, targets, stage, noise=None):
        """Forward computations from the mixture to the separated signals.

        Arguments
        ---------
        mix : tuple
            ``(signal, lengths)`` pair of the mixture.
        targets : list
            One ``(signal, lengths)`` pair per source; only the first
            ``hparams.num_spks`` entries are used.
        stage : sb.Stage
            Augmentation is applied only for ``sb.Stage.TRAIN``.
        noise : torch.Tensor
            Noise signal added to the re-mixed sources during training.
            NOTE(review): it is dereferenced whenever speed perturbation or
            random shift is enabled, so it must not be None then.

        Returns
        -------
        list
            ``[est_source, sep_h]``: the estimated waveform(s) and the masked
            representation.
        torch.Tensor
            The (possibly augmented) targets with the last axis squeezed.
        """

        # Unpack lists and put tensors in the right device
        mix, mix_lens = mix
        mix, mix_lens = mix.to(self.device), mix_lens.to(self.device)

        # Convert targets to tensor
        targets = torch.cat(
            [targets[i][0].unsqueeze(-1) for i in range(self.hparams.num_spks)],
            dim=-1,
        ).to(self.device)

        # Add speech distortions
        if stage == sb.Stage.TRAIN:
            with torch.no_grad():
                if self.hparams.use_speedperturb or self.hparams.use_rand_shift:
                    mix, targets = self.add_speed_perturb(targets, mix_lens)

                    if "whamr" in self.hparams.data_folder:
                        try:
                            targets_rev = [
                                self.hparams.reverb(targets[:, :, i], None)
                                for i in range(self.hparams.num_spks)
                            ]
                        except Exception:
                            # Best effort: fall back to the dry sources.
                            print("reverb error, not adding reverb")
                            targets_rev = [
                                targets[:, :, i]
                                for i in range(self.hparams.num_spks)
                            ]

                        targets_rev = torch.stack(targets_rev, dim=-1)
                        mix = targets_rev.sum(-1)

                        # if we do not dereverberate, we set the targets to be reverberant
                        if not self.hparams.dereverberate:
                            targets = targets_rev
                    else:
                        mix = targets.sum(-1)

                    noise = noise.to(self.device)
                    len_noise = noise.shape[1]
                    len_mix = mix.shape[1]
                    min_len = min(len_noise, len_mix)

                    # add the noise
                    mix = mix[:, :min_len] + noise[:, :min_len]

                    # fix the length of targets also
                    targets = targets[:, :min_len, :]

                if self.hparams.use_wavedrop:
                    mix = self.hparams.wavedrop(mix, mix_lens)

                if self.hparams.limit_training_signal_len:
                    mix, targets = self.cut_signals(mix, targets)

        # Debugging aid kept from the original (disabled):
        # torchaudio.save(
        #     'mix.wav', mix.data.cpu(), self.hparams.sample_rate
        # )
        # torchaudio.save(
        #     'targets.wav', targets.squeeze(-1).data.cpu(), self.hparams.sample_rate
        # )

        # Separation
        if self.use_freq_domain:
            mix_w = self.compute_feats(mix)
            est_mask = self.modules.masknet(mix_w)

            sep_h = mix_w * est_mask
            # expm1 undoes the log1p applied in compute_feats
            est_source = self.hparams.resynth(torch.expm1(sep_h), mix)
        else:
            mix_w = self.hparams.Encoder(mix)
            est_mask = self.modules.masknet(mix_w)

            # One copy of the encoded mixture per source mask
            mix_w = torch.stack([mix_w] * self.hparams.num_spks)
            sep_h = mix_w * est_mask
            est_source = torch.cat(
                [
                    self.hparams.Decoder(sep_h[i]).unsqueeze(-1)
                    for i in range(self.hparams.num_spks)
                ],
                dim=-1,
            )

        # T changed after conv1d in encoder, fix it here
        T_origin = mix.size(1)
        T_est = est_source.size(1)
        est_source = est_source.squeeze(-1)
        if T_origin > T_est:
            est_source = F.pad(est_source, (0, T_origin - T_est))
        else:
            est_source = est_source[:, :T_origin]

        return [est_source, sep_h], targets.squeeze(-1)

    def compute_feats(self, wavs):
        """Feature computation pipeline.

        Applies ``hparams.Encoder``, takes the spectral magnitude with
        ``power=0.5`` and compresses it with ``log1p``.
        """
        feats = self.hparams.Encoder(wavs)
        feats = spectral_magnitude(feats, power=0.5)
        feats = torch.log1p(feats)
        return feats

    def compute_objectives(self, predictions, targets):
        """Computes the loss between predictions and targets.

        In the frequency-domain setup the loss compares features of the
        targets with the predicted features; otherwise it compares waveforms
        (each given an extra trailing axis).
        """
        predicted_wavs, predicted_specs = predictions

        if self.use_freq_domain:
            target_specs = self.compute_feats(targets)
            return self.hparams.loss(target_specs, predicted_specs)
        else:
            return self.hparams.loss(
                targets.unsqueeze(-1), predicted_wavs.unsqueeze(-1)
            )

    def fit_batch(self, batch):
        """Trains one batch.

        Runs the forward pass (under ``autocast`` when ``auto_mix_prec`` is
        set), optionally keeps only losses above ``hparams.threshold``, and
        steps the optimizer unless the loss reaches
        ``hparams.loss_upper_lim``, in which case the batch is skipped and
        the loss value is zeroed.

        Returns
        -------
        The detached loss, moved to CPU.
        """
        # Unpacking batch list
        mixture = batch.mix_sig
        targets = [batch.s1_sig, batch.s2_sig]
        noise = batch.noise_sig[0]

        if self.auto_mix_prec:
            with autocast():
                predictions, targets = self.compute_forward(
                    mixture, targets, sb.Stage.TRAIN, noise
                )
                loss = self.compute_objectives(predictions, targets)

                # hard threshold the easy dataitems
                # NOTE(review): if no item exceeds the threshold, `loss` is
                # left unreduced here - confirm the check below copes.
                if self.hparams.threshold_byloss:
                    th = self.hparams.threshold
                    loss_to_keep = loss[loss > th]
                    if loss_to_keep.nelement() > 0:
                        loss = loss_to_keep.mean()
                else:
                    loss = loss.mean()

            if (
                loss < self.hparams.loss_upper_lim and loss.nelement() > 0
            ):  # the fix for computational problems
                self.scaler.scale(loss).backward()
                if self.hparams.clip_grad_norm >= 0:
                    # Gradients must be unscaled before clipping
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.modules.parameters(),
                        self.hparams.clip_grad_norm,
                    )
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.nonfinite_count += 1
                logger.info(
                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
                        self.nonfinite_count
                    )
                )
                loss.data = torch.tensor(0).to(self.device)
        else:
            predictions, targets = self.compute_forward(
                mixture, targets, sb.Stage.TRAIN, noise
            )
            loss = self.compute_objectives(predictions, targets)

            # Same easy-item thresholding as the mixed-precision branch
            if self.hparams.threshold_byloss:
                th = self.hparams.threshold
                loss_to_keep = loss[loss > th]
                if loss_to_keep.nelement() > 0:
                    loss = loss_to_keep.mean()
            else:
                loss = loss.mean()

            if (
                loss < self.hparams.loss_upper_lim and loss.nelement() > 0
            ):  # the fix for computational problems
                loss.backward()
                if self.hparams.clip_grad_norm >= 0:
                    torch.nn.utils.clip_grad_norm_(
                        self.modules.parameters(), self.hparams.clip_grad_norm
                    )
                self.optimizer.step()
            else:
                self.nonfinite_count += 1
                logger.info(
                    "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
                        self.nonfinite_count
                    )
                )
                loss.data = torch.tensor(0).to(self.device)
        self.optimizer.zero_grad()

        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches.

        Computes the loss without gradients, feeds the PESQ tracker and,
        on TEST with ``hparams.save_audio`` set, saves audio for the first
        item of the batch (at most ``hparams.n_audio_to_save`` times when
        that attribute exists).
        """
        snt_id = batch.id
        mixture = batch.mix_sig
        targets = [batch.s1_sig, batch.s2_sig]

        with torch.no_grad():
            predictions, targets = self.compute_forward(mixture, targets, stage)
            loss = self.compute_objectives(predictions, targets)

        if stage != sb.Stage.TRAIN:
            self.pesq_metric.append(
                ids=batch.id, predict=predictions[0].cpu(), target=targets.cpu()
            )

        # Manage audio file saving
        if stage == sb.Stage.TEST and self.hparams.save_audio:
            if hasattr(self.hparams, "n_audio_to_save"):
                if self.hparams.n_audio_to_save > 0:
                    self.save_audio(snt_id[0], mixture, targets, predictions[0])
                    # Counts down the remaining number of files to write
                    self.hparams.n_audio_to_save += -1
            else:
                self.save_audio(snt_id[0], mixture, targets, predictions[0])

        return loss.detach()

    def on_stage_start(self, stage, epoch=None):
        """Gets called at the beginning of each epoch.

        Creates the PESQ tracker for every stage except TRAIN.
        """

        if stage != sb.Stage.TRAIN:
            # Define function taking (prediction, target) for parallel eval
            def pesq_eval(pred_wav, target_wav):
                """Computes the PESQ evaluation metric.

                Uses wide-band mode at 16000 Hz and narrow-band mode
                otherwise; returns 0 when PESQ raises.
                """
                psq_mode = "wb" if self.hparams.sample_rate == 16000 else "nb"
                try:
                    return pesq(
                        fs=self.hparams.sample_rate,
                        ref=target_wav.numpy(),
                        deg=pred_wav.numpy(),
                        mode=psq_mode,
                    )
                except Exception:
                    print("pesq encountered an error for this data item")
                    return 0

            self.pesq_metric = MetricStats(
                metric=pesq_eval, n_jobs=1, batch_eval=False
            )

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of a epoch.

        Stores train stats; for VALID it optionally anneals the learning
        rate, logs and checkpoints on PESQ; for TEST it only logs.
        """
        # Compute/store important stats
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stats = {
                "loss": stage_loss,
                "pesq": self.pesq_metric.summarize("average"),
            }

        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:
            # Learning rate annealing
            if isinstance(
                self.hparams.lr_scheduler, schedulers.ReduceLROnPlateau
            ):
                current_lr, next_lr = self.hparams.lr_scheduler(
                    [self.optimizer], epoch, stage_loss
                )
                schedulers.update_learning_rate(self.optimizer, next_lr)
            else:
                # if we do not use the reducelronplateau, we do not change the lr
                current_lr = self.hparams.optimizer.optim.param_groups[0]["lr"]

            self.hparams.train_logger.log_stats(
                stats_meta={"epoch": epoch, "lr": current_lr},
                train_stats=self.train_stats,
                valid_stats=stats,
            )
            # Either keep every checkpoint or only the best one by PESQ
            if (
                hasattr(self.hparams, "save_all_checkpoints")
                and self.hparams.save_all_checkpoints
            ):
                self.checkpointer.save_checkpoint(meta={"pesq": stats["pesq"]})
            else:
                self.checkpointer.save_and_keep_only(
                    meta={"pesq": stats["pesq"]}, max_keys=["pesq"],
                )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )

    def add_speed_perturb(self, targets, targ_lens):
        """Adds speed perturbation and random_shift to the input signals.

        Each source is perturbed independently, cropped to the shortest
        perturbed length, and the sources are summed into a new mixture.

        NOTE(review): the random shift is applied to the speed-perturbed
        copies, so it is nested under ``use_speedperturb`` here - confirm
        against the intended configuration.

        Returns
        -------
        The re-mixed signal and the (possibly modified) targets.
        """

        min_len = -1
        recombine = False

        if self.hparams.use_speedperturb:
            # Performing speed change (independently on each source)
            new_targets = []
            recombine = True

            for i in range(targets.shape[-1]):
                new_target = self.hparams.speedperturb(
                    targets[:, :, i], targ_lens
                )
                new_targets.append(new_target)
                # Track the shortest perturbed source
                if i == 0:
                    min_len = new_target.shape[-1]
                else:
                    if new_target.shape[-1] < min_len:
                        min_len = new_target.shape[-1]

            if self.hparams.use_rand_shift:
                # Performing random_shift (independently on each source)
                recombine = True
                for i in range(targets.shape[-1]):
                    rand_shift = torch.randint(
                        self.hparams.min_shift, self.hparams.max_shift, (1,)
                    )
                    new_targets[i] = new_targets[i].to(self.device)
                    new_targets[i] = torch.roll(
                        new_targets[i], shifts=(rand_shift[0],), dims=1
                    )

            # Re-combination
            if recombine:
                if self.hparams.use_speedperturb:
                    targets = torch.zeros(
                        targets.shape[0],
                        min_len,
                        targets.shape[-1],
                        device=targets.device,
                        dtype=torch.float,
                    )
                for i, new_target in enumerate(new_targets):
                    targets[:, :, i] = new_targets[i][:, 0:min_len]

        mix = targets.sum(-1)
        return mix, targets

    def cut_signals(self, mixture, targets):
        """Selects a random segment of a given length within the mixture.

        The corresponding targets are cut at the same position. The segment
        length is ``hparams.training_signal_len``.
        """
        randstart = torch.randint(
            0,
            1 + max(0, mixture.shape[1] - self.hparams.training_signal_len),
            (1,),
        ).item()
        targets = targets[
            :, randstart : randstart + self.hparams.training_signal_len, :
        ]
        mixture = mixture[
            :, randstart : randstart + self.hparams.training_signal_len
        ]
        return mixture, targets

    def reset_layer_recursively(self, layer):
        """Reinitializes the parameters of the neural networks.

        Calls ``reset_parameters()`` on ``layer`` when it has one, then
        recurses into every other module returned by ``layer.modules()``.
        """
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()
        for child_layer in layer.modules():
            # modules() also yields the layer itself; skip it
            if layer != child_layer:
                self.reset_layer_recursively(child_layer)

    def save_results(self, test_data):
        """Computes the SDR, SI-SNR and PESQ metrics and saves them to a csv file.

        One row is written per test item plus a final "avg" row, to
        ``test_results.csv`` in ``hparams.output_folder``; the averages are
        also logged.
        """

        # This package is required for SDR computation
        from mir_eval.separation import bss_eval_sources

        # Path of the csv file holding the results
        save_file = os.path.join(self.hparams.output_folder, "test_results.csv")

        # Variable init
        all_sdrs = []
        all_sdrs_i = []
        all_sisnrs = []
        all_sisnrs_i = []
        all_pesqs = []
        csv_columns = ["snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i", "pesq"]

        test_loader = sb.dataio.dataloader.make_dataloader(
            test_data, **self.hparams.dataloader_opts
        )

        with open(save_file, "w") as results_csv:
            writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
            writer.writeheader()

            # Loop over all test sentence
            with tqdm(test_loader, dynamic_ncols=True) as t:
                for i, batch in enumerate(t):

                    # Apply Separation
                    mixture, mix_len = batch.mix_sig
                    snt_id = batch.id
                    targets = [batch.s1_sig, batch.s2_sig]
                    if self.hparams.num_spks == 3:
                        targets.append(batch.s3_sig)

                    with torch.no_grad():
                        predictions, targets = self.compute_forward(
                            batch.mix_sig, targets, sb.Stage.TEST
                        )

                    # Compute SI-SNR
                    sisnr = self.compute_objectives(predictions, targets)

                    # Compute SI-SNR improvement
                    # (baseline: the unprocessed mixture used as estimate)
                    mixture_signal = torch.stack(
                        [mixture] * self.hparams.num_spks, dim=-1
                    )
                    mixture_signal = mixture_signal.to(targets.device)
                    sisnr_baseline = self.compute_objectives(
                        [mixture_signal.squeeze(-1), None], targets
                    )
                    sisnr_i = sisnr - sisnr_baseline

                    # Compute SDR
                    sdr, _, _, _ = bss_eval_sources(
                        targets[0].t().cpu().numpy(),
                        predictions[0][0].t().detach().cpu().numpy(),
                    )

                    sdr_baseline, _, _, _ = bss_eval_sources(
                        targets[0].t().cpu().numpy(),
                        mixture_signal[0].t().detach().cpu().numpy(),
                    )

                    sdr_i = sdr.mean() - sdr_baseline.mean()

                    # Compute PESQ
                    psq_mode = (
                        "wb" if self.hparams.sample_rate == 16000 else "nb"
                    )
                    psq = pesq(
                        self.hparams.sample_rate,
                        targets.squeeze().cpu().numpy(),
                        predictions[0].squeeze().cpu().numpy(),
                        mode=psq_mode,
                    )

                    # Saving on a csv file
                    # (the loss is negated to report SI-SNR)
                    row = {
                        "snt_id": snt_id[0],
                        "sdr": sdr.mean(),
                        "sdr_i": sdr_i,
                        "si-snr": -sisnr.item(),
                        "si-snr_i": -sisnr_i.item(),
                        "pesq": psq,
                    }
                    writer.writerow(row)

                    # Metric Accumulation
                    all_sdrs.append(sdr.mean())
                    all_sdrs_i.append(sdr_i.mean())
                    all_sisnrs.append(-sisnr.item())
                    all_sisnrs_i.append(-sisnr_i.item())
                    all_pesqs.append(psq)

                # Final row with the averages over all items
                row = {
                    "snt_id": "avg",
                    "sdr": np.array(all_sdrs).mean(),
                    "sdr_i": np.array(all_sdrs_i).mean(),
                    "si-snr": np.array(all_sisnrs).mean(),
                    "si-snr_i": np.array(all_sisnrs_i).mean(),
                    "pesq": np.array(all_pesqs).mean(),
                }
                writer.writerow(row)

        logger.info("Mean SISNR is {}".format(np.array(all_sisnrs).mean()))
        logger.info("Mean SISNRi is {}".format(np.array(all_sisnrs_i).mean()))
        logger.info("Mean SDR is {}".format(np.array(all_sdrs).mean()))
        logger.info("Mean SDRi is {}".format(np.array(all_sdrs_i).mean()))
        logger.info("Mean PESQ {}".format(np.array(all_pesqs).mean()))

    def save_audio(self, snt_id, mixture, targets, predictions):
        """Saves the test audio (mixture, targets, and estimated sources) on disk.

        Files go to ``audio_results`` inside ``hparams.save_folder``; each
        signal is divided by its absolute peak before saving. Only the first
        item of each tensor is written.
        """

        # Create output folder
        save_path = os.path.join(self.hparams.save_folder, "audio_results")
        if not os.path.exists(save_path):
            os.mkdir(save_path)

        # Estimated source
        signal = predictions[0, :]
        signal = signal / signal.abs().max()
        save_file = os.path.join(
            save_path, "item{}_sourcehat.wav".format(snt_id)
        )
        torchaudio.save(
            save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
        )

        # Original source
        signal = targets[0, :]
        signal = signal / signal.abs().max()
        save_file = os.path.join(save_path, "item{}_source.wav".format(snt_id))
        torchaudio.save(
            save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
        )

        # Mixture
        signal = mixture[0][0, :]
        signal = signal / signal.abs().max()
        save_file = os.path.join(save_path, "item{}_mix.wav".format(snt_id))
        torchaudio.save(
            save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
        )
class SEBrain(sb.Brain):
    """Brain that trains a generator (``model_g``) against a discriminator
    (``model_d``) for speech enhancement on fixed-size waveform chunks."""

    def compute_forward_g(self, noisy_wavs):
        """Forward computations of the generator.

        Input noisy signal, output clean signal. When the generator has
        ``latent_vae`` set, callers unpack three values from the result.
        """
        noisy_wavs = noisy_wavs.to(self.device)
        predict_wavs = self.modules["model_g"](noisy_wavs)
        return predict_wavs

    def compute_forward_d(self, noisy_wavs, clean_wavs):
        """Forward computations of the discriminator.

        The two inputs are concatenated along the last axis and scored by
        ``model_d``.
        """
        noisy_wavs = noisy_wavs.to(self.device)
        clean_wavs = clean_wavs.to(self.device)
        inpt = torch.cat((noisy_wavs, clean_wavs), -1)
        out = self.modules["model_d"](inpt)
        return out

    def compute_objectives_d1(self, d_outs, batch):
        """Computes the "d1" discriminator loss and tracks it per batch."""
        loss = self.hparams.compute_cost["d1"](d_outs)
        self.loss_metric_d1.append(batch.id, d_outs, reduction="batch")
        return loss

    def compute_objectives_d2(self, d_outs, batch):
        """Computes the "d2" discriminator loss and tracks it per batch."""
        loss = self.hparams.compute_cost["d2"](d_outs)
        self.loss_metric_d2.append(batch.id, d_outs, reduction="batch")
        return loss

    def compute_objectives_g3(
        self,
        d_outs,
        predict_wavs,
        clean_wavs,
        batch,
        stage,
        z_mean=None,
        z_logvar=None,
    ):
        """Computes the "g3" generator loss; on valid/test also tracks
        STOI/PESQ and, on TEST, writes the enhanced waveforms to disk.

        Arguments
        ---------
        d_outs : torch.Tensor
            Discriminator output for the generated signal.
        predict_wavs : torch.Tensor
            Generator output (chunked).
        clean_wavs : torch.Tensor
            Clean reference (chunked).
        batch : PaddedBatch
            Provides ``id`` and the relative lengths of ``clean_sig``.
        stage : sb.Stage
            Current stage.
        z_mean, z_logvar : torch.Tensor
            Latent statistics forwarded to the loss (None when unused).

        Returns
        -------
        The generator loss.
        """
        _, lens = batch.clean_sig
        clean_wavs = clean_wavs.to(self.device)

        loss = self.hparams.compute_cost["g3"](
            d_outs,
            predict_wavs,
            clean_wavs,
            lens,
            l1LossCoeff=self.hparams.l1LossCoeff,
            klLossCoeff=self.hparams.klLossCoeff,
            z_mean=z_mean,
            z_logvar=z_logvar,
        )
        self.loss_metric_g3.append(
            batch.id,
            d_outs,
            predict_wavs,
            clean_wavs,
            lens,
            l1LossCoeff=self.hparams.l1LossCoeff,
            klLossCoeff=self.hparams.klLossCoeff,
            z_mean=z_mean,
            z_logvar=z_logvar,
            reduction="batch",
        )

        if stage != sb.Stage.TRAIN:
            # Evaluate speech quality/intelligibility.
            # Chunks are merged back per utterance, and the padding added in
            # evaluate_batch is dropped again.
            predict_wavs = predict_wavs.reshape(self.batch_current, -1)
            clean_wavs = clean_wavs.reshape(self.batch_current, -1)

            predict_wavs = predict_wavs[:, 0 : self.original_len]
            clean_wavs = clean_wavs[:, 0 : self.original_len]

            self.stoi_metric.append(
                batch.id, predict_wavs, clean_wavs, lens, reduction="batch"
            )
            self.pesq_metric.append(
                batch.id, predict=predict_wavs.cpu(), target=clean_wavs.cpu()
            )

            # Write enhanced test wavs to file
            if stage == sb.Stage.TEST:
                lens = lens * clean_wavs.shape[1]
                for name, pred_wav, length in zip(batch.id, predict_wavs, lens):
                    name += ".wav"
                    enhance_path = os.path.join(
                        self.hparams.enhanced_folder, name
                    )
                    print(enhance_path)
                    # Peak-normalize just below full scale
                    pred_wav = pred_wav / torch.max(torch.abs(pred_wav)) * 0.99
                    torchaudio.save(
                        enhance_path,
                        pred_wav[: int(length)].cpu().unsqueeze(0),
                        hparams["sample_rate"],
                    )

        return loss

    def fit_batch(self, batch):
        """Fit one batch with three successive updates.

        1. Discriminator update on (noisy, clean) pairs with the "d1" loss.
        2. Discriminator update on (generated, clean) pairs with the "d2" loss.
        3. Generator update with the "g3" loss.

        Arguments
        ---------
        batch : PaddedBatch
            Batch providing ``noisy_sig`` and ``clean_sig``.

        Returns
        -------
        Detached sum of the three losses, moved to CPU.
        """
        noisy_wavs, _ = batch.noisy_sig
        clean_wavs, _ = batch.clean_sig

        # split sentences in smaller chunks
        noisy_wavs = create_chunks(
            noisy_wavs,
            chunk_size=hparams["chunk_size"],
            chunk_stride=hparams["chunk_stride"],
        )
        clean_wavs = create_chunks(
            clean_wavs,
            chunk_size=hparams["chunk_size"],
            chunk_stride=hparams["chunk_stride"],
        )

        # first of three step training process detailed in SEGAN paper
        out_d1 = self.compute_forward_d(noisy_wavs, clean_wavs)
        loss_d1 = self.compute_objectives_d1(out_d1, batch)
        loss_d1.backward()
        if self.check_gradients(loss_d1):
            self.optimizer_d.step()
        self.optimizer_d.zero_grad()

        # second training step
        z_mean = None
        z_logvar = None
        if self.modules["model_g"].latent_vae:
            out_g2, z_mean, z_logvar = self.compute_forward_g(noisy_wavs)
        else:
            out_g2 = self.compute_forward_g(noisy_wavs)
        out_d2 = self.compute_forward_d(out_g2, clean_wavs)
        loss_d2 = self.compute_objectives_d2(out_d2, batch)
        # The generator graph is reused by the third step, so keep it alive
        loss_d2.backward(retain_graph=True)
        if self.check_gradients(loss_d2):
            self.optimizer_d.step()
        self.optimizer_d.zero_grad()

        # third (last) training step
        self.optimizer_g.zero_grad()
        out_d3 = self.compute_forward_d(out_g2, clean_wavs)
        loss_g3 = self.compute_objectives_g3(
            out_d3,
            out_g2,
            clean_wavs,
            batch,
            sb.Stage.TRAIN,
            z_mean=z_mean,
            z_logvar=z_logvar,
        )
        loss_g3.backward()
        if self.check_gradients(loss_g3):
            self.optimizer_g.step()
        self.optimizer_g.zero_grad()
        self.optimizer_d.zero_grad()

        # detach() returns a new tensor instead of working in place, so it
        # has to be applied to the value that is actually returned.
        return (loss_d1 + loss_d2 + loss_g3).detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Evaluate one batch: compute the three losses without any update.

        The signals are zero-padded by one chunk and split into
        non-overlapping chunks so that the whole signal is processed.

        Arguments
        ---------
        batch : PaddedBatch
            Batch providing ``noisy_sig`` and ``clean_sig``.
        stage : sb.Stage
            The stage of the experiment: Stage.VALID, Stage.TEST

        Returns
        -------
        Detached sum of the three losses, moved to CPU.
        """
        noisy_wavs, _ = batch.noisy_sig
        clean_wavs, _ = batch.clean_sig

        # Remembered so compute_objectives_g3 can undo chunking and padding
        self.batch_current = clean_wavs.shape[0]
        self.original_len = clean_wavs.shape[1]

        # Add padding to make sure all the signal will be processed.
        padding_elements = torch.zeros(
            clean_wavs.shape[0],
            hparams["chunk_size"],
            device=clean_wavs.device,
        )
        clean_wavs = torch.cat([clean_wavs, padding_elements], dim=1)
        noisy_wavs = torch.cat([noisy_wavs, padding_elements], dim=1)

        # Split sentences in smaller chunks
        noisy_wavs = create_chunks(
            noisy_wavs,
            chunk_size=hparams["chunk_size"],
            chunk_stride=hparams["chunk_size"],
        )
        clean_wavs = create_chunks(
            clean_wavs,
            chunk_size=hparams["chunk_size"],
            chunk_stride=hparams["chunk_size"],
        )

        # Perform speech enhancement with the current model
        out_d1 = self.compute_forward_d(noisy_wavs, clean_wavs)
        loss_d1 = self.compute_objectives_d1(out_d1, batch)

        z_mean = None
        z_logvar = None
        if self.modules["model_g"].latent_vae:
            out_g2, z_mean, z_logvar = self.compute_forward_g(noisy_wavs)
        else:
            out_g2 = self.compute_forward_g(noisy_wavs)
        out_d2 = self.compute_forward_d(out_g2, clean_wavs)
        loss_d2 = self.compute_objectives_d2(out_d2, batch)

        loss_g3 = self.compute_objectives_g3(
            out_d2,
            out_g2,
            clean_wavs,
            batch,
            stage=stage,
            z_mean=z_mean,
            z_logvar=z_logvar,
        )

        # detach() is not in-place; detach the returned value itself.
        return (loss_d1 + loss_d2 + loss_g3).detach().cpu()

    def init_optimizers(self):
        """Create one optimizer per network from ``self.opt_class``.

        Both optimizers are registered with the checkpointer when one is
        configured. Nothing is created when ``opt_class`` is None.
        """
        if self.opt_class is not None:
            self.optimizer_d = self.opt_class(
                self.modules["model_d"].parameters()
            )
            self.optimizer_g = self.opt_class(
                self.modules["model_g"].parameters()
            )

            if self.checkpointer is not None:
                self.checkpointer.add_recoverable(
                    "optimizer_g", self.optimizer_g
                )
                self.checkpointer.add_recoverable(
                    "optimizer_d", self.optimizer_d
                )

    def on_stage_start(self, stage, epoch=None):
        """Gets called at the beginning of each epoch.

        Creates fresh loss and STOI trackers, plus a PESQ tracker for every
        stage except TRAIN.
        """
        self.loss_metric_d1 = MetricStats(
            metric=self.hparams.compute_cost["d1"]
        )
        self.loss_metric_d2 = MetricStats(
            metric=self.hparams.compute_cost["d2"]
        )
        self.loss_metric_g3 = MetricStats(
            metric=self.hparams.compute_cost["g3"]
        )
        self.stoi_metric = MetricStats(metric=stoi_loss)

        # Define function taking (prediction, target) for parallel eval
        def pesq_eval(pred_wav, target_wav):
            """Computes the PESQ evaluation metric"""
            return pesq(
                fs=hparams["sample_rate"],
                ref=target_wav.numpy().squeeze(),
                deg=pred_wav.numpy().squeeze(),
                mode="wb",
            )

        if stage != sb.Stage.TRAIN:
            self.pesq_metric = MetricStats(
                metric=pesq_eval, batch_eval=False, n_jobs=1
            )

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Gets called at the end of an epoch.

        Stores train stats; for VALID it logs and checkpoints on PESQ; for
        TEST it only logs.
        """
        if stage == sb.Stage.TRAIN:
            self.train_loss = stage_loss
            self.train_stats = {
                "loss_d1": self.loss_metric_d1.scores,
                "loss_d2": self.loss_metric_d2.scores,
                "loss_g3": self.loss_metric_g3.scores,
            }
        else:
            stats = {
                "loss": stage_loss,
                "pesq": self.pesq_metric.summarize("average"),
                "stoi": -self.stoi_metric.summarize("average"),
            }

        if stage == sb.Stage.VALID:
            if self.hparams.use_tensorboard:
                valid_stats = {
                    "loss_d1": self.loss_metric_d1.scores,
                    "loss_d2": self.loss_metric_d2.scores,
                    "loss_g3": self.loss_metric_g3.scores,
                    "stoi": self.stoi_metric.scores,
                    "pesq": self.pesq_metric.scores,
                }
                self.hparams.tensorboard_train_logger.log_stats(
                    {"Epoch": epoch}, self.train_stats, valid_stats
                )
            self.hparams.train_logger.log_stats(
                {"Epoch": epoch},
                train_stats={"loss": self.train_loss},
                valid_stats=stats,
            )
            self.checkpointer.save_and_keep_only(meta=stats, max_keys=["pesq"])

        if stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stats,
            )