def _load_models(input_model):
    """Load every model file matched by the whitespace-separated glob patterns in *input_model*.

    Each pattern has ``~`` expanded before globbing; every matched path is
    loaded with ``ModelInterface.load``. Prints a message and exits the
    process with status 1 when no model file matches.
    """
    patterns = [os.path.expanduser(p) for p in input_model.strip().split()]
    paths = itertools.chain.from_iterable(glob.glob(p) for p in patterns)
    models = [ModelInterface.load(p) for p in paths]
    if not models:
        print("No model file found in %s" % input_model)
        sys.exit(1)
    return models


def _rank_models(pool, feat, models):
    """Score *feat* against every model and return ``[(label, probability), ...]``, best first."""
    # get_score receives a (feat, model) pair; its result is used as
    # (label, raw_score) -- index 0 is the label, index 1 the score.
    scored = pool.map(get_score, [(feat, m) for m in models])
    # Turn the raw per-model scores into a softmax distribution so the
    # reported scores are comparable across files.
    proba = GMMSet.softmax([score for _, score in scored])
    ranked = [(label, p) for (label, _), p in zip(scored, proba)]
    return sorted(ranked, key=operator.itemgetter(1), reverse=True)


def task_predict(input_files, input_model, top_k=10, n_threads=2):
    """Identify the speaker of each wav file and report top-1 / top-k accuracy.

    Args:
        input_files: glob pattern (``~`` is expanded) selecting the wav files
            to evaluate.
        input_model: whitespace-separated glob patterns (``~`` is expanded)
            selecting the model files to load.
        top_k: number of best-ranked labels that count as a top-k hit
            (default 10).
        n_threads: worker threads used to score one file against all models
            (default 2).

    The ground-truth label of a wav file is the part of its basename before
    the first ``_``. The process exits with status 1 if no model file is
    found.
    """
    models = _load_models(input_model)

    top1_right = top1_wrong = topk_right = topk_wrong = num = 0
    # One pool for the whole run (instead of one per file); the context
    # manager guarantees it is shut down even if scoring raises.
    with ThreadPool(n_threads) as pool:
        for f in glob.glob(os.path.expanduser(input_files)):
            start_time = time.time()
            fs, signal = read_wav(f)
            print(f)
            feat = get_feature(fs, signal)
            ranked = _rank_models(pool, feat, models)

            # Label format of the WeChat voice dataset: "<label>_....wav".
            # (For the AISHELL dataset use os.path.basename(f)[6:11].)
            label = os.path.basename(f).split('_')[0]
            predict, predict_score = ranked[0]
            print("Predict ", time.time() - start_time, " seconds")

            # Top-1: exact match, consistent with the top-k membership test
            # below (a substring test could count a top-1 hit that top-k
            # then reports as a miss).
            if label == predict:
                top1_right += 1
                print('label:', label, ' predict:', predict, ' score:', predict_score, ' top1 right')
            else:
                top1_wrong += 1
                print('label:', label, ' predict:', predict, ' score:', predict_score, ' top1 wrong')

            # Top-k: is the true label among the k best-ranked labels?
            top = ranked[:top_k]
            predicts = [lbl for lbl, _ in top]
            predict_scores = [p for _, p in top]
            if label in predicts:
                topk_right += 1
                print('label:', label, ' predicts:', predicts, ' scores:', predict_scores, ' top%d Right' % top_k)
            else:
                topk_wrong += 1
                print('label:', label, ' predicts:', predicts, ' scores:', predict_scores, ' top%d Wrong' % top_k)
            num += 1

    # Guard against ZeroDivisionError when the wav glob matched nothing.
    if num == 0:
        print("No wav file found in %s" % input_files)
        return
    print('top1:', num, ' right:', top1_right, ' wrong:', top1_wrong, ' top1 acc:', top1_right / num)
    print('top%d:' % top_k, num, ' right:', topk_right, ' wrong:', topk_wrong, ' top%d acc:' % top_k, topk_right / num)