Example No. 1
def forward(model, x, mean, std, cuda, volatile=False):
    # `hamming_window` and `dft` are not parameters; in the original file
    # they are presumably module-level globals (the analysis window and a
    # torch-based DFT helper exposing rdft()/irdft()).
    stack_num = x.shape[1]

    x = move_data_to_gpu(x, cuda, volatile)

    # Apply the Hamming analysis window.
    x = x * hamming_window

    # Real DFT -> magnitude plus phase as cos/sin components.
    (x_real, x_imag) = dft.rdft(x)
    x_mag = torch.sqrt(x_real ** 2 + x_imag ** 2)
    x_mag = torch.clamp(x_mag, min=1e-8)  # avoid division by zero on silent bins
    cos = x_real / x_mag
    sin = x_imag / x_mag

    x = transform(x_mag, type='torch')
    x = pp_data.scale(x, mean, std)

    output = model(x)

    output = pp_data.inv_scale(output, mean, std)
    output = inv_transform(output, type='torch')

    # Re-attach the phase of the centre frame of the input stack.
    y_real = output * cos[:, stack_num // 2, :]
    y_imag = output * sin[:, stack_num // 2, :]

    # Inverse real DFT back to time-domain frames.
    s = dft.irdft(y_real, y_imag)

    # Undo the analysis window.
    s /= hamming_window

    return s
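Example 1 relies on two names that are never passed in: hamming_window and dft. A minimal sketch of how they might be set up, assuming the frame length comes from the project's cfg module as in Example 4 (the fft_size value here is hypothetical):

import numpy as np
import torch

fft_size = 512  # hypothetical; the project reads this from cfg.fft_size
hamming_window = torch.from_numpy(np.hamming(fft_size).astype(np.float32))
# `dft` is assumed to be the project's torch-based DFT helper exposing
# rdft() and irdft(); any forward/inverse real-DFT pair operating on
# torch tensors would satisfy the calls above.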
Example No. 2
File: tmp01.py Project: zqy1/sednn
def forward(model, x, mean, std, cuda, volatile=False):
    # Magnitude-domain forward pass: scale, run the DNN, un-scale.
    x = np.abs(x)
    x = move_data_to_gpu(x, cuda, volatile)

    x = transform(x, type='torch')
    x = pp_data.scale(x, mean, std)

    output = model(x)

    output = pp_data.inv_scale(output, mean, std)
    output = inv_transform(output, type='torch')

    return output
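A minimal usage sketch for Example 2's forward(), mirroring the preprocessing shown in Example 4; model, mean_, std_ and the pp_data helpers come from the surrounding project, and the shapes are assumptions:

x = np.abs(cmplx_sp)                           # magnitude spectrogram, (n_frames, n_freq)
n_pad = (stack_num - 1) // 2
x = pp_data.pad_with_border(x, n_pad)
x = pp_data.mat_2d_to_3d(x, stack_num, hop=1)  # (n_frames, stack_num, n_freq)
enh = forward(model, x, mean_, std_, cuda)     # enhanced magnitudes as a torch tensor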
Example No. 3
def forward(model, x, mean, std, dft, cuda, volatile=False):

    x = np.abs(x)
    x = move_data_to_gpu(x, cuda, volatile)

    # (x_real, x_imag) = dft.rdft(x)
    # x = torch.sqrt(x_real ** 2 + x_imag ** 2)

    # (The original had ad-hoc breakpoints here and after scale():
    # `import crash` imports a nonexistent module, raising ImportError
    # to halt the script. Removed so the snippet runs.)

    x = transform(x, type='torch')
    x = pp_data.scale(x, mean, std)

    output = model(x)

    output = pp_data.inv_scale(output, mean, std)
    output = inv_transform(output, type='torch')

    return output
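The import crash / pause pairs in the original were ad-hoc breakpoints: importing a module that does not exist raises ImportError and stops the script at that point. The standard library's pdb does the same job without crashing; a sketch of the equivalent:

import pdb

x = transform(x, type='torch')
pdb.set_trace()  # drops into an interactive debugger here; `c` continues
x = pp_data.scale(x, mean, std)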
Example No. 4
File: tmp01b.py Project: zqy1/sednn
def inference(args):
    workspace = args.workspace
    audio_type = args.audio_type
    iteration = args.iteration  # renamed from `iter` to avoid shadowing the builtin
    stack_num = args.stack_num
    filename = args.filename
    mini_num = args.mini_num
    visualize = args.visualize
    cuda = args.use_cuda and torch.cuda.is_available()
    print("cuda:", cuda)

    sample_rate = cfg.sample_rate
    fft_size = cfg.fft_size
    hop_size = cfg.hop_size
    window_type = cfg.window_type

    if window_type == 'hamming':
        window = np.hamming(fft_size)
    else:
        raise Exception("Unsupported window type: %s" % window_type)

    # Audio
    audio_dir = "/vol/vssp/msos/qk/workspaces/speech_enhancement/mixed_audios/spectrogram/test/0db"
    # audio_dir = "/user/HS229/qk00006/my_code2015.5-/python/pub_speech_enhancement/mixture2clean_dnn/workspace/mixed_audios/spectrogram/test/0db"
    names = os.listdir(audio_dir)

    # Load model
    model_path = os.path.join(workspace, "models", filename, audio_type,
                              "md_%d_iters.tar" % iteration)
    n_freq = 257
    model = DNN(stack_num, n_freq)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])

    if cuda:
        model.cuda()

    # Load scalar
    scalar_path = os.path.join(workspace, "scalars", filename, "scalar.p")
    (mean_, std_) = cPickle.load(open(scalar_path, 'rb'))
    mean_ = move_data_to_gpu(mean_, cuda, volatile=True)
    std_ = move_data_to_gpu(std_, cuda, volatile=True)

    if mini_num > 0:
        n_every = len(names) // mini_num  # integer division; n_every feeds `cnt % n_every` below
    else:
        n_every = 1

    out_wav_dir = os.path.join(workspace, "enh_wavs", filename)
    pp_data.create_folder(out_wav_dir)

    for (cnt, name) in enumerate(names):
        if cnt % n_every == 0:
            audio_path = os.path.join(audio_dir, name)
            (audio, _) = pp_data.read_audio(audio_path, sample_rate)

            audio = pp_data.normalize(audio)
            cmplx_sp = pp_data.calc_sp(audio, fft_size, hop_size, window)
            x = np.abs(cmplx_sp)

            # Process data.
            n_pad = (stack_num - 1) // 2  # integer division; pad width must be an int
            x = pp_data.pad_with_border(x, n_pad)
            x = pp_data.mat_2d_to_3d(x, stack_num, hop=1)

            output = forward(model, x, mean_, std_, cuda)

            # NOTE: forward() already applies inv_scale() and inv_transform()
            # (see Examples 1-3), so repeating them here would double-invert:
            # output = pp_data.inv_scale(output, mean_, std_)
            # output = inv_transform(output, type='torch')

            pred_mag_sp = output.data.cpu().numpy()

            pred_cmplx_sp = stft.real_to_complex(pred_mag_sp, cmplx_sp)
            frames = stft.istft(pred_cmplx_sp)

            cola_constant = stft.get_cola_constant(hop_size, window)
            seq = stft.overlap_add(frames, hop_size, cola_constant)
            seq = seq[0:len(audio)]

            # Write out wav
            out_wav_path = os.path.join(out_wav_dir, name)
            pp_data.write_audio(out_wav_path, seq, sample_rate)
            print("Write out wav to: %s" % out_wav_path)

            if visualize:
                vmin = -5.
                vmax = 5.
                fig, axs = plt.subplots(2, 1, sharex=True)
                axs[0].matshow(np.log(np.abs(cmplx_sp)).T,
                               origin='lower',
                               aspect='auto',
                               cmap='jet',
                               vmin=vmin,
                               vmax=vmax)
                # Plot the numpy copy; `output` is still a torch tensor here.
                axs[1].matshow(np.log(pred_mag_sp).T,
                               origin='lower',
                               aspect='auto',
                               cmap='jet',
                               vmin=vmin,
                               vmax=vmax)
                plt.show()
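Example 4 reconstructs the waveform by combining the predicted magnitudes with the noisy mixture's phase via stft.real_to_complex. That function's implementation is not shown here; a sketch of what it is presumed to do, based on how it is called above:

import numpy as np

def real_to_complex_sketch(pred_mag, noisy_cmplx):
    # Attach the noisy mixture's phase to the predicted magnitude --
    # the usual phase-reuse trick in magnitude-domain enhancement.
    phase = np.angle(noisy_cmplx)
    return pred_mag * np.exp(1j * phase)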