예제 #1
0
def classify(category, src):
    """Run every sentence in ``src`` through the CNN classifier and return its outputs.

    Args:
        category: forwarded to ``setup_args`` together with ``src`` to build the
            option object (its meaning is defined by ``setup_args``, not here).
        src: iterable of sentence strings. A ``None`` item is treated as an
            end-of-input marker; anything after it is ignored.

    Returns:
        list: the concatenation of ``outs.data.tolist()`` over all batches, in
        input order (presumably one entry per sentence — verify against
        ``Translator_cnn.translate``).
    """
    opt = setup_args(category, src)
    model = onmt.Translator_cnn(opt)

    def _run(batch_src, batch_tgt):
        # translate() also returns (n_correct, batch_size, ..., preds); only the
        # raw outputs are needed by callers of classify().
        _, _, outs, _ = model.translate(batch_src, batch_tgt)
        return outs.data.tolist()

    outputs = []
    src_batch, tgt_batch = [], []

    for line in tqdm(src, desc="Naturalness"):
        if line is None:
            # End-of-input sentinel: stop reading; the pending batch is
            # flushed below instead of crashing on line.split().
            break

        # Truncate over-long sentences to the model's maximum length.
        src_batch.append(line.split()[:opt.max_sent_length])
        tgt_batch.append(opt.tgt)

        if len(src_batch) >= opt.batch_size:
            outputs += _run(src_batch, tgt_batch)
            src_batch, tgt_batch = [], []

    # Flush the final partial batch so trailing sentences are not dropped.
    if src_batch:
        outputs += _run(src_batch, tgt_batch)

    return outputs
예제 #2
0
def main():
    """Classify every line of ``opt.src`` with a CNN classifier and print accuracy.

    Reads command-line options via ``parser``. Each line of the UTF-8 input file
    is whitespace-tokenised and truncated to ``opt.max_sent_length`` tokens. It
    is labelled with the single gold label ``opt.tgt`` and sent to
    ``Translator_cnn.translate`` in batches of ``opt.batch_size``.

    If ``opt.output`` is set, writes one ``prediction<TAB>output<TAB>sentence``
    record per input line (UTF-8). It then prints the accuracy as a percentage
    of the words/sentences counted by ``translate``.
    """
    opt = parser.parse_args()
    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    translator = onmt.Translator_cnn(opt)

    srcBatch, tgtBatch = [], []
    count = 0
    total_correct, total_words = 0, 0
    outputs, predictions, sents = [], [], []

    # The loop below treats a None item from addone() as end-of-file, which is
    # the signal to flush the last (possibly partial) batch.
    with codecs.open(opt.src, "r", "utf-8") as srcF:
        for line in addone(srcF):
            count += 1
            if line is not None:
                sents.append(line)
                # Slicing handles both short and over-long sentences.
                srcBatch.append(line.split()[:opt.max_sent_length])
                tgtBatch.append(opt.tgt)

                if len(srcBatch) < opt.batch_size:
                    continue
            elif not srcBatch:
                # End of file and nothing left to translate.
                break

            num_correct, batchSize, outs, preds = translator.translate(
                srcBatch, tgtBatch)

            total_correct += num_correct.item()
            total_words += batchSize
            outputs += outs.data.tolist()
            predictions += preds.tolist()

            srcBatch, tgtBatch = [], []
            if count % 1000 == 0:
                print('Completed: ', str(count))
                sys.stdout.flush()

    if opt.output:
        # Write with the same encoding the input was decoded with, so
        # non-ASCII sentences do not depend on the platform default encoding.
        with codecs.open(opt.output, "w", "utf-8") as outF:
            for pred, out, sent in zip(predictions, outputs, sents):
                outF.write(str(pred) + "\t" + str(out) + "\t" + sent)

    if total_words:
        print('Accuracy: ', str((total_correct * 100) / total_words))
    else:
        # Empty input: avoid ZeroDivisionError.
        print('Accuracy: ', 'N/A (no sentences were classified)')
    print('')
예제 #3
0
def main():
    """Classify every line of ``opt.src`` with a BPE-input CNN classifier and print accuracy.

    Reads command-line options via ``parser``. For each line of the UTF-8 input
    file:

    - BPE-encode it with an encoder built from the model's source dictionary.
    - Clip it with ``reclip`` so that ``SOS``/``EOS`` still fit.
    - Wrap it in ``SOS``/``EOS`` and pad it with ``Constants.PAD`` up to the
      model's ``sequence_length``.

    Every sentence gets the same integer label: 0 if ``opt.tgt == opt.label0``,
    otherwise 1. Batches of ``opt.batch_size`` are sent to
    ``Translator_cnn.translate_bpe``.

    If ``opt.output`` is set, writes one ``prediction<TAB>output<TAB>sentence``
    record per input line (UTF-8). It then prints the accuracy as a percentage.
    """
    opt = parser.parse_args()
    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    translator = onmt.Translator_cnn(opt)

    # Single integer gold label applied to every sentence in the file.
    tgt_id = 0 if opt.tgt == opt.label0 else 1

    srcBatch, tgtBatch = [], []
    total_correct, total_words = 0, 0
    outputs, predictions, sents = [], [], []

    # Load the BPE encoder from the model's source vocabulary.
    bpe_enc = bpe_encoder.from_dict(translator.src_dict)
    bpe_enc.mute()

    max_seq_length = translator.model_opt.sequence_length

    with codecs.open(opt.src, "r", "utf-8") as srcF:
        # total is the line count + 1, presumably for the trailing None item
        # addone() yields (handled below as end-of-file) — confirm in addone.
        for line in tqdm(addone(srcF), total=get_len(opt.src) + 1):
            if line is not None:
                sents.append(line)

                # BPE-tokenise the single line.
                tokens = list(bpe_enc.transform([line]))[0]
                # Clip to leave room for the SOS and EOS markers.
                tokens = reclip(line, tokens, bpe_enc, max_seq_length - 2)
                tokens = [SOS] + tokens + [EOS]
                # Pad to the fixed model input length (no-op if already full).
                tokens += [Constants.PAD] * (max_seq_length - len(tokens))

                srcBatch.append(tokens)
                tgtBatch.append(tgt_id)

                if len(srcBatch) < opt.batch_size:
                    continue
            elif not srcBatch:
                # End of file and nothing left to translate.
                break

            num_correct, batchSize, outs, preds = translator.translate_bpe(
                srcBatch, tgtBatch)

            total_correct += num_correct.item()
            total_words += batchSize
            outputs += outs.data.tolist()
            predictions += preds.tolist()

            srcBatch, tgtBatch = [], []

    if opt.output:
        # Write with the same encoding the input was decoded with, so
        # non-ASCII sentences do not depend on the platform default encoding.
        with codecs.open(opt.output, "w", "utf-8") as outF:
            for pred, out, sent in zip(predictions, outputs, sents):
                outF.write(str(pred) + "\t" + str(out) + "\t" + sent)

    if total_words:
        print('Accuracy: ', str((total_correct * 100) / total_words))
    else:
        # Empty input: avoid ZeroDivisionError.
        print('Accuracy: ', 'N/A (no sentences were classified)')
    print('')