def parse(self, dataset, eval_batch_size=5000):
    """Parse every sentence in *dataset* and score the result.

    Args:
        dataset: iterable of examples; each example is a dict with
            'word', 'head', 'label' and 'pos' ID lists, where index 0 is
            the ROOT token (presumably — TODO confirm against the loader).
        eval_batch_size: number of sentences parsed per minibatch.

    Returns:
        (UAS, dependencies): unlabeled attachment score (float in [0, 1])
        and the list of predicted (head, dependent) arc lists, one per
        example.
    """
    sentences = []
    sentence_id_to_idx = {}
    for i, example in enumerate(dataset):
        # Index 0 is ROOT, so the sentence proper has len - 1 tokens.
        n_words = len(example['word']) - 1
        sentence = [j + 1 for j in range(n_words)]
        sentences.append(sentence)
        # id(obj) is the object's memory address — used as a stable key to
        # map each sentence list back to its dataset index.
        sentence_id_to_idx[id(sentence)] = i
    model = ModelWrapper(self, dataset, sentence_id_to_idx)
    # minibatch_parse runs the transition parser in minibatches and
    # returns the predicted dependency arcs for every sentence.
    dependencies = minibatch_parse(sentences, model, eval_batch_size)
    UAS = all_tokens = 0.0
    with tqdm(total=len(dataset)) as prog:
        for i, ex in enumerate(dataset):
            head = [-1] * len(ex['word'])
            for h, t in dependencies[i]:
                head[t] = h
            for pred_h, gold_h, gold_l, pos in \
                    zip(head[1:], ex['head'][1:], ex['label'][1:], ex['pos'][1:]):
                assert self.id2tok[pos].startswith(P_PREFIX)
                pos_str = self.id2tok[pos][len(P_PREFIX):]
                # Count punctuation only if with_punct is set; otherwise
                # skip tokens whose POS tag is punctuation.
                if self.with_punct or not punct(self.language, pos_str):
                    # +1 when the predicted head matches the gold head.
                    UAS += 1 if pred_h == gold_h else 0
                    all_tokens += 1
            # BUG FIX: tqdm.update() ADDS its argument to the counter, so
            # the original prog.update(i + 1) overcounted quadratically.
            prog.update(1)
    UAS /= all_tokens
    return UAS, dependencies
def parse(self, dataset, eval_batch_size=5000):
    """Parse *dataset* with the minibatch parser and compute UAS.

    Args:
        dataset: iterable of examples, each a dict of 'word', 'head',
            'label' and 'pos' ID lists with ROOT at index 0 (presumably —
            verify against the data loader).
        eval_batch_size: sentences per parsing minibatch.

    Returns:
        (UAS, dependencies): the unlabeled attachment score and the
        predicted arc lists, one per example.
    """
    all_sentences = []
    idx_by_sentence_id = {}
    for ds_idx, example in enumerate(dataset):
        # Token positions 1..n (position 0 is ROOT).
        token_count = len(example['word']) - 1
        sent = list(range(1, token_count + 1))
        all_sentences.append(sent)
        # Key by the list object's identity so the wrapper can recover
        # the originating dataset index.
        idx_by_sentence_id[id(sent)] = ds_idx
    wrapper = ModelWrapper(self, dataset, idx_by_sentence_id)
    dependencies = minibatch_parse(all_sentences, wrapper, eval_batch_size)
    correct = 0.0
    scored = 0.0
    for ds_idx, ex in enumerate(dataset):
        predicted_heads = [-1] * len(ex['word'])
        for arc_head, arc_dep in dependencies[ds_idx]:
            predicted_heads[arc_dep] = arc_head
        token_quads = zip(predicted_heads[1:], ex['head'][1:],
                          ex['label'][1:], ex['pos'][1:])
        for pred_h, gold_h, _gold_label, pos in token_quads:
            assert self.id2tok[pos].startswith(P_PREFIX)
            pos_str = self.id2tok[pos][len(P_PREFIX):]
            # Skip punctuation tokens unless with_punct is enabled.
            if self.with_punct or not punct(self.language, pos_str):
                if pred_h == gold_h:
                    correct += 1
                scored += 1
    UAS = correct / scored
    return UAS, dependencies
def step_impl(context, sentences, batch_size):
    """BDD step: run minibatch parsing and stash the output on the context."""
    parsed = minibatch_parse(sentences, context.model, batch_size)
    context.result = parsed