def process_through_model(path_to_input_wav: str, path_to_output: str,
                          model: nn.Module,
                          args: argparse.Namespace) -> AudioData:
    # 1. Create an AudioData object from the input .wav file.
    full_input_audiodata: AudioData = AudioData(wav_filepath=path_to_input_wav)
    if full_input_audiodata.sampling_freq not in (CD_QUALITY_SAMPLING_FREQ,
                                                  TARGET_SAMPLING_FREQ):
        raise ValueError(
            "The input .wav file should have a sampling frequency of either "
            "44.1 kHz or 44.1/3 kHz (~14.7 kHz), but it has {} kHz instead."
            .format(full_input_audiodata.sampling_freq / 1000))
    if full_input_audiodata.sampling_freq == CD_QUALITY_SAMPLING_FREQ:
        full_input_audiodata = downsample(full_input_audiodata,
                                          DEFAULT_DOWNSAMPLE_FACTOR)

    # 2. Split the input into non-overlapping segments, discarding a small
    # trailing chunk (< 4.5 s) of the input audio, then convert each segment
    # to the frequency domain.
    input_segments = []
    num_samples_per_segment = DEFAULT_NUM_SAMPLES_PER_SEGMENT
    num_segments = (len(full_input_audiodata.time_amplitudes) //
                    num_samples_per_segment)
    for i in range(num_segments):
        begin_idx = i * num_samples_per_segment
        end_idx = begin_idx + num_samples_per_segment
        next_segment: AudioData = AudioData(manual_init=(
            full_input_audiodata.sampling_freq,
            full_input_audiodata.time_amplitudes[begin_idx:end_idx]))
        next_segment_freq: StftData = StftData(args=DEFAULT_STFT_ARGS,
                                               audiodata=next_segment)
        input_segments.append(next_segment_freq)

    # 3. Run the model on the input segments
    num_freqs = input_segments[0].sample_freqs.shape[0]
    num_windows = input_segments[0].segment_times.shape[0]
    input_freq_domain = np.zeros(
        (len(input_segments), 1, num_freqs, num_windows), dtype=np.complex64)

    for i, segment in enumerate(input_segments):
        input_freq_domain[i, 0, :, :] = segment.data

    output_segments = apply_model(input_freq_domain, model, input_segments[0],
                                  args)

    # 4. Concatenate the output segments into an output AudioData object
    output_time_amplitudes = np.zeros(num_samples_per_segment *
                                      len(output_segments))
    for i in range(len(output_segments)):
        begin_idx = i * num_samples_per_segment
        end_idx = begin_idx + num_samples_per_segment
        output_time_amplitudes[begin_idx:end_idx] = \
            output_segments[i].time_amplitudes

    output_audio: AudioData = AudioData(manual_init=(TARGET_SAMPLING_FREQ,
                                                     output_time_amplitudes))

    # 5. Convert the output to a .wav file and save it at the output filepath.
    audiodata_to_wav(output_audio, path_to_output)
    return output_audio
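# A minimal usage sketch for process_through_model. The file paths and the
# alpha default here are illustrative assumptions; only the argparse flags
# actually read downstream (nonboolean_mask, alpha) come from this code.
def _demo_process_through_model(model: nn.Module) -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument('--nonboolean_mask', action='store_true')
    parser.add_argument('--alpha', type=float, default=0.5)  # assumed default
    demo_args = parser.parse_args([])  # boolean mask with alpha = 0.5
    model.eval()
    # 'debate_clip.wav' and 'separated.wav' are placeholder paths.
    process_through_model('debate_clip.wav', 'separated.wav', model, demo_args)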
def invert_batch_like(batch: np.ndarray, container: StftData) -> np.ndarray:
    '''
    Takes a batch of shape (m, nsrc, 1, H, W) and returns an ndarray of shape
    (m * nsrc, time_samples, 1) containing the inverse STFT of each H x W
    image in the batch.
    '''
    m, nsrc, _, _, _ = batch.shape
    batch_small = np.squeeze(batch, axis=2)  # (m, nsrc, H, W)
    output = []
    for i in range(m):
        for j in range(nsrc):
            # Reuse the container's STFT parameters to invert each image.
            container.data = batch_small[i, j, :, :]
            audio = container.invert()
            output.append(audio.time_amplitudes)
    output = np.array(output)  # (m * nsrc, time_samples)
    time_samples = output.shape[1]
    return np.reshape(output, (m * nsrc, time_samples, 1))
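# Shape sketch for invert_batch_like. The sizes are illustrative (512 x 128
# matches the spectrograms produced by create_data_for_model below); the
# container can be any StftData whose parameters match the batch, e.g. the
# one returned by load_data.
def _demo_invert_batch_like(container: StftData) -> None:
    m, nsrc, H, W = 4, 2, 512, 128  # 4 mixtures, 2 sources each
    batch = np.zeros((m, nsrc, 1, H, W), dtype=np.complex64)
    audio = invert_batch_like(batch, container)
    assert audio.shape[0] == m * nsrc  # flattened to (m * nsrc, time_samples, 1)
    assert audio.shape[2] == 1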
def apply_model(input_batch: np.ndarray, model: nn.Module,
                example_stft: StftData,
                args: argparse.Namespace) -> List[AudioData]:
    input_tensor = torch.tensor(input_batch, dtype=torch.complex64)
    input_mags = input_tensor.abs()
    predictions, _ = model(input_mags)
    if args.nonboolean_mask:
        # Soft mask: the clamped ratio of predicted to mixture magnitudes.
        predicted_mask = torch.clamp(predictions / input_mags, 0, 1)
    else:
        # Boolean mask: threshold the clamped ratio at alpha.
        predicted_mask = torch.ones_like(predictions) * (torch.clamp(
            predictions / input_mags, 0, 1) > args.alpha)

    predicted_mask = predicted_mask.detach().numpy()
    # Lightly blur each mask to soften hard spectral edges.
    for i, mask in enumerate(predicted_mask[:, 0, :, :]):
        blurred_mask = cv2.GaussianBlur(mask,
                                        ksize=(5, 5),
                                        sigmaX=0,
                                        sigmaY=0.5)
        predicted_mask[i, 0] = blurred_mask

    output_freqs = input_tensor.numpy() * predicted_mask

    # Debugging aid: uncomment to visualize input, mask, and output spectrograms.
    # for i in range(predicted_mask.shape[0]):
    #     example_stft.data = input_tensor[i, 0].numpy()
    #     example_stft.save_spectrogram(show=True)
    #     example_stft.data = predicted_mask[i, 0]
    #     example_stft.save_spectrogram(show=True)
    #     example_stft.data = output_freqs[i, 0]
    #     example_stft.save_spectrogram(show=True)

    # Convert the outputs back to the time domain
    output_time_data = []
    for i in range(output_freqs.shape[0]):
        freq_data = StftData(
            args=example_stft.args,
            manual_init=(output_freqs[i, 0, :, :], example_stft.sample_freqs,
                         example_stft.segment_times, example_stft.fs))
        output_time_data.append(freq_data.invert())

    return output_time_data
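# Toy illustration of the two masking modes in apply_model: the soft
# ("nonboolean") mask is the clamped magnitude ratio, while the boolean mask
# thresholds that ratio at alpha. The numbers are made up for demonstration.
def _demo_masking_modes(alpha: float = 0.5) -> None:
    mags = torch.tensor([[4.0, 1.0], [2.0, 8.0]])    # mixture magnitudes
    preds = torch.tensor([[3.0, 0.2], [2.5, 1.0]])   # predicted source magnitudes
    soft_mask = torch.clamp(preds / mags, 0, 1)      # [[0.75, 0.2], [1.0, 0.125]]
    bool_mask = (soft_mask > alpha).float()          # [[1., 0.], [1., 0.]]
    print(soft_mask, bool_mask, sep='\n')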
def create_data_for_model(target_dir: str, source1_dir: str, source2_dir: str,
                          source1_name: str = "source1", source2_name: str = "source2",
                          num_examples_to_create: int = 15000):
    if num_examples_to_create > 50000:
        raise ValueError("Let's not brick the VM. If you really want to create this many examples, edit the code to allow it.")

    # Desired frequency-domain outputs are 512x128; see the shape check
    # (_demo_stft_shape) sketched after this function.
    args: StftArgs = StftArgs(nperseg=1022, noverlap=511)
    source1_snippets = load_snippets(source1_dir)
    source2_snippets = load_snippets(source2_dir)

    if num_examples_to_create > len(source1_snippets) * len(source2_snippets):
        print("There aren't enough unique combinations to create {} unique examples! Falling back to only creating {} examples instead."
              .format(num_examples_to_create, len(source1_snippets) * len(source2_snippets)))
        num_examples_to_create = len(source1_snippets) * len(source2_snippets)

    random.shuffle(source1_snippets)
    random.shuffle(source2_snippets)

    pairgen: PairGenerator = PairGenerator(source1_snippets, source2_snippets)
    example_num = 0
    while example_num < num_examples_to_create:
        if example_num % 100 == 0:
            print("Processing example #{}".format(example_num))
        audio1, audio2 = pairgen.get_pair()
        superimposed, audio1, audio2 = AudioDataUtils.superimpose(audio1, audio2)

        freqdata1 = trim_stft_data(StftData(args=args, audiodata=audio1))
        freqdata2 = trim_stft_data(StftData(args=args, audiodata=audio2))
        freqdata_super = trim_stft_data(StftData(args=args, audiodata=superimposed))

        file_prefix = target_dir + "/"
        freqdata1.save(file_prefix + source1_name + "_{}.pkl".format(example_num))
        freqdata2.save(file_prefix + source2_name + "_{}.pkl".format(example_num))
        freqdata_super.save(file_prefix + "combined_{}.pkl".format(example_num))
        example_num += 1
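# Why nperseg=1022 and noverlap=511 yield 512-row spectrograms: a one-sided
# STFT has nperseg // 2 + 1 frequency bins, i.e. 1022 // 2 + 1 = 512. The 128
# columns come from trim_stft_data truncating to DESIRED_NUM_STFT_WINDOWS
# windows. A quick check with scipy directly (the 14700 Hz sampling rate,
# i.e. 44.1/3 kHz, is assumed from the rest of this module):
def _demo_stft_shape() -> None:
    from scipy import signal
    fs = 14700
    samples = np.random.randn(fs * 5)  # ~5 s of noise
    _, _, Zxx = signal.stft(samples, fs=fs, nperseg=1022, noverlap=511)
    assert Zxx.shape[0] == 512  # 1022 // 2 + 1 frequency bins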
def test_model(
    dev_dl: data.DataLoader,
    model: nn.Module,
    args: argparse.Namespace,
    stft_container: StftData,
) -> nn.Module:

    device = model_utils.get_device()

    print('\nRunning test metrics...')

    # Validation portion

    # Forward inference on model
    predicted_masks = []
    print('  Running forward inference...')
    with tqdm(total=args.batch_size * len(dev_dl)) as progress_bar:
        for i, (x_batch, _, _) in enumerate(dev_dl):
            x_batch = x_batch.abs().to(device)

            # Forward pass on model
            # y_pred = model(torch.clamp_min(torch.log(x_batch), 0))

            y_pred_b, y_pred_t = model(x_batch)

            if args.nonboolean_mask:
                y_biden_mask = torch.clamp(y_pred_b.detach() / x_batch, 0, 1)
                y_trump_mask = torch.clamp(y_pred_t.detach() / x_batch, 0, 1)
            else:
                y_biden_mask = torch.ones_like(y_pred_b) * (torch.clamp(
                    y_pred_b.detach() / x_batch, 0, 1) > args.alpha)
                y_trump_mask = torch.ones_like(y_pred_t) * (
                    (1 - torch.clamp(y_pred_t.detach() / x_batch, 0, 1)) >
                    args.alpha)

            predicted_masks.append((y_biden_mask.cpu(), y_trump_mask.cpu()))

            progress_bar.update(len(x_batch))

            del x_batch
            del y_biden_mask
            del y_trump_mask
            del y_pred_b
            del y_pred_t

    print('\n  Processing results...')
    SDR, ISR, SIR, SAR = [], [], [], []
    with tqdm(total=args.batch_size * len(dev_dl)) as progress_bar:
        for i, ((x_batch, _, ground_truth),
                (y_biden_mask,
                 y_trump_mask)) in enumerate(zip(dev_dl, predicted_masks)):
            stft_biden_audio = y_biden_mask * x_batch
            stft_trump_audio = y_trump_mask * x_batch
            stft_audio = torch.stack([stft_biden_audio, stft_trump_audio],
                                     dim=1)

            # Calculate other stats
            # Invert the masked STFTs back to the time domain;
            # invert_batch_like sets stft_container.data per image itself.
            model_stft = stft_audio.cpu().numpy()
            model_audio = model_utils.invert_batch_like(
                model_stft, stft_container)

            m, nsrc, timesamples, chan = ground_truth.shape
            gt = torch.reshape(ground_truth, (m * nsrc, timesamples, 1))
            if args.biden_only_sdr:
                batch_sdr, batch_isr, batch_sir, batch_sar = bsseval.evaluate(
                    gt[:1, :, :],
                    model_audio[:1, :, :],
                    win=stft_container.fs,
                    hop=stft_container.fs,
                )
            else:
                batch_sdr, batch_isr, batch_sir, batch_sar = bsseval.evaluate(
                    gt,
                    model_audio,
                    win=stft_container.fs,
                    hop=stft_container.fs,
                )

            SDR = np.concatenate([SDR, np.mean(batch_sdr, axis=1)], axis=0)
            ISR = np.concatenate([ISR, np.mean(batch_isr, axis=1)], axis=0)
            SIR = np.concatenate([SIR, np.mean(batch_sir, axis=1)], axis=0)
            SAR = np.concatenate([SAR, np.mean(batch_sar, axis=1)], axis=0)

            progress_bar.update(len(x_batch))

    print('\n  Calculating overall metrics...')

    print()
    print('*' * 30)
    print(f'SDR: {np.mean(SDR)}')
    print(f'ISR: {np.mean(ISR)}')
    print(f'SIR: {np.mean(SIR)}')
    print(f'SAR: {np.mean(SAR)}')
    print('*' * 30)

    # Debugging aid: uncomment to play ground-truth vs. model audio.
    # for i in range(ground_truths.shape[0]):
    #     audio = AudioData(manual_init=[stft_container.fs, ground_truths[i, :, 0]])
    #     audio2 = AudioData(manual_init=[stft_container.fs, model_outputs[i, :, 0]])
    #     play(audio)
    #     play(audio2)

    return model
def trim_stft_data(freq_data: StftData) -> StftData:
    '''Truncates the STFT (in place) to the first DESIRED_NUM_STFT_WINDOWS time windows.'''
    freq_data.segment_times = freq_data.segment_times[0:DESIRED_NUM_STFT_WINDOWS]
    freq_data.data = freq_data.data[:, 0:DESIRED_NUM_STFT_WINDOWS]
    return freq_data
def load_data(
    directory_path: str,
    dev_frac: float = 0.1,
    max_entries: int = 15000,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, StftData]:
    '''
    Returns data in the order combined_train, biden_train, trump_train, mask_train,
    combined_dev, biden_dev, trump_dev, mask_dev, plus an example StftData container.
    combined, biden, and trump are complex64; the masks are float32.
    Note: Train the model on x_train.abs() and y_train.abs(), but at inference time
    input x_train.abs() and multiply the output by x_train (complex) to get the real
    prediction. Perhaps threshold the output to create a more binary mask.
    '''
    # Open given directory, expecting files with the names
    # biden_%d.pkl, combined_%d.pkl, and trump_%d.pkl

    # Choose file names and deterministically shuffle
    print('Reading datasets...')
    pickled_filenames = os.listdir(directory_path)
    combined_filenames = sorted([fn for fn in pickled_filenames if 'combined_' in fn])
    biden_filenames = sorted([fn for fn in pickled_filenames if 'biden_' in fn])
    trump_filenames = sorted([fn for fn in pickled_filenames if 'trump_' in fn])
    zipped_filenames = list(zip(combined_filenames, biden_filenames, trump_filenames))
    random.Random(230).shuffle(zipped_filenames)
    if len(zipped_filenames) > max_entries:
        zipped_filenames = zipped_filenames[:max_entries]

    # Read files as STFT data
    combined_ls, biden_ls, trump_ls = [], [], []
    masks_ls = []
    with tqdm(total=len(zipped_filenames)) as progress_bar:
        for i, (cd, bd, td) in enumerate(zipped_filenames):
            combined_data = StftData(pickle_file=f'{directory_path}/{cd}')
            biden_data = StftData(pickle_file=f'{directory_path}/{bd}')
            trump_data = StftData(pickle_file=f'{directory_path}/{td}')
            # Commented lines below are debugging substitutions (zeroed or
            # swapped sources) for sanity-checking the pipeline.
            combined_ls.append(combined_data.data)
            # combined_ls.append(trump_data.data)
            # combined_ls.append(biden_data.data)
            biden_ls.append(biden_data.data)
            # biden_ls.append(np.zeros_like(biden_data.data))
            # biden_ls.append(biden_data.data)
            trump_ls.append(trump_data.data)
            # trump_ls.append(trump_data.data)
            # trump_ls.append(np.zeros_like(trump_data.data))
            biden_mag = np.abs(biden_data.data)
            trump_mag = np.abs(trump_data.data)
            # Ideal binary mask: 1 wherever the biden source dominates.
            masks_ls.append((biden_mag > trump_mag).astype(np.float32))
            # masks_ls.append(np.zeros_like(trump_mag, dtype=np.float32))
            # masks_ls.append(np.ones_like(biden_mag, dtype=np.float32))
            progress_bar.update()

    # Reformat arrays
    combined = _convert_to_tensor(combined_ls, torch.complex64)
    biden = _convert_to_tensor(biden_ls, torch.complex64)
    trump = _convert_to_tensor(trump_ls, torch.complex64)
    masks = _convert_to_tensor(masks_ls, torch.float32)

    # Partition into train and dev
    print('  Done!')
    num_examples = len(combined)
    dev_idx = num_examples - int(num_examples * dev_frac)
    return (
        combined[:dev_idx],
        biden[:dev_idx],
        trump[:dev_idx],
        masks[:dev_idx],
        combined[dev_idx:],
        biden[dev_idx:],
        trump[dev_idx:],
        masks[dev_idx:],
        biden_data,  # example StftData container (last one loaded)
    )
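# Unpacking sketch for load_data's nine return values; the directory path is
# illustrative. The trailing StftData serves as a template container when
# inverting model outputs back to audio.
def _demo_load_data() -> None:
    (combined_train, biden_train, trump_train, mask_train,
     combined_dev, biden_dev, trump_dev, mask_dev,
     stft_container) = load_data('data/pickled', dev_frac=0.1)
    assert combined_train.dtype == torch.complex64
    assert mask_train.dtype == torch.float32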
def load_test_data(
    directory_path: str,
    dev_frac: float = 0.1,
    max_entries: int = 15000,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, StftData]:
    '''
    Loads only the dev split and returns combined, biden, gt_samples, and an
    example StftData container. combined and biden are complex64 STFT tensors;
    gt_samples stacks the inverted biden and trump waveforms as ground truth
    with shape (N, 2, time_samples, 1).
    '''
    # Open given directory, expecting files with the names
    # biden_%d.pkl, combined_%d.pkl, and trump_%d.pkl

    # Choose file names and deterministically shuffle
    print('Reading datasets...')
    pickled_filenames = os.listdir(directory_path)
    combined_filenames = sorted([fn for fn in pickled_filenames if 'combined_' in fn])
    biden_filenames = sorted([fn for fn in pickled_filenames if 'biden_' in fn])
    trump_filenames = sorted([fn for fn in pickled_filenames if 'trump_' in fn])
    zipped_filenames = list(zip(combined_filenames, biden_filenames, trump_filenames))
    random.Random(230).shuffle(zipped_filenames)
    if len(zipped_filenames) > max_entries:
        zipped_filenames = zipped_filenames[:max_entries]

    # Only load dev examples
    num_examples = len(zipped_filenames)
    dev_idx = num_examples - int(num_examples * dev_frac)
    zipped_filenames = zipped_filenames[dev_idx:]

    # Need combined as STFT data (for network input) and inverted biden
    combined_ls, biden_samples_ls, trump_samples_ls, biden_ls = [], [], [], []
    with tqdm(total=len(zipped_filenames)) as progress_bar:
        for i, (cd, bd, td) in enumerate(zipped_filenames):
            combined_data = StftData(pickle_file=f'{directory_path}/{cd}')
            biden_data = StftData(pickle_file=f'{directory_path}/{bd}')
            trump_data = StftData(pickle_file=f'{directory_path}/{td}')
            combined_ls.append(combined_data.data)
            biden_ls.append(biden_data.data)
            biden_time_amplitude = biden_data.invert().time_amplitudes
            trump_time_amplitude = trump_data.invert().time_amplitudes
            biden_samples_ls.append(biden_time_amplitude)
            trump_samples_ls.append(trump_time_amplitude)
            progress_bar.update()

    # Reformat arrays
    combined = _convert_to_tensor(combined_ls, torch.complex64)
    biden = _convert_to_tensor(biden_ls, torch.complex64)
    biden_samples = np.expand_dims(np.asarray(biden_samples_ls), axis=2)
    trump_samples = np.expand_dims(np.asarray(trump_samples_ls), axis=2)
    gt_samples = torch.tensor(np.stack([biden_samples, trump_samples], axis=1))  # (N, 2, time_samples, 1)

    print('  Done!')

    return (
        combined,
        biden,
        gt_samples,
        biden_data,
    )
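# Hedged end-to-end sketch: feed load_test_data's dev split through test_model.
# The directory path and DataLoader wiring are assumptions; test_model needs
# batches of (combined_stft, biden_stft, ground_truth_samples), an args object
# defining batch_size, nonboolean_mask, alpha, and biden_only_sdr, and an
# StftData container for inversion.
def _demo_test_model(model: nn.Module, args: argparse.Namespace) -> None:
    combined, biden, gt_samples, container = load_test_data('data/pickled')
    dev_ds = data.TensorDataset(combined, biden, gt_samples)
    dev_dl = data.DataLoader(dev_ds, batch_size=args.batch_size, shuffle=False)
    test_model(dev_dl, model, args, container)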