def make_model_and_optimizer(conf):
    """Build the STFT-domain masking model and its optimizer from a config dict.

    Args:
        conf: Dictionary produced by the hierarchical argparse. The
            ``'filterbank'``, ``'main_args'``, ``'masknet'`` and ``'optim'``
            sections are read.

    Returns:
        tuple: ``(model, optimizer)``.

    Keeping all construction in one function makes it easy to rebuild the
    exact same model when resuming training or running evaluation.

    Note:
        ``conf['masknet']`` is updated in place with the computed
        ``input_size`` and ``output_size``.
    """
    stft, istft = make_enc_dec('stft', **conf['filterbank'])
    is_complex = conf['main_args']['is_complex']
    n_feats = stft.n_feats_out
    if not is_complex:
        inp_size = output_size = int(n_feats / 2)
    else:
        # The network input is (re, im, mag) concatenated, and it predicts a
        # complex mask.
        inp_size = int(n_feats * 3 / 2)
        output_size = n_feats
    conf['masknet'].update(dict(input_size=inp_size, output_size=output_size))
    masker = SimpleModel(**conf['masknet'])
    model = Model(stft, masker, istft, is_complex=is_complex)
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    return model, optimizer
def check_ola():
    """Check that a window-less STFT analysis/synthesis of a constant signal
    gives a constant output once the borders are trimmed.

    Reads the module-level ``fb_config`` dict, which must contain
    ``'kernel_size'``.
    """
    trim = fb_config['kernel_size']
    enc, dec = make_enc_dec('stft', window=None, **fb_config)
    constant = torch.ones(1, 1, 4096)
    synth = dec(enc(constant))
    # Cut one kernel length of samples at the start and at the end.
    synth = synth[:, :, trim:-trim]
    testing.assert_allclose(synth, synth.mean())
def test_dcunet():
    fft_size = 1024
    _, istft = make_enc_dec(
        "stft", n_filters=fft_size, kernel_size=1024, stride=256, sample_rate=16000
    )
    # The test input length comes from the iSTFT of an all-zero
    # (fft_size + 2, 17) tensor.
    input_samples = istft(torch.zeros((fft_size + 2, 17))).shape[0]

    for arch in ("DCUNet-10", "DCUNet-16", "DCUNet-20", "Large-DCUNet-20"):
        _default_test_model(DCUNet(arch), input_samples=input_samples)
    _default_test_model(DCUNet("DCUNet-10", n_src=2), input_samples=input_samples)

    def complex_zeros(*shape):
        return torch.zeros(shape, dtype=torch.complex64)

    # Frequency dimension: 9 bins is accepted, a wrong size (42) must be rejected.
    DCUNet("mini").masker(complex_zeros(1, 9, 17))
    with pytest.raises(TypeError):
        DCUNet("mini").masker(complex_zeros(1, 42, 17))

    # Time dimension: a fix_length_mode ("pad" or "trim") makes 17 frames
    # acceptable; without one, a wrong frame count (16) must be rejected.
    for mode in ("pad", "trim"):
        DCUNet("mini", fix_length_mode=mode).masker(complex_zeros(1, 9, 17))
    with pytest.raises(TypeError):
        DCUNet("mini").masker(complex_zeros(1, 9, 16))
def make_model_and_optimizer(conf):
    """Build the TDConvNet masking model and its optimizer from a config dict.

    Args:
        conf: Dictionary produced by the hierarchical argparse. The
            ``'filterbank'``, ``'masknet'`` and ``'optim'`` sections are read.

    Returns:
        tuple: ``(model, optimizer)``.

    Centralizing construction here keeps reloading simple, both for resuming
    training and for evaluation.
    """
    encoder, decoder = fb.make_enc_dec(**conf['filterbank'])
    # Input post-processing and the mask type both rescale the feature
    # dimension the mask network consumes and produces. The encoder exposes
    # the two multipliers, so the sizes are corrected here rather than
    # inside Model.
    n_feats = encoder.n_feats_out
    masker = TDConvNet(
        in_chan=int(n_feats * encoder.in_chan_mul),
        out_chan=int(n_feats * encoder.out_chan_mul),
        **conf['masknet'],
    )
    model = Model(encoder, masker, decoder)
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    return model, optimizer
def test_stft_def(fb_config):
    """Building the STFT encoder/decoder by hand or through ``make_enc_dec``
    must yield identical filters."""
    filterbank = STFTFB(**fb_config)
    by_hand = (Encoder(filterbank), Decoder(filterbank))
    enc2, dec2 = make_enc_dec('stft', **fb_config)
    for manual, built in zip(by_hand, (enc2, dec2)):
        testing.assert_allclose(manual.filterbank.filters, built.filterbank.filters)
def test_perfect_istft_default_parameters(fb_config):
    """STFT followed by iSTFT with default parameters reconstructs the input,
    ignoring one kernel length at each border."""
    border = fb_config['kernel_size']
    enc, dec = make_enc_dec('stft', **fb_config)
    inp_wav = torch.randn(2, 1, 32000)
    core = slice(border, -border)
    reconstructed = dec(enc(inp_wav))[:, :, core]
    testing.assert_allclose(inp_wav[:, :, core], reconstructed)
def test_istft():
    """Check perfect STFT -> iSTFT reconstruction using the module-level
    ``fb_config``.

    Without dividing by the overlap-added window, the STFT/iSTFT pair cannot
    pass this test.
    """
    kernel_size = fb_config['kernel_size']
    enc, dec = make_enc_dec('stft', **fb_config)
    inp_wav = torch.randn(2, 1, 32000)
    # Exclude one kernel length at each border before comparing.
    out_wav = dec(enc(inp_wav))[:, :, kernel_size:-kernel_size]
    inp_test = inp_wav[:, :, kernel_size:-kernel_size]
    # Without this comparison the test computed both signals but checked nothing.
    testing.assert_allclose(inp_test, out_wav)
def main(conf):
    """Train a free-filterbank TDConvNet model on WHAM data.

    Args:
        conf: Dictionary produced by the hierarchical argparse. The
            ``'data'``, ``'filterbank'``, ``'masknet'``, ``'main_args'``,
            ``'optim'`` and ``'training'`` sections are read.
    """
    data_conf = conf['data']

    def _dataset(dir_key):
        # One WHAM split, located by the given key of conf['data'].
        return WhamDataset(
            data_conf[dir_key],
            data_conf['task'],
            sample_rate=data_conf['sample_rate'],
            nondefault_nsrc=data_conf['nondefault_nsrc'],
        )

    def _loader(dataset):
        return DataLoader(
            dataset,
            shuffle=True,
            batch_size=data_conf['batch_size'],
            num_workers=data_conf['num_workers'],
        )

    train_set = _dataset('train_dir')
    val_set = _dataset('valid_dir')
    loaders = {
        'train_loader': _loader(train_set),
        'val_loader': _loader(val_set),
    }

    # Encoder and decoder come from a filterbank name plus the conf.yml
    # settings (number of filters, filter size, stride). They can also be
    # built explicitly as ``fb.Encoder(fb.FreeFB(**conf['filterbank']))`` and
    # ``fb.Decoder(fb.FreeFB(**conf['filterbank']))``.
    enc, dec = fb.make_enc_dec('free', **conf['filterbank'])

    # The encoder's feature count sets the mask network's input and output size.
    n_feats = enc.filterbank.n_feats_out
    n_src = train_set.n_src
    masker = TDConvNet(in_chan=n_feats, out_chan=n_feats, n_src=n_src, **conf['masknet'])

    # Container handles the forward pass through encoder, masker and decoder.
    model = nn.DataParallel(Container(enc, masker, dec))
    if conf['main_args']['use_cuda']:
        model.cuda()

    loss_class = PITLossContainer(pairwise_neg_sisdr, n_src=n_src)
    optimizer = make_optimizer(model.parameters(), **conf['optim'])

    # The 'training' section of conf.yml holds the remaining solver options.
    solver = Solver(
        loaders,
        model,
        loss_class,
        optimizer,
        model_path=conf['main_args']['model_path'],
        **conf['training'],
    )
    solver.train()
def test_dccrnet():
    _, istft = make_enc_dec("stft", 512, 512)
    n_samples = istft(torch.zeros((514, 16))).shape[0]
    for extra in ({}, {"n_src": 2}):
        _default_test_model(DCCRNet("DCCRN-CL", **extra), input_samples=n_samples)

    # DCCRMaskNet accepts 256 frequency bins and must reject a wrong size (42).
    DCCRNet("mini").masker(torch.zeros((1, 256, 3), dtype=torch.complex64))
    with pytest.raises(TypeError):
        DCCRNet("mini").masker(torch.zeros((1, 42, 3), dtype=torch.complex64))
def test_serialization():
    """Round-trip a Container through ``serialize`` and ``load_model``.

    The reloaded model is built from freshly initialized components. Matching
    outputs must therefore come from the loaded weights, not from sharing the
    original module objects.
    """

    def _build_container():
        # Same architecture each time, but with independent, fresh parameters.
        enc, dec = make_enc_dec('free', n_filters=512, kernel_size=16, stride=8)
        masker = TDConvNet(in_chan=512, n_src=2, out_chan=512)
        return Container(enc, masker, dec)

    container = _build_container()
    inp = torch.randn(2, 1, 16000)
    out = container(inp)

    # Serialize.
    model_pack = container.serialize()

    # Load into a separate model instance and run the forward pass.
    new_model = _build_container()
    new_model.load_model(model_pack)
    new_out = new_model(inp)

    testing.assert_allclose(out, new_out)
def make_model_and_optimizer(conf):
    """Build the Chimera model on an STFT front-end, plus its optimizer.

    Args:
        conf: Dictionary produced by the hierarchical argparse. The
            ``'filterbank'``, ``'masknet'`` and ``'optim'`` sections are read.

    Returns:
        tuple: ``(model, optimizer)``.

    Centralizing construction here keeps reloading simple, both for resuming
    training and for evaluation.
    """
    enc, dec = fb.make_enc_dec("stft", **conf["filterbank"])
    # Chimera's first argument is half the encoder's feature count
    # (presumably the number of frequency bins; confirm against Chimera).
    n_freq = enc.n_feats_out // 2
    model = Model(enc, Chimera(n_freq, **conf["masknet"]), dec)
    return model, make_optimizer(model.parameters(), **conf["optim"])
def make_generator_and_optimizer(conf):
    """Build the generator, its optimizer and its loss from a config dict.

    Args:
        conf: Dictionary produced by the hierarchical argparse. The
            ``'filterbank'``, ``'optim'`` and ``'g_loss'`` sections are read.

    Returns:
        tuple: ``(model, optimizer, g_loss)``, where ``g_loss`` is a
        ``GeneratorLoss`` built from ``conf['g_loss']['s']``.

    Centralizing construction here keeps reloading simple, both for resuming
    training and for evaluation.
    """
    encoder, decoder = make_enc_dec(**conf['filterbank'])
    generator = Generator(encoder, decoder)
    return (
        generator,
        make_optimizer(generator.parameters(), **conf['optim']),
        GeneratorLoss(conf['g_loss']['s']),
    )
def test_ola(kernel_size, stride):
    """A boxcar input must come out of window-less DFT-filter
    analysis/synthesis unchanged (perfect overlap-add)."""
    # STFT filters with no analysis or synthesis window.
    enc, dec = make_enc_dec(
        'stft',
        window=None,
        n_filters=2 * kernel_size,
        kernel_size=kernel_size,
        stride=stride,
    )
    boxcar = torch.ones(1, 1, 4096)
    # Cut the leading and trailing frames before comparing.
    core = slice(kernel_size, -kernel_size)
    testing.assert_allclose(dec(enc(boxcar))[:, :, core], boxcar[:, :, core])
def test_dcunet():
    _, istft = make_enc_dec("stft", 512, 512)
    n_samples = istft(torch.zeros((514, 17))).shape[0]
    for extra in ({}, {"n_src": 2}):
        _default_test_model(DCUNet("DCUNet-10", **extra), input_samples=n_samples)

    def complex_zeros(*shape):
        return torch.zeros(shape, dtype=torch.complex64)

    # Frequency dimension: 9 bins is accepted, a wrong size (42) must be rejected.
    DCUNet("mini").masker(complex_zeros(1, 9, 17))
    with pytest.raises(TypeError):
        DCUNet("mini").masker(complex_zeros(1, 42, 17))

    # Time dimension: a fix_length_mode ("pad" or "trim") makes 17 frames
    # acceptable; without one, a wrong frame count (16) must be rejected.
    for mode in ("pad", "trim"):
        DCUNet("mini", fix_length_mode=mode).masker(complex_zeros(1, 9, 17))
    with pytest.raises(TypeError):
        DCUNet("mini").masker(complex_zeros(1, 9, 16))
def make_model_and_optimizer(conf):
    """Build the DPRNN model on a free filterbank, plus its optimizer.

    Args:
        conf: Dictionary produced by the hierarchical argparse. The
            ``'filterbank'``, ``'masknet'`` and ``'optim'`` sections are read.

    Returns:
        tuple: ``(model, optimizer)``.

    Centralizing construction here keeps reloading simple, both for resuming
    training and for evaluation.
    """
    enc, dec = fb.make_enc_dec('free', **conf['filterbank'])
    model = Model(enc, DPRNN(**conf['masknet']), dec)
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    return model, optimizer
def test_dccrnet():
    fft_size = 512
    _, istft = make_enc_dec(
        "stft", n_filters=fft_size, kernel_size=400, stride=100, sample_rate=16000
    )
    n_samples = istft(torch.zeros((fft_size + 2, 16))).shape[0]
    for extra in ({}, {"n_src": 2}):
        _default_test_model(DCCRNet("DCCRN-CL", **extra), input_samples=n_samples)

    # DCCRMaskNet accepts 256 frequency bins and must reject a wrong size (42).
    DCCRNet("mini").masker(torch.zeros((1, 256, 3), dtype=torch.complex64))
    with pytest.raises(TypeError):
        DCCRNet("mini").masker(torch.zeros((1, 42, 3), dtype=torch.complex64))
def make_discriminator_and_optimizer(conf):
    """Build the discriminator, its optimizer and its loss from a config dict.

    Args:
        conf: Dictionary produced by the hierarchical argparse. The
            ``'filterbank'``, ``'optim'``, ``'metric_to_opt'`` and ``'data'``
            sections are read.

    Returns:
        tuple: ``(model, optimizer, d_loss)``, where ``d_loss`` is a
        ``DiscriminatorLoss`` built from ``conf['metric_to_opt']['metric']``
        and ``conf['data']['rate']``.

    Centralizing construction here keeps reloading simple, both for resuming
    training and for evaluation.
    """
    encoder, decoder = make_enc_dec(**conf['filterbank'])
    discriminator = Discriminator(encoder, decoder)
    optimizer = make_optimizer(discriminator.parameters(), **conf['optim'])
    metric = conf['metric_to_opt']['metric']
    d_loss = DiscriminatorLoss(metric, conf['data']['rate'])
    return discriminator, optimizer, d_loss
def main(conf):
    """Build the data pipeline, model, loss and optimizer from ``conf``, then train.

    Args:
        conf: Dictionary produced by the hierarchical argparse. The
            ``'data'``, ``'filterbank'``, ``'masknet'``, ``'main_args'``,
            ``'optim'`` and ``'training'`` sections are read.
    """
    # Define the data pipeline.
    train_set = WhamDataset(conf['data']['train_dir'], conf['data']['task'],
                            sample_rate=conf['data']['sample_rate'],
                            nondefault_nsrc=conf['data']['nondefault_nsrc'])
    val_set = WhamDataset(conf['data']['valid_dir'], conf['data']['task'],
                          sample_rate=conf['data']['sample_rate'],
                          nondefault_nsrc=conf['data']['nondefault_nsrc'])
    train_loader = DataLoader(train_set, shuffle=True,
                              batch_size=conf['data']['batch_size'],
                              num_workers=conf['data']['num_workers'])
    val_loader = DataLoader(val_set, shuffle=True,
                            batch_size=conf['data']['batch_size'],
                            num_workers=conf['data']['num_workers'])
    loaders = {'train_loader': train_loader, 'val_loader': val_loader}

    # The encoder and decoder are built directly from the config dictionary.
    encoder, decoder = filterbanks.make_enc_dec(**conf['filterbank'])
    # Input post-processing changes the feature dimension fed to the mask
    # network, and the mask type changes the dimension it must output. The
    # encoder exposes both multipliers, so the sizes are corrected here.
    nn_in = int(encoder.n_feats_out * encoder.in_chan_mul)
    nn_out = int(encoder.n_feats_out * encoder.out_chan_mul)
    masker = TDConvNet(in_chan=nn_in, out_chan=nn_out,
                       n_src=train_set.n_src, **conf['masknet'])
    # The model is a Container wrapped in DataParallel.
    model = nn.DataParallel(Container(encoder, masker, decoder))
    if conf['main_args']['use_cuda']:
        model.cuda()
    # Loss: permutation-invariant negative SI-SDR in the time domain.
    loss_class = PITLossContainer(pairwise_neg_sisdr, n_src=train_set.n_src)
    # The optimizer can also be instantiated from the config dictionary.
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    # Pass everything to the solver and train. The code previously ran only
    # one validation epoch (a debug leftover) instead of training.
    solver = Solver(loaders, model, loss_class, optimizer,
                    model_path=conf['main_args']['model_path'],
                    **conf['training'])
    solver.train()
def __init__(
    self,
    fb_name="free",
    kernel_size=16,
    n_filters=32,
    stride=8,
    encoder_activation=None,
    **fb_kwargs,
):
    """Pair an encoder and decoder from ``make_enc_dec`` with an identity masker.

    Args:
        fb_name: Filterbank identifier forwarded to ``make_enc_dec``.
        kernel_size: Filter length forwarded to ``make_enc_dec``.
        n_filters: Number of filters forwarded to ``make_enc_dec``.
        stride: Hop size forwarded to ``make_enc_dec``.
        encoder_activation: Forwarded unchanged to the parent constructor.
        **fb_kwargs: Extra keyword arguments for ``make_enc_dec``.
    """
    fb_params = dict(kernel_size=kernel_size, n_filters=n_filters, stride=stride)
    encoder, decoder = make_enc_dec(fb_name, **fb_params, **fb_kwargs)
    # No masking network: features pass through the masker unchanged.
    super().__init__(
        encoder,
        torch.nn.Identity(),
        decoder,
        encoder_activation=encoder_activation,
    )
def test_make_enc_dec(who):
    """``make_enc_dec`` must work with a filterbank name and with a filterbank
    class, for the given ``who_is_pinv`` setting.

    Both results are checked. The first result used to be overwritten before
    any assertion, so only the class-based path was verified.
    """
    fb_config = {"n_filters": 500, "kernel_size": 16, "stride": 8}
    for fb_identifier in ("free", FreeFB):
        enc, _ = make_enc_dec(fb_identifier, who_is_pinv=who, **fb_config)
        assert enc.filterbank == filterbanks.get(enc.filterbank)
def test_nomasker():
    """A Container with ``masker=None`` returns an output shaped like its input."""
    enc, dec = make_enc_dec('free', n_filters=512, kernel_size=16, stride=8)
    passthrough = Container(enc, None, dec)
    mixture = torch.randn(2, 1, 16000)
    estimate = passthrough(mixture)
    assert mixture.shape == estimate.shape
def test_dcunet():
    _, istft = make_enc_dec("stft", 512, 512)
    # The test input length comes from the iSTFT of an all-zero (514, 17) tensor.
    n_samples = istft(torch.zeros((514, 17))).shape[0]
    _default_test_model(DCUNet("DCUNet-10"), input_samples=n_samples)