コード例 #1
0
def computeValidAccuracy(args, modelDir):
    """Compute frame-level validation accuracy of the newest checkpoint.

    Loads the most recently created file in ``modelDir`` as a checkpoint,
    builds the network named by ``args.modelType`` and scores every example
    found in ``args.featDir + '/valid_egs.*.ark'``.

    Args:
        args: Namespace-like object with ``modelType`` (name of a model
            constructor in scope), ``numSpkrs`` and ``featDir``.
        modelDir: Directory holding saved checkpoints.

    Returns:
        Percentage (0-100) of validation examples whose argmax prediction
        equals the label.

    Raises:
        ValueError: if ``modelDir`` contains no files (from ``max``).
        ZeroDivisionError: if no validation examples were found.
    """
    modelFile = max(glob.glob(modelDir + '/*'), key=os.path.getctime)

    # NOTE(review): eval() executes args.modelType as code; only safe when
    # args come from a trusted command line. An explicit name -> class
    # mapping would be safer.
    net = eval('{}({}, p_dropout=0)'.format(args.modelType, args.numSpkrs))

    checkpoint = torch.load(modelFile, map_location=torch.device('cuda'))
    # Checkpoints saved from a wrapped model (e.g. nn.DataParallel) prefix
    # every key with 'module.'; strip it so the weights fit the bare model.
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in checkpoint['model_state_dict'].items())
    net.load_state_dict(new_state_dict)
    net = net.cuda()
    net.eval()

    correct, incorrect = 0, 0
    # Pure inference: disable autograd so no graph is kept per example.
    with torch.no_grad():
        for validArk in glob.glob(args.featDir + '/valid_egs.*.ark'):
            for _, mat in kaldi_python_io.Nnet3EgsReader(validArk):
                feats = torch.Tensor(mat[0]['matrix']).permute(
                    1, 0).unsqueeze(0).cuda()
                out = net(x=feats, eps=0)
                label = mat[1]['matrix'][0][0][0]
                # Label and argmax are compared directly (the old
                # "+ 1 on both sides" was equivalent).
                if label == torch.argmax(out).item():
                    correct += 1
                else:
                    incorrect += 1
    return 100.0 * correct / (correct + incorrect)
コード例 #2
0
def writeHdf5File(egsFile, scpFile, chunkLen, hdf5File, featDim=30):
    """Copy features and labels from a Kaldi nnet3 egs archive into HDF5.

    Creates two datasets in ``hdf5File`` (overwriting the file):
    ``feats`` of shape (numSamples, chunkLen, featDim) with dtype 'f', and
    ``labels`` of shape (numSamples, 1) with dtype 'i8'.

    Args:
        egsFile: Path of the egs ``.ark`` archive to read.
        scpFile: Path of the matching ``.scp`` file; its line count decides
            how many rows are allocated.
        chunkLen: Number of frames per example.
        hdf5File: Output HDF5 path.
        featDim: Feature dimension per frame (defaults to 30, the value that
            used to be hard-coded).
    """
    # Count newline bytes in-process (same semantics as `wc -l`) instead of
    # spawning `wc`, whose failures were silently turned into an IndexError.
    with open(scpFile, 'rb') as scp:
        numSamples = sum(chunk.count(b'\n')
                         for chunk in iter(lambda: scp.read(1 << 20), b''))

    reader = kaldi_python_io.Nnet3EgsReader(egsFile)
    with h5py.File(hdf5File, 'w') as fid:
        feats = fid.create_dataset('feats', (numSamples, chunkLen, featDim),
                                   dtype='f')
        labels = fid.create_dataset('labels', (numSamples, 1), dtype='i8')
        for count, (_, mat) in enumerate(reader):
            labels[count] = mat[1]['matrix'][0][0][0]
            feats[count] = mat[0]['matrix']
コード例 #3
0
import glob
import os
import socket
import sys

import kaldi_python_io
import torch

from train_utils import *


# Report frame-level accuracy of the newest checkpoint in modelDir on the
# validation egs (valid_egs.*.ark) found in the egs directory given on the
# command line.
if len(sys.argv) < 2:
    sys.exit('Usage: {} <egsDir>'.format(sys.argv[0]))
egsDir = sys.argv[1]
modelDir = '/home/manoj/Projects/pytorch_spkembed/xvectors_voxceleb/models/isXvec_False_modelType_3_event_202002-1719-0729'
# Pick the newest file *inside* modelDir; globbing the bare directory path
# only returns the directory itself, which torch.load cannot read.
modelFile = max(glob.glob(modelDir + '/*'), key=os.path.getctime)

# Load the model on CPU: this script never moves the network to a GPU, so
# GPU-saved tensors must be remapped to load on a CPU-only host.
net = simpleTDNN(params['numSpkrs'], p_dropout=0)
checkpoint = torch.load(modelFile, map_location='cpu')
# Strip the 'module.' prefix that wrapped (e.g. DataParallel) models add.
prefix = 'module.'
net.load_state_dict({
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in checkpoint['model_state_dict'].items()})
net.eval()

correct, incorrect = 0, 0
with torch.no_grad():  # inference only; no autograd graph needed
    for validArk in glob.glob(egsDir + '/valid_egs.*.ark'):
        for _, mat in kaldi_python_io.Nnet3EgsReader(validArk):
            out = net(torch.Tensor(mat[0]['matrix']).permute(1, 0).unsqueeze(0))
            # NOTE(review): compared against the raw (0-based) argmax; the
            # old `argmax + 1` disagreed with the label comparison used by
            # the other validation code in this file. Confirm labels are
            # 0-based.
            if mat[1]['matrix'][0][0][0] == torch.argmax(out).item():
                correct += 1
            else:
                incorrect += 1
# Scale by 100 so the printed number really is a percentage.
print('Valid accuracy: %1.2f percent' % (100.0 * correct / (correct + incorrect)))
コード例 #4
0
 def __init__(self, arkFile):
     """Wrap ``arkFile`` in a ``kaldi_python_io.Nnet3EgsReader``.

     The reader is stored as ``self.fid``.

     Args:
         arkFile: Path of the nnet3 egs ``.ark`` archive to read.
     """
     self.fid = kaldi_python_io.Nnet3EgsReader(arkFile)