def process_through_model(path_to_input_wav: str, path_to_output: str,
                          model: nn.Module,
                          args: argparse.Namespace) -> AudioData:
    # 1. Create an AudioData object from the input.
    full_input_audiodata: AudioData = AudioData(wav_filepath=path_to_input_wav)
    if full_input_audiodata.sampling_freq not in (CD_QUALITY_SAMPLING_FREQ,
                                                  TARGET_SAMPLING_FREQ):
        raise ValueError(
            "The input .wav file should have a sampling frequency of either "
            "44.1 kHz or 44.1/3 kHz, but it has {} kHz instead".format(
                full_input_audiodata.sampling_freq / 1000))
    if full_input_audiodata.sampling_freq == CD_QUALITY_SAMPLING_FREQ:
        full_input_audiodata = downsample(full_input_audiodata,
                                          DEFAULT_DOWNSAMPLE_FACTOR)

    # 2. Split the input data into non-overlapping segments. This truncates a
    # small chunk (<4.5 s) at the end of the input audio. Then convert each
    # segment to the frequency domain.
    input_segments = list()
    num_samples_per_segment = DEFAULT_NUM_SAMPLES_PER_SEGMENT
    for i in range(
            len(full_input_audiodata.time_amplitudes) // num_samples_per_segment):
        begin_idx = i * num_samples_per_segment
        end_idx = begin_idx + num_samples_per_segment
        next_segment: AudioData = AudioData(manual_init=(
            full_input_audiodata.sampling_freq,
            full_input_audiodata.time_amplitudes[begin_idx:end_idx]))
        next_segment_freq: StftData = StftData(args=DEFAULT_STFT_ARGS,
                                               audiodata=next_segment)
        input_segments.append(next_segment_freq)

    # 3. Run the model on the input segments.
    num_freqs = input_segments[0].sample_freqs.shape[0]
    num_windows = input_segments[0].segment_times.shape[0]
    input_freq_domain = np.zeros(
        (len(input_segments), 1, num_freqs, num_windows), dtype=np.complex64)
    for i in range(len(input_segments)):
        input_freq_domain[i, 0, :, :] = input_segments[i].data
    output_segments = apply_model(input_freq_domain, model, input_segments[0],
                                  args)

    # 4. Concatenate the output segments into an output AudioData object.
    output_time_amplitudes = np.zeros(
        num_samples_per_segment * len(output_segments))
    for i in range(len(output_segments)):
        begin_idx = i * num_samples_per_segment
        end_idx = begin_idx + num_samples_per_segment
        output_time_amplitudes[begin_idx:end_idx] = (
            output_segments[i].time_amplitudes)
    output_audio: AudioData = AudioData(manual_init=(TARGET_SAMPLING_FREQ,
                                                     output_time_amplitudes))

    # 5. Convert the output to a .wav file and save it to the output filepath.
    audiodata_to_wav(output_audio, path_to_output)
    return output_audio
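# Example invocation of process_through_model (a minimal sketch; `model` is
# assumed to be an already-trained, loaded nn.Module, and the file paths are
# illustrative). The `nonboolean_mask` and `alpha` flags are the ones consumed
# by apply_model below:
#
#     args = argparse.Namespace(nonboolean_mask=True, alpha=0.5)
#     separated = process_through_model(
#         "input/debate_clip.wav", "output/separated.wav", model, args)
#     print(separated.sampling_freq, separated.time_amplitudes.shape)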
def invert_batch_like(batch: np.ndarray, container: StftData) -> np.ndarray:
    '''
    Takes in a batch of shape (m, nsrc, 1, H, W) and returns an ndarray of
    shape (m * nsrc, time_samples, 1) containing the inverse STFT of each
    H x W image in the batch.
    '''
    m, nsrc, _, _, _ = batch.shape
    # Drop the singleton channel axis so each (i, j) entry is an H x W image.
    batch_small = np.squeeze(batch, axis=2)
    output = []
    for i in range(m):
        for j in range(nsrc):
            container.data = batch_small[i, j, :, :]
            audio = container.invert()
            output.append(audio.time_amplitudes)
    output = np.array(output)  # (m * nsrc, time_samples)
    time_samples = output.shape[1]
    output = np.reshape(output, (m * nsrc, time_samples, 1))
    return output
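# Shape sanity check for invert_batch_like (a sketch; assumes `container` is a
# valid StftData whose invert() yields a fixed-length 1-D signal, and the
# dimensions below are illustrative, not the real model's):
#
#     batch = np.zeros((4, 2, 1, 512, 128), dtype=np.complex64)  # (m, nsrc, 1, H, W)
#     out = invert_batch_like(batch, container)
#     assert out.shape[0] == 4 * 2 and out.shape[2] == 1  # (m * nsrc, time_samples, 1)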
def apply_model(input_batch: np.ndarray, model: nn.Module,
                example_stft: StftData,
                args: argparse.Namespace) -> List[AudioData]:
    input_tensor = torch.tensor(input_batch, dtype=torch.complex64)
    input_mags = input_tensor.abs()
    predictions, _ = model(input_mags)
    if args.nonboolean_mask:
        # Soft mask: the ratio of predicted to input magnitude, clamped to [0, 1].
        predicted_mask = torch.clamp(predictions / input_mags, 0, 1)
    else:
        # Boolean mask: 1 wherever the clamped ratio exceeds the threshold alpha.
        predicted_mask = torch.ones_like(predictions) * (torch.clamp(
            predictions / input_mags, 0, 1) > args.alpha)
    predicted_mask = predicted_mask.detach().numpy()

    # Lightly blur each mask to smooth out isolated bins.
    for i, mask in enumerate(predicted_mask[:, 0, :, :]):
        blurred_mask = cv2.GaussianBlur(mask, ksize=(5, 5), sigmaX=0, sigmaY=0.5)
        predicted_mask[i, 0] = blurred_mask

    output_freqs = input_tensor.numpy() * predicted_mask

    # Optional debugging: visualize the input, mask, and output spectrograms.
    # for i in range(predicted_mask.shape[0]):
    #     example_stft.data = input_tensor[i, 0].numpy()
    #     example_stft.save_spectrogram(show=True)
    #     example_stft.data = predicted_mask[i, 0]
    #     example_stft.save_spectrogram(show=True)
    #     example_stft.data = output_freqs[i, 0]
    #     example_stft.save_spectrogram(show=True)

    # Convert the outputs back to the time domain.
    output_time_data = list()
    for i in range(output_freqs.shape[0]):
        freq_data = StftData(
            args=example_stft.args,
            manual_init=(output_freqs[i, 0, :, :], example_stft.sample_freqs,
                         example_stft.segment_times, example_stft.fs))
        output_time_data.append(freq_data.invert())
    return output_time_data
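# The two masking modes above, on a toy tensor (a sketch; the values are
# illustrative and alpha = 0.5 is an assumed threshold):
#
#     mags = torch.tensor([[4.0, 1.0], [2.0, 8.0]])
#     preds = torch.tensor([[3.0, 0.1], [2.5, 1.0]])
#     soft = torch.clamp(preds / mags, 0, 1)  # [[0.75, 0.1], [1.0, 0.125]]
#     hard = (soft > 0.5).float()             # [[1., 0.], [1., 0.]]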
def create_data_for_model(target_dir: str,
                          source1_dir: str,
                          source2_dir: str,
                          source1_name: str = "source1",
                          source2_name: str = "source2",
                          num_examples_to_create: int = 15000):
    if num_examples_to_create > 50000:
        raise ValueError(
            "Let's not brick the VM. If you really want to create this many "
            "examples, edit the code to allow it.")

    # These STFT parameters yield the desired 512x128 frequency outputs.
    args: StftArgs = StftArgs(nperseg=1022, noverlap=511)

    source1_snippets = load_snippets(source1_dir)
    source2_snippets = load_snippets(source2_dir)
    if num_examples_to_create > len(source1_snippets) * len(source2_snippets):
        print("There aren't enough unique combinations to create {} unique "
              "examples! Falling back to only creating {} examples instead."
              .format(num_examples_to_create,
                      len(source1_snippets) * len(source2_snippets)))
        num_examples_to_create = len(source1_snippets) * len(source2_snippets)
    random.shuffle(source1_snippets)
    random.shuffle(source2_snippets)
    pairgen: PairGenerator = PairGenerator(source1_snippets, source2_snippets)

    example_num = 0
    while example_num < num_examples_to_create:
        if example_num % 100 == 0:
            print("Processing example #{}".format(example_num))
        audio1, audio2 = pairgen.get_pair()
        superimposed, audio1, audio2 = AudioDataUtils.superimpose(audio1, audio2)
        freqdata1 = trim_stft_data(StftData(args=args, audiodata=audio1))
        freqdata2 = trim_stft_data(StftData(args=args, audiodata=audio2))
        freqdata_super = trim_stft_data(StftData(args=args,
                                                 audiodata=superimposed))
        file_prefix = target_dir + "/"
        freqdata1.save(file_prefix + source1_name + "_{}.pkl".format(example_num))
        freqdata2.save(file_prefix + source2_name + "_{}.pkl".format(example_num))
        freqdata_super.save(file_prefix + "combined_{}.pkl".format(example_num))
        example_num += 1
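# Example dataset-generation call (a sketch; the directory paths are
# illustrative). The source names must match the 'biden_'/'trump_' filename
# prefixes that load_data and load_test_data search for:
#
#     create_data_for_model(
#         target_dir="data/train",
#         source1_dir="data/raw/biden",
#         source2_dir="data/raw/trump",
#         source1_name="biden",
#         source2_name="trump",
#         num_examples_to_create=15000)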
def test_model(
    dev_dl: data.DataLoader,
    model: nn.Module,
    args: argparse.Namespace,
    stft_container: StftData,
) -> nn.Module:
    device = model_utils.get_device()

    print('\nRunning test metrics...')

    # Forward inference on the model
    predicted_masks = []
    print('  Running forward inference...')
    with tqdm(total=args.batch_size * len(dev_dl)) as progress_bar:
        for i, (x_batch, _, _) in enumerate(dev_dl):
            x_batch = x_batch.abs().to(device)

            # Forward pass on the model
            y_pred_b, y_pred_t = model(x_batch)

            if args.nonboolean_mask:
                y_biden_mask = torch.clamp(y_pred_b.detach() / x_batch, 0, 1)
                y_trump_mask = torch.clamp(y_pred_t.detach() / x_batch, 0, 1)
            else:
                y_biden_mask = torch.ones_like(y_pred_b) * (torch.clamp(
                    y_pred_b / x_batch, 0, 1) > args.alpha)
                y_trump_mask = torch.ones_like(y_pred_t) * (
                    (1 - torch.clamp(y_pred_t / x_batch, 0, 1)) > args.alpha)

            predicted_masks.append((y_biden_mask.cpu(), y_trump_mask.cpu()))
            progress_bar.update(len(x_batch))

            del x_batch
            del y_biden_mask
            del y_trump_mask
            del y_pred_b
            del y_pred_t

    print('\n  Processing results...')
    SDR, ISR, SIR, SAR = [], [], [], []
    with tqdm(total=args.batch_size * len(dev_dl)) as progress_bar:
        for i, ((x_batch, _, ground_truth),
                (y_biden_mask, y_trump_mask)) in enumerate(
                    zip(dev_dl, predicted_masks)):
            stft_biden_audio = y_biden_mask * x_batch
            stft_trump_audio = y_trump_mask * x_batch
            stft_audio = torch.stack([stft_biden_audio, stft_trump_audio],
                                     dim=1)

            # Invert the masked spectrograms back to time-domain audio
            model_stft = stft_audio.cpu().numpy()
            model_audio = model_utils.invert_batch_like(model_stft,
                                                        stft_container)

            m, nsrc, timesamples, chan = ground_truth.shape
            gt = torch.reshape(ground_truth, (m * nsrc, timesamples, 1))

            if args.biden_only_sdr:
                batch_sdr, batch_isr, batch_sir, batch_sar = bsseval.evaluate(
                    gt[:1, :, :],
                    model_audio[:1, :, :],
                    win=stft_container.fs,
                    hop=stft_container.fs,
                )
            else:
                batch_sdr, batch_isr, batch_sir, batch_sar = bsseval.evaluate(
                    gt,
                    model_audio,
                    win=stft_container.fs,
                    hop=stft_container.fs,
                )
            SDR = np.concatenate([SDR, np.mean(batch_sdr, axis=1)], axis=0)
            ISR = np.concatenate([ISR, np.mean(batch_isr, axis=1)], axis=0)
            SIR = np.concatenate([SIR, np.mean(batch_sir, axis=1)], axis=0)
            SAR = np.concatenate([SAR, np.mean(batch_sar, axis=1)], axis=0)

            progress_bar.update(len(x_batch))

    print('\n  Calculating overall metrics...')
    print()
    print('*' * 30)
    print(f'SDR: {np.mean(SDR)}')
    print(f'ISR: {np.mean(ISR)}')
    print(f'SIR: {np.mean(SIR)}')
    print(f'SAR: {np.mean(SAR)}')
    print('*' * 30)

    return model
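# Flags consumed by test_model (a sketch of the argparse setup; the attribute
# names match the ones used above, but the defaults are illustrative
# assumptions):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--batch_size', type=int, default=32)
#     parser.add_argument('--alpha', type=float, default=0.5)
#     parser.add_argument('--nonboolean_mask', action='store_true')
#     parser.add_argument('--biden_only_sdr', action='store_true')
#     args = parser.parse_args()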
def trim_stft_data(freq_data: StftData) -> StftData:
    # Truncate the STFT to a fixed number of time windows so that every
    # example has the same width.
    freq_data.segment_times = freq_data.segment_times[:DESIRED_NUM_STFT_WINDOWS]
    freq_data.data = freq_data.data[:, :DESIRED_NUM_STFT_WINDOWS]
    return freq_data
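# trim_stft_data mutates its argument in place and also returns it, so it can
# be chained directly around a constructor, as in create_data_for_model above
# (a sketch):
#
#     freq = trim_stft_data(StftData(args=args, audiodata=audio1))
#     assert freq.data.shape[1] == DESIRED_NUM_STFT_WINDOWS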
def load_data(
    directory_path: str,
    dev_frac: float = 0.1,
    max_entries: int = 15000,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, StftData]:
    '''
    Returns data in the order: combined_train, biden_train, trump_train,
    mask_train, combined_dev, biden_dev, trump_dev, mask_dev, plus the last
    StftData object read (for use as a container when inverting).

    combined, biden, and trump are complex64, but the masks are float32.

    Note: Train the model on x_train.abs() and y_train.abs(), but in practice
    you need to input x_train.abs() and multiply the output by x_train
    (complex) to get the real prediction. Perhaps threshold the output to
    create a more binary mask.
    '''
    # Open the given directory, expecting files with the names
    # biden_%d.pkl, combined_%d.pkl, and trump_%d.pkl.

    # Choose file names and deterministically shuffle
    print('Reading datasets...')
    pickled_filenames = os.listdir(directory_path)
    combined_filenames = sorted(
        [fn for fn in pickled_filenames if 'combined_' in fn])
    biden_filenames = sorted(
        [fn for fn in pickled_filenames if 'biden_' in fn])
    trump_filenames = sorted(
        [fn for fn in pickled_filenames if 'trump_' in fn])
    zipped_filenames = list(
        zip(combined_filenames, biden_filenames, trump_filenames))
    random.Random(230).shuffle(zipped_filenames)
    if len(zipped_filenames) > max_entries:
        zipped_filenames = zipped_filenames[:max_entries]

    # Read files as STFT data
    combined_ls, biden_ls, trump_ls = [], [], []
    masks_ls = []
    with tqdm(total=len(zipped_filenames)) as progress_bar:
        for i, (cd, bd, td) in enumerate(zipped_filenames):
            combined_data = StftData(pickle_file=f'{directory_path}/{cd}')
            biden_data = StftData(pickle_file=f'{directory_path}/{bd}')
            trump_data = StftData(pickle_file=f'{directory_path}/{td}')

            combined_ls.append(combined_data.data)
            biden_ls.append(biden_data.data)
            trump_ls.append(trump_data.data)

            # Ideal binary mask: 1 wherever Biden's magnitude dominates.
            biden_mag = np.abs(biden_data.data)
            trump_mag = np.abs(trump_data.data)
            masks_ls.append(
                np.ones_like(biden_mag, dtype=np.float32) *
                (biden_mag > trump_mag))

            progress_bar.update()

    # Reformat arrays
    combined = _convert_to_tensor(combined_ls, torch.complex64)
    biden = _convert_to_tensor(biden_ls, torch.complex64)
    trump = _convert_to_tensor(trump_ls, torch.complex64)
    masks = _convert_to_tensor(masks_ls, torch.float32)

    # Partition into train and dev
    print('  Done!')
    num_examples = len(combined)
    dev_idx = num_examples - int(num_examples * dev_frac)
    return (
        combined[:dev_idx],
        biden[:dev_idx],
        trump[:dev_idx],
        masks[:dev_idx],
        combined[dev_idx:],
        biden[dev_idx:],
        trump[dev_idx:],
        masks[dev_idx:],
        biden_data,
    )
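# Example consumption of load_data (a sketch; the 9-tuple unpacking matches
# the return order documented above, and the DataLoader setup is an assumed
# training harness, not part of this module):
#
#     (combined_train, biden_train, trump_train, mask_train,
#      combined_dev, biden_dev, trump_dev, mask_dev,
#      stft_container) = load_data("data/train", dev_frac=0.1)
#     train_dl = data.DataLoader(
#         data.TensorDataset(combined_train, biden_train, mask_train),
#         batch_size=32, shuffle=True)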
def load_test_data(
    directory_path: str,
    dev_frac: float = 0.1,
    max_entries: int = 15000,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, StftData]:
    '''
    Loads only the dev split and returns: combined (the network input), biden,
    the stacked ground-truth time-domain samples, and the last StftData object
    read (for use as a container when inverting).

    combined and biden are complex64; the ground-truth samples are a float
    tensor of shape (num_dev, 2, time_samples, 1).
    '''
    # Open the given directory, expecting files with the names
    # biden_%d.pkl, combined_%d.pkl, and trump_%d.pkl.

    # Choose file names and deterministically shuffle (same seed as load_data,
    # so the dev split matches)
    print('Reading datasets...')
    pickled_filenames = os.listdir(directory_path)
    combined_filenames = sorted(
        [fn for fn in pickled_filenames if 'combined_' in fn])
    biden_filenames = sorted(
        [fn for fn in pickled_filenames if 'biden_' in fn])
    trump_filenames = sorted(
        [fn for fn in pickled_filenames if 'trump_' in fn])
    zipped_filenames = list(
        zip(combined_filenames, biden_filenames, trump_filenames))
    random.Random(230).shuffle(zipped_filenames)
    if len(zipped_filenames) > max_entries:
        zipped_filenames = zipped_filenames[:max_entries]

    # Only load dev examples
    num_examples = len(zipped_filenames)
    dev_idx = num_examples - int(num_examples * dev_frac)
    zipped_filenames = zipped_filenames[dev_idx:]

    # Need combined as STFT data (for network input) and the inverted
    # time-domain sources as ground truth.
    combined_ls, biden_samples_ls, trump_samples_ls, biden_ls = [], [], [], []
    with tqdm(total=len(zipped_filenames)) as progress_bar:
        for i, (cd, bd, td) in enumerate(zipped_filenames):
            combined_data = StftData(pickle_file=f'{directory_path}/{cd}')
            biden_data = StftData(pickle_file=f'{directory_path}/{bd}')
            trump_data = StftData(pickle_file=f'{directory_path}/{td}')

            combined_ls.append(combined_data.data)
            biden_ls.append(biden_data.data)

            biden_time_amplitude = biden_data.invert().time_amplitudes
            trump_time_amplitude = trump_data.invert().time_amplitudes
            biden_samples_ls.append(biden_time_amplitude)
            trump_samples_ls.append(trump_time_amplitude)

            progress_bar.update()

    # Reformat arrays
    combined = _convert_to_tensor(combined_ls, torch.complex64)
    biden = _convert_to_tensor(biden_ls, torch.complex64)
    biden_samples = np.expand_dims(np.asarray(biden_samples_ls), axis=2)
    trump_samples = np.expand_dims(np.asarray(trump_samples_ls), axis=2)
    # Shape (num_dev, 2, time_samples, 1)
    gt_samples = torch.tensor(np.stack([biden_samples, trump_samples], axis=1))

    print('  Done!')
    return (
        combined,
        biden,
        gt_samples,
        biden_data,
    )
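# Example test harness wiring load_test_data into test_model (a sketch; the
# DataLoader yields (combined, biden, ground_truth) batches, which is the
# 3-tuple shape test_model unpacks, and `model`/`args` are assumed to exist):
#
#     combined, biden, gt_samples, container = load_test_data("data/train")
#     dev_dl = data.DataLoader(
#         data.TensorDataset(combined, biden, gt_samples),
#         batch_size=args.batch_size)
#     test_model(dev_dl, model, args, container)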