def valid_epoch(data_path, sess, model):
    """Run one pass over the validation batches and score the predictions.

    Batches whose size differs from the module-level ``batch_size`` are
    skipped entirely: neither their gold labels nor their predictions take
    part in the evaluation, so both lists stay aligned sample by sample.

    Args:
        data_path: Location handed to ``get_batch`` to load batch ``i``.
        sess: Session object used to execute the model via ``sess.run``.
        model: Model exposing ``loss``, ``y_pred``, ``logits`` and the
            placeholders fed below.

    Returns:
        Tuple ``(_costs, f1_micro, f1_macro, score12)`` where ``_costs`` is
        the sum of the per-batch losses of the evaluated batches and the
        remaining values come from ``cail_evaluator``.
    """
    _costs = 0.0
    predict_labels_list = list()
    marked_labels_list = list()
    logits = None
    for i in range(n_va_batches):
        [X_batch, raw_labels] = get_batch(data_path, i)
        y_batch = to_categorical(raw_labels, settings.n_class)
        _batch_size = len(y_batch)
        if _batch_size != batch_size:
            # Skip odd-sized batches *before* recording their labels;
            # otherwise the gold labels outnumber the predictions.
            continue
        marked_labels_list.extend(raw_labels)
        fetches = [model.loss, model.y_pred, model.logits]
        feed_dict = {
            model.X_inputs: X_batch,
            model.y_inputs: y_batch,
            model.batch_size: _batch_size,
            model.tst: True,
            model.keep_prob: 1.0,
        }
        _cost, predict_labels, logits = sess.run(fetches, feed_dict)
        _costs += _cost
        predict_labels_list.extend(predict_labels)
    # Debug output for the first prediction; only meaningful when at least
    # one batch was actually evaluated.
    if predict_labels_list and logits is not None:
        print(max(predict_labels_list[0]), min(predict_labels_list[0]))
        print(predict_labels_list[0])
        print(logits[0])
    f1_micro, f1_macro, score12 = cail_evaluator(predict_labels_list,
                                                 marked_labels_list)
    return _costs, f1_micro, f1_macro, score12
def predict(sess, model, logger):
    """Test on the test data.

    Runs the model over every batch found in ``data_test_path``, saves the
    predictions and the gold labels as ``.npy`` files under
    ``scores_path + model_name``, then prints and logs the scores returned
    by ``cail_evaluator``.

    Args:
        sess: Session object used to execute the model via ``sess.run``.
        model: Model exposing ``y_pred``, ``global_step`` and the
            placeholders fed below.
        logger: Logger whose ``info`` method receives the summary line.
    """
    time0 = time.time()
    te_batches = os.listdir(data_test_path)
    n_te_batches = len(te_batches)
    predict_labels_list = list()  # all prediction results
    marked_labels_list = list()
    for i in tqdm(range(n_te_batches)):
        X_batch, y_batch = get_batch(data_test_path, i)
        _batch_size = len(X_batch)
        marked_labels_list.extend(y_batch)
        fetches = [model.y_pred]
        feed_dict = {
            model.X_inputs: X_batch,
            model.batch_size: _batch_size,
            model.tst: True,
            model.keep_prob: 1.0,
        }
        predict_labels = sess.run(fetches, feed_dict)[0]
        predict_labels_list.extend(predict_labels)
    predict_scores_file = scores_path + model_name + '/' + 'predict.npy'
    marked_scores_file = scores_path + model_name + '/' + 'origin.npy'
    np.save(predict_scores_file, predict_labels_list)
    np.save(marked_scores_file, marked_labels_list)
    f1_micro, f1_macro, score12 = cail_evaluator(predict_labels_list,
                                                 marked_labels_list)
    print('f1_micro=%g, f1_macro=%g, score12=%g, time=%g s' %
          (f1_micro, f1_macro, score12, time.time() - time0))
    # The elapsed time uses a str.format placeholder; a '%g' inside a
    # .format() template is emitted literally and drops the value.
    logger.info(
        '\nTest predicting...\nEND:Global_step={}: f1_micro={}, f1_macro={}, score12={}, time={:g} s'
        .format(sess.run(model.global_step), f1_micro, f1_macro, score12,
                time.time() - time0))
def predict_valid(sess, model, logger):
    """Test on the valid data.

    Runs the model over the ``n_va_batches`` batches of ``data_valid_path``
    and prints and logs the scores returned by ``cail_evaluator``.

    Args:
        sess: Session object used to execute the model via ``sess.run``.
        model: Model exposing ``y_pred``, ``global_step`` and the
            placeholders fed below.
        logger: Logger whose ``info`` method receives the summary line.
    """
    time0 = time.time()
    predict_labels_list = list()  # all prediction results
    marked_labels_list = list()
    for i in tqdm(range(int(n_va_batches))):
        [X_batch, y_batch] = get_batch(data_valid_path, i)
        marked_labels_list.extend(y_batch)
        _batch_size = len(X_batch)
        fetches = [model.y_pred]
        feed_dict = {
            model.X_inputs: X_batch,
            model.batch_size: _batch_size,
            model.tst: True,
            model.keep_prob: 1.0,
        }
        predict_labels = sess.run(fetches, feed_dict)[0]
        predict_labels_list.extend(predict_labels)
    f1_micro, f1_macro, score12 = cail_evaluator(predict_labels_list,
                                                 marked_labels_list)
    # Label the values as what they are (f1 scores), matching the log line.
    print('f1_micro=%g, f1_macro=%g, score12=%g, time=%g s' %
          (f1_micro, f1_macro, score12, time.time() - time0))
    # The elapsed time uses a str.format placeholder; a '%g' inside a
    # .format() template is emitted literally and drops the value.
    logger.info(
        '\nValid predicting...\nEND:Global_step={}: f1_micro={}, f1_macro={}, score12={}, time={:g} s'
        .format(sess.run(model.global_step), f1_micro, f1_macro, score12,
                time.time() - time0))
def valid_epoch(data_path, sess, model):
    """Evaluate the model on every batch stored under ``data_path``.

    The number of batches equals the number of directory entries in
    ``data_path``. The per-batch loss is fetched and summed locally but is
    not part of the return value.

    Args:
        data_path: Directory of batches; also handed to ``get_batch``.
        sess: Session object used to execute the model via ``sess.run``.
        model: Model exposing ``loss``, ``y_pred`` and the placeholders
            fed below.

    Returns:
        Tuple ``(f1_micro, f1_macro, score12)`` from ``cail_evaluator``.
    """
    batch_count = len(os.listdir(data_path))
    total_cost = 0.0
    all_predictions = []
    all_gold = []
    for batch_idx in range(batch_count):
        inputs, gold = get_batch(data_path, batch_idx)
        all_gold.extend(gold)
        onehot_gold = to_categorical(gold)
        batch_cost, batch_predictions = sess.run(
            [model.loss, model.y_pred],
            {
                model.X_inputs: inputs,
                model.y_inputs: onehot_gold,
                model.batch_size: len(onehot_gold),
                model.tst: True,
                model.keep_prob: 1.0,
            })
        total_cost += batch_cost
        all_predictions.extend(batch_predictions)
    f1_micro, f1_macro, score12 = cail_evaluator(all_predictions, all_gold)
    return f1_micro, f1_macro, score12
def valid_epoch(data_path, sess, model):
    """Evaluate the model on every batch stored under ``data_path``.

    The number of batches equals the number of directory entries in
    ``data_path``. The per-batch loss is fetched and summed locally but is
    not part of the return value.

    Args:
        data_path: Directory of batches; also handed to ``get_batch``.
        sess: Session object used to execute the model via ``sess.run``.
        model: Model exposing ``loss``, ``y_pred`` and the ``token_seq``,
            ``gold_label`` and ``is_train`` inputs fed below.

    Returns:
        Tuple ``(f1_micro, f1_macro, score12)`` from ``cail_evaluator``.
    """
    batch_count = len(os.listdir(data_path))
    total_cost = 0.0
    all_predictions = []
    all_gold = []
    for batch_idx in range(batch_count):
        inputs, gold = get_batch(data_path, batch_idx)
        all_gold.extend(gold)
        onehot_gold = to_categorical(gold, cfg.n_class)
        batch_cost, batch_predictions = sess.run(
            [model.loss, model.y_pred],
            {
                model.token_seq: inputs,
                model.gold_label: onehot_gold,
                model.is_train: False,
            })
        total_cost += batch_cost
        all_predictions.extend(batch_predictions)
    f1_micro, f1_macro, score12 = cail_evaluator(all_predictions, all_gold)
    return f1_micro, f1_macro, score12
def valid_epoch(sess, model):
    """Score the model on the validation batches and print the result.

    Runs the model over the ``n_va_batches`` batches of ``data_valid_path``
    and prints ``f1_micro``, ``f1_macro`` and ``score12`` as returned by
    ``cail_evaluator``. Nothing is returned.

    Args:
        sess: Session object used to execute the model via ``sess.run``.
        model: Model exposing ``y_pred`` and the placeholders fed below.
    """
    predict_labels_list = list()
    marked_labels_list = list()
    for i in tqdm(range(int(n_va_batches))):
        [X_batch, y_batch] = get_batch(data_valid_path, i)
        marked_labels_list.extend(y_batch)
        _batch_size = len(y_batch)
        fetches = [model.y_pred]
        feed_dict = {
            model.X_inputs: X_batch,
            model.batch_size: _batch_size,
            model.tst: True,
            model.keep_prob: 1.0,
        }
        # ``fetches`` is a one-element list, so the result is a one-element
        # list too; take its first item to get the per-sample predictions
        # instead of appending the whole batch as a single entry.
        predict_labels = sess.run(fetches, feed_dict)[0]
        predict_labels_list.extend(predict_labels)
    f1_micro, f1_macro, score12 = cail_evaluator(predict_labels_list,
                                                 marked_labels_list)
    print(f1_micro, f1_macro, score12)
def predict(sess, model, logger):
    """Predict on every test batch, save the outputs and report the scores.

    Each batch from ``data_test_path`` yields inputs, gold labels and two
    further values that are fed to ``model.wNum`` and ``model.sNum``.
    Predictions and gold labels are saved as ``predict.npy`` and
    ``origin.npy`` under ``scores_path + model_name``; the scores returned
    by ``cail_evaluator`` are printed and logged.

    Args:
        sess: Session object used to execute the model via ``sess.run``.
        model: Model exposing ``y_pred``, ``global_step`` and the inputs
            fed below.
        logger: Logger whose ``info`` method receives the summary line.
    """
    started_at = time.time()
    batch_total = len(os.listdir(data_test_path))
    all_predictions = []  # all prediction results
    all_gold = []
    for batch_idx in tqdm(range(batch_total)):
        inputs, gold, sent_len, length = get_batch(data_test_path, batch_idx)
        all_gold.extend(gold)
        feed = {
            model.X_inputs: inputs,
            model.wNum: sent_len,
            model.sNum: length,
            model.batch_size: len(inputs),
            model.tst: True,
            model.is_train: False,
        }
        batch_predictions = sess.run([model.y_pred], feed)[0]
        all_predictions.extend(batch_predictions)

    output_prefix = scores_path + model_name + '/'
    predict_scores_file = output_prefix + 'predict.npy'
    marked_scores_file = output_prefix + 'origin.npy'
    np.save(predict_scores_file, all_predictions)
    np.save(marked_scores_file, all_gold)
    print('save predict_labels_list', predict_scores_file)

    f1_micro, f1_macro, score12 = cail_evaluator(all_predictions, all_gold)
    print('cail_evaluator: f1_micro=%g, f1_macro=%g, score12=%g, time=%g s' %
          (f1_micro, f1_macro, score12, time.time() - started_at))
    step = sess.run(model.global_step)
    summary = (
        '\nTest predicting...\ncail_evaluator--END:Global_step={}: f1_micro={}, f1_macro={}, score12={}, time={}s'
        .format(step, f1_micro, f1_macro, score12, time.time() - started_at))
    logger.info(summary)
def valid_epoch(data_path, sess, model, logger, global_step):
    """Evaluate the multi-output model on every batch under ``data_path``.

    Each batch provides inputs plus five label sets (``acc``, ``law``,
    ``death``, ``imp``, ``lif``). Accusation and law predictions are scored
    with ``cail_evaluator``; death, imprisonment and life predictions are
    scored together with ``cail_imprisonment_evaluator``. All scores are
    logged, but only the accusation ``score12`` is returned.

    Args:
        data_path: Directory of batches; also handed to ``get_batch``.
        sess: Session object used to execute the model via ``sess.run``.
        model: Model exposing ``loss``, the five ``*_pred`` outputs and the
            placeholders fed below.
        logger: Logger whose ``info`` method receives the score lines.
        global_step: Value written into the first log line.

    Returns:
        The ``score12`` value computed for the accusation predictions.
    """
    tasks = ('acc', 'law', 'death', 'imp', 'lif')
    gold = {task: [] for task in tasks}
    predicted = {task: [] for task in tasks}
    total_cost = 0.0  # summed per-batch loss; not part of the return value

    for batch_idx in range(len(os.listdir(data_path))):
        X_batch, acc, law, death, imp, lif = get_batch(data_path, batch_idx)
        for task, labels in zip(tasks, (acc, law, death, imp, lif)):
            gold[task].extend(labels)

        # Convert the raw labels into the form each placeholder is fed with.
        acc_batch = to_categorical(acc, 202)
        law_batch = to_categorical(law, 183)
        death_batch = to_categorical_death(death, 2)
        imp_batch = np.expand_dims(imp, axis=1)
        lif_batch = to_categorical_death(lif, 2)

        outputs = sess.run(
            [
                model.loss,
                model.accu_pred,
                model.law_pred,
                model.death_pred,
                model.imp_pred,
                model.lif_pred,
            ],
            {
                model.X_inputs: X_batch,
                model.acc_y: acc_batch,
                model.article_y: law_batch,
                model.death_y: death_batch,
                model.imp_y: imp_batch,
                model.lif_y: lif_batch,
                model.batch_size: len(acc_batch),
                model.tst: True,
                model.keep_prob: 1.0,
            })
        total_cost += outputs[0]
        for task, task_output in zip(tasks, outputs[1:]):
            predicted[task].extend(task_output)

    accu_f1_micro, accu_f1_macro, accu_score12 = cail_evaluator(
        predicted['acc'], gold['acc'])
    law_f1_micro, law_f1_macro, law_score12 = cail_evaluator(
        predicted['law'], gold['law'])
    imp_score = cail_imprisonment_evaluator(
        predicted['death'], gold['death'],
        predicted['imp'], gold['imp'],
        predicted['lif'], gold['lif'])

    logger.info('Global_step={}'.format(global_step))
    logger.info('accu: f1_micro={}, f1_macro={}, score12={}'.format(
        accu_f1_micro, accu_f1_macro, accu_score12))
    logger.info('law: f1_micro={}, f1_macro={}, score12={}'.format(
        law_f1_micro, law_f1_macro, law_score12))
    logger.info('imprisonment: score={}'.format(imp_score))
    # Only the accusation score is reported back to the caller.
    return accu_score12