Exemplo n.º 1
0
def predict(args):
    """Run batched inference over ``args.predict_data`` and write predictions as CSV.

    Loads the pickled embedding module from ``args.embedding_save_path`` and a
    ``textCNN`` whose weights come from ``args.model_save_path``. Both are put in
    eval mode on ``args.device``. The prediction set is iterated batch by batch
    (after ``reorderForEval``) until the dataset signals the final batch. Each
    output row is ``<id>,<label>`` under an ``idx,labels`` header, where
    ``label`` is the arg-max class index plus one.

    Args:
        args: namespace providing ``device``, ``embedding_save_path``,
            ``model_save_path``, ``embedding_size``, ``cls_num``,
            ``l1_channels_num``, ``predict_data``, ``batch_size`` and
            ``predict_writeto``.
    """
    device = torch.device(args.device)
    # map_location lets checkpoints saved on a GPU be loaded on a CPU-only host.
    cls_embed = torch.load(args.embedding_save_path, map_location=device)
    model = textCNN(args.embedding_size, args.cls_num, args.l1_channels_num)
    model.load_state_dict(torch.load(args.model_save_path, map_location=device))
    model.eval()
    model = model.to(device)
    cls_embed = cls_embed.to(device)
    cls_embed.eval()

    # Second argument is None; presumably "no label source" for prediction.
    dev_dataset = DataSet(args.predict_data, None, args.batch_size)
    dev_dataset.reorderForEval()

    print("Begin Predict task...")
    # `with` closes the output file even if a batch raises; no_grad avoids
    # building autograd graphs, since nothing is back-propagated here.
    with open(args.predict_writeto, "w") as towrite, torch.no_grad():
        towrite.write("idx,labels\n")
        while True:
            (example, dump), p = dev_dataset.getPredictBatch(False)
            example = cls_embed(example, device=device)
            outs = model(example)
            # reshape(-1) always yields a 1-D tensor. squeeze() would produce a
            # 0-d tensor for a one-sample batch, and tolist() would then return
            # a bare int that zip() cannot iterate.
            preds = (torch.argmax(outs, -1) + 1).reshape(-1).tolist()
            # NOTE(review): `example` was rebound to the embedding output above,
            # so example[1] indexes that output rather than the raw batch. If the
            # row ids were meant to come from the raw batch (or from `dump`),
            # keep the raw batch under a separate name. TODO confirm against
            # DataSet.getPredictBatch.
            for ident, label in zip(example[1], preds):
                towrite.write("{0},{1}\n".format(ident, int(label)))
            if p:
                break
    print("Predict task Done!")
Exemplo n.º 2
0
def predict(args):
    """Run inference over ``args.predict_data`` and write predictions as CSV.

    Loads the pickled embedding module from ``args.embedding_save_path`` and a
    ``textCNN`` whose weights come from ``args.model_save_path``. Both are put in
    eval mode on ``args.device``. Batches are pulled from the dataset until it
    signals the final one. Every prediction is written as ``<idx>,<label>``
    under an ``idx,labels`` header. ``idx`` is a running 0-based counter over
    all predictions, and ``label`` is the arg-max class index plus one.

    Args:
        args: namespace providing ``device``, ``embedding_save_path``,
            ``model_save_path``, ``embedding_size``, ``cls_num``,
            ``l1_channels_num``, ``predict_data``, ``batch_size`` and
            ``predict_writeto``.
    """
    device = torch.device(args.device)
    # map_location lets checkpoints saved on a GPU be loaded on a CPU-only host.
    cls_embed = torch.load(args.embedding_save_path, map_location=device)
    model = textCNN(args.embedding_size, args.cls_num, args.l1_channels_num)
    model.load_state_dict(torch.load(args.model_save_path, map_location=device))
    model.eval()
    model = model.to(device)
    cls_embed = cls_embed.to(device)
    cls_embed.eval()

    # Second argument is None; presumably "no label source" for prediction.
    dev_dataset = DataSet(args.predict_data, None, args.batch_size)
    idx = 0
    print("Begin Predict task...")
    # `with` closes the output file even if a batch raises; no_grad avoids
    # building autograd graphs, since nothing is back-propagated here.
    with open(args.predict_writeto, "w") as towrite, torch.no_grad():
        towrite.write("idx,labels\n")
        while True:
            example, p = dev_dataset.getPredictBatch()
            example = cls_embed(example, device=device)
            out = model(example)
            # `.item()` only works when exactly one prediction is produced.
            # Iterating the flattened predictions handles any batch size and
            # writes the same single row when the batch holds one sample.
            for label in (torch.argmax(out, -1) + 1).reshape(-1).tolist():
                towrite.write("{0},{1}\n".format(idx, label))
                idx += 1
            if p:
                break
    print("Predict task Done!")