예제 #1
0
def valid_epoch(data_path, sess, model):
    """Evaluate ``model`` on the validation batches stored under ``data_path``.

    Only batches whose size equals the global ``batch_size`` are evaluated;
    any other batch (e.g. a short trailing batch) is skipped entirely, for
    both predictions and gold labels, so the two lists stay aligned.

    Args:
        data_path: location passed to ``get_batch`` to load batch ``i``.
        sess: session whose ``run`` method executes the model fetches.
        model: object exposing ``loss``, ``y_pred``, ``logits`` and the
            input placeholders used in the feed dict below.

    Returns:
        ``(total_cost, f1_micro, f1_macro, score12)``: the summed per-batch
        loss followed by the three values returned by ``cail_evaluator``.
    """
    _costs = 0.0
    predict_labels_list = list()
    marked_labels_list = list()
    logits = None  # stays None if no batch is evaluated
    for i in range(n_va_batches):
        X_batch, y_batch = get_batch(data_path, i)
        y_onehot = to_categorical(y_batch, settings.n_class)
        _batch_size = len(y_onehot)
        if _batch_size != batch_size:
            # Skip the whole batch. The gold labels must be skipped too,
            # otherwise they no longer line up with the predictions.
            continue
        marked_labels_list.extend(y_batch)
        fetches = [model.loss, model.y_pred, model.logits]
        feed_dict = {
            model.X_inputs: X_batch,
            model.y_inputs: y_onehot,
            model.batch_size: _batch_size,
            model.tst: True,
            model.keep_prob: 1.0
        }
        _cost, predict_labels, logits = sess.run(fetches, feed_dict)
        _costs += _cost
        predict_labels_list.extend(predict_labels)
    # Debug output for the first evaluated sample. It is guarded so that an
    # epoch where every batch was skipped does not crash here.
    if predict_labels_list:
        print(max(predict_labels_list[0]), min(predict_labels_list[0]))
        print(predict_labels_list[0])
    if logits is not None:
        print(logits[0])
    f1_micro, f1_macro, score12 = cail_evaluator(predict_labels_list,
                                                 marked_labels_list)
    return _costs, f1_micro, f1_macro, score12
예제 #2
0
def predict(sess, model, logger):
    """Predict on every test batch, save the results and log CAIL scores.

    The predictions and gold labels are saved as ``predict.npy`` and
    ``origin.npy`` under ``scores_path + model_name + '/'``. That directory
    is created if it is missing.

    Args:
        sess: session whose ``run`` method executes the model fetches.
        model: object exposing ``y_pred``, ``global_step`` and the input
            placeholders used in the feed dict below.
        logger: logger receiving the final summary line.
    """
    time0 = time.time()
    te_batches = os.listdir(data_test_path)
    n_te_batches = len(te_batches)
    predict_labels_list = list()  # all predictions
    marked_labels_list = list()
    for i in tqdm(range(n_te_batches)):
        X_batch, y_batch = get_batch(data_test_path, i)
        _batch_size = len(X_batch)
        marked_labels_list.extend(y_batch)
        fetches = [model.y_pred]
        feed_dict = {
            model.X_inputs: X_batch,
            model.batch_size: _batch_size,
            model.tst: True,
            model.keep_prob: 1.0
        }
        # sess.run returns one value per fetch; take the single y_pred array.
        predict_labels = sess.run(fetches, feed_dict)[0]
        predict_labels_list.extend(predict_labels)
    scores_dir = scores_path + model_name + '/'
    # np.save does not create missing parent directories.
    if not os.path.isdir(scores_dir):
        os.makedirs(scores_dir)
    predict_scores_file = scores_dir + 'predict.npy'
    marked_scores_file = scores_dir + 'origin.npy'
    np.save(predict_scores_file, predict_labels_list)
    np.save(marked_scores_file, marked_labels_list)

    f1_micro, f1_macro, score12 = cail_evaluator(predict_labels_list,
                                                 marked_labels_list)
    print('f1_micro=%g, f1_macro=%g, score12=%g, time=%g s' %
          (f1_micro, f1_macro, score12, time.time() - time0))
    # str.format placeholders only: the old '%g' was never substituted and
    # the elapsed time argument was silently dropped.
    logger.info(
        '\nTest predicting...\nEND:Global_step={}: f1_micro={}, f1_macro={}, score12={}, time={}s'
        .format(sess.run(model.global_step), f1_micro, f1_macro, score12,
                time.time() - time0))
예제 #3
0
def predict_valid(sess, model, logger):
    """Predict on the validation batches and report CAIL F1 scores.

    Iterates over ``n_va_batches`` batches loaded from ``data_valid_path``,
    prints the scores with the elapsed time and logs a summary line.

    Args:
        sess: session whose ``run`` method executes the model fetches.
        model: object exposing ``y_pred``, ``global_step`` and the input
            placeholders used in the feed dict below.
        logger: logger receiving the final summary line.
    """
    time0 = time.time()
    predict_labels_list = list()  # all predictions
    marked_labels_list = list()
    for i in tqdm(range(int(n_va_batches))):
        X_batch, y_batch = get_batch(data_valid_path, i)
        marked_labels_list.extend(y_batch)
        _batch_size = len(X_batch)
        fetches = [model.y_pred]
        feed_dict = {
            model.X_inputs: X_batch,
            model.batch_size: _batch_size,
            model.tst: True,
            model.keep_prob: 1.0
        }
        # sess.run returns one value per fetch; take the single y_pred array.
        predict_labels = sess.run(fetches, feed_dict)[0]
        predict_labels_list.extend(predict_labels)

    f1_micro, f1_macro, score12 = cail_evaluator(predict_labels_list,
                                                 marked_labels_list)
    # Labels match the values actually printed (previously mislabelled as
    # precision/recall).
    print('f1_micro=%g, f1_macro=%g, score12=%g, time=%g s' %
          (f1_micro, f1_macro, score12, time.time() - time0))
    logger.info(
        '\nValid predicting...\nEND:Global_step={}: f1_micro={}, f1_macro={}, score12={}, time={}s'
        .format(sess.run(model.global_step), f1_micro, f1_macro, score12,
                time.time() - time0))
예제 #4
0
def valid_epoch(data_path, sess, model):
    """Score the model on every batch found under ``data_path``.

    The number of batches equals the number of entries listed in
    ``data_path``; batch ``k`` is loaded with ``get_batch(data_path, k)``.

    Returns:
        ``(f1_micro, f1_macro, score12)`` as computed by ``cail_evaluator``.
    """
    n_batches = len(os.listdir(data_path))
    predictions, gold_labels = [], []
    for batch_idx in range(n_batches):
        inputs, labels = get_batch(data_path, batch_idx)
        gold_labels.extend(labels)
        onehot = to_categorical(labels)
        feeds = {
            model.X_inputs: inputs,
            model.y_inputs: onehot,
            model.batch_size: len(onehot),
            model.tst: True,
            model.keep_prob: 1.0,
        }
        # The loss is fetched alongside the predictions but not reported.
        _, batch_preds = sess.run([model.loss, model.y_pred], feeds)
        predictions.extend(batch_preds)
    f1_micro, f1_macro, score12 = cail_evaluator(predictions, gold_labels)
    return f1_micro, f1_macro, score12
예제 #5
0
def valid_epoch(data_path, sess, model):
    """Evaluate a ``token_seq``/``gold_label`` model on all batches in ``data_path``.

    Labels are one-hot encoded with ``cfg.n_class`` classes before feeding,
    and the model runs with ``is_train`` set to ``False``.

    Returns:
        ``(f1_micro, f1_macro, score12)`` as computed by ``cail_evaluator``.
    """
    batch_count = len(os.listdir(data_path))
    predicted, expected = [], []
    for idx in range(batch_count):
        tokens, labels = get_batch(data_path, idx)
        expected.extend(labels)
        gold = to_categorical(labels, cfg.n_class)
        feed = {model.token_seq: tokens, model.gold_label: gold, model.is_train: False}
        # Loss is computed with the predictions but not used afterwards.
        _, batch_pred = sess.run([model.loss, model.y_pred], feed)
        predicted.extend(batch_pred)
    f1_micro, f1_macro, score12 = cail_evaluator(predicted, expected)
    return f1_micro, f1_macro, score12
예제 #6
0
def valid_epoch(sess, model):
    """Predict on the validation batches, print and return the CAIL scores.

    Iterates over ``n_va_batches`` batches loaded from ``data_valid_path``.

    Args:
        sess: session whose ``run`` method executes the model fetches.
        model: object exposing ``y_pred`` and the input placeholders used
            in the feed dict below.

    Returns:
        ``(f1_micro, f1_macro, score12)`` from ``cail_evaluator``. They are
        also printed, as before.
    """
    predict_labels_list = list()
    marked_labels_list = list()
    for i in tqdm(range(int(n_va_batches))):
        X_batch, y_batch = get_batch(data_valid_path, i)
        marked_labels_list.extend(y_batch)
        _batch_size = len(y_batch)
        fetches = [model.y_pred]
        feed_dict = {
            model.X_inputs: X_batch,
            model.batch_size: _batch_size,
            model.tst: True,
            model.keep_prob: 1.0
        }
        # sess.run mirrors the fetch list, so index [0] to get the y_pred
        # array. Without it the whole batch is appended as ONE element and
        # the predictions no longer align with the per-sample gold labels.
        predict_labels = sess.run(fetches, feed_dict)[0]
        predict_labels_list.extend(predict_labels)
    f1_micro, f1_macro, score12 = cail_evaluator(predict_labels_list,
                                                 marked_labels_list)
    print(f1_micro, f1_macro, score12)
    return f1_micro, f1_macro, score12
예제 #7
0
def predict(sess, model, logger):
    """Predict on every hierarchical test batch, save results and log scores.

    Each batch from ``get_batch`` yields inputs, labels, per-sentence word
    counts (fed to ``model.wNum``) and sentence counts (fed to
    ``model.sNum``). Predictions and gold labels are saved as
    ``predict.npy`` / ``origin.npy`` under ``scores_path + model_name + '/'``.
    That directory is created if it is missing.

    Args:
        sess: session whose ``run`` method executes the model fetches.
        model: object exposing ``y_pred``, ``global_step`` and the input
            placeholders used in the feed dict below.
        logger: logger receiving the final summary line.
    """
    time0 = time.time()
    te_batches = os.listdir(data_test_path)
    n_te_batches = len(te_batches)
    predict_labels_list = list()  # all predictions
    marked_labels_list = list()
    for i in tqdm(range(n_te_batches)):
        X_batch, y_batch, sent_len, length = get_batch(data_test_path, i)
        _batch_size = len(X_batch)
        marked_labels_list.extend(y_batch)
        fetches = [model.y_pred]
        feed_dict = {
            model.X_inputs: X_batch,
            model.wNum: sent_len,
            model.sNum: length,
            model.batch_size: _batch_size,
            model.tst: True,
            model.is_train: False
        }
        # sess.run returns one value per fetch; take the single y_pred array.
        predict_labels = sess.run(fetches, feed_dict)[0]
        predict_labels_list.extend(predict_labels)

    scores_dir = scores_path + model_name + '/'
    # np.save does not create missing parent directories.
    if not os.path.isdir(scores_dir):
        os.makedirs(scores_dir)
    predict_scores_file = scores_dir + 'predict.npy'
    marked_scores_file = scores_dir + 'origin.npy'
    np.save(predict_scores_file, predict_labels_list)
    np.save(marked_scores_file, marked_labels_list)
    print('save predict_labels_list', predict_scores_file)

    f1_micro, f1_macro, score12 = cail_evaluator(predict_labels_list,
                                                 marked_labels_list)
    print('cail_evaluator: f1_micro=%g, f1_macro=%g, score12=%g, time=%g s' %
          (f1_micro, f1_macro, score12, time.time() - time0))
    logger.info(
        '\nTest predicting...\ncail_evaluator--END:Global_step={}: f1_micro={}, f1_macro={}, score12={}, time={}s'
        .format(sess.run(model.global_step), f1_micro, f1_macro, score12,
                time.time() - time0))
예제 #8
0
def valid_epoch(data_path, sess, model, logger, global_step):
    """Evaluate the multi-task (accusation/law/imprisonment) model.

    Each batch from ``get_batch`` yields inputs plus five label arrays
    (accusation, law article, death, imprisonment, life). Accusation and law
    labels are one-hot encoded with 202 and 183 classes, death and life
    labels go through ``to_categorical_death(..., 2)``, and imprisonment
    values get an extra trailing axis via ``np.expand_dims``.

    Scores for all three tasks are logged together with ``global_step``.

    Returns:
        The accusation ``score12`` from ``cail_evaluator``. The law and
        imprisonment scores are only logged.
    """
    n_va_batches = len(os.listdir(data_path))
    tasks = ('acc', 'law', 'death', 'imp', 'lif')
    gold = dict((task, []) for task in tasks)
    pred = dict((task, []) for task in tasks)
    for batch_idx in range(n_va_batches):
        X_batch, acc, law, death, imp, lif = get_batch(data_path, batch_idx)
        for task, labels in zip(tasks, (acc, law, death, imp, lif)):
            gold[task].extend(labels)
        acc_onehot = to_categorical(acc, 202)
        law_onehot = to_categorical(law, 183)
        death_onehot = to_categorical_death(death, 2)
        imp_column = np.expand_dims(imp, axis=1)
        lif_onehot = to_categorical_death(lif, 2)
        feed_dict = {
            model.X_inputs: X_batch,
            model.acc_y: acc_onehot,
            model.article_y: law_onehot,
            model.death_y: death_onehot,
            model.imp_y: imp_column,
            model.lif_y: lif_onehot,
            model.batch_size: len(acc_onehot),
            model.tst: True,
            model.keep_prob: 1.0
        }
        loss_value, accu_pred, law_pred, death_pred, imp_pred, lif_pred = sess.run(
            [model.loss, model.accu_pred, model.law_pred, model.death_pred,
             model.imp_pred, model.lif_pred], feed_dict)
        # The loss is fetched with the predictions but not reported.
        for task, batch_pred in zip(
                tasks, (accu_pred, law_pred, death_pred, imp_pred, lif_pred)):
            pred[task].extend(batch_pred)

    accu_f1_micro, accu_f1_macro, accu_score12 = cail_evaluator(
        pred['acc'], gold['acc'])
    law_f1_micro, law_f1_macro, law_score12 = cail_evaluator(
        pred['law'], gold['law'])
    imp_score = cail_imprisonment_evaluator(
        pred['death'], gold['death'], pred['imp'], gold['imp'],
        pred['lif'], gold['lif'])

    logger.info('Global_step={}'.format(global_step))
    logger.info('accu: f1_micro={}, f1_macro={}, score12={}'.format(
        accu_f1_micro, accu_f1_macro, accu_score12))
    logger.info('law: f1_micro={}, f1_macro={}, score12={}'.format(
        law_f1_micro, law_f1_macro, law_score12))
    logger.info('imprisonment: score={}'.format(imp_score))
    # Only the accusation score is returned.
    return accu_score12