def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs):
    self.refresh_batch(training)
    # encode
    input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(annotated_insts, training)
    mask_expr = BK.input_real(mask_arr)
    # the parsing loss
    arc_score = self.scorer_helper.score_arc(enc_repr)
    lab_score = self.scorer_helper.score_label(enc_repr)
    full_score = arc_score + lab_score
    parsing_loss, info = self._loss(annotated_insts, full_score, mask_expr)
    # other loss?
    jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
    reg_loss = self.reg_scores_loss(arc_score, lab_score)
    #
    info["loss_parse"] = BK.get_value(parsing_loss).item()
    final_loss = parsing_loss
    if jpos_loss is not None:
        info["loss_jpos"] = BK.get_value(jpos_loss).item()
        final_loss = parsing_loss + jpos_loss
    if reg_loss is not None:
        final_loss = final_loss + reg_loss
    info["fb"] = 1
    if training:
        BK.backward(final_loss, loss_factor)
    return info

def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs):
    self.refresh_batch(training)
    # todo(note): here always using training lambdas
    full_score, original_scores, jpos_pack, mask_expr, valid_mask_d, _ = \
        self._score(annotated_insts, False, self.lambda_g1_arc_training, self.lambda_g1_lab_training)
    parsing_loss, info = self._loss(annotated_insts, full_score, mask_expr, valid_mask_d)
    # other loss?
    jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
    reg_loss = self.reg_scores_loss(*original_scores)
    #
    info["loss_parse"] = BK.get_value(parsing_loss).item()
    final_loss = parsing_loss
    if jpos_loss is not None:
        info["loss_jpos"] = BK.get_value(jpos_loss).item()
        final_loss = parsing_loss + jpos_loss
    if reg_loss is not None:
        final_loss = final_loss + reg_loss
    info["fb"] = 1
    if training:
        BK.backward(final_loss, loss_factor)
    return info

def _inference_mentions(self, insts: List[Sentence], lexi_repr, enc_repr, mask_expr,
                        extractor: NodeExtractorBase, item_creator):
    sel_logprobs, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds = \
        extractor.predict(insts, lexi_repr, enc_repr, mask_expr)
    # handling outputs here: prepare new items
    head_idxes_arr = BK.get_value(sel_idxes)  # [*, max-count]
    lab_idxes_arr = BK.get_value(sel_lab_idxes)  # [*, max-count]
    logprobs_arr = BK.get_value(sel_logprobs)  # [*, max-count]
    valid_arr = BK.get_value(sel_valid_mask)  # [*, max-count]
    all_items = []
    bsize, mc = valid_arr.shape
    for one_idxes, one_valids, one_lab_idxes, one_logprobs, one_sent in \
            zip(head_idxes_arr, valid_arr, lab_idxes_arr, logprobs_arr, insts):
        sid = one_sent.sid
        partial_id0 = f"{one_sent.doc.doc_id}-s{one_sent.sid}-i"
        for this_i in range(mc):
            this_valid = float(one_valids[this_i])
            if this_valid == 0:
                # must be compact
                assert np.all(one_valids[this_i:] == 0.)
                all_items.extend([None] * (mc - this_i))
                break
            # todo(note): we need to assign various info at the outside
            this_mention = Mention(HardSpan(sid, int(one_idxes[this_i]), None, None))
            # todo(note): where to filter None?
            this_hlidx = extractor.idx2hlidx(one_lab_idxes[this_i])
            all_items.append(item_creator(partial_id0 + str(this_i), this_mention, this_hlidx,
                                          float(one_logprobs[this_i])))
    # only return the items and the ones useful for later steps: List(sent)[List(items)], *[*, max-count]
    ret_items = np.asarray(all_items, dtype=object).reshape((bsize, mc))
    return ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds

def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs):
    self.refresh_batch(training)
    # pruning and scores from g1
    valid_mask, go1_pack = self._get_g1_pack(annotated_insts, self.lambda_g1_arc_training,
                                             self.lambda_g1_lab_training)
    # encode
    input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(annotated_insts, training)
    mask_expr = BK.input_real(mask_arr)
    # the parsing loss
    final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
    parsing_loss, parsing_scores, info = \
        self.dl.loss(annotated_insts, enc_repr, final_valid_expr, go1_pack, True, self.margin.value)
    info["loss_parse"] = BK.get_value(parsing_loss).item()
    final_loss = parsing_loss
    # other loss?
    jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
    if jpos_loss is not None:
        info["loss_jpos"] = BK.get_value(jpos_loss).item()
        final_loss = parsing_loss + jpos_loss
    if parsing_scores is not None:
        reg_loss = self.reg_scores_loss(*parsing_scores)
        if reg_loss is not None:
            final_loss = final_loss + reg_loss
    info["fb"] = 1
    if training:
        BK.backward(final_loss, loss_factor)
    return info

def main(args):
    conf, model, vpack, test_iter = prepare_test(args)
    dconf = conf.dconf
    # todo(note): here is the main change
    # make sure the model is an order-1 graph model, otherwise it cannot run through
    all_results = []
    all_insts = []
    with utils.Timer(tag="Run-score", info="", print_date=True):
        for cur_insts in test_iter:
            all_insts.extend(cur_insts)
            batched_arc_scores, batched_label_scores = model.score_on_batch(cur_insts)
            batched_arc_scores, batched_label_scores = \
                BK.get_value(batched_arc_scores), BK.get_value(batched_label_scores)
            for cur_idx in range(len(cur_insts)):
                cur_len = len(cur_insts[cur_idx]) + 1
                # discarding paddings
                cur_res = (batched_arc_scores[cur_idx, :cur_len, :cur_len],
                           batched_label_scores[cur_idx, :cur_len, :cur_len])
                all_results.append(cur_res)
    # reorder to the original order
    orig_indexes = [z.inst_idx for z in all_insts]
    orig_results = [None] * len(orig_indexes)
    for new_idx, orig_idx in enumerate(orig_indexes):
        assert orig_results[orig_idx] is None
        orig_results[orig_idx] = all_results[new_idx]
    # saving
    with utils.Timer(tag="Run-write", info=f"Writing to {dconf.output_file}", print_date=True):
        import pickle
        with utils.zopen(dconf.output_file, "wb") as fd:
            for one in orig_results:
                pickle.dump(one, fd)
    utils.printing("The end.")

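# -----
# A minimal standalone sketch (plain Python, hypothetical data; not part of this repo)
# of the reorder-to-original-order pattern above: batching may visit instances out of
# order, so each result is written back into the slot given by its original index.
def reorder_to_original(results, orig_indexes):
    out = [None] * len(orig_indexes)
    for new_idx, orig_idx in enumerate(orig_indexes):
        assert out[orig_idx] is None  # each original slot is filled exactly once
        out[orig_idx] = results[new_idx]
    return out

# usage: instances were visited in batch order [2, 0, 1]
assert reorder_to_original(["r2", "r0", "r1"], [2, 0, 1]) == ["r0", "r1", "r2"]
# -----
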
def _inference_args(self, ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
                    evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt):
    arg_linker = self.arg_linker
    repr_ef = BK.gather_first_dims(enc_repr_ef, ef_widxes, -2)  # [*, len-ef, D]
    repr_evt = BK.gather_first_dims(enc_repr_evt, evt_widxes, -2)  # [*, len-evt, D]
    role_logprobs, role_predictions = arg_linker.predict(repr_ef, repr_evt, ef_lab_idxes, evt_lab_idxes,
                                                         ef_valid_mask, evt_valid_mask)
    # add them in place
    roles_arr = BK.get_value(role_predictions)  # [*, len-ef, len-evt]
    logprobs_arr = BK.get_value(role_logprobs)
    for bidx, one_roles_arr in enumerate(roles_arr):
        one_ef_items, one_evt_items = ef_items[bidx], evt_items[bidx]
        # =====
        # todo(note): delete origin links!
        for z in one_ef_items:
            if z is not None:
                z.links.clear()
        for z in one_evt_items:
            if z is not None:
                z.links.clear()
        # =====
        one_logprobs = logprobs_arr[bidx]
        for ef_idx, one_ef in enumerate(one_ef_items):
            if one_ef is None:
                continue
            for evt_idx, one_evt in enumerate(one_evt_items):
                if one_evt is None:
                    continue
                one_role_idx = int(one_roles_arr[ef_idx, evt_idx])
                if one_role_idx > 0:
                    # link
                    this_hlidx = arg_linker.idx2hlidx(one_role_idx)
                    one_evt.add_arg(one_ef, role=str(this_hlidx), role_idx=this_hlidx,
                                    score=float(one_logprobs[ef_idx, evt_idx]))

def lookup(self, insts: List, input_lexi, input_expr, input_mask):
    bsize = len(insts)
    # get gold or pre-set ones, again [*, slen, L] -> [*, mc]
    gold_masks, _, gold_items_arr, gold_valid = self.batch_inputs_g0(insts)
    sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds = self._pmask2idxes(gold_masks)
    ret_items = gold_items_arr[np.arange(bsize)[:, np.newaxis],
                               BK.get_value(sel_idxes), BK.get_value(sel_lab_idxes)]
    return ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds

def lookup(self, insts: List, input_lexi, input_expr, input_mask):
    conf = self.conf
    bsize = len(insts)
    # first get gold/input info, also multiple valid-masks
    gold_masks, gold_idxes, gold_items_arr, gold_valid, gold_idxes2, gold_items2_arr = \
        self.batch_inputs_h(insts)
    # step 1: no selection, simply forward using gold_masks
    sel_idxes, sel_valid_mask = BK.mask2idx(gold_masks)  # [*, max-count]
    sel_gold_idxes = gold_idxes.gather(-1, sel_idxes)
    sel_gold_idxes2 = gold_idxes2.gather(-1, sel_idxes)
    # todo(+N): only get items by head position!
    _tmp_i0, _tmp_i1 = np.arange(bsize)[:, np.newaxis], BK.get_value(sel_idxes)
    sel_items = gold_items_arr[_tmp_i0, _tmp_i1]  # [*, mc]
    sel2_items = gold_items2_arr[_tmp_i0, _tmp_i1]
    # step 2: encoding and labeling
    sel_shape = BK.get_shape(sel_idxes)
    if sel_shape[-1] == 0:
        sel_lab_idxes = sel_gold_idxes
        sel_lab_embeds = BK.zeros(sel_shape + [conf.lab_conf.n_dim])
        ret_items = sel_items  # dim-1 == 0
    else:
        # sel_hid_exprs = self._enc(input_expr, input_mask, sel_idxes)  # [*, mc, DLab]
        sel_lab_idxes = sel_gold_idxes
        sel_lab_embeds = self.hl.lookup(sel_lab_idxes)  # todo(note): here no softlookup?
        ret_items = sel_items
        # second type
        if self.use_secondary_type:
            sel2_lab_idxes = sel_gold_idxes2
            sel2_lab_embeds = self.hl.lookup(sel2_lab_idxes)  # todo(note): here no softlookup?
            sel2_valid_mask = (sel2_lab_idxes > 0).float()
            # combine the two
            if sel2_lab_idxes.sum().item() > 0:  # if there are any gold sectypes
                ret_items = np.concatenate([ret_items, sel2_items], -1)  # [*, mc*2]
                sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
                sel_valid_mask = BK.concat([sel_valid_mask, sel2_valid_mask], -1)
                sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
                sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
    # step 3: exclude nil, assuming no deliberate nil in gold/inputs
    if conf.exclude_nil:  # [*, mc', ...]
        sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, _, ret_items = \
            self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds,
                              sel_items_arr=ret_items)
    # step 4: return
    # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
    # mask out invalid items with None
    ret_items[BK.get_value(1. - sel_valid_mask).astype(np.bool_)] = None
    return ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds

def decode_one(self, slen: int, projective: bool, arr_o1_masks, arr_o1_scores, input_pack, cur_bidx_mask):
    m_idxes, h_idxes, _, _, final_scores = input_pack
    if arr_o1_scores is None:
        arr_o1_scores = np.full([slen, slen], 0., dtype=np.double)
    # direct add to the scores
    m_idxes, h_idxes, final_scores = \
        m_idxes[cur_bidx_mask], h_idxes[cur_bidx_mask], final_scores[cur_bidx_mask]
    arr_o1_scores[BK.get_value(m_idxes), BK.get_value(h_idxes)] += BK.get_value(final_scores)
    return hop_decode(slen, projective, arr_o1_masks, arr_o1_scores, None, None, None)

def _pred_and_put_res(self, predictor, hidden_t, evt_arr, put_f):
    logits = predictor(hidden_t)  # [bsize, ?, Out]
    log_probs = BK.log_softmax(logits, -1)
    max_log_probs, max_label_idxes = log_probs.max(-1)  # [bs, ?], simply argmax prediction
    max_log_probs_arr, max_label_idxes_arr = BK.get_value(max_log_probs), BK.get_value(max_label_idxes)
    for evt_row, lprob_row, lidx_row in zip(evt_arr, max_log_probs_arr, max_label_idxes_arr):
        for one_evt, one_lprob, one_lidx in zip(evt_row, lprob_row, lidx_row):
            if one_evt is not None:
                put_f(one_evt, one_lprob, one_lidx)  # callback for in-place setting

def _new_states(self, flattened_states: List[EfState], scoring_mask_ct,
                topk_arc_scores, topk_m, topk_h, topk_label_scores, topk_label_idxes):
    topk_arc_scores, topk_m, topk_h, topk_label_scores, topk_label_idxes = \
        (BK.get_value(z) for z in (topk_arc_scores, topk_m, topk_h, topk_label_scores, topk_label_idxes))
    new_states = []
    # for each batch element
    for one_state, one_mask, one_arc_scores, one_ms, one_hs, one_label_scores, one_labels in \
            zip(flattened_states, scoring_mask_ct, topk_arc_scores, topk_m, topk_h,
                topk_label_scores, topk_label_idxes):
        one_new_states = []
        # for each of the k arc selections
        for cur_arc_score, cur_m, cur_h, cur_label_scores, cur_labels in \
                zip(one_arc_scores, one_ms, one_hs, one_label_scores, one_labels):
            # first, the selection needs to be valid
            cur_arc_score, cur_m, cur_h = cur_arc_score.item(), cur_m.item(), cur_h.item()
            if one_mask[cur_m, cur_h].item() > 0.:
                # for each of the labels
                for this_label_score, this_label in zip(cur_label_scores, cur_labels):
                    this_label_score, this_label = this_label_score.item(), this_label.item()
                    # todo(note): actually add new state; do not include label score if label does not come from ef
                    cur_all_score = (cur_arc_score + this_label_score) if self.system_labeled else cur_arc_score
                    this_new_state = one_state.build_next(action=EfAction(cur_h, cur_m, this_label),
                                                          score=cur_all_score)
                    one_new_states.append(this_new_state)
        new_states.append(one_new_states)
    return new_states

def _emb_and_enc(self, cur_input_map: Dict, collect_loss: bool, insts=None):
    conf = self.conf
    # -----
    # special mode
    if conf.aug_word2 and conf.aug_word2_aug_encoder:
        _rop = RefreshOptions(training=False)  # special feature-mode!!
        self.embedder.refresh(_rop)
        self.encoder.refresh(_rop)
    # -----
    emb_t, mask_t = self.embedder(cur_input_map)
    rel_dist = cur_input_map.get("rel_dist", None)
    if rel_dist is not None:
        rel_dist = BK.input_idx(rel_dist)
    if conf.enc_choice == "vrec":
        enc_t, cache, enc_loss = self.encoder(emb_t, src_mask=mask_t, rel_dist=rel_dist,
                                              collect_loss=collect_loss)
    elif conf.enc_choice == "original":
        # todo(note): change back to arr for backward compatibility
        assert rel_dist is None, "Original encoder does not support rel_dist"
        enc_t = self.encoder(emb_t, BK.get_value(mask_t))
        cache, enc_loss = None, None
    else:
        raise NotImplementedError()
    # another encoder based on attn
    final_enc_t = self.rpreper(emb_t, enc_t, cache)  # [*, slen, D] => final encoder output
    if conf.aug_word2:
        emb2_t = self.aug_word2(insts)
        if conf.aug_word2_aug_encoder:
            # simply add them all together, detach orig-enc as features
            stack_hidden_t = BK.stack(cache.list_hidden[-conf.aug_detach_numlayer:], -2).detach()
            features = self.aug_mixturer(stack_hidden_t)
            aug_input = (emb2_t + conf.aug_detach_ratio * self.aug_detach_drop(features))
            final_enc_t, cache, enc_loss = self.aug_encoder(aug_input, src_mask=mask_t, rel_dist=rel_dist,
                                                            collect_loss=collect_loss)
        else:
            final_enc_t = (final_enc_t + emb2_t)  # otherwise, simply adding
    return emb_t, mask_t, final_enc_t, cache, enc_loss

def nmst_greedy(scores_expr, mask_expr, lengths_arr, labeled=True, ret_arr=False):
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # mask out diag
        scores_expr += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
        # combine the last two dimensions and max over them
        combined_scores_expr = scores_expr.view(scores_shape[:-2] + [-1])
        combine_max_scores, combined_max_idxes = BK.max(combined_scores_expr, dim=-1)
        # back to real idxes
        last_size = scores_shape[-1]
        greedy_heads = combined_max_idxes // last_size
        greedy_labels = combined_max_idxes % last_size
        if ret_arr:
            mst_heads_arr, mst_labels_arr, mst_scores_arr = \
                [BK.get_value(z) for z in (greedy_heads, greedy_labels, combine_max_scores)]
            return mst_heads_arr, mst_labels_arr, mst_scores_arr
        else:
            return greedy_heads, greedy_labels, combine_max_scores

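# -----
# A standalone numpy sketch (toy shapes; not the BK API) of the joint head/label
# argmax trick in nmst_greedy above: flatten the last two dims [h, L] into one,
# take a single argmax, then recover head = idx // L and label = idx % L.
import numpy as np

scores = np.random.randn(2, 5, 5, 3)            # [bs, m, h, L]
bs, slen, _, n_lab = scores.shape
flat = scores.reshape(bs, slen, slen * n_lab)   # [bs, m, h*L]
best = flat.argmax(-1)                          # [bs, m], one argmax for both choices
heads, labels = best // n_lab, best % n_lab     # recover the (head, label) pair
best_scores = np.take_along_axis(flat, best[..., None], axis=-1).squeeze(-1)
assert np.allclose(best_scores, scores.max(axis=(-2, -1)))
# -----
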
def jpos_decode(self, insts: List[ParseInstance], jpos_pack):
    # jpos prediction (directly index, no converting)
    jpos_preds_expr = jpos_pack[2]
    if jpos_preds_expr is not None:
        jpos_preds_arr = BK.get_value(jpos_preds_expr)
        for one_idx, one_inst in enumerate(insts):
            cur_length = len(one_inst) + 1  # including the artificial ROOT
            one_inst.pred_poses.build_vals(jpos_preds_arr[one_idx][:cur_length], self.bter.pos_vocab)

def predict(self, ms_items: List, bert_expr):
    conf = self.conf
    bsize = len(ms_items)
    # collect instances
    col_efs, col_sents, col_bidxes_t, col_hidxes_t, _, _ = self._collect_insts(ms_items, False)
    if len(col_efs) == 0:
        return
    left_scores, right_scores = self._score(bert_expr, col_bidxes_t, col_hidxes_t)
    if conf.use_binary_scorer:
        lscores_arr, rscores_arr = BK.get_value(left_scores), BK.get_value(right_scores)
        #
        for one_ef, one_sent, one_lscores, one_rscores in zip(col_efs, col_sents, lscores_arr, rscores_arr):
            one_ldist, one_rdist = self._binary_decide_dist(one_lscores), self._binary_decide_dist(one_rscores)
            # set span
            hspan = one_ef.mention.hard_span
            sid, head_wid = hspan.sid, hspan.head_wid
            left_wid = max(1, head_wid - one_ldist)  # not the artificial root
            right_wid = min(one_sent.length - 1, head_wid + one_rdist)
            hspan.wid = left_wid
            hspan.length = right_wid - left_wid + 1
    else:
        # simply pick max
        _, left_max_dist = left_scores.max(-1)
        _, right_max_dist = right_scores.max(-1)
        lmax_arr, rmax_arr = BK.get_value(left_max_dist), BK.get_value(right_max_dist)
        #
        for one_ef, one_sent, one_ldist, one_rdist in zip(col_efs, col_sents, lmax_arr, rmax_arr):
            one_ldist, one_rdist = int(one_ldist), int(one_rdist)
            # set span
            hspan = one_ef.mention.hard_span
            sid, head_wid = hspan.sid, hspan.head_wid
            left_wid = max(1, head_wid - one_ldist)  # not the artificial root
            right_wid = min(one_sent.length - 1, head_wid + one_rdist)
            hspan.wid = left_wid
            hspan.length = right_wid - left_wid + 1

def collect_pruning_info(insts: List[ParseInstance], valid_mask_f):
    # two dimensions: coverage and pruning-effect
    maxlen = BK.get_shape(valid_mask_f, -1)
    # 1. coverage
    valid_mask_f_flattened = valid_mask_f.view([-1, maxlen])  # [bs*len, len]
    cur_mod_base = 0
    all_mods, all_heads = [], []
    for cur_idx, cur_inst in enumerate(insts):
        for m, h in enumerate(cur_inst.heads.vals[1:], 1):
            all_mods.append(m + cur_mod_base)
            all_heads.append(h)
        cur_mod_base += maxlen
    cov_count = len(all_mods)
    cov_valid = BK.get_value(valid_mask_f_flattened[all_mods, all_heads].sum()).item()
    # 2. pruning-rate
    # todo(warn): to speed up, these stats are approximate because of including paddings
    # edges
    pr_edges = int(np.prod(BK.get_shape(valid_mask_f)))
    pr_edges_valid = BK.get_value(valid_mask_f.sum()).item()
    # valid as structured heads
    pr_o2_sib = pr_o2_g = pr_edges
    pr_o3_gsib = maxlen * pr_edges
    valid_chs_counts, valid_par_counts = valid_mask_f.sum(-2), valid_mask_f.sum(-1)  # [*, len]
    valid_gsibs = valid_chs_counts * valid_par_counts
    pr_o2_sib_valid = BK.get_value(valid_chs_counts.sum()).item()
    pr_o2_g_valid = BK.get_value(valid_par_counts.sum()).item()
    pr_o3_gsib_valid = BK.get_value(valid_gsibs.sum()).item()
    return {"cov_count": cov_count, "cov_valid": cov_valid,
            "pr_edges": pr_edges, "pr_edges_valid": pr_edges_valid,
            "pr_o2_sib": pr_o2_sib, "pr_o2_g": pr_o2_g, "pr_o3_gsib": pr_o3_gsib,
            "pr_o2_sib_valid": pr_o2_sib_valid, "pr_o2_g_valid": pr_o2_g_valid,
            "pr_o3_gsib_valid": pr_o3_gsib_valid}

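# -----
# A standalone numpy sketch (hypothetical data; not the BK API) of the coverage
# statistic above: how many gold (mod, head) edges survive a pruning mask. Flattening
# the batch into [bs*len, len] lets all gold edges be checked in one fancy-indexing lookup.
import numpy as np

valid_mask = np.zeros((2, 4, 4))                # [bs, mod, head] pruning mask
valid_mask[0, 1, 0] = valid_mask[0, 2, 1] = 1.  # kept edges in sentence 0
gold_heads = [[0, 1, 2], [0, 0, 1]]             # heads of tokens 1..3 (token 0 is ROOT)

flat = valid_mask.reshape(-1, 4)                # [bs*len, len]
mods, heads = [], []
for b, inst_heads in enumerate(gold_heads):
    for m, h in enumerate(inst_heads, 1):       # mods start at 1 (skip ROOT)
        mods.append(m + b * 4)                  # offset into the flattened batch
        heads.append(h)
covered = flat[mods, heads].sum()               # gold edges that survive pruning
print(covered, "/", len(mods))                  # -> 2.0 / 6
# -----
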
def decode_one(self, slen: int, projective: bool, arr_o1_masks, arr_o1_scores, input_pack, cur_bidx_mask):
    m_idxes, h_idxes, sib_idxes, gp_idxes, final_scores = input_pack
    o3gsib_pack = [m_idxes[cur_bidx_mask].int(), h_idxes[cur_bidx_mask].int(),
                   sib_idxes[cur_bidx_mask].int(), gp_idxes[cur_bidx_mask].int(),
                   final_scores[cur_bidx_mask].double()]
    o3gsib_arr_pack = [BK.get_value(z) for z in o3gsib_pack]
    return hop_decode(slen, projective, arr_o1_masks, arr_o1_scores, None, None, o3gsib_arr_pack)

def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    assert labeled
    with BK.no_grad_env():
        # argmax-label: [BS, m, h]
        scores_unlabeled_max, labels_argmax = scores_expr.max(-1)
        #
        scores_unlabeled_max_arr = BK.get_value(scores_unlabeled_max)
        mst_heads_arr, _, mst_scores_arr = CPU_f(scores_unlabeled_max_arr, lengths_arr, labeled=False)
        # [BS, m]
        mst_heads_expr = BK.input_idx(mst_heads_arr)
        mst_labels_expr = BK.gather_one_lastdim(labels_argmax, mst_heads_expr).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return mst_heads_arr, BK.get_value(mst_labels_expr), mst_scores_arr
        else:
            return mst_heads_expr, mst_labels_expr, BK.input_real(mst_scores_arr)

def nmarginal_proj(scores_expr, mask_expr, lengths_arr, labeled=True):
    assert labeled
    with BK.no_grad_env():
        # first make it unlabeled by log-sum-exp
        scores_unlabeled = BK.logsumexp(scores_expr, dim=-1)  # [BS, m, h]
        # marginal for unlabeled
        scores_unlabeled_arr = BK.get_value(scores_unlabeled)
        marginals_unlabeled_arr = marginal_proj(scores_unlabeled_arr, lengths_arr, False)
        # back to labeled values
        marginals_unlabeled_expr = BK.input_real(marginals_unlabeled_arr)
        marginals_labeled_expr = marginals_unlabeled_expr.unsqueeze(-1) * \
            BK.exp(scores_expr - scores_unlabeled.unsqueeze(-1))
        # [BS, m, h, L]
        return _ensure_margins_norm(marginals_labeled_expr)

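# -----
# A standalone numpy check (toy numbers) of the factorization used above: labeled
# marginal = unlabeled arc marginal * conditional label probability, where the
# conditional is exp(s(m,h,l) - logsumexp_l s(m,h,:)).
import numpy as np

s = np.random.randn(3, 3, 4)              # [m, h, L] labeled scores
s_u = np.log(np.exp(s).sum(-1))           # logsumexp over labels -> unlabeled scores [m, h]
arc_marg = np.random.rand(3, 3)           # stand-in for marginal_proj's output
lab_marg = arc_marg[..., None] * np.exp(s - s_u[..., None])
# label probs sum to 1 per arc, so labeled marginals sum back to the arc marginals
assert np.allclose(lab_marg.sum(-1), arc_marg)
# -----
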
def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs):
    self.refresh_batch(training)
    # encode
    input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(annotated_insts, training)
    mask_expr = BK.input_real(mask_arr)
    # g1 score
    g1_pack = self._get_g1_pack(annotated_insts, self.lambda_g1_arc_training, self.lambda_g1_lab_training)
    # the parsing loss
    parsing_loss, parsing_scores, info = self.losser.loss(annotated_insts, enc_repr, mask_arr, g1_pack)
    # whether to add the jpos loss?
    jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
    #
    no_loss = True
    final_loss = 0.
    if parsing_loss is None:
        info["loss_parse"] = 0.
    else:
        final_loss = final_loss + parsing_loss
        info["loss_parse"] = BK.get_value(parsing_loss).item()
        no_loss = False
    if jpos_loss is None:
        info["loss_jpos"] = 0.
    else:
        final_loss = final_loss + jpos_loss
        info["loss_jpos"] = BK.get_value(jpos_loss).item()
        no_loss = False
    if parsing_scores is not None:
        arc_scores, lab_scores = parsing_scores
        reg_loss = self.reg_scores_loss(arc_scores, lab_scores)
        if reg_loss is not None:
            final_loss = final_loss + reg_loss
    info["fb"] = 1
    if training and not no_loss:
        info["fb_back"] = 1
        BK.backward(final_loss, loss_factor)
    return info

def predict(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
    conf = self.conf
    # score
    scores_t = self._score(repr_t)  # [bs, ?+rlen, D]
    _, argmax_idxes = scores_t.max(-1)  # [bs, ?+rlen]
    argmax_idxes_arr = BK.get_value(argmax_idxes)  # [bs, ?+rlen]
    # assign; todo(+2): record scores?
    one_offset = int(self.add_root_token)
    for one_bidx, one_inst in enumerate(insts):
        one_pidxes = argmax_idxes_arr[one_bidx, one_offset:one_offset + len(one_inst)].tolist()
        one_pseq = SeqField(None)
        one_pseq.build_vals(one_pidxes, self.vocab)
        one_inst.add_item("pred_" + self.attr_name, one_pseq, assert_non_exist=False)
    return

def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
    # iconf = self.conf.iconf
    with BK.no_grad_env():
        self.refresh_batch(False)
        # pruning and scores from g1
        valid_mask, go1_pack = self._get_g1_pack(insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing)
        # encode
        input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False)
        mask_expr = BK.input_real(mask_arr)
        # decode
        final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
        ret_heads, ret_labels, _, _ = self.dl.decode(insts, enc_repr, final_valid_expr, go1_pack, False, 0.)
        # collect the results together
        all_heads = Helper.join_list(ret_heads)
        if ret_labels is None:
            # todo(note): simply get labels from the go1-label classifier; must provide g1parser
            if go1_pack is None:
                _, go1_pack = self._get_g1_pack(insts, 1., 1.)
            _, go1_label_max_idxes = go1_pack[1].max(-1)  # [bs, slen, slen]
            pred_heads_arr, _ = self.predict_padder.pad(all_heads)  # [bs, slen]
            pred_heads_expr = BK.input_idx(pred_heads_arr)
            pred_labels_expr = BK.gather_one_lastdim(go1_label_max_idxes, pred_heads_expr).squeeze(-1)
            all_labels = BK.get_value(pred_labels_expr)  # [bs, slen]
        else:
            all_labels = np.concatenate(ret_labels, 0)
        # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
        for one_idx, one_inst in enumerate(insts):
            cur_length = len(one_inst) + 1
            one_inst.pred_heads.set_vals(all_heads[one_idx][:cur_length])  # directly int-val for heads
            one_inst.pred_labels.build_vals(all_labels[one_idx][:cur_length], self.label_vocab)
            # one_inst.pred_par_scores.set_vals(all_scores[one_idx][:cur_length])
        # =====
        # put jpos result (possibly)
        self.jpos_decode(insts, jpos_pack)
    # -----
    info = {"sent": len(insts), "tok": sum(map(len, insts))}
    return info

def predict(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
    conf = self.conf
    if self.add_root_token:
        repr_t = repr_t[:, 1:]
        mask_t = mask_t[:, 1:]
    # score
    scores_t = self._score(repr_t)  # [bs, rlen, D]
    # decode
    _, decode_idx = self._viterbi_decode(scores_t, mask_t.bool())
    decode_idx_arr = BK.get_value(decode_idx)  # [bs, rlen]
    for one_bidx, one_inst in enumerate(insts):
        one_pidxes = decode_idx_arr[one_bidx].tolist()[:len(one_inst)]
        one_pseq = SeqField(None)
        one_pseq.build_vals(one_pidxes, self.vocab)
        one_inst.add_item("pred_" + self.attr_name, one_pseq, assert_non_exist=False)
    return

def _exclude_nil(self, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds,
                 sel_logprobs=None, sel_items_arr=None):
    # todo(note): assure that nil is 0
    sel_valid_mask = sel_valid_mask * (sel_lab_idxes != 0).float()  # not in-place
    # idx on idx
    s2_idxes, s2_valid_mask = BK.mask2idx(sel_valid_mask)
    sel_idxes = sel_idxes.gather(-1, s2_idxes)
    sel_valid_mask = s2_valid_mask
    sel_lab_idxes = sel_lab_idxes.gather(-1, s2_idxes)
    sel_lab_embeds = BK.gather_first_dims(sel_lab_embeds, s2_idxes, -2)
    sel_logprobs = None if sel_logprobs is None else sel_logprobs.gather(-1, s2_idxes)
    sel_items_arr = None if sel_items_arr is None \
        else sel_items_arr[np.arange(len(sel_items_arr))[:, np.newaxis], BK.get_value(s2_idxes)]
    return sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_logprobs, sel_items_arr

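# -----
# A minimal numpy sketch (hypothetical helper, mirroring what BK.mask2idx appears to do)
# of the compaction pattern above: turn a 0/1 mask into left-packed indices plus a
# validity mask, so gathers can drop the masked-out (here: nil) entries.
import numpy as np

def mask2idx(mask):  # mask: [bs, n] of 0./1.
    n_max = int(mask.sum(-1).max())
    idxes = np.zeros((mask.shape[0], n_max), dtype=np.int64)
    valid = np.zeros((mask.shape[0], n_max))
    for b, row in enumerate(mask):
        pos = np.flatnonzero(row)
        idxes[b, :len(pos)] = pos
        valid[b, :len(pos)] = 1.
    return idxes, valid

mask = np.array([[1., 0., 1., 1.], [0., 1., 0., 0.]])
idxes, valid = mask2idx(mask)
print(idxes)  # [[0 2 3] [1 0 0]] -- padded with 0, masked by `valid`
print(valid)  # [[1. 1. 1.] [1. 0. 0.]]
# -----
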
def _decode(self, insts: List[ParseInstance], full_score, mask_expr, misc_prefix):
    # decode
    mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for mst decoding
    mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32)
    mst_heads_arr, mst_labels_arr, mst_scores_arr = \
        nmst_unproj(full_score, mask_expr, mst_lengths_arr, labeled=True, ret_arr=True)
    if self.conf.iconf.output_marginals:
        # todo(note): here, we care about marginals for arc
        # lab_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True)
        arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).sum(-1)
        bsize, max_len = BK.get_shape(mask_expr)
        idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)
        idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0)
        output_marg = arc_marginals[idxes_bs_expr, idxes_m_expr, BK.input_idx(mst_heads_arr)]
        mst_marg_arr = BK.get_value(output_marg)
    else:
        mst_marg_arr = None
    # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
    for one_idx, one_inst in enumerate(insts):
        cur_length = mst_lengths[one_idx]
        one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length])  # directly int-val for heads
        one_inst.pred_labels.build_vals(mst_labels_arr[one_idx][:cur_length], self.label_vocab)
        one_scores = mst_scores_arr[one_idx][:cur_length]
        one_inst.pred_par_scores.set_vals(one_scores)
        # extra output
        one_inst.extra_pred_misc[misc_prefix + "_score"] = one_scores.tolist()
        if mst_marg_arr is not None:
            one_inst.extra_pred_misc[misc_prefix + "_marg"] = mst_marg_arr[one_idx][:cur_length].tolist()

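# -----
# A minimal numpy sketch (toy shapes) of the gather above: pick, for every modifier m,
# the arc marginal of its *predicted* head, using broadcasted batch/modifier index
# grids against a [bs, mod, head] marginal table.
import numpy as np

arc_marginals = np.random.rand(2, 4, 4)              # [bs, mod, head]
pred_heads = np.array([[0, 2, 0, 1], [0, 0, 3, 2]])  # [bs, mod]
bs, slen = pred_heads.shape
b_idx = np.arange(bs)[:, None]                       # [bs, 1], broadcasts over mods
m_idx = np.arange(slen)[None, :]                     # [1, mod], broadcasts over batch
out = arc_marginals[b_idx, m_idx, pred_heads]        # [bs, mod]
assert out[1, 2] == arc_marginals[1, 2, 3]
# -----
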
def _assign_attns_item(self, insts, prefix, input_erase_mask_arr=None, abs_posi_arr=None, cache=None):
    if cache is not None:
        attn_names, attn_list = [], []
        for one_sidx, one_attn in enumerate(cache.list_attn):
            attn_names.append(f"{prefix}_att{one_sidx}")
            attn_list.append(one_attn)
        if cache.accu_attn is not None:
            attn_names.append(f"{prefix}_att_accu")
            attn_list.append(cache.accu_attn)
        for one_name, one_attn in zip(attn_names, attn_list):
            # (step_idx, ) -> [bs, len_q, len_k, head]
            one_attn_arr = BK.get_value(one_attn)
            for bidx, inst in enumerate(insts):
                save_arr = one_attn_arr[bidx]
                inst.add_item(one_name, NpArrField(save_arr, float_decimal=4), assert_non_exist=False)
    if abs_posi_arr is not None:
        for bidx, inst in enumerate(insts):
            inst.add_item(f"{prefix}_abs_posi", NpArrField(abs_posi_arr[bidx], float_decimal=0),
                          assert_non_exist=False)
    if input_erase_mask_arr is not None:
        for bidx, inst in enumerate(insts):
            inst.add_item(f"{prefix}_erase_mask", NpArrField(input_erase_mask_arr[bidx], float_decimal=4),
                          assert_non_exist=False)

def loss(self, insts: List[ParseInstance], enc_expr, final_valid_expr, go1_pack, training: bool, margin: float):
    # first do decoding and related preparation
    with BK.no_grad_env():
        _, _, g_packs, p_packs = self.decode(insts, enc_expr, final_valid_expr, go1_pack, training, margin)
        # flatten the packs (remember to rebase the indexes)
        gold_pack = self._flatten_packs(g_packs)
        pred_pack = self._flatten_packs(p_packs)
        if self.filter_pruned:
            # filter out non-valid (pruned) edges, to avoid prune error
            mod_unpruned_mask, gold_mask = self.helper.get_unpruned_mask(final_valid_expr, gold_pack)
            pred_mask = mod_unpruned_mask[pred_pack[0], pred_pack[1]]  # filter by specific mod
            gold_pack = [(None if z is None else z[gold_mask]) for z in gold_pack]
            pred_pack = [(None if z is None else z[pred_mask]) for z in pred_pack]
    # calculate the scores for loss
    gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes = gold_pack
    pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, pred_lab_idxes = pred_pack
    gold_arc_score, gold_label_score_all = self._get_basic_score(enc_expr, gold_b_idxes, gold_m_idxes,
                                                                 gold_h_idxes, gold_sib_idxes, gold_gp_idxes)
    pred_arc_score, pred_label_score_all = self._get_basic_score(enc_expr, pred_b_idxes, pred_m_idxes,
                                                                 pred_h_idxes, pred_sib_idxes, pred_gp_idxes)
    # whether we have labeled scores
    if self.system_labeled:
        gold_label_score = BK.gather_one_lastdim(gold_label_score_all, gold_lab_idxes).squeeze(-1)
        pred_label_score = BK.gather_one_lastdim(pred_label_score_all, pred_lab_idxes).squeeze(-1)
        ret_scores = (gold_arc_score, pred_arc_score, gold_label_score, pred_label_score)
        pred_full_scores, gold_full_scores = pred_arc_score + pred_label_score, gold_arc_score + gold_label_score
    else:
        ret_scores = (gold_arc_score, pred_arc_score)
        pred_full_scores, gold_full_scores = pred_arc_score, gold_arc_score
    # hinge loss: filter-margin by loss*margin to be aware of search error
    if self.filter_margin:
        with BK.no_grad_env():
            mat_shape = BK.get_shape(enc_expr)[:2]  # [bs, slen]
            heads_gold = self._get_tmp_mat(mat_shape, 0, BK.int64, gold_b_idxes, gold_m_idxes, gold_h_idxes)
            heads_pred = self._get_tmp_mat(mat_shape, 0, BK.int64, pred_b_idxes, pred_m_idxes, pred_h_idxes)
            error_count = (heads_gold != heads_pred).float()
            if self.system_labeled:
                labels_gold = self._get_tmp_mat(mat_shape, 0, BK.int64, gold_b_idxes, gold_m_idxes, gold_lab_idxes)
                labels_pred = self._get_tmp_mat(mat_shape, 0, BK.int64, pred_b_idxes, pred_m_idxes, pred_lab_idxes)
                error_count += (labels_gold != labels_pred).float()
            scores_gold = self._get_tmp_mat(mat_shape, 0., BK.float32, gold_b_idxes, gold_m_idxes, gold_full_scores)
            scores_pred = self._get_tmp_mat(mat_shape, 0., BK.float32, pred_b_idxes, pred_m_idxes, pred_full_scores)
            # todo(note): here, a small 0.1 is to exclude zero error: anyway they will get zero gradient
            sent_mask = ((scores_gold.sum(-1) - scores_pred.sum(-1)) <=
                         (margin * error_count.sum(-1) + 0.1)).float()
            num_valid_sent = float(BK.get_value(sent_mask.sum()))
        final_loss_sum = (pred_full_scores * sent_mask[pred_b_idxes] -
                          gold_full_scores * sent_mask[gold_b_idxes]).sum()
    else:
        num_valid_sent = len(insts)
        final_loss_sum = (pred_full_scores - gold_full_scores).sum()
    # prepare final loss
    # divide loss by what?
    num_sent = len(insts)
    num_valid_tok = sum(len(z) for z in insts)
    if self.loss_div_tok:
        final_loss = final_loss_sum / num_valid_tok
    else:
        final_loss = final_loss_sum / num_sent
    final_loss_sum_val = float(BK.get_value(final_loss_sum))
    info = {"sent": num_sent, "sent_valid": num_valid_sent, "tok": num_valid_tok, "loss_sum": final_loss_sum_val}
    return final_loss, ret_scores, info

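# -----
# A toy numpy sketch (hypothetical numbers) of the margin-filtered hinge above: a
# sentence only contributes loss when its gold score fails to beat the predicted
# score by at least margin * (its number of erroneous attachments).
import numpy as np

margin = 1.0
scores_gold = np.array([5.0, 3.0])   # per-sentence summed gold scores
scores_pred = np.array([6.0, 1.5])   # per-sentence summed predicted scores
errors = np.array([2.0, 1.0])        # per-sentence error counts
# the small 0.1 mirrors the code's trick of dropping exactly-zero-error sentences
sent_mask = ((scores_gold - scores_pred) <= (margin * errors + 0.1)).astype(float)
loss = ((scores_pred - scores_gold) * sent_mask).sum()
print(sent_mask, loss)  # [1. 0.] 1.0: sentence 2 already wins by a big margin
# -----
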
def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
    conf = self.conf
    bsize = len(insts)
    # first get gold info, also multiple valid-masks
    gold_masks, gold_idxes, gold_items_arr, gold_valid, gold_idxes2, gold_items2_arr = \
        self.batch_inputs_h(insts)
    input_mask = input_mask * gold_valid.unsqueeze(-1)  # [*, slen]
    # step 1: selector
    if conf.use_selector:
        sel_loss, sel_mask = self.sel.loss(input_expr, input_mask, gold_masks, margin=margin)
    else:
        sel_loss, sel_mask = None, self._select_cands_training(input_mask, gold_masks, conf.train_min_rate)
    sel_idxes, sel_valid_mask = BK.mask2idx(sel_mask)  # [*, max-count]
    sel_gold_idxes = gold_idxes.gather(-1, sel_idxes)
    sel_gold_idxes2 = gold_idxes2.gather(-1, sel_idxes)
    # todo(+N): only get items by head position!
    _tmp_i0, _tmp_i1 = np.arange(bsize)[:, np.newaxis], BK.get_value(sel_idxes)
    sel_items = gold_items_arr[_tmp_i0, _tmp_i1]  # [*, mc]
    sel2_items = gold_items2_arr[_tmp_i0, _tmp_i1]
    # step 2: encoding and labeling
    # if we select nothing
    # ----- debug
    # zlog(f"fb-extractor 1: shape sel_idxes = {sel_idxes.shape}")
    # -----
    sel_shape = BK.get_shape(sel_idxes)
    if sel_shape[-1] == 0:
        lab_loss = [[BK.zeros([]), BK.zeros([])]]
        sel2_lab_loss = [[BK.zeros([]), BK.zeros([])]] if self.use_secondary_type else None
        sel_lab_idxes = sel_gold_idxes
        sel_lab_embeds = BK.zeros(sel_shape + [conf.lab_conf.n_dim])
        ret_items = sel_items  # dim-1 == 0
    else:
        sel_hid_exprs = self._enc(input_lexi, input_expr, input_mask, sel_idxes)  # [*, mc, DLab]
        lab_loss, sel_lab_idxes, sel_lab_embeds = self.hl.loss(sel_hid_exprs, sel_valid_mask,
                                                               sel_gold_idxes, margin=margin)
        if conf.train_gold_corr:
            sel_lab_idxes = sel_gold_idxes
            if not self.hl.conf.use_lookup_soft:
                sel_lab_embeds = self.hl.lookup(sel_lab_idxes)
        ret_items = sel_items
        # =====
        if self.use_secondary_type:
            sectype_embeds = self.t1tot2(sel_lab_idxes)  # [*, mc, D]
            if conf.sectype_noback_enc:
                sel2_input = sel_hid_exprs.detach() + sectype_embeds  # [*, mc, D]
            else:
                sel2_input = sel_hid_exprs + sectype_embeds  # [*, mc, D]
            # =====
            # special for the sectype mask (sample it within the gold ones)
            sel2_valid_mask = self._select_cands_training((sel_gold_idxes > 0).float(),
                                                          (sel_gold_idxes2 > 0).float(),
                                                          conf.train_min_rate_s2)
            # =====
            sel2_lab_loss, sel2_lab_idxes, sel2_lab_embeds = self.hl.loss(sel2_input, sel2_valid_mask,
                                                                          sel_gold_idxes2, margin=margin)
            if conf.train_gold_corr:
                sel2_lab_idxes = sel_gold_idxes2
                if not self.hl.conf.use_lookup_soft:
                    sel2_lab_embeds = self.hl.lookup(sel2_lab_idxes)
            if conf.sectype_t2ift1:
                sel2_lab_idxes = sel2_lab_idxes * (sel_lab_idxes > 0).long()  # pred t2 only if t1 is not 0 (nil)
            # combine the two
            if sel2_lab_idxes.sum().item() > 0:  # if there are any gold sectypes
                ret_items = np.concatenate([ret_items, sel2_items], -1)  # [*, mc*2]
                sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
                sel_valid_mask = BK.concat([sel_valid_mask, sel2_valid_mask], -1)
                sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
                sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
        else:
            sel2_lab_loss = None
    # =====
    # step 3: exclude nil and return
    if conf.exclude_nil:  # [*, mc', ...]
        sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, _, ret_items = \
            self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds,
                              sel_items_arr=ret_items)
    # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
    # step 4: finally prepare loss and items
    for one_loss in lab_loss:
        one_loss[0] *= conf.lambda_ne
    ret_losses = lab_loss
    if sel2_lab_loss is not None:
        for one_loss in sel2_lab_loss:
            one_loss[0] *= conf.lambda_ne2
        ret_losses = ret_losses + sel2_lab_loss
    if sel_loss is not None:
        for one_loss in sel_loss:
            one_loss[0] *= conf.lambda_ns
        ret_losses = ret_losses + sel_loss
    # ----- debug
    # zlog(f"fb-extractor 2: shape sel_idxes = {sel_idxes.shape}")
    # -----
    # mask out invalid items with None
    ret_items[BK.get_value(1. - sel_valid_mask).astype(np.bool_)] = None
    return ret_losses, ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds

def fb(self, annotated_insts, scoring_expr_pack, training: bool, loss_factor: float):
    # depth constraint: <= sched_depth
    cur_depth_constrain = int(self.sched_depth.value)
    # run
    ags = [BfsLinearAgenda.init_agenda(TdState, z, self.require_sg) for z in annotated_insts]
    self.oracle_manager.refresh_insts(annotated_insts)
    self.searcher.refresh(scoring_expr_pack)
    self.searcher.go(ags)
    # collect local loss: credit assignment
    if self.train_force or self.train_ss:
        states = []
        for ag in ags:
            for final_state in ag.local_golds:
                # todo(warn): remember to use depth_eff rather than depth
                # todo(warn): deprecated
                # if final_state.depth_eff > cur_depth_constrain:
                #     continue
                states.append(final_state)
        logprobs_arc = [s.arc_score_slice for s in states]
        # no labeling scores for reduce operations
        logprobs_label = [s.label_score_slice for s in states if s.label_score_slice is not None]
        credits_arc, credits_label = None, None
    elif self.train_of:
        states = []
        for ag in ags:
            for final_state in ag.ends:
                for s in final_state.get_path(True):
                    states.append(s)
        logprobs_arc = [s.arc_score_slice for s in states]
        # no labeling scores for reduce operations
        logprobs_label = [s.label_score_slice for s in states if s.label_score_slice is not None]
        credits_arc, credits_label = None, None
    elif self.train_rl:
        logprobs_arc, logprobs_label, credits_arc, credits_label = [], [], [], []
        for ag in ags:
            # todo(+2): need to check search failure?
            # todo(+2): ignoring labels when reducing or wrong-arc
            for final_state in ag.ends:
                # todo(warn): deprecated
                # if final_state.depth_eff > cur_depth_constrain:
                #     continue
                one_credits_arc = []
                one_credits_label = []
                self.oracle_manager.set_losses(final_state)
                for s in final_state.get_path(True):
                    _, _, delta_arc, delta_label = s.oracle_loss_cache
                    logprobs_arc.append(s.arc_score_slice)
                    if delta_arc > 0:
                        # only blame arc
                        one_credits_arc.append(-delta_arc)
                    else:
                        one_credits_arc.append(0)
                        if delta_label > 0:
                            logprobs_label.append(s.label_score_slice)
                            one_credits_label.append(-delta_label)
                        elif s.label_score_slice is not None:
                            # not bad labeling
                            logprobs_label.append(s.label_score_slice)
                            one_credits_label.append(0)
                # TODO(+N): minus average may encourage bad moves?
                # balance
                # avg_arc = sum(one_credits_arc) / len(one_credits_arc)
                # avg_label = 0. if len(one_credits_label)==0 else sum(one_credits_label) / len(one_credits_label)
                baseline_arc = baseline_label = -0.5
                credits_arc.extend(z - baseline_arc for z in one_credits_arc)
                credits_label.extend(z - baseline_label for z in one_credits_label)
    else:
        raise NotImplementedError("CANNOT get here!")
    # sum all local losses
    loss_zero = BK.zeros([])
    if len(logprobs_arc) > 0:
        batched_logprobs_arc = SliceManager.combine_slices(logprobs_arc, None)
        loss_arc = (-BK.sum(batched_logprobs_arc)) if (credits_arc is None) \
            else (-BK.sum(batched_logprobs_arc * BK.input_real(credits_arc)))
    else:
        loss_arc = loss_zero
    if len(logprobs_label) > 0:
        batched_logprobs_label = SliceManager.combine_slices(logprobs_label, None)
        loss_label = (-BK.sum(batched_logprobs_label)) if (credits_label is None) \
            else (-BK.sum(batched_logprobs_label * BK.input_real(credits_label)))
    else:
        loss_label = loss_zero
    final_loss_sum = loss_arc + loss_label
    # divide loss by what?
    num_sent = len(annotated_insts)
    num_valid_arcs, num_valid_labels = len(logprobs_arc), len(logprobs_label)
    # num_valid_steps = len(states)
    if self.tconf.loss_div_step:
        final_loss = loss_arc / max(1, num_valid_arcs) + loss_label / max(1, num_valid_labels)
    else:
        final_loss = final_loss_sum / num_sent
    #
    val_loss_arc = BK.get_value(loss_arc).item()
    val_loss_label = BK.get_value(loss_label).item()
    val_loss_sum = val_loss_arc + val_loss_label
    #
    cur_has_loss = 1 if ((num_valid_arcs + num_valid_labels) > 0) else 0
    if training and cur_has_loss:
        BK.backward(final_loss, loss_factor)
    # todo(warn): make tok==steps for dividing in common.run
    info = {"sent": num_sent, "tok": num_valid_arcs,
            "valid_arc": num_valid_arcs, "valid_label": num_valid_labels,
            "loss_sum": val_loss_sum, "loss_arc": val_loss_arc, "loss_label": val_loss_label,
            "fb_all": 1, "fb_valid": cur_has_loss}
    return info

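# -----
# A toy sketch (hypothetical numbers) of the credit-assigned policy-gradient loss in
# the train_rl branch above: each action's log-prob is weighted by its (negative)
# oracle delta shifted by the constant baseline of -0.5.
import numpy as np

logprobs = np.array([-0.1, -0.7, -0.3])   # per-action log-probabilities
deltas = np.array([0.0, 2.0, 0.0])        # oracle loss deltas (wrong action -> > 0)
credits = np.where(deltas > 0, -deltas, 0.0)
baseline = -0.5
weights = credits - baseline              # correct actions get +0.5, bad ones go negative
loss = -(logprobs * weights).sum()        # push up good actions, push down bad ones
print(weights, loss)                      # [ 0.5 -1.5  0.5] -0.85
# -----
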
def _decode(self, mb_insts: List[ParseInstance], mb_enc_expr, mb_valid_expr, mb_go1_pack,
            training: bool, margin: float):
    # =====
    use_sib, use_gp = self.use_sib, self.use_gp
    # =====
    mb_size = len(mb_insts)
    mat_shape = BK.get_shape(mb_valid_expr)
    max_slen = mat_shape[-1]
    # step 1: extract the candidate features
    batch_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes = self.helper.get_cand_features(mb_valid_expr)
    # =====
    # step 2: high order scoring
    # step 2.1: basic scoring, [*], [*, Lab]
    arc_scores, lab_scores = self._get_basic_score(mb_enc_expr, batch_idxes, m_idxes, h_idxes,
                                                   sib_idxes, gp_idxes)
    cur_system_labeled = (lab_scores is not None)
    # step 2.2: margin
    # get gold labels, which can be useful for later calculating loss
    if training:
        gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes = \
            [(None if z is None else BK.input_idx(z))
             for z in self._get_idxes_from_insts(mb_insts, use_sib, use_gp)]
        # add the margins to the scores: (m,h), (m,sib), (m,gp)
        cur_margin = margin / self.margin_div
        self._add_margin_inplaced(mat_shape, gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_lab_idxes,
                                  batch_idxes, m_idxes, h_idxes, arc_scores, lab_scores,
                                  cur_margin, cur_margin)
        if use_sib:
            self._add_margin_inplaced(mat_shape, gold_b_idxes, gold_m_idxes, gold_sib_idxes, gold_lab_idxes,
                                      batch_idxes, m_idxes, sib_idxes, arc_scores, lab_scores,
                                      cur_margin, cur_margin)
        if use_gp:
            self._add_margin_inplaced(mat_shape, gold_b_idxes, gold_m_idxes, gold_gp_idxes, gold_lab_idxes,
                                      batch_idxes, m_idxes, gp_idxes, arc_scores, lab_scores,
                                      cur_margin, cur_margin)
        # may be useful for later training
        gold_pack = (mb_size, gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes,
                     gold_gp_idxes, gold_lab_idxes)
    else:
        gold_pack = None
    # step 2.3: o1scores
    if mb_go1_pack is not None:
        go1_arc_scores, go1_lab_scores = mb_go1_pack
        # todo(note): go1_arc_scores is not added here, but as the input to the dec-algo
        if cur_system_labeled:
            lab_scores += go1_lab_scores[batch_idxes, m_idxes, h_idxes]
    else:
        go1_arc_scores = None
    # step 2.4: max out labels; todo(+N): or using logsumexp here?
    if cur_system_labeled:
        max_lab_scores, max_lab_idxes = lab_scores.max(-1)
        final_scores = arc_scores + max_lab_scores  # [*], final input arc scores
    else:
        max_lab_idxes = None
        final_scores = arc_scores
    # =====
    # step 3: actual decode
    res_heads = []
    for sid, inst in enumerate(mb_insts):
        slen = len(inst) + 1  # plus one for the art-root
        arr_o1_masks = BK.get_value(mb_valid_expr[sid, :slen, :slen].int())
        arr_o1_scores = BK.get_value(go1_arc_scores[sid, :slen, :slen].double()) \
            if (go1_arc_scores is not None) else None
        cur_bidx_mask = (batch_idxes == sid)
        input_pack = [m_idxes, h_idxes, sib_idxes, gp_idxes, final_scores]
        one_heads = self.helper.decode_one(slen, self.projective, arr_o1_masks, arr_o1_scores,
                                           input_pack, cur_bidx_mask)
        res_heads.append(one_heads)
    # =====
    # step 4: get labels back and pred_pack
    pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, _ = \
        [(None if z is None else BK.input_idx(z))
         for z in self._get_idxes_from_preds(res_heads, None, use_sib, use_gp)]
    if cur_system_labeled:
        # obtain hit components
        pred_hit_mask = self._get_hit_mask(mat_shape, pred_b_idxes, pred_m_idxes, pred_h_idxes,
                                           batch_idxes, m_idxes, h_idxes)
        if use_sib:
            pred_hit_mask &= self._get_hit_mask(mat_shape, pred_b_idxes, pred_m_idxes, pred_sib_idxes,
                                                batch_idxes, m_idxes, sib_idxes)
        if use_gp:
            pred_hit_mask &= self._get_hit_mask(mat_shape, pred_b_idxes, pred_m_idxes, pred_gp_idxes,
                                                batch_idxes, m_idxes, gp_idxes)
        # get pred labels (there should be only one hit per mod!)
        pred_labels = BK.constants_idx([mb_size, max_slen], 0)
        pred_labels[batch_idxes[pred_hit_mask], m_idxes[pred_hit_mask]] = max_lab_idxes[pred_hit_mask]
        res_labels = BK.get_value(pred_labels)
        pred_lab_idxes = pred_labels[pred_b_idxes, pred_m_idxes]
    else:
        res_labels = None
        pred_lab_idxes = None
    pred_pack = (mb_size, pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes,
                 pred_gp_idxes, pred_lab_idxes)
    # return
    return res_heads, res_labels, gold_pack, pred_pack