def parse(self, dataset, eval_batch_size=5000):
    """Parse every sentence in *dataset* and compute unlabeled attachment score.

    Parameters
    ----------
    dataset : list of dict
        Examples with parallel id sequences under the keys 'word',
        'head', 'label', and 'pos'; index 0 is assumed to be the ROOT
        token and is excluded from scoring.
    eval_batch_size : int
        Number of sentences handed to ``minibatch_parse`` at a time.

    Returns
    -------
    tuple
        ``(UAS, dependencies)`` — the unlabeled attachment score as a
        float, and the per-sentence lists of predicted ``(head, dependent)``
        arcs produced by ``minibatch_parse``.
    """
    sentences = []
    sentence_id_to_idx = {}
    for i, example in enumerate(dataset):
        # Word indices 1..n_words; position 0 is ROOT and is not parsed.
        n_words = len(example['word']) - 1
        sentence = [j + 1 for j in range(n_words)]
        sentences.append(sentence)
        # Key by object identity so the model wrapper can recover the
        # dataset index of a sentence even if two sentences have equal
        # contents.
        sentence_id_to_idx[id(sentence)] = i

    model = ModelWrapper(self, dataset, sentence_id_to_idx)
    dependencies = minibatch_parse(sentences, model, eval_batch_size)

    UAS = all_tokens = 0.0
    for i, ex in enumerate(dataset):
        # Rebuild a head array from the predicted arcs; -1 marks tokens
        # that received no head.
        head = [-1] * len(ex['word'])
        for h, t in dependencies[i]:
            head[t] = h
        # Skip index 0 (ROOT) when comparing predictions to gold heads.
        for pred_h, gold_h, gold_l, pos in \
                zip(head[1:], ex['head'][1:], ex['label'][1:], ex['pos'][1:]):
            # NOTE(review): assert is stripped under `python -O`; kept to
            # preserve the original failure mode for malformed pos ids.
            assert self.id2tok[pos].startswith(P_PREFIX)
            pos_str = self.id2tok[pos][len(P_PREFIX):]
            # Optionally exclude punctuation tokens from the score.
            if self.with_punct or not punct(self.language, pos_str):
                UAS += 1 if pred_h == gold_h else 0
                all_tokens += 1

    # Guard against division by zero when nothing was scorable (empty
    # dataset, or every token filtered out as punctuation).
    UAS = UAS / all_tokens if all_tokens else 0.0
    return UAS, dependencies