    def _reset(self, args):
        # Override
        super()._reset(args)

        self.F_bin = args.F_bin
        self.istft = BatchInvSTFT(args.fft_size,
                                  args.hop_size,
                                  window_fn=args.window_fn)
Example #2
    def __init__(self, fft_size, hop_size=None, window_fn='hann'):
        super().__init__()
        if hop_size is None:
            hop_size = fft_size // 4

        self.fft_size = fft_size

        self.stft = BatchSTFT(fft_size, hop_size=hop_size, window_fn=window_fn, normalize=True)
        self.istft = BatchInvSTFT(fft_size, hop_size=hop_size, window_fn=window_fn, normalize=True)
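
The two members above compose into an analysis/synthesis round trip. A minimal usage sketch, assuming the enclosing class (whose name is not shown in this fragment) is instantiated as model, and following the istft(..., T=...) call pattern used in Examples #4 and #6:

import torch

# Hypothetical usage; model is an instance of the (unnamed) class above.
signal = torch.randn(1, 16000)                               # (batch, time)
spectrogram = model.stft(signal)                             # analysis
reconstructed = model.istft(spectrogram, T=signal.size(-1))  # synthesis, length-matched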
Example #3
    def _reset(self, args):
        # Override
        super()._reset(args)

        self.n_bins = args.n_bins
        self.istft = BatchInvSTFT(args.fft_size,
                                  args.hop_size,
                                  window_fn=args.window_fn)

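        # Geometric decay: lr * lr_decay**epochs == lr_end by construction.
        # e.g., lr=1e-3, lr_end=1e-4 over 100 epochs gives lr_decay = 0.1**(1/100) ≈ 0.977.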
        self.lr_decay = (args.lr_end / args.lr)**(1 / self.epochs)
Example #4
import numpy as np
import torch
import matplotlib.pyplot as plt

from algorithm.stft import BatchSTFT, BatchInvSTFT
# NMF, read_wav, write_wav, and the constant EPS are local to this repo; their
# module paths are not shown in this excerpt.

def _test(metric='EUC'):
    torch.manual_seed(111)

    fft_size, hop_size = 1024, 256
    n_bases = 6
    iteration = 100
    
    signal, sr = read_wav("data/music-8000.wav")
    
    T = len(signal)
    signal = torch.Tensor(signal).unsqueeze(dim=0)
    
    stft = BatchSTFT(fft_size=fft_size, hop_size=hop_size)
    istft = BatchInvSTFT(fft_size=fft_size, hop_size=hop_size)

    spectrogram = stft(signal).squeeze(dim=0)
    real = spectrogram[...,0]
    imag = spectrogram[...,1]
    amplitude = torch.sqrt(real**2 + imag**2)
    power = amplitude**2

    log_spectrogram = 10 * torch.log10(power + EPS).numpy()
    plt.figure()
    plt.pcolormesh(log_spectrogram, cmap='jet')
    plt.colorbar()
    plt.savefig('data/NMF/spectrogram.png', bbox_inches='tight')
    plt.close()

    nmf = NMF(n_bases, metric=metric)
    nmf.update(power, iteration=iteration)

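    # Wiener-like amplitude masking: scale the observed complex STFT by
    # estimated/observed amplitude, i.e. reuse the mixture phase for synthesis.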
    estimated_power = torch.matmul(nmf.base, nmf.activation)
    estimated_amplitude = torch.sqrt(estimated_power)
    ratio = estimated_amplitude / (amplitude + EPS)
    estimated_real, estimated_imag = ratio * real, ratio * imag
    estimated_spectrogram = torch.cat([estimated_real.unsqueeze(dim=2), estimated_imag.unsqueeze(dim=2)], dim=2).unsqueeze(dim=0)

    estimated_signal = istft(estimated_spectrogram, T=T)
    estimated_signal = estimated_signal.squeeze(dim=0).numpy()
    estimated_signal = estimated_signal / np.abs(estimated_signal).max()
    write_wav("data/NMF/{}/music-8000-estimated-iter{}.wav".format(metric, iteration), signal=estimated_signal, sr=8000)

    for idx in range(n_bases):
        estimated_power = torch.matmul(nmf.base[:, idx: idx+1], nmf.activation[idx: idx+1, :])
        estimated_amplitude = torch.sqrt(estimated_power)
        ratio = estimated_amplitude / (amplitude + EPS)
        estimated_real, estimated_imag = ratio * real, ratio * imag
        estimated_spectrogram = torch.cat([estimated_real.unsqueeze(dim=2), estimated_imag.unsqueeze(dim=2)], dim=2).unsqueeze(dim=0)

        estimated_signal = istft(estimated_spectrogram, T=T)
        estimated_signal = estimated_signal.squeeze(dim=0).numpy()
        estimated_signal = estimated_signal / np.abs(estimated_signal).max()
        write_wav("data/NMF/{}/music-8000-estimated-iter{}-base{}.wav".format(metric, iteration, idx), signal=estimated_signal, sr=8000)

        log_spectrogram = 10 * torch.log10(estimated_power + EPS).numpy()
        plt.figure()
        plt.pcolormesh(log_spectrogram, cmap='jet')
        plt.colorbar()
        plt.savefig('data/NMF/{}/estimated-spectrogram-iter{}-base{}.png'.format(metric, iteration, idx), bbox_inches='tight')
        plt.close()
    
    plt.figure()
    plt.plot(nmf.loss)
    plt.savefig('data/NMF/{}/loss.png'.format(metric), bbox_inches='tight')
    plt.close()
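
The nmf.update call above factorizes the power spectrogram V into nonnegative bases and activations, V ≈ WH. A minimal sketch of the standard multiplicative updates for the squared-Euclidean ('EUC') metric (Lee & Seung); this is an illustration under that assumption, not this repo's NMF implementation:

import torch

def nmf_euc_update(V, n_bases, iteration=100, eps=1e-12):
    # Illustrative multiplicative updates for min ||V - WH||_F^2.
    n_bins, n_frames = V.size()
    W = torch.rand(n_bins, n_bases)    # bases (spectral templates)
    H = torch.rand(n_bases, n_frames)  # activations
    loss = []
    for _ in range(iteration):
        H = H * (W.t() @ V) / (W.t() @ W @ H + eps)  # H <- H * (W^T V) / (W^T W H)
        W = W * (V @ H.t()) / (W @ H @ H.t() + eps)  # W <- W * (V H^T) / (W H H^T)
        loss.append(((V - W @ H)**2).sum().item())
    return W, H, loss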
Example #5
    from algorithm.stft import BatchSTFT, BatchInvSTFT
    from algorithm.frequency_mask import ideal_binary_mask
    from criterion.deep_clustering import AffinityLoss

    torch.manual_seed(111)

    batch_size, T = 2, 512
    n_sources = 2
    fft_size, hop_size = 256, 128
    window_fn = 'hann'
    n_bins = fft_size // 2 + 1
    hidden_channels, embed_dim = 600, 40

    stft = BatchSTFT(fft_size=fft_size, hop_size=hop_size, window_fn=window_fn)
    istft = BatchInvSTFT(fft_size=fft_size,
                         hop_size=hop_size,
                         window_fn=window_fn)
    criterion = AffinityLoss()

    signal = torch.randn((batch_size * n_sources, T), dtype=torch.float)
    spectrogram = stft(signal)
    real, imag = spectrogram[..., 0], spectrogram[..., 1]
    power = real**2 + imag**2
    target = 10 * torch.log10(power + EPS)
    _, _, n_frames = target.size()
    target = target.view(batch_size, n_sources, n_bins, n_frames)
    target = ideal_binary_mask(target)
    input = target.sum(dim=1)

    print("=" * 10, "Deep embedding", "=" * 10)
Example #6
    os.makedirs("data/GriffinLim", exist_ok=True)
    torch.manual_seed(111)

    fft_size, hop_size = 1024, 256
    n_basis = 4

    signal, sr = read_wav("data/man-44100.wav")
    signal = resample_poly(signal, up=16000, down=sr)
    write_wav("data/man-16000.wav", signal=signal, sr=16000)

    T = len(signal)
    signal = torch.Tensor(signal).unsqueeze(dim=0)

    stft = BatchSTFT(fft_size=fft_size, hop_size=hop_size)
    istft = BatchInvSTFT(fft_size=fft_size, hop_size=hop_size)

    spectrogram = stft(signal)
    oracle_signal = istft(spectrogram, T=T)
    oracle_signal = oracle_signal.squeeze(dim=0).numpy()
    write_wav("data/man-oracle.wav", signal=oracle_signal, sr=16000)

    griffin_lim = GriffinLim(fft_size, hop_size=hop_size)

    spectrogram = spectrogram.squeeze(dim=0)
    real, imag = spectrogram[..., 0], spectrogram[..., 1]
    amplitude = torch.sqrt(real**2 + imag**2)

    # Griffin-Lim iteration 10
    iteration = 10
    estimated_phase = griffin_lim(amplitude, iteration=iteration)
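
    # Continuation sketch (not in the original excerpt): recombine the known
    # amplitude with the estimated phase and resynthesize. Assumes
    # estimated_phase holds per-bin phase angles in radians.
    estimated_real = amplitude * torch.cos(estimated_phase)
    estimated_imag = amplitude * torch.sin(estimated_phase)
    estimated_spectrogram = torch.stack([estimated_real, estimated_imag], dim=2).unsqueeze(dim=0)
    estimated_signal = istft(estimated_spectrogram, T=T)
    estimated_signal = estimated_signal.squeeze(dim=0).numpy()
    write_wav("data/GriffinLim/man-estimated-iter{}.wav".format(iteration), signal=estimated_signal, sr=16000)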
Example #7
import os

import numpy as np
import pyaudio
import torch

# FORMAT, NUM_CHANNEL, DEVICE_INDEX, show_progress_bar, write_wav, load_model,
# BatchSTFT, and BatchInvSTFT are defined elsewhere in this repo (not shown in
# this excerpt).

def process_offline(sr,
                    num_chunk,
                    duration=5,
                    model_path=None,
                    save_dir="results",
                    args=None):
    num_loop = int(duration * sr / num_chunk)
    sequence = []

    P = pyaudio.PyAudio()

    # Record
    stream = P.open(format=FORMAT,
                    channels=NUM_CHANNEL,
                    rate=sr,
                    input_device_index=DEVICE_INDEX,
                    frames_per_buffer=num_chunk,
                    input=True,
                    output=False)

    for i in range(num_loop):
        input = stream.read(num_chunk)
        sequence.append(input)
        time = int(i * num_chunk / sr)
        show_progress_bar(time, duration)

    show_progress_bar(duration, duration)
    print()

    stream.stop_stream()
    stream.close()
    P.terminate()

    print("Stop recording")

    os.makedirs(save_dir, exist_ok=True)

    # Save
    signal = b"".join(sequence)
    signal = np.frombuffer(signal, dtype=np.int16)
    signal = signal / 32768

    save_path = os.path.join(save_dir, "mixture.wav")
    write_wav(save_path, signal=signal, sr=sr)

    # Separate by DNN
    model = load_model(model_path)
    model.eval()

    fft_size, hop_size = args.fft_size, args.hop_size
    window_fn = args.window_fn

    if hop_size is None:
        hop_size = fft_size // 2

    n_sources = args.n_sources
    iter_clustering = args.iter_clustering

    F_bin = fft_size // 2 + 1
    stft = BatchSTFT(fft_size, hop_size=hop_size, window_fn=window_fn)
    istft = BatchInvSTFT(fft_size, hop_size=hop_size, window_fn=window_fn)

    print("Start separation...")

    with torch.no_grad():
        mixture = torch.Tensor(signal).float()
        T = mixture.size(0)
        mixture = mixture.unsqueeze(dim=0)
        mixture = stft(mixture).unsqueeze(dim=0)
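        # This STFT output packs real and imaginary parts along the frequency
        # axis, so split the two halves at F_bin.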
        real, imag = mixture[:, :, :F_bin], mixture[:, :, F_bin:]
        mixture_amplitude = torch.sqrt(real**2 + imag**2)
        estimated_sources_amplitude = model(
            mixture_amplitude,
            n_sources=n_sources,
            iter_clustering=iter_clustering)  # TODO: Args, threshold
        ratio = estimated_sources_amplitude / mixture_amplitude
        real, imag = ratio * real, ratio * imag
        estimated_sources = torch.cat([real, imag], dim=2)
        estimated_sources = estimated_sources.squeeze(dim=0)
        estimated_sources = istft(estimated_sources, T=T).numpy()

    print("Finished separation...")

    for idx, estimated_source in enumerate(estimated_sources):
        save_path = os.path.join(save_dir, "estimated-{}.wav".format(idx))
        write_wav(save_path, signal=estimated_source, sr=sr)
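
A minimal invocation sketch; the values are illustrative, and parse_args is a hypothetical stand-in for whatever argparse setup supplies fft_size, hop_size, window_fn, n_sources, and iter_clustering:

if __name__ == '__main__':
    args = parse_args()  # hypothetical helper; must provide fft_size, hop_size, window_fn, n_sources, iter_clustering
    process_offline(sr=16000,
                    num_chunk=1024,
                    duration=5,
                    model_path="best.pth",
                    save_dir="results",
                    args=args)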