示例#1
0
def make_model_and_optimizer(conf):
    """Build the STFT-based model and its optimizer from a config dictionary.

    Keeping construction in a single function makes it easy to rebuild the
    exact same objects when resuming training or running evaluation.

    Args:
        conf: Nested configuration dictionary (output of the hierarchical
            argparse). The keys 'filterbank', 'main_args', 'masknet' and
            'optim' are read. Note that ``conf['masknet']`` is updated in
            place with the computed ``input_size`` and ``output_size``.

    Returns:
        Tuple ``(model, optimizer)``.
    """
    stft, istft = make_enc_dec('stft', **conf['filterbank'])
    is_complex = conf['main_args']['is_complex']
    n_feats = stft.n_feats_out

    # In complex mode the masker input is the concatenation (re, im, mag)
    # and the output is a complex mask; otherwise both sizes are half of the
    # filterbank feature count.
    if is_complex:
        sizes = dict(input_size=int(n_feats * 3 / 2), output_size=n_feats)
    else:
        half = int(n_feats / 2)
        sizes = dict(input_size=half, output_size=half)

    # Store the sizes in the masker config before instantiating the masker.
    conf['masknet'].update(sizes)
    masker = SimpleModel(**conf['masknet'])

    model = Model(stft, masker, istft, is_complex=is_complex)
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    return model, optimizer
示例#2
0
def check_ola():
    """Check that a constant input stays constant after STFT -> iSTFT.

    The filterbank is built from the module-level ``fb_config`` with
    ``window=None``. ``kernel_size`` samples are dropped on both ends before
    comparing every remaining sample to the mean of the remaining signal.
    """
    trim = fb_config['kernel_size']
    encoder, decoder = make_enc_dec('stft', window=None, **fb_config)

    ones = torch.ones(1, 1, 4096)
    resynth = decoder(encoder(ones))
    inner = resynth[:, :, trim:-trim]
    testing.assert_allclose(inner, inner.mean())
示例#3
0
def test_dcunet():
    """Run the default model test on DCUNet variants and check shape errors."""
    fft_size = 1024
    stft_kwargs = dict(n_filters=fft_size,
                       kernel_size=1024,
                       stride=256,
                       sample_rate=16000)
    _, istft = make_enc_dec("stft", **stft_kwargs)
    # Length of the iSTFT output for an all-zero (n_fft + 2, 17) input; used
    # as the number of input samples for every model test below.
    zero_spec = torch.zeros((fft_size + 2, 17))
    input_samples = istft(zero_spec).shape[0]

    for arch in ("DCUNet-10", "DCUNet-16", "DCUNet-20", "Large-DCUNet-20"):
        _default_test_model(DCUNet(arch), input_samples=input_samples)
    _default_test_model(DCUNet("DCUNet-10", n_src=2),
                        input_samples=input_samples)

    def complex_zeros(*shape):
        # All-zero complex64 tensor of the requested shape.
        return torch.zeros(shape, dtype=torch.complex64)

    # DCUMaskNet should fail with wrong frequency dimensions
    DCUNet("mini").masker(complex_zeros(1, 9, 17))
    with pytest.raises(TypeError):
        DCUNet("mini").masker(complex_zeros(1, 42, 17))

    # DCUMaskNet should fail with wrong time dimensions if fix_length_mode
    # is not used
    for mode in ("pad", "trim"):
        DCUNet("mini", fix_length_mode=mode).masker(complex_zeros(1, 9, 17))
    with pytest.raises(TypeError):
        DCUNet("mini").masker(complex_zeros(1, 9, 16))
示例#4
0
def make_model_and_optimizer(conf):
    """Build the TDConvNet-based model and its optimizer from a config dict.

    Keeping construction in a single function makes it easy to rebuild the
    exact same objects when resuming training or running evaluation.

    Args:
        conf: Nested configuration dictionary (output of the hierarchical
            argparse). The keys 'filterbank', 'masknet' and 'optim' are read.

    Returns:
        Tuple ``(model, optimizer)``.
    """
    # Encoder and decoder are created straight from the filterbank config.
    encoder, decoder = fb.make_enc_dec(**conf['filterbank'])

    # The encoder's channel multipliers account for input post-processing
    # (masker input size) and for the mask type (masker output size).
    n_feats = encoder.n_feats_out
    masker = TDConvNet(in_chan=int(n_feats * encoder.in_chan_mul),
                       out_chan=int(n_feats * encoder.out_chan_mul),
                       **conf['masknet'])

    # The size correction could also live inside Model, but then the masker
    # would have to be instantiated there as well.
    model = Model(encoder, masker, decoder)

    # The optimizer is created from the 'optim' config section.
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    return model, optimizer
示例#5
0
def test_stft_def(fb_config):
    """Check that manual construction and `make_enc_dec` give equal filters."""
    filterbank = STFTFB(**fb_config)
    manual_enc = Encoder(filterbank)
    manual_dec = Decoder(filterbank)
    made_enc, made_dec = make_enc_dec('stft', **fb_config)

    # Compare encoder filters first, then decoder filters.
    for manual, made in ((manual_enc, made_enc), (manual_dec, made_dec)):
        testing.assert_allclose(manual.filterbank.filters,
                                made.filterbank.filters)
示例#6
0
def test_perfect_istft_default_parameters(fb_config):
    """Check that STFT followed by iSTFT reproduces a random input.

    ``kernel_size`` samples are discarded at both ends of the last axis
    before the input and the resynthesised signal are compared.
    """
    margin = fb_config['kernel_size']
    encoder, decoder = make_enc_dec('stft', **fb_config)

    wav = torch.randn(2, 1, 32000)
    interior = (slice(None), slice(None), slice(margin, -margin))
    resynth = decoder(encoder(wav))[interior]
    expected = wav[interior]
    testing.assert_allclose(expected, resynth)
示例#7
0
def test_istft():
    """Run an STFT -> iSTFT round trip on random input without asserting.

    Without dividing by the overlap-added window, the STFT iSTFT cannot
    pass the unit test, so the trimmed input and output are only computed
    here and never compared.
    """
    margin = fb_config['kernel_size']
    encoder, decoder = make_enc_dec('stft', **fb_config)

    wav = torch.randn(2, 1, 32000)
    interior = (slice(None), slice(None), slice(margin, -margin))
    out_wav = decoder(encoder(wav))[interior]
    inp_test = wav[interior]
示例#8
0
def main(conf):
    """Build datasets, model, loss and optimizer from `conf`, then train.

    Args:
        conf: Nested configuration dictionary. The sections 'data',
            'filterbank', 'masknet', 'main_args', 'optim' and 'training'
            are read.
    """
    data_conf = conf['data']

    def build_dataset(dir_key):
        # One WhamDataset per split; only the directory differs.
        return WhamDataset(data_conf[dir_key],
                           data_conf['task'],
                           sample_rate=data_conf['sample_rate'],
                           nondefault_nsrc=data_conf['nondefault_nsrc'])

    def build_loader(dataset):
        # Both loaders share the same shuffling, batch size and workers.
        return DataLoader(dataset,
                          shuffle=True,
                          batch_size=data_conf['batch_size'],
                          num_workers=data_conf['num_workers'])

    # Data pipeline: datasets first, then their loaders.
    train_set = build_dataset('train_dir')
    val_set = build_dataset('valid_dir')
    loaders = {
        'train_loader': build_loader(train_set),
        'val_loader': build_loader(val_set),
    }

    # Encoder and decoder are built from a filterbank name plus the config
    # section holding number of filters, filter size and stride (conf.yml).
    # Equivalent manual construction:
    # enc = fb.Encoder(fb.FreeFB(**conf['filterbank']))
    # dec = fb.Decoder(fb.FreeFB(**conf['filterbank']))
    enc, dec = fb.make_enc_dec('free', **conf['filterbank'])

    # The mask network's input and output sizes are dictated by the encoder;
    # the remaining arguments come from the 'masknet' section of conf.yml.
    n_feats = enc.filterbank.n_feats_out
    masker = TDConvNet(in_chan=n_feats,
                       out_chan=n_feats,
                       n_src=train_set.n_src,
                       **conf['masknet'])

    # Container handles the forward pass of encoder -> masker -> decoder.
    model = nn.DataParallel(Container(enc, masker, dec))
    if conf['main_args']['use_cuda']:
        model.cuda()

    loss_class = PITLossContainer(pairwise_neg_sisdr, n_src=train_set.n_src)
    optimizer = make_optimizer(model.parameters(), **conf['optim'])

    # Hand everything to the solver together with the 'training' section of
    # the configuration, then start training.
    solver = Solver(loaders,
                    model,
                    loss_class,
                    optimizer,
                    model_path=conf['main_args']['model_path'],
                    **conf['training'])
    solver.train()
示例#9
0
def test_dccrnet():
    """Run the default model test on DCCRNet and check masker shape errors."""
    _, istft = make_enc_dec("stft", 512, 512)
    # Length of the iSTFT output for an all-zero (514, 16) input.
    zero_spec = torch.zeros((514, 16))
    input_samples = istft(zero_spec).shape[0]

    for extra_kwargs in ({}, {"n_src": 2}):
        _default_test_model(DCCRNet("DCCRN-CL", **extra_kwargs),
                            input_samples=input_samples)

    def complex_zeros(*shape):
        # All-zero complex64 tensor of the requested shape.
        return torch.zeros(shape, dtype=torch.complex64)

    # DCCRMaskNet should fail with wrong input dimensions
    DCCRNet("mini").masker(complex_zeros(1, 256, 3))
    with pytest.raises(TypeError):
        DCCRNet("mini").masker(complex_zeros(1, 42, 3))
示例#10
0
def test_serialization():
    """Check that a serialized container reloads and gives the same output."""
    encoder, decoder = make_enc_dec('free',
                                    n_filters=512,
                                    kernel_size=16,
                                    stride=8)
    masker = TDConvNet(in_chan=512, n_src=2, out_chan=512)
    original = Container(encoder, masker, decoder)

    batch = torch.randn(2, 1, 16000)
    reference = original(batch)

    # Serialize the first container ...
    model_pack = original.serialize()

    # ... then load the pack into a second container and run it.
    # NOTE(review): the second container wraps the very same encoder, masker
    # and decoder objects, so the outputs may match even if loading did
    # nothing - confirm whether fresh modules were intended here.
    reloaded = Container(encoder, masker, decoder)
    reloaded.load_model(model_pack)
    reloaded_out = reloaded(batch)

    testing.assert_allclose(reference, reloaded_out)
示例#11
0
def make_model_and_optimizer(conf):
    """Build the Chimera-based model and its optimizer from a config dict.

    Keeping construction in a single function makes it easy to rebuild the
    exact same objects when resuming training or running evaluation.

    Args:
        conf: Nested configuration dictionary (output of the hierarchical
            argparse). The keys "filterbank", "masknet" and "optim" are read.

    Returns:
        Tuple ``(model, optimizer)``.
    """
    encoder, decoder = fb.make_enc_dec("stft", **conf["filterbank"])
    # The masker's first argument is half of the encoder's feature count.
    masker_in = encoder.n_feats_out // 2
    masker = Chimera(masker_in, **conf["masknet"])

    model = Model(encoder, masker, decoder)
    return model, make_optimizer(model.parameters(), **conf["optim"])
示例#12
0
def make_generator_and_optimizer(conf):
    """Build the generator, its optimizer and its loss from a config dict.

    Keeping construction in a single function makes it easy to rebuild the
    exact same objects when resuming training or running evaluation.

    Args:
        conf: Nested configuration dictionary (output of the hierarchical
            argparse). The keys 'filterbank', 'optim' and 'g_loss' are read.

    Returns:
        Tuple ``(model, optimizer, g_loss)``.
    """
    enc, dec = make_enc_dec(**conf['filterbank'])
    generator = Generator(enc, dec)

    # The optimizer covers all generator parameters.
    generator_optimizer = make_optimizer(generator.parameters(),
                                         **conf['optim'])
    loss = GeneratorLoss(conf['g_loss']['s'])
    return generator, generator_optimizer, loss
示例#13
0
def test_ola(kernel_size, stride):
    """Unit-test perfect overlap-add for boxcar-weighted DFT filters.

    Args:
        kernel_size: Filter length; the filterbank uses twice as many
            filters, and this many samples are trimmed from both ends.
        stride: Hop size passed to the filterbank.
    """
    fb_config = dict(n_filters=2 * kernel_size,
                     kernel_size=kernel_size,
                     stride=stride)
    # STFT filters without analysis and synthesis windows.
    encoder, decoder = make_enc_dec('stft', window=None, **fb_config)

    boxcar = torch.ones(1, 1, 4096)
    # Leading and trailing samples are cut before comparing.
    keep = (slice(None), slice(None), slice(kernel_size, -kernel_size))
    resynth = decoder(encoder(boxcar))[keep]

    # A boxcar input must come back as a boxcar output.
    testing.assert_allclose(resynth, boxcar[keep])
示例#14
0
def test_dcunet():
    """Run the default model test on DCUNet-10 and check masker shape errors."""
    _, istft = make_enc_dec("stft", 512, 512)
    # Length of the iSTFT output for an all-zero (514, 17) input.
    zero_spec = torch.zeros((514, 17))
    input_samples = istft(zero_spec).shape[0]

    for extra_kwargs in ({}, {"n_src": 2}):
        _default_test_model(DCUNet("DCUNet-10", **extra_kwargs),
                            input_samples=input_samples)

    def complex_zeros(*shape):
        # All-zero complex64 tensor of the requested shape.
        return torch.zeros(shape, dtype=torch.complex64)

    # DCUMaskNet should fail with wrong frequency dimensions
    DCUNet("mini").masker(complex_zeros(1, 9, 17))
    with pytest.raises(TypeError):
        DCUNet("mini").masker(complex_zeros(1, 42, 17))

    # DCUMaskNet should fail with wrong time dimensions if fix_length_mode
    # is not used
    for mode in ("pad", "trim"):
        DCUNet("mini", fix_length_mode=mode).masker(complex_zeros(1, 9, 17))
    with pytest.raises(TypeError):
        DCUNet("mini").masker(complex_zeros(1, 9, 16))
示例#15
0
文件: model.py 项目: zwb0626/asteroid
def make_model_and_optimizer(conf):
    """Build the DPRNN-based model and its optimizer from a config dict.

    Keeping construction in a single function makes it easy to rebuild the
    exact same objects when resuming training or running evaluation.

    Args:
        conf: Nested configuration dictionary (output of the hierarchical
            argparse). The keys 'filterbank', 'masknet' and 'optim' are read.

    Returns:
        Tuple ``(model, optimizer)``.
    """
    encoder, decoder = fb.make_enc_dec('free', **conf['filterbank'])
    mask_net = DPRNN(**conf['masknet'])
    model = Model(encoder, mask_net, decoder)
    # The optimizer covers all parameters of the assembled model.
    return model, make_optimizer(model.parameters(), **conf['optim'])
示例#16
0
def test_dccrnet():
    """Run the default model test on DCCRNet and check masker shape errors."""
    fft_size = 512
    stft_kwargs = dict(n_filters=fft_size,
                       kernel_size=400,
                       stride=100,
                       sample_rate=16000)
    _, istft = make_enc_dec("stft", **stft_kwargs)
    # Length of the iSTFT output for an all-zero (n_fft + 2, 16) input.
    zero_spec = torch.zeros((fft_size + 2, 16))
    input_samples = istft(zero_spec).shape[0]

    for extra_kwargs in ({}, {"n_src": 2}):
        _default_test_model(DCCRNet("DCCRN-CL", **extra_kwargs),
                            input_samples=input_samples)

    def complex_zeros(*shape):
        # All-zero complex64 tensor of the requested shape.
        return torch.zeros(shape, dtype=torch.complex64)

    # DCCRMaskNet should fail with wrong input dimensions
    DCCRNet("mini").masker(complex_zeros(1, 256, 3))
    with pytest.raises(TypeError):
        DCCRNet("mini").masker(complex_zeros(1, 42, 3))
示例#17
0
def make_discriminator_and_optimizer(conf):
    """Build the discriminator, its optimizer and its loss from a config dict.

    Keeping construction in a single function makes it easy to rebuild the
    exact same objects when resuming training or running evaluation.

    Args:
        conf: Nested configuration dictionary (output of the hierarchical
            argparse). The keys 'filterbank', 'optim', 'metric_to_opt' and
            'data' are read.

    Returns:
        Tuple ``(model, optimizer, d_loss)``.
    """
    enc, dec = make_enc_dec(**conf['filterbank'])
    discriminator = Discriminator(enc, dec)

    # The optimizer covers all discriminator parameters.
    discriminator_optimizer = make_optimizer(discriminator.parameters(),
                                             **conf['optim'])
    metric_name = conf['metric_to_opt']['metric']
    loss = DiscriminatorLoss(metric_name, conf['data']['rate'])
    return discriminator, discriminator_optimizer, loss
示例#18
0
def main(conf):
    """Build datasets, model, loss and optimizer, then run one validation epoch.

    Args:
        conf: Nested configuration dictionary. The sections 'data',
            'filterbank', 'masknet', 'main_args', 'optim' and 'training'
            are read.
    """
    data_conf = conf['data']

    # Data pipeline: both datasets are created first, then both loaders.
    datasets = []
    for dir_key in ('train_dir', 'valid_dir'):
        datasets.append(
            WhamDataset(data_conf[dir_key], data_conf['task'],
                        sample_rate=data_conf['sample_rate'],
                        nondefault_nsrc=data_conf['nondefault_nsrc']))
    train_set, val_set = datasets

    loaders = {}
    for loader_key, dataset in (('train_loader', train_set),
                                ('val_loader', val_set)):
        loaders[loader_key] = DataLoader(
            dataset, shuffle=True,
            batch_size=data_conf['batch_size'],
            num_workers=data_conf['num_workers'])

    # Encoder and decoder are created straight from the filterbank config.
    encoder, decoder = filterbanks.make_enc_dec(**conf['filterbank'])

    # The encoder's channel multipliers account for input post-processing
    # (masker input size) and for the mask type (masker output size).
    n_feats = encoder.n_feats_out
    masker = TDConvNet(in_chan=int(n_feats * encoder.in_chan_mul),
                       out_chan=int(n_feats * encoder.out_chan_mul),
                       n_src=train_set.n_src, **conf['masknet'])

    # The model is a Container wrapped in DataParallel.
    model = nn.DataParallel(Container(encoder, masker, decoder))
    if conf['main_args']['use_cuda']:
        model.cuda()

    # Time-domain SI-SDR, wrapped in the PIT loss container.
    loss_class = PITLossContainer(pairwise_neg_sisdr, n_src=train_set.n_src)
    # The optimizer is created from the 'optim' config section.
    optimizer = make_optimizer(model.parameters(), **conf['optim'])

    solver = Solver(loaders, model, loss_class, optimizer,
                    model_path=conf['main_args']['model_path'],
                    **conf['training'])
    # Full training (solver.train()) is disabled here; only a single
    # validation epoch is run.
    solver.run_one_epoch(0, validation=True)
 def __init__(
     self,
     fb_name="free",
     kernel_size=16,
     n_filters=32,
     stride=8,
     encoder_activation=None,
     **fb_kwargs,
 ):
     """Create an encoder/decoder pair and wire it with an identity masker.

     Args:
         fb_name: Filterbank identifier forwarded to `make_enc_dec`.
         kernel_size: Forwarded to `make_enc_dec`.
         n_filters: Forwarded to `make_enc_dec`.
         stride: Forwarded to `make_enc_dec`.
         encoder_activation: Forwarded to the parent constructor.
         **fb_kwargs: Extra keyword arguments forwarded to `make_enc_dec`.
     """
     enc_dec_kwargs = dict(kernel_size=kernel_size,
                           n_filters=n_filters,
                           stride=stride)
     enc, dec = make_enc_dec(fb_name, **enc_dec_kwargs, **fb_kwargs)
     # The masker slot is filled with an identity module.
     identity_masker = torch.nn.Identity()
     super().__init__(enc,
                      identity_masker,
                      dec,
                      encoder_activation=encoder_activation)
示例#20
0
def test_make_enc_dec(who):
    """Check `make_enc_dec` accepts a filterbank name and a filterbank class.

    Args:
        who: Value forwarded as ``who_is_pinv`` to `make_enc_dec`.
    """
    fb_config = dict(n_filters=500, kernel_size=16, stride=8)
    # First by name, then by class; the pair from the last call is kept.
    for fb_spec in ("free", FreeFB):
        enc, dec = make_enc_dec(fb_spec, who_is_pinv=who, **fb_config)
    assert enc.filterbank == filterbanks.get(enc.filterbank)
示例#21
0
def test_nomasker():
    """Check that a Container without a masker preserves the input shape."""
    encoder, decoder = make_enc_dec('free',
                                    n_filters=512,
                                    kernel_size=16,
                                    stride=8)
    model = Container(encoder, None, decoder)
    batch = torch.randn(2, 1, 16000)
    result = model(batch)
    assert batch.shape == result.shape
示例#22
0
def test_dcunet():
    """Run the default model test on DCUNet-10."""
    _, istft = make_enc_dec("stft", 512, 512)
    model = DCUNet("DCUNet-10")
    # Length of the iSTFT output for an all-zero (514, 17) input.
    n_samples = istft(torch.zeros((514, 17))).shape[0]
    _default_test_model(model, input_samples=n_samples)