コード例 #1
0
 def test_in_out(self):
     """Check SpeakerEncoder output shapes, normalisation and dtype.

     Exercises ``forward``, ``inference`` and ``compute_embedding`` on random
     inputs laid out as (batch, time, feature) and checks that the
     ``inference`` output is unchanged by an extra L2 normalisation along
     dim 1.
     """
     dummy_input = T.rand(4, 20, 80)  # B x T x D
     model = SpeakerEncoder(input_dim=80,
                            proj_dim=256,
                            lstm_dim=768,
                            num_lstm_layers=3)
     # computing d vectors
     output = model.forward(dummy_input)
     assert output.shape[0] == 4
     assert output.shape[1] == 256
     output = model.inference(dummy_input)
     assert output.shape[0] == 4
     assert output.shape[1] == 256
     # check normalization: the test expects inference() to return vectors
     # that are already unit-L2-norm along dim 1, so re-normalising them must
     # be a no-op. Compare the largest absolute difference; a signed sum
     # would let positive and negative errors cancel and hide a mismatch.
     output_norm = T.nn.functional.normalize(output, dim=1, p=2)
     assert_diff = (output_norm - output).abs().max().item()
     assert output.type() == "torch.FloatTensor"
     assert (assert_diff <
             1e-4), f" [!] output_norm has wrong values - {assert_diff}"
     # compute a single d vector for a longer (240-frame) input; presumably
     # compute_embedding windows it into 160-frame chunks with 50% overlap
     # and pools them -- confirm against SpeakerEncoder.compute_embedding.
     dummy_input = T.rand(1, 240, 80)  # B x T x D
     output = model.compute_embedding(dummy_input,
                                      num_frames=160,
                                      overlap=0.5)
     assert output.shape[0] == 1
     assert output.shape[1] == 256
     assert len(output.shape) == 2
コード例 #2
0
            #print(f'wav_file: {wav_file}')
            if os.path.exists(wav_file):
                wav_files.append(wav_file)
    print(f'Count of wavs imported: {len(wav_files)}')
else:
    # Parse all wav files in data_path
    wav_path = data_path
    wav_files = glob.glob(data_path + '/**/*.wav', recursive=True)

def _embedding_path(wav_file, src_root, dst_root):
    """Return the ``.npy`` path that mirrors ``wav_file`` under ``dst_root``.

    Only a *leading* ``src_root`` prefix and a *trailing* ``.wav`` suffix are
    rewritten. Plain ``str.replace`` would also rewrite any other occurrence
    of those substrings (e.g. a sub-directory whose name contains the root
    string, or ``.wav`` inside a directory name), and an empty ``src_root``
    would make ``replace`` insert ``dst_root`` between every character.
    """
    out = wav_file
    if src_root and out.startswith(src_root):
        out = dst_root + out[len(src_root):]
    if out.endswith('.wav'):
        out = out[:-len('.wav')] + '.npy'
    return out


# Mirror the input directory layout under args.output_path, one .npy per wav.
output_files = [
    _embedding_path(wav_file, wav_path, args.output_path)
    for wav_file in wav_files
]

# Create each target directory once (many files usually share a directory).
# An empty dirname means "current directory", which os.makedirs rejects.
for output_dir in {os.path.dirname(path) for path in output_files}:
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

model = SpeakerEncoder(**c.model)
# Load tensors onto the CPU first so a checkpoint saved from a GPU run also
# loads on a CPU-only machine; the model is moved to the GPU below if asked.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(
    torch.load(args.model_path, map_location='cpu')['model'])
model.eval()
if args.use_cuda:
    model.cuda()

# Inference only: eval() does not disable autograd, so turn it off explicitly
# to avoid building (and holding memory for) a graph per utterance.
with torch.no_grad():
    for wav_file, output_file in zip(tqdm(wav_files), output_files):
        # Add a leading batch axis; presumably (1, frames, n_mels) after the
        # transpose -- confirm against ap.melspectrogram's output layout.
        mel_spec = ap.melspectrogram(ap.load_wav(wav_file)).T
        mel_spec = torch.FloatTensor(mel_spec[None, :, :])
        if args.use_cuda:
            mel_spec = mel_spec.cuda()
        embedding = model.compute_embedding(mel_spec)
        np.save(output_file, embedding.detach().cpu().numpy())