def computeValidAccuracy(args, modelDir):
    """
    Computes frame-level validation accuracy

    The most recently created file inside ``modelDir`` is loaded as the
    checkpoint, the network named by ``args.modelType`` is rebuilt with
    dropout disabled, and every example found in
    ``args.featDir + '/valid_egs.*.ark'`` is classified on the GPU.

    Args:
        args: object exposing ``modelType`` (name of the network class),
            ``numSpkrs`` (first constructor argument of that class) and
            ``featDir`` (directory holding the ``valid_egs.*.ark`` files).
        modelDir: directory containing the saved checkpoints.

    Returns:
        Accuracy in percent (0-100) as a float.

    Raises:
        ValueError: if ``modelDir`` contains no files (raised by ``max``).
        ZeroDivisionError: if no validation example was found.
    """
    # The newest file (by ctime) in the directory is used as the checkpoint.
    modelFile = max(glob.glob(modelDir + '/*'), key=os.path.getctime)

    # Load the model
    # NOTE(review): eval() executes whatever string is in args.modelType.
    # This is only safe while that value comes from a trusted source.
    net = eval('{}({}, p_dropout=0)'.format(args.modelType, args.numSpkrs))
    checkpoint = torch.load(modelFile, map_location=torch.device('cuda'))

    # Strip a leading 'module.' from parameter names so the keys match the
    # bare network (presumably the prefix comes from a parallel wrapper used
    # at training time - confirm).
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in checkpoint['model_state_dict'].items()
    )

    # load params
    net.load_state_dict(new_state_dict)
    net = net.cuda()
    net.eval()

    correct, incorrect = 0, 0
    # Pure inference: disable autograd so no graph is kept per example.
    with torch.no_grad():
        for validArk in glob.glob(args.featDir + '/valid_egs.*.ark'):
            reader = kaldi_python_io.Nnet3EgsReader(validArk)
            for _, mat in reader:
                # mat[0] carries the features, mat[1] the label entry.
                feats = torch.Tensor(mat[0]['matrix']).permute(1, 0)
                out = net(x=feats.unsqueeze(0).cuda(), eps=0)
                label = mat[1]['matrix'][0][0][0]
                if label == torch.argmax(out).item():
                    correct += 1
                else:
                    incorrect += 1

    total = correct + incorrect
    if total == 0:
        # Same exception type as the bare division would raise, but with a
        # message that points at the actual cause.
        raise ZeroDivisionError(
            'no validation examples found in {}/valid_egs.*.ark'.format(
                args.featDir))
    return 100.0 * correct / total
def writeHdf5File(egsFile, scpFile, chunkLen, hdf5File, featDim=30):
    """
    Converts a Kaldi nnet3 egs archive into an HDF5 file

    Two datasets are created in ``hdf5File``:
      * ``feats``  with shape (numSamples, chunkLen, featDim), float
      * ``labels`` with shape (numSamples, 1), 64-bit integer

    ``numSamples`` is the number of lines in ``scpFile`` (assumes the scp
    file lists one example per line - confirm against the egs pipeline).

    Args:
        egsFile: path of the egs archive read with Nnet3EgsReader.
        scpFile: path of the scp file whose line count sizes the datasets.
        chunkLen: number of frames in each example.
        hdf5File: output path; an existing file is overwritten.
        featDim: feature dimension of each frame (defaults to 30, the value
            that was previously hard-coded).
    """
    # Count lines in-process; this avoids spawning `wc` and also counts a
    # final line that lacks a trailing newline.
    with open(scpFile, 'rb') as scp:
        numSamples = sum(1 for _ in scp)

    reader = kaldi_python_io.Nnet3EgsReader(egsFile)
    with h5py.File(hdf5File, 'w') as fid:
        # Compression is intentionally left disabled on both datasets.
        feats = fid.create_dataset(
            'feats', (numSamples, chunkLen, featDim), dtype='f')
        labels = fid.create_dataset('labels', (numSamples, 1), dtype='i8')
        for count, (_, mat) in enumerate(reader):
            # mat[1] carries the label entry, mat[0] the feature matrix.
            labels[count] = mat[1]['matrix'][0][0][0]
            feats[count] = mat[0]['matrix']
# Stand-alone script: load the newest checkpoint and report the validation
# accuracy over all valid_egs.*.ark archives.
import glob
import os
import sys
import socket

import torch
import kaldi_python_io

from train_utils import *

egsDir = modelDir = '/home/manoj/Projects/pytorch_spkembed/xvectors_voxceleb/models/isXvec_False_modelType_3_event_202002-1719-0729'

# glob.glob() on a plain directory path returns the directory itself, which
# torch.load cannot read, so look inside the directory when it is one.
# NOTE(review): egsDir and modelDir are the same path here, so the newest
# entry could be an egs archive and not a checkpoint - confirm the layout.
modelPattern = os.path.join(modelDir, '*') if os.path.isdir(modelDir) else modelDir
modelFile = max(glob.glob(modelPattern), key=os.path.getctime)

# Load the model
net = simpleTDNN(params['numSpkrs'], p_dropout=0)
checkpoint = torch.load(modelFile)
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()

correct, incorrect = 0, 0
# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for validArk in glob.glob(egsDir + '/valid_egs.*.ark'):
        reader = kaldi_python_io.Nnet3EgsReader(validArk)
        for _, mat in reader:
            feats = torch.Tensor(mat[0]['matrix']).permute(1, 0).unsqueeze(0)
            out = net(feats)
            # NOTE(review): the predicted index is shifted by one before the
            # comparison, i.e. labels are treated as 1-based - confirm
            # against how the egs were generated.
            if mat[1]['matrix'][0][0][0] == torch.argmax(out).item() + 1:
                correct += 1
            else:
                incorrect += 1

total = correct + incorrect
if total == 0:
    sys.exit('No validation examples found in ' + egsDir)

# The message says "percent", so scale the fraction to 0-100.
print('Valid accuracy: %1.2f percent' % (100.0 * correct / total))
def __init__(self, arkFile):
    """Open ``arkFile`` with a Kaldi nnet3 egs reader and keep it as ``self.fid``."""
    egsReader = kaldi_python_io.Nnet3EgsReader(arkFile)
    self.fid = egsReader