def test_sisdr(n_src, function_triplet):
    # Unpack the triplet
    pairwise, nosrc, nonpit = function_triplet
    # Fake targets and estimates
    targets = torch.randn(2, n_src, 10000)
    est_targets = torch.randn(2, n_src, 10000)
    # Create the 3 PIT wrappers
    pw_wrapper = PITLossWrapper(pairwise, pit_from="pw_mtx")
    wo_src_wrapper = PITLossWrapper(nosrc, pit_from="pw_pt")
    w_src_wrapper = PITLossWrapper(nonpit, pit_from="perm_avg")
    # Circular tests on value
    assert_allclose(pw_wrapper(est_targets, targets), wo_src_wrapper(est_targets, targets))
    assert_allclose(wo_src_wrapper(est_targets, targets), w_src_wrapper(est_targets, targets))
    # Circular tests on returned estimates
    assert_allclose(
        pw_wrapper(est_targets, targets, return_est=True)[1],
        wo_src_wrapper(est_targets, targets, return_est=True)[1],
    )
    assert_allclose(
        wo_src_wrapper(est_targets, targets, return_est=True)[1],
        w_src_wrapper(est_targets, targets, return_est=True)[1],
    )

def test_sisdr_and_mse(n_src, loss):
    # Unpack the loss tuple
    pairwise, singlesrc, multisrc, _ = loss
    # Fake targets and estimates
    targets = torch.randn(2, n_src, 10000)
    est_targets = torch.randn(2, n_src, 10000)
    # Create the 3 PIT wrappers
    pw_wrapper = PITLossWrapper(pairwise, pit_from="pw_mtx")
    wo_src_wrapper = PITLossWrapper(singlesrc, pit_from="pw_pt")
    w_src_wrapper = PITLossWrapper(multisrc, pit_from="perm_avg")
    # Circular tests on value
    assert_allclose(pw_wrapper(est_targets, targets), wo_src_wrapper(est_targets, targets))
    assert_allclose(wo_src_wrapper(est_targets, targets), w_src_wrapper(est_targets, targets))
    # Circular tests on returned estimates
    assert_allclose(
        pw_wrapper(est_targets, targets, return_est=True)[1],
        wo_src_wrapper(est_targets, targets, return_est=True)[1],
    )
    assert_allclose(
        wo_src_wrapper(est_targets, targets, return_est=True)[1],
        w_src_wrapper(est_targets, targets, return_est=True)[1],
    )

def test_best_perm_match(n_src):
    pwl = torch.randn(2, n_src, n_src)
    min_loss, min_idx = PITLossWrapper.find_best_perm_factorial(pwl)
    min_loss_hun, min_idx_hun = PITLossWrapper.find_best_perm_hungarian(pwl)
    assert_allclose(min_loss, min_loss_hun)
    assert_allclose(min_idx, min_idx_hun)

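# A brute-force cross-check of the minimum found above (illustrative sketch,
# not part of the test suite). The minimum mean loss over all permutations
# does not depend on the index convention, so it should match what both the
# factorial and the Hungarian searches return.
def brute_force_min_loss(pwl):
    """pwl: (batch, n_src, n_src) pairwise loss matrix."""
    from itertools import permutations

    batch, n_src, _ = pwl.shape
    best = torch.full((batch,), float("inf"))
    for perm in permutations(range(n_src)):
        # Mean loss of this source-to-estimate assignment, per batch element.
        cand = pwl[:, list(range(n_src)), list(perm)].mean(-1)
        best = torch.minimum(best, cand)
    return best
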
def main(conf):
    perms = list(permutations(range(conf["train_conf"]["data"]["n_src"])))
    model_path = os.path.join(conf["exp_dir"], conf["ckpt_path"])
    if conf["ckpt_path"] == "best_model.pth":
        # Serialized checkpoint
        model = getattr(asteroid, conf["model"]).from_pretrained(model_path)
    else:
        # Non-serialized checkpoint, _ckpt_epoch_{i}.ckpt: keys start with
        # "model." and need to be stripped before loading.
        model = getattr(asteroid, conf["model"])(
            **conf["train_conf"]["filterbank"], **conf["train_conf"]["masknet"]
        )
        all_states = torch.load(model_path, map_location="cpu")
        state_dict = {
            k.split('.', 1)[1]: all_states["state_dict"][k] for k in all_states["state_dict"]
        }
        model.load_state_dict(state_dict)
        # model.load_state_dict(all_states["state_dict"], strict=False)

    # Handle device placement
    if conf["use_gpu"]:
        model.cuda()
    model_device = next(model.parameters()).device
    test_set = make_test_dataset(
        corpus=conf["corpus"],
        test_dir=conf["test_dir"],
        task=conf["task"],
        sample_rate=conf["sample_rate"],
        n_src=conf["train_conf"]["data"]["n_src"],
    )
    # Used to reorder sources only
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")

    # All resulting files are saved in eval_save_dir
    eval_save_dir = os.path.join(conf["exp_dir"], conf["out_dir"])
    os.makedirs(eval_save_dir, exist_ok=True)

    series_list = []
    torch.no_grad().__enter__()
    for idx in tqdm(range(len(test_set))):
        # Forward the network on the mixture.
        mix, sources = tensors_to_device(test_set[idx], device=model_device)
        est_sources = model(mix.unsqueeze(0))
        # When running separation inference for multi-task training, exclude
        # the last channel. Does not affect single-task training models
        # (from_scratch, pre+FT).
        est_sources = est_sources[:, : sources.shape[0]]
        _, best_perm_idx = loss_func.find_best_perm(
            pairwise_neg_sisdr(est_sources, sources[None]),
            conf["train_conf"]["data"]["n_src"],
        )

        utt_metrics = {}
        if hasattr(test_set, "mixture_path"):
            utt_metrics["mix_path"] = test_set.mixture_path
        utt_metrics["best_perm_idx"] = ' '.join([str(pidx) for pidx in perms[best_perm_idx[0]]])
        series_list.append(pd.Series(utt_metrics))

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(eval_save_dir, "best_perms.csv"))

def forward_wav(self, wav, slice_size=32000, *args, **kwargs):
    """Separation method for waveforms.

    Unfolds a full audio into slices, separates each slice and stitches the
    slice estimates back together with overlap-add.

    Args:
        wav (torch.Tensor): waveform array/tensor.
            Shape: 1D, 2D or 3D tensor, time last.

    Return:
        output_cat (torch.Tensor): concatenated output tensor.
            [num_spks, T]
    """
    assert not self.training, "forward_wav is only used for test mode"
    T = wav.size(-1)
    if wav.ndim == 1:
        wav = wav.reshape(1, wav.size(0))
    assert wav.ndim == 2  # [1, T]
    slice_stride = slice_size // 2
    # Pad wav to an integer multiple of slice_stride
    T_padded = max(int(np.ceil(T / slice_stride)), 2) * slice_stride
    wav = F.pad(wav, (0, T_padded - T))
    slices = wav.unfold(
        dimension=-1, size=slice_size, step=slice_stride
    )  # [1, slice_nb, slice_size]
    slice_nb = slices.size(1)
    slices = slices.squeeze(0).unsqueeze(1)
    tf_rep = self.enc_activation(self.encoder(slices))
    est_masks_list = self.masker(tf_rep)
    selector_input = est_masks_list[-1]  # [slice_nb, bn_chan, chunk_size, n_chunks]
    selector_output = self.decoder_select.selector(selector_input).reshape(
        slice_nb, -1
    )  # [slice_nb, num_decs]
    est_idx, _ = selector_output.argmax(-1).mode()
    est_spks = self.decoder_select.n_srcs[est_idx]
    output_wavs, _ = self.decoder_select(
        est_masks_list, tf_rep, ground_truth=[est_spks] * slice_nb
    )  # [slice_nb, 1, n_spks, slice_size]
    output_wavs = output_wavs.squeeze(1)[:, :est_spks, :]
    # TODO: overlap and add (with division)
    output_cat = output_wavs.new_zeros(est_spks, slice_nb * slice_size)
    output_cat[:, :slice_size] = output_wavs[0]
    start = slice_stride
    for i in range(1, slice_nb):
        end = start + slice_size
        overlap_prev = output_cat[:, start : start + slice_stride].unsqueeze(0)
        overlap_next = output_wavs[i : i + 1, :, :slice_stride]
        pw_losses = pairwise_neg_sisdr(overlap_next, overlap_prev)
        _, best_indices = PITLossWrapper.find_best_perm(pw_losses)
        reordered = PITLossWrapper.reorder_source(output_wavs[i : i + 1, :, :], best_indices)
        output_cat[:, start : start + slice_size] += reordered.squeeze(0)
        output_cat[:, start : start + slice_stride] /= 2
        start += slice_stride
    return output_cat[:, :T]

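# Minimal usage sketch (assumes a trained model instance named `model` that
# exposes forward_wav): run slice-wise separation on a raw mono waveform.
model.eval()
with torch.no_grad():
    wav = torch.randn(48000)              # mono waveform, time last
    est_sources = model.forward_wav(wav)  # [est_spks, 48000]
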
def forward(self, est_src, logits, src):
    """Forward

    Args:
        est_src: $(num_stages, n_src, T)
        logits: $(num_stages, num_decoders)
        src: $(n_src, T)
    """
    assert est_src.size()[1:] == src.size()
    num_stages, n_src, T = est_src.size()
    target_src = src.unsqueeze(0).repeat(num_stages, 1, 1)
    target_idx = self.n_src2idx[n_src]

    pw_losses = pairwise_neg_sisdr(est_src, target_src)
    sdr_loss, _ = PITLossWrapper.find_best_perm(pw_losses)
    pos_sdr = -sdr_loss[-1]

    cls_target = torch.LongTensor([target_idx] * num_stages).to(logits.device)
    cls_loss = self.cce(logits, cls_target)
    correctness = logits[-1].argmax().item() == target_idx

    coeffs = torch.Tensor(
        [(c_idx + 1) * (1 / num_stages) for c_idx in range(num_stages)]
    ).to(logits.device)
    assert coeffs.size() == sdr_loss.size() == cls_loss.size()
    # use sum of SDR for each channel, not mean
    loss = torch.sum(coeffs * (sdr_loss * n_src + cls_loss * self.lamb))
    return loss, pos_sdr, correctness

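# Worked example of the stage weighting above (illustrative sketch): with
# num_stages = 4 the coefficients are [0.25, 0.5, 0.75, 1.0], so losses from
# later stages contribute more to the summed objective.
num_stages = 4
coeffs = torch.tensor([(c_idx + 1) * (1 / num_stages) for c_idx in range(num_stages)])
print(coeffs)  # tensor([0.2500, 0.5000, 0.7500, 1.0000])
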
def test_system():
    discriminator = Discriminator()
    generator = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
    opt_d = optim.Adam(discriminator.parameters(), lr=1e-3)
    opt_g = optim.Adam(generator.parameters(), lr=1e-3)
    scheduler_d = ReduceLROnPlateau(optimizer=opt_d, factor=0.5, patience=5)
    scheduler_g = ReduceLROnPlateau(optimizer=opt_g, factor=0.5, patience=5)
    g_loss = GeneratorLoss()
    d_loss = DiscriminatorLoss()
    validation_loss = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    dataset = DummyDataset()
    loader = data.DataLoader(dataset, batch_size=4, num_workers=4)
    gan = TrainGAN(discriminator=discriminator, generator=generator,
                   opt_d=opt_d, opt_g=opt_g,
                   discriminator_loss=d_loss, generator_loss=g_loss,
                   validation_loss=validation_loss,
                   train_loader=loader, val_loader=loader,
                   scheduler_d=scheduler_d, scheduler_g=scheduler_g)
    trainer = Trainer(max_epochs=1, fast_dev_run=True)
    trainer.fit(gan)

def __init__(self, alpha=0.1):
    super().__init__()
    assert alpha >= 0, "Negative alpha values don't make sense."
    assert alpha <= 1, "Alpha values above 1 don't make sense."
    # PIT loss
    self.src_mse = PITLossWrapper(pairwise_mse, pit_from='pw_mtx')
    self.alpha = alpha

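# Hypothetical companion use of `alpha` (sketch only, not the original class):
# blend the PIT source MSE with a second loss term. All names below
# (`src_mse_wrapper`, `other_term`) are illustrative assumptions.
def blended_loss(src_mse_wrapper, alpha, est_targets, targets, other_term=0.0):
    src_loss = src_mse_wrapper(est_targets, targets)
    return (1 - alpha) * src_loss + alpha * other_term
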
def main(conf):
    model = get_model(conf)
    test_set = WhamDataset(conf['test_dir'], conf['task'],
                           sample_rate=conf['sample_rate'],
                           nondefault_nsrc=conf['nondefault_nsrc'],
                           segment=None)
    loss_func = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')
    model_device = next(model.parameters()).device
    for idx in range(len(test_set)):
        mix, sources, _ = tensors_to_device(test_set[idx], device=model_device)
        est_sources = model(mix)
        loss, reordered_sources = loss_func(sources, est_sources, return_est=True)
        mix_np = mix.data.numpy()[0]
        sources_np = sources.data.numpy()[0]
        est_sources_np = reordered_sources.data.numpy()[0]
        # Waiting for pb_bss support to compute subset of metrics.
        # We will probably want SI-SDR, + add option for mir_eval SDR, stoi, pesq
        input_metrics = InputMetrics(observation=mix_np,
                                     speech_source=sources_np,
                                     enable_si_sdr=True,
                                     sample_rate=conf["sample_rate"])
        output_metrics = OutputMetrics(speech_prediction=est_sources_np,
                                       speech_source=sources_np,
                                       enable_si_sdr=True,
                                       sample_rate=conf["sample_rate"])

def test_proximity_sinkhorn_hungarian(batch_size, n_src, beta, n_iter, function_triplet):
    time = 16000
    noise_level = 0.1
    pairwise, nosrc, nonpit = function_triplet

    # random data
    targets = torch.randn(batch_size, n_src, time) * 10  # ground truth
    noise = torch.randn(batch_size, n_src, time) * noise_level
    est_targets = (
        targets[:, torch.randperm(n_src), :] + noise
    )  # reorder channels, and add small noise

    # initialize wrappers
    loss_sinkhorn = SinkPITLossWrapper(pairwise, n_iter=n_iter)
    loss_hungarian = PITLossWrapper(pairwise, pit_from="pw_mtx")

    # compute loss by sinkhorn
    loss_sinkhorn.beta = beta
    mean_loss_sinkhorn = loss_sinkhorn(est_targets, targets, return_est=False)

    # compute loss by hungarian
    mean_loss_hungarian = loss_hungarian(est_targets, targets, return_est=False)

    # compare
    assert_allclose(mean_loss_sinkhorn, mean_loss_hungarian)

def main(conf):
    # from asteroid.data.toy_data import WavSet
    # train_set = WavSet(n_ex=1000, n_src=2, ex_len=32000)
    # val_set = WavSet(n_ex=1000, n_src=2, ex_len=32000)

    # Define data pipeline
    train_set = WhamDataset(conf['data']['train_dir'], conf['data']['task'],
                            sample_rate=conf['data']['sample_rate'],
                            nondefault_nsrc=conf['data']['nondefault_nsrc'])
    val_set = WhamDataset(conf['data']['valid_dir'], conf['data']['task'],
                          sample_rate=conf['data']['sample_rate'],
                          nondefault_nsrc=conf['data']['nondefault_nsrc'])
    train_loader = DataLoader(train_set, shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'])
    val_loader = DataLoader(val_set, shuffle=False,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'])
    conf['masknet'].update({'n_src': train_set.n_src})

    # Define model and optimizer in a local function (defined in the recipe).
    # Two advantages to this: re-instantiating the model and optimizer
    # for retraining and evaluating is straightforward.
    model, optimizer = make_model_and_optimizer(conf)

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    # loss_class = PITLossContainer(pairwise_neg_sisdr, n_src=train_set.n_src)

    # Checkpointing callback can monitor any quantity which is returned by
    # validation step, defaults to val_loss here (see System).
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min', save_best_only=False)
    # New PL version will come the 7th of December / will have save_top_k
    system = System(model=model, loss_func=loss_func, optimizer=optimizer,
                    train_loader=train_loader, val_loader=val_loader,
                    config=conf)
    trainer = pl.Trainer(max_nb_epochs=conf['training']['epochs'],
                         checkpoint_callback=checkpoint,
                         default_save_path=exp_dir,
                         gpus=conf['main_args']['gpus'],
                         distributed_backend='dp')
    trainer.fit(system)

def test_negstoi_pit(n_src, sample_rate, use_vad, extended):
    ref, est = torch.randn(2, n_src, 16000), torch.randn(2, n_src, 16000)
    singlesrc_negstoi = SingleSrcNegSTOI(sample_rate=sample_rate, use_vad=use_vad,
                                         extended=extended)
    loss_func = PITLossWrapper(singlesrc_negstoi, pit_from='pw_pt')
    # Assert forward ok.
    loss_value = loss_func(est, ref)

def test_negstoi_pit(n_src, sample_rate, use_vad, extended):
    ref, est = torch.randn(2, n_src, 8000), torch.randn(2, n_src, 8000)
    singlesrc_negstoi = SingleSrcNegSTOI(
        sample_rate=sample_rate, use_vad=use_vad, extended=extended
    )
    loss_func = PITLossWrapper(singlesrc_negstoi, pit_from="pw_pt")
    # Assert forward ok.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        loss_func(est, ref)

def __init__(self, num_srcs, n_fft, hop_length, win_length, window, center):
    self.num_srcs = num_srcs
    self.n_fft = n_fft
    self.hop_length = hop_length
    self.win_length = win_length
    # Only the Hann window is handled here; self.window stays unset otherwise.
    if window == 'hann':
        self.window = torch.hann_window(win_length).cuda()
    self.center = center
    self.loss = PITLossWrapper(PairwiseNegSDR("sisdr"), pit_from="pw_mtx")

def train_model_part(conf, train_part='filterbank', pretrained_filterbank=None):
    train_loader, val_loader = get_data_loaders(conf, train_part=train_part)

    # Define model and optimizer in a local function (defined in the recipe).
    # Two advantages to this: re-instantiating the model and optimizer
    # for retraining and evaluating is straightforward.
    model, optimizer = make_model_and_optimizer(
        conf, model_part=train_part, pretrained_filterbank=pretrained_filterbank
    )
    # Define scheduler
    scheduler = None
    if conf[train_part + '_training'][train_part[0] + '_half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir, checkpoint_dir = get_encoded_paths(conf, train_part)
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(PairwiseNegSDR('sisdr', zero_mean=False),
                               pit_from='pw_mtx')
    system = SystemTwoStep(model=model, loss_func=loss_func,
                           optimizer=optimizer, train_loader=train_loader,
                           val_loader=val_loader, scheduler=scheduler,
                           config=conf, module=train_part)

    # Define callbacks
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min', save_top_k=1, verbose=1)
    early_stopping = False
    if conf[train_part + '_training'][train_part[0] + '_early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=10,
                                       verbose=1)

    # Don't ask GPU if they are not available.
    if not torch.cuda.is_available():
        print('No available GPU was found, setting gpus to None')
        conf['main_args']['gpus'] = None
    trainer = pl.Trainer(
        max_nb_epochs=conf[train_part + '_training'][train_part[0] + '_epochs'],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_save_path=exp_dir,
        gpus=conf['main_args']['gpus'],
        distributed_backend='dp',
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.)
    trainer.fit(system)

    with open(os.path.join(checkpoint_dir, "best_k_models.json"), "w") as file:
        json.dump(checkpoint.best_k_models, file, indent=0)

def test_sisdr(n_src):
    targets = torch.randn(2, n_src, 32000)
    est_targets = torch.randn(2, n_src, 32000)
    pw_wrapper = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')
    wo_src_wrapper = PITLossWrapper(nosrc_neg_sisdr, mode='wo_src')
    w_src_wrapper = PITLossWrapper(nonpit_neg_sisdr, mode='w_src')

    pw = pw_wrapper(targets, est_targets)
    wo_src = wo_src_wrapper(targets, est_targets)
    w_src = w_src_wrapper(targets, est_targets)
    assert_allclose(pw_wrapper(targets, est_targets),
                    wo_src_wrapper(targets, est_targets))
    assert_allclose(w_src_wrapper(targets, est_targets),
                    wo_src_wrapper(targets, est_targets))

    assert_allclose(pw_wrapper(targets, est_targets, return_est=True)[1],
                    wo_src_wrapper(targets, est_targets, return_est=True)[1])
    assert_allclose(w_src_wrapper(targets, est_targets, return_est=True)[1],
                    wo_src_wrapper(targets, est_targets, return_est=True)[1])

def test_permreduce():
    from functools import partial

    n_src = 3
    sources = torch.randn(10, n_src, 8000)
    est_sources = torch.randn(10, n_src, 8000)
    wo_reduce = PITLossWrapper(pairwise_mse, pit_from='pw_mtx')
    w_mean_reduce = PITLossWrapper(
        pairwise_mse, pit_from='pw_mtx',
        # perm_reduce=partial(torch.mean, dim=-1))
        perm_reduce=lambda x: torch.mean(x, dim=-1))
    w_sum_reduce = PITLossWrapper(pairwise_mse, pit_from='pw_mtx',
                                  perm_reduce=partial(torch.sum, dim=-1))

    wo = wo_reduce(est_sources, sources)
    w_mean = w_mean_reduce(est_sources, sources)
    w_sum = w_sum_reduce(est_sources, sources)

    assert_allclose(wo, w_mean)
    assert_allclose(wo, w_sum / n_src)

def _reorder_sources(
    current: torch.FloatTensor,
    previous: torch.FloatTensor,
    n_src: int,
    window_size: int,
    hop_size: int,
):
    """
    Reorder sources in current chunk to maximize correlation with previous chunk.
    Used for Continuous Source Separation. Standard dsp correlation is used
    for reordering.

    Args:
        current (:class:`torch.Tensor`): current chunk, tensor
            of shape (batch, n_src, window_size)
        previous (:class:`torch.Tensor`): previous chunk, tensor
            of shape (batch, n_src, window_size)
        n_src (:class:`int`): number of sources.
        window_size (:class:`int`): window_size, equal to last dimension of
            both current and previous.
        hop_size (:class:`int`): hop_size between current and previous tensors.

    Returns:
        current:
    """
    batch, frames = current.size()
    current = current.reshape(-1, n_src, frames)
    previous = previous.reshape(-1, n_src, frames)

    overlap_f = window_size - hop_size
    pw_losses = PITLossWrapper.get_pw_losses(
        lambda x, y: torch.sum((x.unsqueeze(1) * y.unsqueeze(2))),
        current[..., :overlap_f],
        previous[..., -overlap_f:],
    )
    _, perms = PITLossWrapper.find_best_perm(pw_losses, n_src)
    current = PITLossWrapper.reorder_source(current, n_src, perms)
    return current.reshape(batch, frames)

def test_pmsqe_pit(n_src, sample_rate):
    # Define supported STFT
    if sample_rate == 16000:
        stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
    else:
        stft = Encoder(STFTFB(kernel_size=256, n_filters=256, stride=128))
    # Usage by itself
    ref, est = torch.randn(2, n_src, 16000), torch.randn(2, n_src, 16000)
    ref_spec = transforms.mag(stft(ref))
    est_spec = transforms.mag(stft(est))
    loss_func = PITLossWrapper(SingleSrcPMSQE(sample_rate=sample_rate), pit_from="pw_pt")
    # Assert forward ok.
    loss_func(est_spec, ref_spec)

def test_multi_scale_spectral_PIT(n_src):
    # Test with a reduced number of STFT scales.
    filt_list = [512, 256, 32]
    # Fake targets and estimates
    targets = torch.randn(2, n_src, 8000)
    est_targets = torch.randn(2, n_src, 8000)
    # Create PITLossWrapper in 'pw_pt' mode
    pt_loss = SingleSrcMultiScaleSpectral(windows_size=filt_list,
                                          n_filters=filt_list,
                                          hops_size=filt_list)
    loss_func = PITLossWrapper(pt_loss, pit_from='pw_pt')
    # Compute the loss
    loss = loss_func(targets, est_targets)

def test_permutation(perm):
    """ Construct fake target/estimates pair. Check the value and reordering."""
    n_src = len(perm)
    perm_tensor = torch.Tensor(perm)
    source_base = torch.ones(1, n_src, 10)
    sources = torch.arange(n_src).unsqueeze(-1) * source_base
    est_sources = perm_tensor.unsqueeze(-1) * source_base
    loss_func = PITLossWrapper(pairwise_mse)
    loss_value, reordered = loss_func(est_sources, sources, return_est=True)

    assert loss_value.item() == 0
    assert_allclose(sources, reordered)

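# Concrete instance of the test above (illustrative sketch): with constant
# sources [0, 1, 2] and estimates given in the order [2, 0, 1], the wrapper
# recovers the matching permutation, so the loss is zero and the reordered
# estimates equal the targets.
sources = torch.arange(3).float().view(1, 3, 1) * torch.ones(1, 3, 10)
est_sources = sources[:, [2, 0, 1], :]
loss_value, reordered = PITLossWrapper(pairwise_mse)(est_sources, sources, return_est=True)
assert loss_value.item() == 0
assert torch.allclose(reordered, sources)
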
def main(conf):
    test_set = WSJ2mixDataset(conf['data']['tt_wav_len_list'],
                              conf['data']['wav_base_path'] + '/tt',
                              sample_rate=conf['data']['sample_rate'])
    test_loader = DataLoader(test_set, shuffle=True, batch_size=1,
                             num_workers=conf['data']['num_workers'],
                             drop_last=False)
    istft = fb.Decoder(fb.STFTFB(**conf['filterbank']))
    exp_dir = conf['main_args']['exp_dir']
    model_path = os.path.join(exp_dir, 'checkpoints/_ckpt_epoch_0.ckpt')
    model = load_best_model(conf, model_path)
    pit_loss = PITLossWrapper(pairwise_mse, mode='pairwise')
    system = DcSystem(model, None, None, None, config=conf)

    # Randomly choose the indexes of sentences to save.
    exp_save_dir = os.path.join(exp_dir, 'examples/')
    n_save = conf['main_args']['n_save_ex']
    if n_save == -1:
        n_save = len(test_set)
    save_idx = random.sample(range(len(test_set)), n_save)
    series_list = []
    torch.no_grad().__enter__()
    for batch in test_loader:
        batch = [ele.type(torch.float32) for ele in batch]
        inputs, targets, masks = system.unpack_data(batch)
        est_targets = system(inputs)
        mix_stft = system.enc(inputs.unsqueeze(1))
        min_loss, min_idx = pit_loss.best_perm_from_perm_avg_loss(
            pairwise_mse, est_targets[1], masks)
        for sidx in min_idx:
            src_stft = mix_stft * est_targets[1][sidx]
            src_sig = istft(src_stft)

def test_wrapper(batch_size, n_src, time):
    targets = torch.randn(batch_size, n_src, time)
    est_targets = torch.randn(batch_size, n_src, time)
    for bad_loss_func in [bad_loss_func_ndim0, bad_loss_func_ndim1]:
        loss = PITLossWrapper(bad_loss_func)
        with pytest.raises(AssertionError):
            loss(est_targets, targets)
    # wo_src loss function / With and without return estimates
    loss = PITLossWrapper(good_batch_loss_func, pit_from="pw_pt")
    loss(est_targets, targets)
    loss_value, reordered_est = loss(est_targets, targets, return_est=True)
    assert reordered_est.shape == est_targets.shape

    # pairwise loss function / With and without return estimates
    loss = PITLossWrapper(good_pairwise_loss_func, pit_from="pw_mtx")
    loss(est_targets, targets)
    loss_value, reordered_est = loss(est_targets, targets, return_est=True)
    assert reordered_est.shape == est_targets.shape

    # w_src loss function / With and without return estimates
    loss = PITLossWrapper(good_batch_loss_func, pit_from="perm_avg")
    loss(est_targets, targets)
    loss_value, reordered_est = loss(est_targets, targets, return_est=True)
    assert reordered_est.shape == est_targets.shape

def test_permreduce_args():
    def reduce_func(perm_losses, class_weights=None):
        # perm_losses is (batch, n_perms, n_src) for now
        if class_weights is None:
            return torch.mean(perm_losses, dim=-1)
        if class_weights.ndim == 2:
            class_weights = class_weights.unsqueeze(1)
        return torch.mean(perm_losses * class_weights, -1)

    n_src = 3
    sources = torch.randn(10, n_src, 8000)
    est_sources = torch.randn(10, n_src, 8000)
    loss_func = PITLossWrapper(pairwise_mse, pit_from="pw_mtx",
                               perm_reduce=reduce_func)
    weights = torch.softmax(torch.randn(10, n_src), dim=-1)
    loss_func(est_sources, sources, reduce_kwargs={"class_weights": weights})

def test_wrapper(batch_size, n_src, time):
    targets = torch.randn(batch_size, n_src, time)
    est_targets = torch.randn(batch_size, n_src, time)
    for bad_loss_func in [bad_loss_func_ndim0, bad_loss_func_ndim1]:
        loss = PITLossWrapper(bad_loss_func)
        with pytest.raises(AssertionError):
            loss(targets, est_targets)
    # wo_src loss function / With and without return estimates
    loss = PITLossWrapper(good_batch_loss_func, mode='wo_src')
    loss_value_no_return = loss(targets, est_targets)
    loss_value, reordered_est = loss(targets, est_targets, return_est=True)
    assert reordered_est.shape == est_targets.shape

    # pairwise loss function / With and without return estimates
    loss = PITLossWrapper(good_pairwise_loss_func, mode='pairwise')
    loss_value_no_return = loss(targets, est_targets)
    loss_value, reordered_est = loss(targets, est_targets, return_est=True)
    assert reordered_est.shape == est_targets.shape

    # w_src loss function / With and without return estimates
    loss = PITLossWrapper(good_batch_loss_func, mode='w_src')
    loss_value_no_return = loss(targets, est_targets)
    loss_value, reordered_est = loss(targets, est_targets, return_est=True)
    assert reordered_est.shape == est_targets.shape

def test_sisdr(n_src, function_triplet):
    # Unpack the triplet
    pairwise, nosrc, nonpit = function_triplet
    # Fake targets and estimates
    targets = torch.randn(2, n_src, 32000)
    est_targets = torch.randn(2, n_src, 32000)
    # Create the 3 PIT wrappers
    pw_wrapper = PITLossWrapper(pairwise, mode='pairwise')
    wo_src_wrapper = PITLossWrapper(nosrc, mode='wo_src')
    w_src_wrapper = PITLossWrapper(nonpit, mode='w_src')
    # Circular tests on value
    assert_allclose(pw_wrapper(est_targets, targets),
                    wo_src_wrapper(est_targets, targets))
    assert_allclose(wo_src_wrapper(est_targets, targets),
                    w_src_wrapper(est_targets, targets))
    # Circular tests on returned estimates
    assert_allclose(
        pw_wrapper(est_targets, targets, return_est=True)[1],
        wo_src_wrapper(est_targets, targets, return_est=True)[1])
    assert_allclose(
        wo_src_wrapper(est_targets, targets, return_est=True)[1],
        w_src_wrapper(est_targets, targets, return_est=True)[1])

def _reorder_sources(
    current: torch.FloatTensor,
    previous: torch.FloatTensor,
    n_src: int,
    window_size: int,
    hop_size: int,
):
    """
    Reorder sources in current chunk to maximize correlation with previous chunk.
    Used for Continuous Source Separation. Standard dsp correlation is used
    for reordering.

    Args:
        current (:class:`torch.Tensor`): current chunk, tensor
            of shape (batch, n_src, window_size)
        previous (:class:`torch.Tensor`): previous chunk, tensor
            of shape (batch, n_src, window_size)
        n_src (:class:`int`): number of sources.
        window_size (:class:`int`): window_size, equal to last dimension of
            both current and previous.
        hop_size (:class:`int`): hop_size between current and previous tensors.

    Returns:
        current:
    """
    batch, frames = current.size()
    current = current.reshape(-1, n_src, frames)
    previous = previous.reshape(-1, n_src, frames)

    overlap_f = window_size - hop_size

    def reorder_func(x, y):
        x = x[..., :overlap_f]
        y = y[..., -overlap_f:]
        # Mean normalization
        x = x - x.mean(-1, keepdim=True)
        y = y - y.mean(-1, keepdim=True)
        # Negative mean Correlation
        return -torch.sum(x.unsqueeze(1) * y.unsqueeze(2), dim=-1)

    # We maximize correlation-like between previous and current.
    pit = PITLossWrapper(reorder_func)
    current = pit(current, previous, return_est=True)[1]
    return current.reshape(batch, frames)

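# Minimal usage sketch for the helper above (assumed shapes and values, for
# illustration only): make the channel order of a chunk consistent with the
# previous chunk before overlap-adding them.
window_size, hop_size, n_src = 16000, 8000, 2
prev_win = torch.randn(n_src, window_size)  # flattened (batch * n_src, frames)
cur_win = torch.randn(n_src, window_size)
cur_win = _reorder_sources(cur_win, prev_win, n_src, window_size, hop_size)
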
def main(conf):
    train_set = WhamDataset(
        conf["data"]["train_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        segment=conf["data"]["segment"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )
    val_set = WhamDataset(
        conf["data"]["valid_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )
    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    # Update number of source values (It depends on the task)
    conf["masknet"].update({"n_src": train_set.n_src})

    model = DPTNet(**conf["filterbank"], **conf["masknet"])
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    from asteroid.engine.schedulers import DPTNetScheduler

    schedulers = {
        "scheduler": DPTNetScheduler(
            optimizer, len(train_loader) // conf["training"]["batch_size"], 64
        ),
        "interval": "step",
    }

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        scheduler=schedulers,
        train_loader=train_loader,
        val_loader=val_loader,
        config=conf,
    )

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(
        checkpoint_dir, monitor="val_loss", mode="min", save_top_k=5, verbose=True
    )
    early_stopping = False
    if conf["training"]["early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss", patience=30, verbose=True)

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend="ddp",
        gradient_clip_val=conf["training"]["gradient_clipping"],
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()
    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))

def main(conf):
    train_set = LibriMix(csv_dir=conf['data']['train_dir'],
                         task=conf['data']['task'],
                         sample_rate=conf['data']['sample_rate'],
                         n_src=conf['data']['n_src'],
                         segment=conf['data']['segment'])
    val_set = LibriMix(csv_dir=conf['data']['valid_dir'],
                       task=conf['data']['task'],
                       sample_rate=conf['data']['sample_rate'],
                       n_src=conf['data']['n_src'],
                       segment=conf['data']['segment'])
    train_loader = DataLoader(train_set, shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)
    val_loader = DataLoader(val_set, shuffle=True,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)
    conf['masknet'].update({'n_src': conf['data']['n_src']})

    # Define model and optimizer in a local function (defined in the recipe).
    # Two advantages to this: re-instantiating the model and optimizer
    # for retraining and evaluating is straightforward.
    model, optimizer = make_model_and_optimizer(conf)
    # Define scheduler
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')
    system = System(model=model, loss_func=loss_func, optimizer=optimizer,
                    train_loader=train_loader, val_loader=val_loader,
                    scheduler=scheduler, config=conf)

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min', save_top_k=5, verbose=1)
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=10,
                                       verbose=1)

    # Don't ask GPU if they are not available.
    if not torch.cuda.is_available():
        print('No available GPU was found, setting gpus to None')
        conf['main_args']['gpus'] = None
    trainer = pl.Trainer(max_epochs=conf['training']['epochs'],
                         checkpoint_callback=checkpoint,
                         early_stop_callback=early_stopping,
                         default_save_path=exp_dir,
                         gpus=conf['main_args']['gpus'],
                         distributed_backend='dp',
                         train_percent_check=1.0,  # Useful for fast experiment
                         gradient_clip_val=5.)
    trainer.fit(system)

    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(checkpoint.best_k_models, f, indent=0)

def main(conf):
    train_set = WhamDataset(conf['data']['train_dir'], conf['data']['task'],
                            sample_rate=conf['data']['sample_rate'],
                            nondefault_nsrc=conf['data']['nondefault_nsrc'])
    val_set = WhamDataset(conf['data']['valid_dir'], conf['data']['task'],
                          sample_rate=conf['data']['sample_rate'],
                          nondefault_nsrc=conf['data']['nondefault_nsrc'])
    train_loader = DataLoader(train_set, shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)
    val_loader = DataLoader(val_set, shuffle=False,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)
    # Update number of source values (It depends on the task)
    conf['masknet'].update({'n_src': train_set.n_src})

    # Define model and optimizer
    model = ConvTasNet(**conf['filterbank'], **conf['masknet'])
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    # Define scheduler
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    system = System(model=model, loss_func=loss_func, optimizer=optimizer,
                    train_loader=train_loader, val_loader=val_loader,
                    scheduler=scheduler, config=conf)

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min', save_top_k=5, verbose=1)
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=10,
                                       verbose=1)

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf['training']['epochs'],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_save_path=exp_dir,
        gpus=gpus,
        distributed_backend='dp',
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.)
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    # Save best model (next PL version will make this easier)
    best_path = [b for b, v in best_k.items() if v == min(best_k.values())][0]
    state_dict = torch.load(best_path)
    system.load_state_dict(state_dict=state_dict['state_dict'])
    system.cpu()
    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, 'best_model.pth'))