    def parse(self, dataset, eval_batch_size=5000):
        """Parse every example with minibatch_parse and return the unlabeled
        attachment score (UAS) together with the predicted dependencies."""
        sentences = []
        sentence_id_to_idx = {}
        for i, example in enumerate(dataset):
            n_words = len(example['word']) - 1  # exclude the ROOT token at index 0
            sentence = [j + 1 for j in range(n_words)]
            sentences.append(sentence)
            sentence_id_to_idx[id(sentence)] = i  # id() gives the object's address, a unique key mapping each sentence list back to its dataset index

        model = ModelWrapper(self, dataset, sentence_id_to_idx)
        dependencies = minibatch_parse(
            sentences, model,
            eval_batch_size)  # parse the sentences in minibatches and collect every predicted dependency

        UAS = all_tokens = 0.0
        with tqdm(total=len(dataset)) as prog:
            for i, ex in enumerate(dataset):
                head = [-1] * len(ex['word'])
                for h, t in dependencies[i]:
                    head[t] = h
                for pred_h, gold_h, gold_l, pos in \
                        zip(head[1:], ex['head'][1:], ex['label'][1:], ex['pos'][1:]):
                    assert self.id2tok[pos].startswith(P_PREFIX)
                    pos_str = self.id2tok[pos][len(P_PREFIX):]
                    # Score punctuation only when with_punct is set; otherwise skip punctuation tokens.
                    if self.with_punct or not punct(self.language, pos_str):
                        UAS += 1 if pred_h == gold_h else 0  # +1 when this token's head is predicted correctly
                        all_tokens += 1
                prog.update(1)  # update() takes an increment, not an index
        UAS /= all_tokens
        return UAS, dependencies
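For context, here is a minimal, illustrative sketch of the minibatch_parse contract the method above relies on; PartialParse, its transition names, and ChainModel are assumptions for illustration, not the actual implementation. The sketch also shows why sentence_id_to_idx is keyed by id(sentence): the parser hands the same sentence list objects back to the model wrapper, which can recover each dataset index by object identity.

class PartialParse:
    """Illustrative stand-in for a transition-based parse state."""
    def __init__(self, sentence):
        self.sentence = sentence      # kept so callers can map back via id(sentence)
        self.stack = [0]              # word indices; index 0 plays the ROOT role
        self.buffer = list(sentence)  # words still to be shifted, in order
        self.dependencies = []        # (head, dependent) pairs found so far

    def parse_step(self, transition):
        if transition == 'S':        # SHIFT: move the next buffered word onto the stack
            self.stack.append(self.buffer.pop(0))
        elif transition == 'LA':     # LEFT-ARC: second-from-top depends on top
            dependent = self.stack.pop(-2)
            self.dependencies.append((self.stack[-1], dependent))
        else:                        # 'RA', RIGHT-ARC: top depends on second-from-top
            dependent = self.stack.pop()
            self.dependencies.append((self.stack[-1], dependent))

def minibatch_parse(sentences, model, batch_size):
    partial_parses = [PartialParse(s) for s in sentences]
    unfinished = partial_parses[:]          # shallow copy; the parse objects are shared
    while unfinished:
        batch = unfinished[:batch_size]     # next minibatch of in-progress parses
        transitions = model.predict(batch)  # one transition per partial parse
        for pp, transition in zip(batch, transitions):
            pp.parse_step(transition)
        # A parse is finished once its buffer is empty and only ROOT remains.
        unfinished = [pp for pp in unfinished if pp.buffer or len(pp.stack) > 1]
    return [pp.dependencies for pp in partial_parses]

# Toy model: shift while anything is buffered, then attach right-to-left.
class ChainModel:
    def predict(self, parses):
        return ['S' if pp.buffer else 'RA' for pp in parses]

print(minibatch_parse([[1, 2, 3]], ChainModel(), batch_size=2))  # [[(2, 3), (1, 2), (0, 1)]]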
Example #2
    def parse(self, dataset, eval_batch_size=5000):
        sentences = []
        sentence_id_to_idx = {}
        for i, example in enumerate(dataset):
            n_words = len(example['word']) - 1
            sentence = [j + 1 for j in range(n_words)]
            sentences.append(sentence)
            sentence_id_to_idx[id(sentence)] = i

        model = ModelWrapper(self, dataset, sentence_id_to_idx)
        dependencies = minibatch_parse(sentences, model, eval_batch_size)

        UAS = all_tokens = 0.0
        for i, ex in enumerate(dataset):
            head = [-1] * len(ex['word'])
            for h, t in dependencies[i]:
                head[t] = h
            for pred_h, gold_h, gold_l, pos in \
                    zip(head[1:], ex['head'][1:], ex['label'][1:], ex['pos'][1:]):
                assert self.id2tok[pos].startswith(P_PREFIX)
                pos_str = self.id2tok[pos][len(P_PREFIX):]
            if self.with_punct or not punct(self.language, pos_str):
                    UAS += 1 if pred_h == gold_h else 0
                    all_tokens += 1
        UAS /= all_tokens
        return UAS, dependencies
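Both examples score with the unlabeled attachment score: UAS = correctly attached tokens / all scored tokens, where punctuation is skipped unless with_punct is set. A tiny check of that counting logic, with made-up head indices:

# Made-up toy values, only to illustrate the counting loop above (0 = ROOT).
pred_heads = [2, 0, 2]   # predicted head for each of three tokens
gold_heads = [2, 0, 1]   # gold head for each token
uas = sum(p == g for p, g in zip(pred_heads, gold_heads)) / len(gold_heads)
assert abs(uas - 2 / 3) < 1e-9   # two of the three heads are correct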
Example #3
def step_impl(context, sentences, batch_size):
    # behave-style step: run minibatch_parse and stash the result on the shared test context
    context.result = minibatch_parse(sentences, context.model, batch_size)
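step_impl follows the behave BDD convention: context carries shared test state and the remaining arguments are parsed out of the step text. Below is a sketch of how such a step is typically registered; the @when text and its fields are assumptions, not taken from the source, and sentences would arrive as raw text unless a custom type converter is registered.

from behave import when

@when('we parse {sentences} with batch size {batch_size:d}')  # hypothetical step text
def step_impl(context, sentences, batch_size):
    context.result = minibatch_parse(sentences, context.model, batch_size)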