示例#1
0
文件: test.py 项目: qtyty/textcnn
def main():
    """Evaluate a saved TextCNN checkpoint on the test split.

    The training split is loaded only so that its vocabulary can be attached
    to the test iterator before the test data is loaded. The object stored at
    ``config.save_model + ".pt"`` is loaded onto ``config.device`` and used
    directly as the model. The predictions and labels returned by
    ``test_textcnn_model`` are summarised with ``classification_report``,
    and the report is printed.
    """
    config = parse_config()

    # The training iterator is built only to obtain its vocabulary.
    train_iter = DataBatchIterator(
        config=config,
        is_train=True,
        dataset="train",
        batch_size=config.batch_size,
        shuffle=True,
    )
    train_iter.load()
    shared_vocab = train_iter.vocab

    # The test iterator reuses the training vocabulary (set before load()).
    test_iter = DataBatchIterator(
        config=config,
        is_train=False,
        dataset="test",
        batch_size=config.batch_size,
    )
    test_iter.set_vocab(shared_vocab)
    test_iter.load()

    # The loaded object is passed straight to the test routine as the model.
    model = torch.load(config.save_model + ".pt", map_location=config.device)

    predictions, labels = test_textcnn_model(model, test_iter, config)
    report = classification_report(labels, predictions)
    print(report)
示例#2
0
文件: test.py 项目: Mydoria/17122852
def main():
    """Evaluate the model saved at ``results/model.pt`` on the test split.

    Runs ``test_textcnn_model`` with a summed cross-entropy criterion. The
    returned loss, precision, recall and F1 are appended as one line to
    ``result.log`` (UTF-8).
    """
    config = parse_config()

    # Load the test split.
    test_data = DataBatchIterator(config=config,
                                  is_train=False,
                                  dataset="test",
                                  batch_size=config.batch_size,
                                  shuffle=True)
    test_data.load()

    # Load the TextCNN model (the .pt file is used directly as the model).
    model = torch.load("results/model.pt")
    # reduction="sum": the loss is summed over examples, not averaged.
    criterion = nn.CrossEntropyLoss(reduction="sum")

    # Evaluate; no training happens in this script.
    loss, precision, recall, f1 = test_textcnn_model(model, test_data,
                                                     criterion, config)

    # Open the log only once there is something to write. ``with`` guarantees
    # the handle is closed even if the write fails; previously the file was
    # opened up front and leaked if loading or evaluation raised.
    with open('result.log', mode='a', encoding='utf-8') as mylog:
        print(
            "test loss: {0:.2f},  precision: {1:.2f},  recall:{2:.2f},  f1:{3:.2f}"
            .format(loss, precision, recall, f1),
            file=mylog)
示例#3
0
文件: test.py 项目: AOLUCY/textcnn
def main():
    """Continue training from a saved model object.

    Steps:
    * Load ``config.save_model + ".pt"`` onto ``config.device``.
    * Build a shuffled training iterator.
    * Build an evaluation iterator over the *test* split that shares the
      training vocabulary.
    * Call ``train_textcnn_model`` with the index of the PAD token.

    This function does not save anything itself. Whether
    ``train_textcnn_model`` does is not visible here.
    """
    config = parse_config()
    resumed_model = torch.load(config.save_model + ".pt",
                               map_location=config.device)

    train_iter = DataBatchIterator(config=config, is_train=True,
                                   dataset="train",
                                   batch_size=config.batch_size,
                                   shuffle=True)
    train_iter.load()
    vocab = train_iter.vocab

    # NOTE: the "test" split doubles as the validation set here.
    eval_iter = DataBatchIterator(config=config, is_train=False,
                                  dataset="test",
                                  batch_size=config.batch_size)
    eval_iter.set_vocab(vocab)
    eval_iter.load()

    pad_index = vocab.stoi[PAD]
    train_textcnn_model(resumed_model, train_iter, eval_iter, pad_index,
                        config)
示例#4
0
def main():
    """Evaluate the model stored at ``./results/model.pt`` on the test split.

    Prints the model, then the accuracy and the correct/total counts returned
    by ``test_textcnn_model``.
    """
    config = parse_config()

    # No batch_size is passed, so DataBatchIterator's default applies.
    # set_vocab() is not called, so the vocabulary is whatever load() sets up.
    test_iter = DataBatchIterator(config=config, is_train=False, dataset="test")
    test_iter.load()

    model = torch.load('./results/model.pt')
    print(model)

    acc, n_correct, n_total = test_textcnn_model(model, test_iter, config)
    summary = '\nEvaluation - acc: {:.4f}%({}/{}) \n'.format(
        acc, n_correct, n_total)
    print(summary)
示例#5
0
def main():
    """Train a fresh TextCNN model and save it to ``<config.save_model>.pt``.

    The training split is shuffled. The "dev" split, which shares the
    training vocabulary, is passed as the validation data. Neither iterator
    gets an explicit batch_size, so DataBatchIterator's default is used.
    """
    config = parse_config()

    train_iter = DataBatchIterator(config=config, is_train=True,
                                   dataset="train", shuffle=True)
    train_iter.load()
    vocab = train_iter.vocab

    dev_iter = DataBatchIterator(config=config, is_train=False, dataset="dev")
    dev_iter.set_vocab(vocab)
    dev_iter.load()

    model = build_textcnn_model(vocab, config, train=True)
    print(model)

    train_textcnn_model(model, train_iter, dev_iter, vocab.stoi[PAD], config)
    # The whole model object is saved (torch.save(model, ...)), not a state_dict.
    torch.save(model, '%s.pt' % (config.save_model))
示例#6
0
def main():
    """Load the saved model and run ``test`` on the test split.

    The test iterator reuses the vocabulary (token-to-index mapping) built
    from the training split.
    """
    config = parse_config()

    train_iter = DataBatchIterator(config=config,
                                   is_train=True,
                                   dataset="train",
                                   batch_size=config.batch_size,
                                   shuffle=True)
    train_iter.load()
    word_vocab = train_iter.vocab  # vocabulary mapping from the training set

    test_iter = DataBatchIterator(config=config,
                                  is_train=False,
                                  dataset="test",
                                  batch_size=config.batch_size)
    test_iter.set_vocab(word_vocab)
    test_iter.load()

    checkpoint_path = config.save_model + ".pt"
    model = torch.load(checkpoint_path, map_location=config.device)
    print(model)

    test(model, test_iter)
示例#7
0
def main():
    """Evaluate the saved model on the test split and print a classification report.

    Loads ``<config.save_model>.pt`` onto ``config.device`` and takes the
    arg-max prediction for every test batch. Prints sklearn's
    ``classification_report`` with readable class names.
    """
    config = parse_config()

    # Load the test split.
    # NOTE(review): unlike scripts that call set_vocab() with the training
    # vocabulary, no vocabulary is attached here. Confirm that load() on a
    # non-training iterator yields the same token ids the model was trained on.
    test_data = DataBatchIterator(config=config,
                                  is_train=False,
                                  dataset="test",
                                  batch_size=config.batch_size)
    test_data.load()

    # Load the model and switch it to inference mode.
    model = torch.load(config.save_model + ".pt", map_location=config.device)
    model.eval()

    y_pred = []  # predicted label ids
    y_true = []  # ground-truth label ids
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for batch in test_data:
            outputs = model(batch.sent)
            # Assumes outputs is (batch, num_classes); arg-max over dim 1 is the
            # predicted class index. .cpu() lets .numpy() work on GPU tensors.
            y_pred.extend(torch.max(outputs, 1)[1].cpu().numpy().tolist())
            y_true.extend(batch.label.cpu().numpy().tolist())

    # Names are paired by classification_report with the sorted label ids
    # (presumably 0..14 in this order; confirm against the label vocabulary).
    # NOTE(review): 'news_tech' appears twice (positions 4 and 11), so two
    # classes share one name in the report. One entry is probably meant to be
    # a different category.
    target_names = [
        'news_edu', 'news_finance', 'news_house', 'news_travel', 'news_tech',
        'news_sports', 'news_game', 'news_culture', 'news_car', 'news_story',
        'news_entertainment', 'news_tech', 'news_agriculture', 'news_world',
        'news_stock'
    ]
    print(classification_report(y_true, y_pred, target_names=target_names))
示例#8
0
def main():
    """Run ``test_textcnn_model`` on the test split with the model at ``./results/model.pt``."""
    config = parse_config()

    # No batch_size is passed, so DataBatchIterator's default batch size applies.
    test_iter = DataBatchIterator(config=config, is_train=False, dataset="test")
    test_iter.load()

    # Evaluate; the return value of test_textcnn_model is not used here.
    test_textcnn_model(torch.load('./results/model.pt'), test_iter, config)
示例#9
0
def main():
    """Evaluate the saved checkpoint on the test split and print a classification report.

    The training split is loaded only to obtain its vocabulary. That
    vocabulary is attached to the test iterator, presumably so test token ids
    match those used in training.
    """
    config = parse_config()

    # Load the training split (only its vocabulary is used).
    train_data = DataBatchIterator(config=config,
                                   is_train=True,
                                   dataset="train",
                                   batch_size=config.batch_size,
                                   shuffle=True)
    train_data.load()

    vocab = train_data.vocab

    # Load the test split with the training vocabulary.
    valid_data = DataBatchIterator(config=config,
                                   is_train=False,
                                   dataset="test",
                                   batch_size=config.batch_size)
    valid_data.set_vocab(vocab)
    valid_data.load()

    # The loaded object is used directly as the model.
    model = torch.load(config.save_model + ".pt", map_location=config.device)
    model.eval()

    # Collect per-batch results and concatenate once at the end. Calling
    # torch.cat inside the loop re-copies everything seen so far (quadratic).
    predict_chunks = []
    label_chunks = []
    with torch.no_grad():  # inference only: no autograd graph needed
        for batch in valid_data:
            outputs = model(batch.sent)
            # Assumes outputs is (batch, num_classes); arg-max = predicted class.
            # Move to CPU so concatenation and sklearn work for GPU runs too.
            predict_chunks.append(torch.max(outputs, 1)[1].cpu())
            label_chunks.append(batch.label.cpu())

    # With no batches, fall back to the same empty LongTensor the loop used to start from.
    total_predict = (torch.cat(predict_chunks, 0) if predict_chunks
                     else torch.LongTensor([]))
    total_label = (torch.cat(label_chunks, 0) if label_chunks
                   else torch.LongTensor([]))
    print(classification_report(total_label, total_predict))
示例#10
0
文件: ft.py 项目: xixi1998/datamining
def main():
    """Load the test and train splits and pass them to the ``fasttext`` training functions.

    NOTE(review): several things here look wrong and should be checked.

    * ``fasttext.skipgram`` and ``fasttext.supervised`` receive
      DataBatchIterator objects. The legacy fasttext Python API, as far as I
      know, expects a path to a text file. Confirm against the installed
      version.
    * ``fasttext.supervised`` is given the *test* split, i.e. it would train
      on test data.
    * Both calls use the same output name ``'model'``, so the second may
      overwrite the first. Neither return value is used.
    """
    config = parse_config()

    # Test split: no batch_size and no set_vocab(), so the iterator's own
    # defaults and load() determine both.
    test_iter = DataBatchIterator(config=config, is_train=False, dataset="test")
    test_iter.load()
    print(test_iter.vocab)

    train_iter = DataBatchIterator(
        config=config,
        is_train=True,
        dataset="train",
        batch_size=config.batch_size,
        shuffle=True,
    )
    train_iter.load()

    fasttext.skipgram(train_iter, 'model')
    fasttext.supervised(test_iter, 'model')
示例#11
0
def main():
    """Build (but do not train) a TextCNN model, then evaluate the saved checkpoint.

    Training is disabled in this script, so the freshly built model is only
    printed. Evaluation uses the model object stored at
    ``<config.save_model>.pt``, passed to ``test_valid`` together with the
    test split.
    """
    config = parse_config()

    # Training split: provides the vocabulary.
    train_data = DataBatchIterator(
        config=config,
        is_train=True,
        dataset="train",
        batch_size=config.batch_size,
        shuffle=True)
    train_data.load()

    vocab = train_data.vocab

    # Dev split, sharing the training vocabulary. Only training would use it,
    # and training is currently disabled.
    valid_data = DataBatchIterator(
        config=config,
        is_train=False,
        dataset="dev",
        batch_size=config.batch_size)
    valid_data.set_vocab(vocab)
    valid_data.load()

    # Build a TextCNN model. With training disabled it is only printed.
    model = build_textcnn_model(
        vocab, config, train=True)
    print(model)

    # Test split. Attach the training vocabulary before load(), as for the dev
    # split above, so test token ids follow the same vocabulary as the model.
    # Previously the test iterator got no vocabulary here.
    test_data = DataBatchIterator(
        config=config,
        is_train=False,
        dataset="test",
        batch_size=config.batch_size)
    test_data.set_vocab(vocab)
    test_data.load()

    # Load the previously trained model (presumably saved by a training run).
    checkpoint = torch.load(config.save_model + ".pt",
                            map_location=config.device)
    # Test and print the evaluation.
    test_valid(checkpoint, config, test_data)
示例#12
0
文件: test.py 项目: tigergeng/TextCNN
import torch
from config import parse_config
from data_loader import DataBatchIterator
from sklearn.metrics import f1_score, precision_score, recall_score
if __name__ == '__main__':
    config = parse_config()
    test_data = DataBatchIterator(config=config,
                                  is_train=False,
                                  dataset="test")
    test_data.load()
    model = torch.load('./results/model.pt')
    model.eval()
    data_iter = iter(test_data)
    count = 0
    score_f1 = 0
    score_precision = 0
    score_recall = 0
    for idx, batch in enumerate(data_iter):
        model.zero_grad()
        truths = batch.label
        outputs = model(batch.sent)
        result = torch.max(outputs, 1)[1]
        y_truth = truths.data.detach().numpy().tolist()
        y_pred = result.view(truths.size()).data.detach().numpy().tolist()
        score_f1 += f1_score(y_truth, y_pred, average='macro')
        score_precision += precision_score(y_truth, y_pred, average='macro')
        score_recall += recall_score(y_truth, y_pred, average='macro')
        count += 1
    size = 8000
    score_f1 = 100.0 * score_f1 / count
    score_precision = 100.0 * score_precision / count