Example #1
        # Decay the learning rate once per epoch
        trainer.learning_rate *= options.learning_rate_decay

        train_loss = train_loss / len(training_instances)

        # Evaluate dev data
        if options.skip_dev:
            continue
        model.disable_dropout()
        dev_loss = 0.0
        dev_total_instance = 0
        dev_oov_total = 0
        total_wrong = 0
        total_wrong_oov = 0
        # Word-level precision/recall/F1 (PRF) over the dev set
        prf = utils.CWSEvaluator(t2i)
        prf_dataset = {}
        # Use a dev batch size of roughly 1% of the dev set
        dev_batch_size = math.ceil(len(dev_instances) * 0.01)
        nbatches = (len(dev_instances) + dev_batch_size - 1) // dev_batch_size
        bar = utils.Progbar(target=nbatches)
        with open("{}/devout-epoch-{:02d}.txt".format(root_dir, epoch + 1),
                  'w') as dev_writer:
            for batch_id, batch in enumerate(
                    utils.minibatches(dev_instances, dev_batch_size)):
                for idx, instance in enumerate(batch):
                    sentence = instance.sentence
                    if len(sentence) == 0:
                        continue

                    gold_tags = instance.tags
                    losses = model.neg_log_loss(sentence, gold_tags)
                    total_loss = losses.scalar_value()
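
Example #1 is truncated at this point; the surrounding loop iterates over dev batches produced by utils.minibatches(dev_instances, dev_batch_size). The project's own helper is not shown here, but a batching function with that call shape could look roughly like the sketch below (an assumption, not the actual utils.minibatches):

def minibatches(instances, batch_size):
    # Yield consecutive slices of the instance list, one batch at a time.
    for start in range(0, len(instances), batch_size):
        yield instances[start:start + batch_size]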
Example #2
File: main.py Project: shepherd233/MCCWS
def tester(model, test_batch, write_out=False):
    """Evaluate on test batches; log P/R/F per dataset and overall,
    and return the overall F score plus any written-out segmentations."""
    res = []
    prf = utils.CWSEvaluator(i2t)  # overall evaluator across all datasets
    prf_dataset = {}               # per-dataset evaluators
    oov_dataset = {}               # per-dataset OOV statistics

    model.eval()
    for batch_x, batch_y in test_batch:
        with torch.no_grad():
            # Pass bigram features only when a bigram embedding is configured
            if bigram_embedding is not None:
                out = model(batch_x["task"], batch_x["uni"],
                            batch_x["seq_len"], batch_x["bi1"], batch_x["bi2"])
            else:
                out = model(batch_x["task"], batch_x["uni"],
                            batch_x["seq_len"])
        out = out["pred"]
        #print(out)
        num = out.size(0)
        out = out.detach().cpu().numpy()
        for i in range(num):
            length = int(batch_x["seq_len"][i])

            out_tags = out[i, 1:length].tolist()
            sentence = batch_x["ori_words"][i]
            gold_tags = batch_y["tags"][i][1:length].numpy().tolist()
            # The first token marks the source dataset; strip it before scoring
            dataset_name = sentence[0]
            sentence = sentence[1:]
            assert utils.is_dataset_tag(dataset_name)
            assert len(gold_tags) == len(out_tags) == len(sentence)

            if dataset_name not in prf_dataset:
                prf_dataset[dataset_name] = utils.CWSEvaluator(i2t)
                oov_dataset[dataset_name] = utils.CWS_OOV(
                    word_dic[dataset_name[1:-1]])

            prf_dataset[dataset_name].add_instance(gold_tags, out_tags)
            prf.add_instance(gold_tags, out_tags)

            if write_out:
                gold_strings = utils.to_tag_strings(i2t, gold_tags)
                obs_strings = utils.to_tag_strings(i2t, out_tags)

                word_list = utils.bmes_to_words(sentence, obs_strings)
                oov_dataset[dataset_name].update(
                    utils.bmes_to_words(sentence, gold_strings), word_list)

                raw_string = ' '.join(word_list)
                res.append(dataset_name + " " + raw_string + " " +
                           dataset_name)

    # Macro-averaged scores across datasets
    Ap = 0.0
    Ar = 0.0
    Af = 0.0
    Aoov = 0.0
    tot = 0
    nw = 0.0
    for dataset_name, performance in sorted(prf_dataset.items()):
        p = performance.result()
        if write_out:
            nw = oov_dataset[dataset_name].oov()
            logger.info('{}\t{:04.2f}\t{:04.2f}\t{:04.2f}\t{:04.2f}'.format(
                dataset_name, p[0], p[1], p[2], nw))
        else:
            logger.info('{}\t{:04.2f}\t{:04.2f}\t{:04.2f}'.format(
                dataset_name, p[0], p[1], p[2]))
        Ap += p[0]
        Ar += p[1]
        Af += p[2]
        Aoov += nw
        tot += 1

    prf = prf.result()
    logger.info('{}\t{:04.2f}\t{:04.2f}\t{:04.2f}'.format(
        'TOT', prf[0], prf[1], prf[2]))
    if not write_out:
        logger.info('{}\t{:04.2f}\t{:04.2f}\t{:04.2f}'.format(
            'AVG', Ap / tot, Ar / tot, Af / tot))
    else:
        logger.info('{}\t{:04.2f}\t{:04.2f}\t{:04.2f}\t{:04.2f}'.format(
            'AVG', Ap / tot, Ar / tot, Af / tot, Aoov / tot))
    return prf[-1], res
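
Both examples exercise only a small surface of utils.CWSEvaluator: construct it with a tag mapping, call add_instance(gold_tags, pred_tags) once per sentence, and read precision/recall/F1 (in percent) from result(). The sketch below is a minimal, hypothetical stand-in with that interface for word-level segmentation scoring over BMES tags; it is not the project's actual implementation.

class MinimalCWSEvaluator:
    def __init__(self, i2t):
        # i2t maps tag indices to BMES strings, e.g. {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
        self.i2t = i2t
        self.correct = 0      # words identical in gold and predicted segmentations
        self.gold_total = 0   # words in the gold segmentation
        self.pred_total = 0   # words in the predicted segmentation

    @staticmethod
    def _spans(tags):
        # Convert a BMES tag sequence into (start, end) word spans.
        spans, start = [], 0
        for i, t in enumerate(tags):
            if t in ('B', 'S'):
                start = i
            if t in ('E', 'S'):
                spans.append((start, i))
        return set(spans)

    def add_instance(self, gold_tags, pred_tags):
        gold = self._spans([self.i2t[t] for t in gold_tags])
        pred = self._spans([self.i2t[t] for t in pred_tags])
        self.correct += len(gold & pred)
        self.gold_total += len(gold)
        self.pred_total += len(pred)

    def result(self):
        # Precision, recall and F1 in percent, matching the '{:04.2f}' logging above.
        p = 100.0 * self.correct / max(self.pred_total, 1)
        r = 100.0 * self.correct / max(self.gold_total, 1)
        f = 2 * p * r / max(p + r, 1e-12)
        return p, r, f

For instance, MinimalCWSEvaluator({0: 'B', 1: 'M', 2: 'E', 3: 'S'}) fed the gold tags [0, 2, 3] ("BE S", two words) and identical predictions reports 100.00 for all three scores. The result() tuple ends with F1, which lines up with the return prf[-1] seen in tester().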