Example #1
0
def make_model_and_optimizer(conf):
    """Build the model and optimizer from a config dictionary.

    Args:
        conf: Dictionary containing the output of hierarchical argparse
            (expects 'filterbank' and 'optim' sub-dictionaries).

    Returns:
        tuple: ``(model, optimizer)``.

    The main goal of this function is to make reloading for resuming
    and evaluation very simple.
    """
    enc = fb.Encoder(fb.STFTFB(**conf['filterbank']))
    # Floor division instead of int(x / 2): same result for the even
    # feature counts produced by STFTFB, without a float round-trip.
    # No backslash continuation needed inside the parentheses.
    masker = ChimeraPP(enc.filterbank.n_feats_out // 2, 2,
                       embedding_dim=20, n_layers=2, hidden_size=600,
                       dropout=0, bidirectional=True)
    model = Model(enc, masker)
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    return model, optimizer
Example #2
0
 def __init__(self,
              model,
              optimizer,
              loss_func,
              train_loader,
              val_loader=None,
              scheduler=None,
              config=None):
     """Delegate setup to the base system and attach an STFT encoder
     built from ``config['filterbank']``.

     NOTE(review): ``config`` defaults to None but is subscripted below,
     so callers must pass a non-None config — confirm intended contract.
     """
     super().__init__(model, optimizer, loss_func, train_loader,
                      val_loader=val_loader, scheduler=scheduler,
                      config=config)
     # Encoder used to compute STFT representations of the inputs.
     self.enc = fb.Encoder(fb.STFTFB(**config['filterbank']))
Example #3
0
def main(conf):
    """Evaluate a trained deep-clustering model on the WSJ 2-mix test set.

    Args:
        conf: Config dictionary with 'data', 'filterbank' and 'main_args'
            sections (paths, sample rate, number of examples to save, ...).
    """
    # Removed leftover debugger call (set_trace()) — it would block
    # every run waiting on an interactive prompt.
    test_set = WSJ2mixDataset(conf['data']['tt_wav_len_list'],
                              conf['data']['wav_base_path'] + '/tt',
                              sample_rate=conf['data']['sample_rate'])
    test_loader = DataLoader(test_set,
                             shuffle=True,
                             batch_size=1,
                             num_workers=conf['data']['num_workers'],
                             drop_last=False)
    istft = fb.Decoder(fb.STFTFB(**conf['filterbank']))
    exp_dir = conf['main_args']['exp_dir']
    model_path = os.path.join(exp_dir, 'checkpoints/_ckpt_epoch_0.ckpt')
    model = load_best_model(conf, model_path)
    pit_loss = PITLossWrapper(pairwise_mse, mode='pairwise')

    # Optimizer/loss/loaders are irrelevant at inference time, hence None.
    system = DcSystem(model, None, None, None, config=conf)

    # Randomly choose the indexes of sentences to save.
    # (exp_dir was already read above — duplicate assignment removed.)
    exp_save_dir = os.path.join(exp_dir, 'examples/')
    n_save = conf['main_args']['n_save_ex']
    if n_save == -1:
        n_save = len(test_set)
    save_idx = random.sample(range(len(test_set)), n_save)
    series_list = []

    # Proper context manager instead of torch.no_grad().__enter__(),
    # which entered the no-grad context and never exited it.
    with torch.no_grad():
        for batch in test_loader:
            batch = [ele.type(torch.float32) for ele in batch]
            inputs, targets, masks = system.unpack_data(batch)
            est_targets = system(inputs)
            mix_stft = system.enc(inputs.unsqueeze(1))
            min_loss, min_idx = pit_loss.best_perm_from_perm_avg_loss(
                pairwise_mse, est_targets[1], masks)
            for sidx in min_idx:
                # Apply the estimated mask to the mixture STFT and
                # invert back to the time domain.
                src_stft = mix_stft * est_targets[1][sidx]
                src_sig = istft(src_stft)
Example #4
0
import argparse

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

from asteroid.filterbanks.transforms import take_mag
import asteroid.filterbanks as fb
from asteroid.data.wsj0_mix import WSJ2mixDataset, BucketingSampler, \
        collate_fn
from asteroid.masknn.blocks import SingleRNN
from asteroid.losses import PITLossWrapper, pairwise_mse
from asteroid.losses import deep_clustering_loss

# float32 machine epsilon — presumably a numerical-stability constant
# (e.g. to avoid division by zero); confirm against its uses.
EPS = torch.finfo(torch.float32).eps
# STFT encoder: 256 filters, kernel size 256, hop (stride) 64, on GPU.
# NOTE(review): .cuda() at import time hard-requires a GPU.
enc = fb.Encoder(fb.STFTFB(256, 256, stride=64))
enc = enc.cuda()

# Command-line interface. NOTE(review): `argparse` must be imported at
# the top of the file for this to run.
parser = argparse.ArgumentParser()
parser.add_argument('--gpus', type=str, help='list of GPUs', default='-1')
parser.add_argument('--exp_dir',
                    default='exp/tmp',
                    help='Full path to save best validation model')

# Pairwise MSE wrapped for permutation-invariant training.
pit_loss = PITLossWrapper(pairwise_mse, mode='pairwise')


class Model(nn.Module):
    def __init__(self):
        #def __init__(self, in_chan, n_src, rnn_type = 'lstm',
        #        embedding_dim=20, n_layers=2, hidden_size=600,