class Model(object):
    """Bundle a generator, a discriminator and an optional regressor for GAN training.

    The generator and discriminator are built from ``opt`` and moved to the
    CUDA device ``opt.gpu_id``. When ``opt.mse_weight`` is truthy, a regressor
    is loaded from ``data/utils/classifier.pth`` and used for an additional
    generator loss term.
    """

    def __init__(self, opt):
        """Build the networks, the losses and the checkpoint path.

        :param opt: options object; this constructor reads ``gpu_id``,
            ``mse_weight``, ``experiments_dir``, ``experiment_name``,
            ``in_channels`` and ``input_idx`` (a comma-separated string).
        """
        super(Model, self).__init__()

        # Generator
        self.gen = Generator(opt).cuda(opt.gpu_id)
        self.gen_params = self.gen.parameters()
        print(self.gen)
        print(self._count_parameters(self.gen))

        # Discriminator
        self.dis = Discriminator(opt).cuda(opt.gpu_id)
        self.dis_params = self.dis.parameters()
        print(self.dis)
        print(self._count_parameters(self.dis))

        # Regressor
        # NOTE(review): torch.load unpickles a whole module object here; only
        # point this at a trusted file.
        if opt.mse_weight:
            self.reg = torch.load('data/utils/classifier.pth').cuda(
                opt.gpu_id).eval()
        else:
            self.reg = None

        # Losses
        self.criterion_gan = GANLoss(opt, self.dis)
        # Despite its name, this criterion is an L1 loss scaled by mse_weight.
        self.criterion_mse = lambda x, y: l1_loss(x, y) * opt.mse_weight

        # Loss placeholders live on the same device as the networks so that
        # they can be summed with losses computed on ``opt.gpu_id``.
        self.loss_mse = Variable(torch.zeros(1).cuda(opt.gpu_id))
        self.loss_adv = Variable(torch.zeros(1).cuda(opt.gpu_id))
        self.loss = Variable(torch.zeros(1).cuda(opt.gpu_id))

        self.path = opt.experiments_dir + opt.experiment_name + '/checkpoints/'
        self.gpu_id = opt.gpu_id
        # Number of noise channels appended to the input so that the generator
        # receives ``opt.in_channels`` channels in total.
        self.noise_channels = opt.in_channels - len(opt.input_idx.split(','))

    @staticmethod
    def _count_parameters(module):
        """Return the total number of scalar parameters held by ``module``."""
        return sum(p.numel() for p in module.parameters())

    def forward(self, inputs):
        """Move a batch to the GPU and run the generator on it.

        :param inputs: triple ``(input, input_orig, target)`` of tensors.
        :return: None; the result is stored in ``self.fake`` and the moved
            tensors in ``self.input``, ``self.input_orig`` and ``self.target``.
        """
        input_batch, input_orig, target = inputs

        self.input = Variable(input_batch.cuda(self.gpu_id))
        self.input_orig = Variable(input_orig.cuda(self.gpu_id))
        self.target = Variable(target.cuda(self.gpu_id))

        # One noise vector per sample, concatenated along dimension 1.
        noise = Variable(
            torch.randn(self.input.size(0),
                        self.noise_channels).cuda(self.gpu_id))

        self.fake = self.gen(torch.cat([self.input, noise], 1))

    def backward_G(self):
        """Back-propagate the generator loss (regressor term + GAN term).

        Expects ``forward`` to have been called first.
        """
        # Regressor loss
        if self.reg is not None:
            fake_input = self.reg(self.fake)
            self.loss_mse = self.criterion_mse(fake_input, self.input_orig)

        # GAN loss (the criterion returns a pair; only the first item is used)
        loss_adv, _ = self.criterion_gan(self.fake)

        loss_G = self.loss_mse + loss_adv
        loss_G.backward()

    def backward_D(self):
        """Back-propagate the discriminator loss.

        The second value returned by the GAN criterion is kept in
        ``self.loss_adv``. Expects ``forward`` to have been called first.
        """
        loss_adv, self.loss_adv = self.criterion_gan(self.target, self.fake)

        loss_D = loss_adv
        loss_D.backward()

    def train(self):
        """Put the generator and the discriminator in training mode."""
        self.gen.train()
        self.dis.train()

    def eval(self):
        """Put the generator and the discriminator in evaluation mode."""
        self.gen.eval()
        self.dis.eval()

    def save_checkpoint(self, epoch):
        """Save both state dicts to ``self.path + '<epoch>.pkl'``.

        The parent directory is created when it does not exist yet.

        :param epoch: epoch number (int); stored in the checkpoint and used
            as the file name.
        :return: None
        """
        import os  # local import keeps this block self-contained

        target = self.path + '%d.pkl' % epoch
        parent = os.path.dirname(target)
        if parent and not os.path.isdir(parent):
            os.makedirs(parent)

        torch.save(
            {
                'epoch': epoch,
                'gen_state_dict': self.gen.state_dict(),
                'dis_state_dict': self.dis.state_dict()
            }, target)

    def load_checkpoint(self, path, pretrained=True):
        """Restore the generator and discriminator weights from ``path``.

        :param path: checkpoint file written by ``save_checkpoint``.
        :param pretrained: unused; kept for backward compatibility.
        :return: None
        """
        weights = torch.load(path)

        self.gen.load_state_dict(weights['gen_state_dict'])
        self.dis.load_state_dict(weights['dis_state_dict'])
class GanTrainer(Trainer):
    """Trainer for a generator/discriminator pair with auxiliary L2 losses.

    The generator loss is a weighted sum of an adversarial term, a
    time-domain L2 term, an L2 term on spectrograms and, when a pre-trained
    auto-encoder is available, an L2 term on its embeddings.
    """

    def __init__(self, train_loader, test_loader, valid_loader, general_args,
                 trainer_args):
        """Build the models, optimizers, schedulers and loss functions.

        :param train_loader: forwarded to the parent ``Trainer``.
        :param test_loader: forwarded to the parent ``Trainer``.
        :param valid_loader: forwarded to the parent ``Trainer``.
        :param general_args: forwarded to the parent ``Trainer`` and to the
            model constructors.
        :param trainer_args: trainer-specific options (paths, learning rates,
            scheduler settings and loss weights).
        """
        super(GanTrainer, self).__init__(train_loader, test_loader,
                                         valid_loader, general_args)

        # Paths
        self.loadpath = trainer_args.loadpath
        self.savepath = trainer_args.savepath

        # Load the auto-encoder
        self.use_autoencoder = False
        if trainer_args.autoencoder_path and os.path.exists(
                trainer_args.autoencoder_path):
            self.use_autoencoder = True
            self.autoencoder = AutoEncoder(general_args=general_args).to(
                self.device)
            self.load_pretrained_autoencoder(trainer_args.autoencoder_path)
            self.autoencoder.eval()

        # Load the generator
        self.generator = Generator(general_args=general_args).to(self.device)
        if trainer_args.generator_path and os.path.exists(
                trainer_args.generator_path):
            self.load_pretrained_generator(trainer_args.generator_path)

        self.discriminator = Discriminator(general_args=general_args).to(
            self.device)

        # Optimizers and schedulers
        self.generator_optimizer = torch.optim.Adam(
            params=self.generator.parameters(), lr=trainer_args.generator_lr)
        self.discriminator_optimizer = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=trainer_args.discriminator_lr)
        self.generator_scheduler = lr_scheduler.StepLR(
            optimizer=self.generator_optimizer,
            step_size=trainer_args.generator_scheduler_step,
            gamma=trainer_args.generator_scheduler_gamma)
        self.discriminator_scheduler = lr_scheduler.StepLR(
            optimizer=self.discriminator_optimizer,
            step_size=trainer_args.discriminator_scheduler_step,
            gamma=trainer_args.discriminator_scheduler_gamma)

        # Load saved states (guarded like the other optional paths so that an
        # unset load path does not raise)
        if self.loadpath and os.path.exists(self.loadpath):
            self.load()

        # Loss function and stored losses
        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.generator_time_criterion = nn.MSELoss()
        self.generator_frequency_criterion = nn.MSELoss()
        self.generator_autoencoder_criterion = nn.MSELoss()

        # Define labels
        self.real_label = 1
        self.generated_label = 0

        # Loss scaling factors
        self.lambda_adv = trainer_args.lambda_adversarial
        self.lambda_freq = trainer_args.lambda_freq
        self.lambda_autoencoder = trainer_args.lambda_autoencoder

        # Spectrogram converter
        self.spectrogram = Spectrogram(normalized=True).to(self.device)

        # Boolean indicating if the model needs to be saved
        self.need_saving = True

        # Boolean if the generator receives the feedback from the discriminator
        self.use_adversarial = trainer_args.use_adversarial

    def load_pretrained_generator(self, generator_path):
        """
        Loads a pre-trained generator. Can be used to stabilize the training.
        :param generator_path: location of the pre-trained generator (string).
        :return: None
        """
        checkpoint = torch.load(generator_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])

    def load_pretrained_autoencoder(self, autoencoder_path):
        """
        Loads a pre-trained auto-encoder. Can be used to infer the embeddings
        compared by the embedding-space loss.
        :param autoencoder_path: location of the pre-trained auto-encoder (string).
        :return: None
        """
        checkpoint = torch.load(autoencoder_path, map_location=self.device)
        self.autoencoder.load_state_dict(checkpoint['autoencoder_state_dict'])

    def train(self, epochs):
        """
        Trains the GAN for a given number of pseudo-epochs.
        :param epochs: Number of time to iterate over a part of the dataset (int).
        :return: None
        """
        for _ in range(epochs):
            # eval() at the end of the previous epoch switched the models to
            # evaluation mode, so training mode is restored every epoch.
            self.generator.train()
            self.discriminator.train()

            for i in range(self.train_batches_per_epoch):
                # Transfer to GPU
                local_batch = next(self.train_loader_iter)
                input_batch = local_batch[0].to(self.device)
                target_batch = local_batch[1].to(self.device)
                batch_size = input_batch.shape[0]

                # BCEWithLogitsLoss needs floating-point targets; an integer
                # fill value without an explicit dtype yields an integer
                # tensor on recent PyTorch versions.
                real_labels = torch.full((batch_size, 1),
                                         float(self.real_label),
                                         dtype=torch.float,
                                         device=self.device)
                generated_labels = torch.full((batch_size, 1),
                                              float(self.generated_label),
                                              dtype=torch.float,
                                              device=self.device)

                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                # Train the discriminator with real data
                self.discriminator_optimizer.zero_grad()
                output = self.discriminator(target_batch)

                # Compute and store the discriminator loss on real data
                loss_discriminator_real = self.adversarial_criterion(
                    output, real_labels)
                self.train_losses['discriminator_adversarial']['real'].append(
                    loss_discriminator_real.item())
                loss_discriminator_real.backward()

                # Train the discriminator with fake data; the batch is
                # detached so that this pass does not reach the generator.
                generated_batch = self.generator(input_batch)
                output = self.discriminator(generated_batch.detach())

                # Compute and store the discriminator loss on fake data
                loss_discriminator_generated = self.adversarial_criterion(
                    output, generated_labels)
                self.train_losses['discriminator_adversarial']['fake'].append(
                    loss_discriminator_generated.item())
                loss_discriminator_generated.backward()

                # Update the discriminator weights
                self.discriminator_optimizer.step()

                ############################
                # Update G network: maximize log(D(G(z)))
                ###########################
                self.generator_optimizer.zero_grad()

                # Get the spectrogram
                specgram_target_batch = self.spectrogram(target_batch)
                specgram_fake_batch = self.spectrogram(generated_batch)

                # Fake labels are real for the generator cost
                output = self.discriminator(generated_batch)

                # Get the adversarial loss (0 when the adversarial feedback is
                # disabled, so that the history keeps one entry per batch)
                loss_generator_adversarial = torch.zeros(size=[1],
                                                         device=self.device)
                if self.use_adversarial:
                    loss_generator_adversarial = self.adversarial_criterion(
                        output, real_labels)
                self.train_losses['generator_adversarial'].append(
                    loss_generator_adversarial.item())

                # Get the L2 loss in time domain
                loss_generator_time = self.generator_time_criterion(
                    generated_batch, target_batch)
                self.train_losses['time_l2'].append(loss_generator_time.item())

                # Get the L2 loss in frequency domain
                loss_generator_frequency = self.generator_frequency_criterion(
                    specgram_fake_batch, specgram_target_batch)
                self.train_losses['freq_l2'].append(
                    loss_generator_frequency.item())

                # Get the L2 loss in embedding space (0 without auto-encoder)
                loss_generator_autoencoder = torch.zeros(size=[1],
                                                         device=self.device,
                                                         requires_grad=True)
                if self.use_autoencoder:
                    # Get the embeddings
                    _, embedding_target_batch = self.autoencoder(target_batch)
                    _, embedding_generated_batch = self.autoencoder(
                        generated_batch)
                    loss_generator_autoencoder = \
                        self.generator_autoencoder_criterion(
                            embedding_generated_batch, embedding_target_batch)
                self.train_losses['autoencoder_l2'].append(
                    loss_generator_autoencoder.item())

                # Combine the different losses
                loss_generator = (
                    self.lambda_adv * loss_generator_adversarial +
                    loss_generator_time +
                    self.lambda_freq * loss_generator_frequency +
                    self.lambda_autoencoder * loss_generator_autoencoder)

                # Back-propagate and update the generator weights
                loss_generator.backward()
                self.generator_optimizer.step()

                # Print message
                if i % 10 == 0:
                    message = ('Batch {}: \n'
                               '\t Generator: \n'
                               '\t\t Time: {} \n'
                               '\t\t Frequency: {} \n'
                               '\t\t Autoencoder {} \n'
                               '\t\t Adversarial: {} \n'
                               '\t Discriminator: \n'
                               '\t\t Real {} \n'
                               '\t\t Fake {} \n').format(
                                   i, loss_generator_time.item(),
                                   loss_generator_frequency.item(),
                                   loss_generator_autoencoder.item(),
                                   loss_generator_adversarial.item(),
                                   loss_discriminator_real.item(),
                                   loss_discriminator_generated.item())
                    print(message)

            # Evaluate the model (eval() disables gradient tracking itself)
            self.eval()

            # Save the trainer state
            # NOTE: the state is saved every epoch; ``self.need_saving`` is
            # not consulted here.
            self.save()

            # Increment epoch counter
            self.epoch += 1
            self.generator_scheduler.step()
            self.discriminator_scheduler.step()

    def eval(self):
        """
        Evaluates the generator on the validation set and stores the mean time
        and frequency L2 losses. Leaves both models in evaluation mode.
        :return: None
        """
        self.generator.eval()
        self.discriminator.eval()

        batch_losses = {'time_l2': [], 'freq_l2': []}
        # No gradient is needed for validation.
        with torch.no_grad():
            for _ in range(self.valid_batches_per_epoch):
                # Transfer to GPU
                local_batch = next(self.valid_loader_iter)
                input_batch = local_batch[0].to(self.device)
                target_batch = local_batch[1].to(self.device)

                generated_batch = self.generator(input_batch)

                # Get the spectrogram
                specgram_target_batch = self.spectrogram(target_batch)
                specgram_generated_batch = self.spectrogram(generated_batch)

                loss_generator_time = self.generator_time_criterion(
                    generated_batch, target_batch)
                batch_losses['time_l2'].append(loss_generator_time.item())
                loss_generator_frequency = self.generator_frequency_criterion(
                    specgram_generated_batch, specgram_target_batch)
                batch_losses['freq_l2'].append(loss_generator_frequency.item())

        # Store the validation losses
        mean_time_l2 = np.mean(batch_losses['time_l2'])
        mean_freq_l2 = np.mean(batch_losses['freq_l2'])
        self.valid_losses['time_l2'].append(mean_time_l2)
        self.valid_losses['freq_l2'].append(mean_freq_l2)

        # Display validation losses
        message = ('Epoch {}: \n'
                   '\t Time: {} \n'
                   '\t Frequency: {} \n').format(self.epoch, mean_time_l2,
                                                 mean_freq_l2)
        print(message)

        # Check if the loss is decreasing
        self.check_improvement()

    def save(self):
        """
        Saves the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        state = {
            'epoch': self.epoch,
            'generator_state_dict': self.generator.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'generator_optimizer_state_dict':
                self.generator_optimizer.state_dict(),
            'discriminator_optimizer_state_dict':
                self.discriminator_optimizer.state_dict(),
            'generator_scheduler_state_dict':
                self.generator_scheduler.state_dict(),
            'discriminator_scheduler_state_dict':
                self.discriminator_scheduler.state_dict(),
            'train_losses': self.train_losses,
            'test_losses': self.test_losses,
            'valid_losses': self.valid_losses
        }
        torch.save(state, self.savepath)

    def load(self):
        """
        Loads the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        checkpoint = torch.load(self.loadpath, map_location=self.device)
        self.epoch = checkpoint['epoch']

        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.discriminator.load_state_dict(
            checkpoint['discriminator_state_dict'])
        self.generator_optimizer.load_state_dict(
            checkpoint['generator_optimizer_state_dict'])
        self.discriminator_optimizer.load_state_dict(
            checkpoint['discriminator_optimizer_state_dict'])
        self.generator_scheduler.load_state_dict(
            checkpoint['generator_scheduler_state_dict'])
        self.discriminator_scheduler.load_state_dict(
            checkpoint['discriminator_scheduler_state_dict'])

        self.train_losses = checkpoint['train_losses']
        self.test_losses = checkpoint['test_losses']
        self.valid_losses = checkpoint['valid_losses']

    def evaluate_metrics(self, n_batches):
        """
        Evaluates the quality of the reconstruction with the SNR and LSD metrics on a specified number of batches
        :param: n_batches: number of batches to process
        :return: mean and std for each metric
        """
        with torch.no_grad():
            snrs = []
            lsds = []
            # nn.Module.eval() returns the module; it stays in evaluation mode.
            generator = self.generator.eval()
            for _ in range(n_batches):
                # Transfer to GPU
                local_batch = next(self.test_loader_iter)
                input_batch = local_batch[0].to(self.device)
                target_batch = local_batch[1].to(self.device)

                # Generates a batch
                generated_batch = generator(input_batch)

                # Get the metrics
                # NOTE(review): squeeze() drops every singleton dimension,
                # including the batch dimension of a 1-sample batch; confirm
                # that snr/lsd handle that case.
                snrs.append(
                    snr(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))
                lsds.append(
                    lsd(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))

            snrs = torch.cat(snrs).cpu().numpy()
            lsds = torch.cat(lsds).cpu().numpy()

            # Some signals corresponding to silence will be all zeroes and cause troubles due to the logarithm
            snrs[np.isinf(snrs)] = np.nan
            lsds[np.isinf(lsds)] = np.nan
        return np.nanmean(snrs), np.nanstd(snrs), np.nanmean(lsds), np.nanstd(
            lsds)
class WGanTrainer(Trainer):
    """Trainer implementing the WGAN (weight clipping) and WGAN-GP schemes.

    The discriminator is updated on every batch and the generator every
    ``n_critic`` batches. The generator always minimises a time-domain L2
    loss; the adversarial term is added once ``self.epoch`` reaches
    ``coupling_epoch``.
    """

    def __init__(self, train_loader, test_loader, valid_loader, general_args,
                 trainer_args):
        """Build the models, optimizers, schedulers and loss containers.

        :param train_loader: forwarded to the parent ``Trainer``.
        :param test_loader: forwarded to the parent ``Trainer``.
        :param valid_loader: forwarded to the parent ``Trainer``.
        :param general_args: forwarded to the parent ``Trainer`` and to the
            model constructors.
        :param trainer_args: trainer-specific options (paths, learning rates,
            scheduler settings, loss weights and WGAN settings).
        """
        super(WGanTrainer, self).__init__(train_loader, test_loader,
                                          valid_loader, general_args)

        # Paths
        self.loadpath = trainer_args.loadpath
        self.savepath = trainer_args.savepath

        # Load the generator
        self.generator = Generator(general_args=general_args).to(self.device)
        if trainer_args.generator_path and os.path.exists(
                trainer_args.generator_path):
            self.load_pretrained_generator(trainer_args.generator_path)

        self.discriminator = Discriminator(general_args=general_args).to(
            self.device)

        # Optimizers and schedulers
        self.generator_optimizer = torch.optim.Adam(
            params=self.generator.parameters(), lr=trainer_args.generator_lr)
        self.discriminator_optimizer = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=trainer_args.discriminator_lr)
        self.generator_scheduler = lr_scheduler.StepLR(
            optimizer=self.generator_optimizer,
            step_size=trainer_args.generator_scheduler_step,
            gamma=trainer_args.generator_scheduler_gamma)
        self.discriminator_scheduler = lr_scheduler.StepLR(
            optimizer=self.discriminator_optimizer,
            step_size=trainer_args.discriminator_scheduler_step,
            gamma=trainer_args.discriminator_scheduler_gamma)

        # Overrides losses from parent class. This is done BEFORE loading a
        # checkpoint so that the restored histories are not wiped afterwards.
        self.train_losses = self._empty_losses()
        self.test_losses = self._empty_losses()
        self.valid_losses = self._empty_losses()

        # Load saved states (guarded like the generator path so that an unset
        # load path does not raise)
        if self.loadpath and os.path.exists(self.loadpath):
            self.load()

        # Loss function and stored losses
        self.generator_time_criterion = nn.MSELoss()

        # Loss scaling factors
        self.lambda_adv = trainer_args.lambda_adversarial
        self.lambda_time = trainer_args.lambda_time

        # Boolean indicating if the model needs to be saved
        self.need_saving = True

        # Select either wgan or wgan-gp method
        self.use_penalty = trainer_args.use_penalty
        self.gamma = trainer_args.gamma_wgan_gp
        self.clipping_limit = trainer_args.clipping_limit
        self.n_critic = trainer_args.n_critic
        self.coupling_epoch = trainer_args.coupling_epoch

    @staticmethod
    def _empty_losses():
        """Return a fresh, empty loss-history dictionary."""
        return {
            'generator': {
                'time_l2': [],
                'adversarial': []
            },
            'discriminator': {
                'penalty': [],
                'adversarial': []
            }
        }

    def load_pretrained_generator(self, generator_path):
        """
        Loads a pre-trained generator. Can be used to stabilize the training.
        :param generator_path: location of the pre-trained generator (string).
        :return: None
        """
        checkpoint = torch.load(generator_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])

    def compute_gradient_penalty(self, input_batch, generated_batch):
        """
        Compute the gradient penalty as described in the original paper
        (https://papers.nips.cc/paper/7159-improved-training-of-wasserstein-gans.pdf).
        :param input_batch: batch of reference samples to interpolate from; the paper uses real data (torch tensor).
        :param generated_batch: batch of generated data with the same shape (torch tensor).
        :return: penalty as a scalar (torch tensor).
        """
        batch_size = input_batch.size(0)

        # One interpolation coefficient per sample, broadcast over all the
        # remaining dimensions whatever their number.
        epsilon_shape = [batch_size] + [1] * (input_batch.dim() - 1)
        epsilon = torch.rand(epsilon_shape, device=self.device)

        # Interpolate; both batches are detached so that the penalty only
        # back-propagates into the discriminator.
        interpolation = (epsilon * input_batch.detach() +
                         (1 - epsilon) * generated_batch.detach())
        interpolation = interpolation.to(self.device).requires_grad_(True)

        # Computes the discriminator's prediction for the interpolated input
        interpolation_logits = self.discriminator(interpolation)

        # Computes a vector of outputs to make it works with 2 output classes if needed
        grad_outputs = torch.ones_like(interpolation_logits)

        # Get the gradients and retain the graph so that the penalty can be back-propagated
        gradients = autograd.grad(outputs=interpolation_logits,
                                  inputs=interpolation,
                                  grad_outputs=grad_outputs,
                                  create_graph=True,
                                  retain_graph=True)[0]
        gradients = gradients.reshape(batch_size, -1)

        # Computes the norm of the gradients; the small constant keeps the
        # derivative of the square root finite when the gradient is zero.
        gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12)
        return ((gradients_norm - 1)**2).mean()

    def train_discriminator_step(self, input_batch, target_batch):
        """
        Trains the discriminator for a single step based on the wasserstein gan-gp framework.
        :param input_batch: batch of input data (torch tensor).
        :param target_batch: batch of target data (torch tensor).
        :return: a batch of generated data (torch tensor).
        """
        # Activate gradient tracking for the discriminator
        self.change_discriminator_grad_requirement(requires_grad=True)

        # Set the discriminator's gradients to zero
        self.discriminator_optimizer.zero_grad()

        # Generate a batch; the graph is kept so the generator step can reuse it
        generated_batch = self.generator(input_batch)
        detached_batch = generated_batch.detach()

        # Compute the loss
        loss_d = (self.discriminator(detached_batch).mean() -
                  self.discriminator(target_batch).mean())
        self.train_losses['discriminator']['adversarial'].append(loss_d.item())

        if self.use_penalty:
            # The penalty is evaluated between real (target) and generated
            # samples, as prescribed by the WGAN-GP paper.
            penalty = self.compute_gradient_penalty(target_batch,
                                                    detached_batch)
            self.train_losses['discriminator']['penalty'].append(
                penalty.item())
            loss_d = loss_d + self.gamma * penalty

        # Update the discriminator's weights
        loss_d.backward()
        self.discriminator_optimizer.step()

        # Apply the weight constraint if needed
        if not self.use_penalty:
            for p in self.discriminator.parameters():
                p.data.clamp_(min=-self.clipping_limit,
                              max=self.clipping_limit)

        # Return the generated batch to avoid redundant computation
        return generated_batch

    def train_generator_step(self, target_batch, generated_batch):
        """
        Trains the generator for a single step based on the wasserstein gan-gp framework.
        :param target_batch: batch of target data (torch tensor).
        :param generated_batch: batch of generated data (torch tensor).
        :return: None
        """
        # Deactivate gradient tracking for the discriminator
        self.change_discriminator_grad_requirement(requires_grad=False)

        # Set generator's gradients to zero
        self.generator_optimizer.zero_grad()

        # Get the generator losses
        loss_g_adversarial = -self.discriminator(generated_batch).mean()
        loss_g_time = self.generator_time_criterion(generated_batch,
                                                    target_batch)

        # Combine the different losses; the adversarial term only kicks in
        # once the coupling epoch has been reached.
        loss_g = loss_g_time
        if self.epoch >= self.coupling_epoch:
            loss_g = loss_g + self.lambda_adv * loss_g_adversarial

        # Back-propagate and update the generator weights
        loss_g.backward()
        self.generator_optimizer.step()

        # Store the losses
        self.train_losses['generator']['time_l2'].append(loss_g_time.item())
        self.train_losses['generator']['adversarial'].append(
            loss_g_adversarial.item())

    def change_discriminator_grad_requirement(self, requires_grad):
        """
        Changes the requires_grad flag of discriminator's parameters. This action is not absolutely needed as the
        discriminator's optimizer is not called after the generators update, but it reduces the computational cost.
        :param requires_grad: flag indicating if the discriminator's parameter require gradient tracking (boolean).
        :return: None
        """
        for p in self.discriminator.parameters():
            p.requires_grad_(requires_grad)

    def train(self, epochs):
        """
        Trains the WGAN-GP for a given number of pseudo-epochs.
        :param epochs: Number of time to iterate over a part of the dataset (int).
        :return: None
        """
        self.generator.train()
        self.discriminator.train()
        for _ in range(epochs):
            for i in range(self.train_batches_per_epoch):
                # Transfer to GPU
                local_batch = next(self.train_loader_iter)
                input_batch = local_batch[0].to(self.device)
                target_batch = local_batch[1].to(self.device)

                # Train the discriminator
                generated_batch = self.train_discriminator_step(
                    input_batch, target_batch)

                # Train the generator every n_critic
                if i % self.n_critic == 0:
                    self.train_generator_step(target_batch, generated_batch)

                # Print message
                if i % 10 == 0:
                    generator_losses = self.train_losses['generator']
                    discriminator_losses = self.train_losses['discriminator']
                    # No penalty is recorded in plain WGAN (clipping) mode, so
                    # the history may be empty.
                    penalties = discriminator_losses['penalty']
                    if self.use_penalty and penalties:
                        last_penalty = penalties[-1]
                    else:
                        last_penalty = float('nan')
                    message = ('Batch {}: \n'
                               '\t Generator: \n'
                               '\t\t Time: {} \n'
                               '\t\t Adversarial: {} \n'
                               '\t Discriminator: \n'
                               '\t\t Penalty: {}\n'
                               '\t\t Adversarial: {} \n').format(
                                   i, generator_losses['time_l2'][-1],
                                   generator_losses['adversarial'][-1],
                                   last_penalty,
                                   discriminator_losses['adversarial'][-1])
                    print(message)

            # Evaluate the model (eval() disables gradient tracking itself and
            # restores training mode when done)
            self.eval()

            # Save the trainer state
            self.save()

            # Increment epoch counter
            self.epoch += 1
            self.generator_scheduler.step()
            self.discriminator_scheduler.step()

    def eval(self):
        """
        Evaluates the generator on the validation set, stores the mean time-domain L2 loss and puts both models
        back in training mode.
        :return: None
        """
        # Set the models in evaluation mode
        self.generator.eval()
        self.discriminator.eval()

        batch_losses = {'time_l2': []}
        # No gradient is needed for validation.
        with torch.no_grad():
            for _ in range(self.valid_batches_per_epoch):
                # Transfer to GPU
                local_batch = next(self.valid_loader_iter)
                input_batch = local_batch[0].to(self.device)
                target_batch = local_batch[1].to(self.device)

                generated_batch = self.generator(input_batch)

                loss_g_time = self.generator_time_criterion(
                    generated_batch, target_batch)
                batch_losses['time_l2'].append(loss_g_time.item())

        # Store the validation losses
        mean_time_l2 = np.mean(batch_losses['time_l2'])
        self.valid_losses['generator']['time_l2'].append(mean_time_l2)

        # Display validation losses
        message = ('Epoch {}: \n'
                   '\t Time: {} \n').format(self.epoch, mean_time_l2)
        print(message)

        # Set the models in train mode (both of them: train() only sets the
        # mode once, before the first epoch)
        self.generator.train()
        self.discriminator.train()

    def save(self):
        """
        Saves the model(s), optimizer(s), scheduler(s) and losses. The file name is the save path with
        ``_<epoch // 5>`` inserted before the extension, so a new file is started every five epochs.
        :return: None
        """
        # splitext only splits on the final extension, so dots in directory
        # names (e.g. './checkpoints/model.tar') are left untouched.
        root, extension = os.path.splitext(self.savepath)
        savepath = '{}_{}{}'.format(root, self.epoch // 5, extension)

        state = {
            'epoch': self.epoch,
            'generator_state_dict': self.generator.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'generator_optimizer_state_dict':
                self.generator_optimizer.state_dict(),
            'discriminator_optimizer_state_dict':
                self.discriminator_optimizer.state_dict(),
            'generator_scheduler_state_dict':
                self.generator_scheduler.state_dict(),
            'discriminator_scheduler_state_dict':
                self.discriminator_scheduler.state_dict(),
            'train_losses': self.train_losses,
            'test_losses': self.test_losses,
            'valid_losses': self.valid_losses
        }
        torch.save(state, savepath)

    def load(self):
        """
        Loads the generator, its optimizer, the scheduler(s) and losses
        :return: None
        """
        checkpoint = torch.load(self.loadpath, map_location=self.device)
        self.epoch = checkpoint['epoch']

        # NOTE(review): the discriminator and its optimizer are not restored
        # although save() stores them and the discriminator scheduler IS
        # restored below - confirm that this is intentional.
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.generator_optimizer.load_state_dict(
            checkpoint['generator_optimizer_state_dict'])
        self.generator_scheduler.load_state_dict(
            checkpoint['generator_scheduler_state_dict'])
        self.discriminator_scheduler.load_state_dict(
            checkpoint['discriminator_scheduler_state_dict'])

        self.train_losses = checkpoint['train_losses']
        self.test_losses = checkpoint['test_losses']
        self.valid_losses = checkpoint['valid_losses']

    def evaluate_metrics(self, n_batches):
        """
        Evaluates the quality of the reconstruction with the SNR and LSD metrics on a specified number of batches
        :param: n_batches: number of batches to process
        :return: mean and std for each metric
        """
        with torch.no_grad():
            snrs = []
            lsds = []
            # nn.Module.eval() returns the module; it stays in evaluation mode.
            generator = self.generator.eval()
            for _ in range(n_batches):
                # Transfer to GPU
                local_batch = next(self.test_loader_iter)
                input_batch = local_batch[0].to(self.device)
                target_batch = local_batch[1].to(self.device)

                # Generates a batch
                generated_batch = generator(input_batch)

                # Get the metrics
                # NOTE(review): squeeze() drops every singleton dimension,
                # including the batch dimension of a 1-sample batch; confirm
                # that snr/lsd handle that case.
                snrs.append(
                    snr(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))
                lsds.append(
                    lsd(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))

            snrs = torch.cat(snrs).cpu().numpy()
            lsds = torch.cat(lsds).cpu().numpy()

            # Some signals corresponding to silence will be all zeroes and cause troubles due to the logarithm
            snrs[np.isinf(snrs)] = np.nan
            lsds[np.isinf(lsds)] = np.nan
        return np.nanmean(snrs), np.nanstd(snrs), np.nanmean(lsds), np.nanstd(
            lsds)