    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        model.to_gpu()

    # optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    mistake_list = []
    for xs, ys, ys_dep_tag, word, is_verb in test_data:
        xs = cuda.to_gpu(xs)
        xs = Variable(xs)
        with chainer.using_config('train', False):
            pred_ys = model.traverse([xs])
        pred_ys = [F.softmax(pred_y) for pred_y in pred_ys]
        pred_ys = [pred_y.data.argmax(axis=0)[1] for pred_y in pred_ys]
        pred_ys = int(pred_ys[0])
        ys = ys.argmax()
        item_type = return_item_type(ys, ys_dep_tag)
        case_num['all'] += 1
        case_num[item_type] += 1
        if pred_ys == ys:
            correct_num['all'] += 1
            correct_num[item_type] += 1
        item_type = return_item_type(ys, [])
        pred_item_type = return_item_type(pred_ys, [])
        confusion_matrix[item_type][pred_item_type] += 1
def predict(model_path, test_data, type_statistics_dict, domain, case):
    confusion_matrix = defaultdict(dict)
    feature_size = test_data[0][0].shape[1]
    model = BiLSTMBase(input_size=feature_size, output_size=feature_size,
                       n_labels=2, n_layers=1, dropout=0.2,
                       type_statistics_dict=type_statistics_dict)
    serializers.load_npz(model_path, model)

    # Per-category counters. Categories: 照応なし (no antecedent), 文内 (intra-sentential),
    # 文内(dep) (intra-sentential, dependency), 文内(zero) (intra-sentential, zero anaphora),
    # 発信者 (writer), 受信者 (reader), 項不定 (unspecified argument).
    correct_num = {
        'all': 0., '照応なし': 0., '文内': 0., '文内(dep)': 0., '文内(zero)': 0.,
        '発信者': 0., '受信者': 0., '項不定': 0.
    }
    case_num = {
        'all': 0., '照応なし': 0., '文内': 0., '文内(dep)': 0., '文内(zero)': 0.,
        '発信者': 0., '受信者': 0., '項不定': 0.
    }
    accuracy = {
        'all': 0., '照応なし': 0., '文内': 0., '文内(dep)': 0., '文内(zero)': 0.,
        '発信者': 0., '受信者': 0., '項不定': 0.
    }
    for key1 in correct_num.keys():
        for key2 in correct_num.keys():
            confusion_matrix[key1][key2] = 0

    cuda.get_device(0).use()
    model.to_gpu()

    mistake_list = []
    for xs, ys, ys_dep_tag, zs, word, is_verb in test_data:
        xs = cuda.to_gpu(xs)
        xs = Variable(xs)
        with chainer.using_config('train', False):
            pred_ys = model.traverse([xs], [zs])
        pred_ys = [F.softmax(pred_y) for pred_y in pred_ys]
        pred_ys = pred_ys[0].data.argmax(axis=0)[1]
        ys = ys.argmax()

        item_type = return_item_type(ys, ys_dep_tag)
        case_num['all'] += 1
        case_num[item_type] += 1
        if pred_ys == ys:
            correct_num['all'] += 1
            correct_num[item_type] += 1

        item_type = return_item_type(ys, [])
        pred_item_type = return_item_type(pred_ys, [])
        confusion_matrix[item_type][pred_item_type] += 1
        if pred_ys != ys:
            # For intra-sentential cases, record the word offset instead of the category label.
            if item_type == '文内':
                item_type = ys - 4
            if pred_item_type == '文内':
                pred_item_type = pred_ys - 4
            sentence = ''.join(word[4:is_verb]) + '"' + word[is_verb:is_verb + 1] + '"' + ''.join(word[is_verb + 1:])
            mistake_list.append([item_type, pred_item_type, is_verb - 4, sentence])

    correct_num['文内'] = correct_num['文内(dep)'] + correct_num['文内(zero)']
    case_num['文内'] = case_num['文内(dep)'] + case_num['文内(zero)']
    for key in accuracy:
        if case_num[key]:
            accuracy[key] = correct_num[key] / case_num[key] * 100
        else:
            accuracy[key] = 999  # sentinel for categories with no test instances

    # Dump per-category accuracy.
    output_path = './' + 'predict'
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    dump_path = '{0}/domain-{1}_case-{2}.tsv'.format(output_path, domain, case)
    print('model_path:{0}_domain:{1}_accuracy:{2:.2f}'.format('majority', domain, accuracy['all']))
    if not os.path.exists(dump_path):
        with open(dump_path, 'a') as f:
            f.write('model_path\tdomain\taccuracy(全体)\taccuracy(照応なし)\taccuracy(発信者)\taccuracy(受信者)\taccuracy(項不定)\taccuracy(文内)\taccuracy(文内(dep))\taccuracy(文内(zero))\ttest_data_size\n')
    with open(dump_path, 'a') as f:
        f.write('{0}\t{1}\t{2:.2f}\t{3:.2f}\t{4:.2f}\t{5:.2f}\t{6:.2f}\t{7:.2f}\t{8:.2f}\t{9:.2f}\t{10}\n'.format(
            'majority', domain, accuracy['all'], accuracy['照応なし'], accuracy['発信者'],
            accuracy['受信者'], accuracy['項不定'], accuracy['文内'],
            accuracy['文内(dep)'], accuracy['文内(zero)'], len(test_data)))

    # Dump the confusion matrix (rows: gold category, columns: predicted category).
    output_path = './' + 'confusion_matrix'
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    dump_path = '{0}/domain-{1}_case-{2}.tsv'.format(output_path, domain, case)
    with open(dump_path, 'w') as f:
        f.write('model_path\t' + 'majority' + '\n')
        f.write(' \t \t予測結果\n')
        f.write(' \t \t照応なし\t発信者\t受信者\t項不定\t文内\tsum(全体)\n実際の分類結果')
        for case_type in ['照応なし', '発信者', '受信者', '項不定', '文内']:
            f.write(' \t{0}\t{1}\t{2}\t{3}\t{4}\t{5}\t{6}\n'.format(
                case_type,
                confusion_matrix[case_type]['照応なし'],
                confusion_matrix[case_type]['発信者'],
                confusion_matrix[case_type]['受信者'],
                confusion_matrix[case_type]['項不定'],
                confusion_matrix[case_type]['文内'],
                case_num[case_type]))
        f.write('\n')

    # Dump the misclassified sentences.
    output_path = './' + 'mistake_sentence'
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    output_path = '{0}/domain-{1}_case-{2}'.format(output_path, domain, case)
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    dump_path = '{0}/model-{1}.txt'.format(output_path, 'majority')
    with open(dump_path, 'a') as f:
        f.write('model_path\t' + 'majority' + '\n')
        f.write('正解位置\t予測位置\t述語位置\t文\n')
        for mistake in mistake_list:
            mistake = [str(i) for i in mistake]
            f.write('\t'.join(mistake))
            f.write('\n')