def decode_frame(self, ibatch, scores_t: BK.Expr, pred_max_layer: int, voc, pred_label: bool, pred_tag: str, pred_check_layer: int):
    """Decode up to K frames per token position from raw label scores.

    Args:
        ibatch: batch of items; each item exposes `sents`, `center_sidx` and
            `seq_info.dec_offsets` (per-sentence start offsets in the msent).
        scores_t: raw label scores [*, dlen, L].
        pred_max_layer: K, maximum number of frames kept per position.
        voc: label vocabulary (idx2word).
        pred_label: if False, every predicted frame gets type "UNK".
        pred_tag: frame tag passed to `sent.make_frame`.
        pred_check_layer: if >=0, suppress a label whose first that-many
            dot-separated name components repeat an earlier label at the
            same position.

    Returns:
        ((bidxes_t, widxes_t), frames, farrs): flat index tensors [??],
        the list of new frame objects, and an object array [*, dlen, K]
        holding each frame at its (batch, msent-widx, k) slot.
    """
    # --
    # first get topk for each position
    logprobs_t = scores_t.log_softmax(-1)  # [*, dlen, L]
    pred_scores, pred_labels = logprobs_t.topk(pred_max_layer)  # [*, dlen, K]
    arr_scores, arr_labels = BK.get_value(pred_scores), BK.get_value(pred_labels)  # [*, dlen, K]
    # put results
    res_bidxes, res_widxes, res_frames = [], [], []  # flattened results
    res_farrs = np.full(arr_scores.shape, None, dtype=object)  # [*, dlen, K]
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            # todo(+N): currently we only predict for center if there is!
            if item.center_sidx is not None and sidx != item.center_sidx:
                continue  # skip non-center sent in this mode!
            _start = _dec_offsets[sidx]
            _len = len(sent)
            _arr_scores, _arr_labels = arr_scores[bidx][_start:_start + _len], arr_labels[bidx][_start:_start + _len]
            for widx in range(_len):
                _full_widx = widx + _start  # idx in the msent
                _tmp_set = set()  # label prefixes already taken at this position
                for _k in range(pred_max_layer):
                    _score, _lab = float(_arr_scores[widx][_k]), int(_arr_labels[widx][_k])
                    if _lab == 0:  # note: lab=0 means NIL
                        break
                    _type_str = (voc.idx2word(_lab) if pred_label else "UNK")
                    _type_str_prefix = '.'.join(_type_str.split('.')[:pred_check_layer])
                    if pred_check_layer >= 0 and _type_str_prefix in _tmp_set:
                        continue  # ignore since constraint
                    _tmp_set.add(_type_str_prefix)
                    # add new one!
                    res_bidxes.append(bidx)
                    res_widxes.append(_full_widx)
                    _new_frame = sent.make_frame(widx, 1, tag=pred_tag, type=_type_str, score=float(_score))
                    _new_frame.set_label_idx(int(_lab))
                    _new_frame._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                    _new_frame._tmp_sidx = sidx
                    _new_frame._tmp_item = item
                    res_frames.append(_new_frame)
                    res_farrs[bidx, _full_widx, _k] = _new_frame
    # return
    res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
    return (res_bidxes_t, res_widxes_t), res_frames, res_farrs  # [??], [*, dlen, K]
def decode_frame_given(self, ibatch, scores_t: BK.Expr, pred_max_layer: int, voc, pred_label: bool, pred_tag: str, assume_osof: bool):
    """Collect already-given frames (gold/predicted earlier) and optionally
    overwrite their labels with the argmax of `scores_t`.

    Args:
        ibatch: batch of items (see `decode_frame`).
        scores_t: raw label scores [*, dlen, L]; only used if `pred_label`.
        pred_max_layer: K, capacity of the per-position frame array.
        voc: label vocabulary (idx2word).
        pred_label: if True, re-assign each frame's label/score from scores_t.
        pred_tag: frame tag used to fetch given frames from each sentence.
        assume_osof: "one seq one frame" — take `item.inst` itself as the
            single target frame instead of collecting per-sentence frames.

    Returns:
        Same triple layout as `decode_frame`.
    """
    if pred_label:  # if overwrite label!
        logprobs_t = scores_t.log_softmax(-1)  # [*, dlen, L]
        pred_scores, pred_labels = logprobs_t.max(-1)  # [*, dlen], note: maximum!
        arr_scores, arr_labels = BK.get_value(pred_scores), BK.get_value(pred_labels)  # [*, dlen]
    else:
        arr_scores = arr_labels = None
    # --
    # read given results
    res_bidxes, res_widxes, res_frames = [], [], []  # flattened results
    tmp_farrs = defaultdict(list)  # later assign
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _trg_frames = [item.inst] if assume_osof else \
            sum([sent.get_frames(pred_tag) for sidx, sent in enumerate(item.sents)
                 if (item.center_sidx is None or sidx == item.center_sidx)], [])  # still only pick center ones!
        # --
        _dec_offsets = item.seq_info.dec_offsets
        for _frame in _trg_frames:
            # note: simply sort by original order!
            sidx = item.sents.index(_frame.sent)
            _start = _dec_offsets[sidx]
            _full_hidx = _start + _frame.mention.shead_widx
            # add new one
            res_bidxes.append(bidx)
            res_widxes.append(_full_hidx)
            _frame._tmp_sstart = _start  # todo(+N): ugly tmp value ...
            _frame._tmp_sidx = sidx
            _frame._tmp_item = item
            res_frames.append(_frame)
            tmp_farrs[(bidx, _full_hidx)].append(_frame)
            # assign/rewrite label?
            if pred_label:
                _lab = int(arr_labels[bidx, _full_hidx])  # label index
                _frame.set_label_idx(_lab)
                _frame.set_label(voc.idx2word(_lab))
                _frame.set_score(float(arr_scores[bidx, _full_hidx]))
        # --
    # --
    # scatter frames into the [*, dlen, K] object array
    res_farrs = np.full(BK.get_shape(scores_t)[:-1] + [pred_max_layer], None, dtype=object)  # [*, dlen, K]
    for _key, _values in tmp_farrs.items():
        bidx, widx = _key
        _values = _values[:pred_max_layer]  # truncate if more!
        res_farrs[bidx, widx, :len(_values)] = _values
    # return
    res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
    return (res_bidxes_t, res_widxes_t), res_frames, res_farrs  # [??], [*, dlen, K]
def assign_boundaries(self, items: List, left_idxes: BK.Expr, right_idxes: BK.Expr):
    """Write decoded [left, right] boundaries back onto each item's mention."""
    arr_left = BK.get_value(left_idxes)
    arr_right = BK.get_value(right_idxes)
    for idx, one_item in enumerate(items):
        mention = one_item.mention
        offset = one_item._tmp_sstart  # indices are in msent coords; shift back to sentence coords
        lw = arr_left[idx].item() - offset
        rw = arr_right[idx].item() - offset
        # collapse the current span to its shead first, then set the new full span
        mention.set_span(*(mention.get_span()), shead=True)
        mention.set_span(lw, rw - lw + 1)
def decode_arg(self, res_evts: List, arg_scores_t: BK.Expr, pred_max_layer: int, voc, arg_allowed_sent_gap: int, arr_efs):
    """Decode up to K arguments per token position for each event.

    Args:
        res_evts: decoded events; each carries cached `_tmp_item`/`_tmp_sidx`.
        arg_scores_t: raw argument label scores [??, dlen, L] (?? = #events).
        pred_max_layer: K, max args kept per position.
        voc: role vocabulary (idx2word).
        arg_allowed_sent_gap: max sentence distance from the event's sentence.
        arr_efs: optional [??, dlen, K] object array of pre-existing entity
            fillers to reuse; positions without one are skipped entirely.

    Returns:
        ((fidxes_t, widxes_t), args): flat event/word index tensors and the
        list of newly added argument links.
    """
    # first get topk
    arg_logprobs_t = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
    pred_arg_scores, pred_arg_labels = arg_logprobs_t.topk(pred_max_layer)  # [??, dlen, K]
    arr_arg_scores, arr_arg_labels = BK.get_value(pred_arg_scores), BK.get_value(pred_arg_labels)  # [??, dlen, K]
    # put results
    res_fidxes, res_widxes, res_args = [], [], []  # flattened results
    for fidx, evt in enumerate(res_evts):  # for each evt
        item = evt._tmp_item  # cached
        _evt_sidx = evt._tmp_sidx
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            if abs(sidx - _evt_sidx) > arg_allowed_sent_gap:
                continue  # larger than allowed sentence gap
            _start = _dec_offsets[sidx]
            _len = len(sent)
            _arr_scores, _arr_labels = arr_arg_scores[fidx][_start:_start + _len], arr_arg_labels[fidx][_start:_start + _len]
            for widx in range(_len):
                _full_widx = widx + _start  # idx in the msent
                _new_ef = None
                if arr_efs is not None:  # note: arr_efs should also expand to frames!
                    _new_ef = arr_efs[fidx, _full_widx, 0]  # todo(+N): only get the first one!
                    if _new_ef is None:
                        continue  # no ef!
                for _score, _lab in zip(_arr_scores[widx], _arr_labels[widx]):  # [K]
                    if _lab == 0:  # note: idx=0 means NIL
                        break
                    # add new one!!
                    res_fidxes.append(fidx)
                    res_widxes.append(_full_widx)
                    if _new_ef is None:
                        _new_ef = sent.make_entity_filler(widx, 1)  # share them if possible!
                    _new_arg = evt.add_arg(_new_ef, role=voc.idx2word(_lab), score=float(_score))
                    _new_arg._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                    res_args.append(_new_arg)
    # return
    res_fidxes_t, res_widxes_t = BK.input_idx(res_fidxes), BK.input_idx(res_widxes)  # [??]
    return (res_fidxes_t, res_widxes_t), res_args
def decode_evt(self, dec, ibatch, evt_scores_t: BK.Expr):
    """Decode up to K events per token from event label scores.

    Args:
        dec: decoder holding `conf.max_layer_evt` and `voc_evt`.
        ibatch: batch of items (see `decode_frame`).
        evt_scores_t: raw event label scores [*, dlen, L].

    Returns:
        ((bidxes_t, widxes_t), evts): flat index tensors and new event frames;
        each event is annotated with `_tmp_sstart/_tmp_sidx/_tmp_item` for
        later argument decoding.
    """
    _pred_max_layer_evt = dec.conf.max_layer_evt
    _voc_evt = dec.voc_evt
    _pred_evt_label = self.conf.pred_evt_label
    # --
    evt_logprobs_t = evt_scores_t.log_softmax(-1)  # [*, dlen, L]
    pred_evt_scores, pred_evt_labels = evt_logprobs_t.topk(_pred_max_layer_evt)  # [*, dlen, K]
    arr_evt_scores, arr_evt_labels = BK.get_value(pred_evt_scores), BK.get_value(pred_evt_labels)  # [*, dlen, K]
    # put results
    res_bidxes, res_widxes, res_evts = [], [], []  # flattened results
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            # note: here we only predict for center if there is!
            if item.center_sidx is not None and sidx != item.center_sidx:
                continue  # skip non-center sent in this mode!
            _start = _dec_offsets[sidx]
            _len = len(sent)
            _arr_scores, _arr_labels = arr_evt_scores[bidx][_start:_start + _len], arr_evt_labels[bidx][_start:_start + _len]
            for widx in range(_len):
                for _score, _lab in zip(_arr_scores[widx], _arr_labels[widx]):  # [K]
                    if _lab == 0:  # note: idx=0 means NIL
                        break
                    # add new one!!
                    res_bidxes.append(bidx)
                    res_widxes.append(_start + widx)  # note: remember to add offset!
                    _new_evt = sent.make_event(widx, 1, type=(_voc_evt.idx2word(_lab) if _pred_evt_label else "UNK"), score=float(_score))
                    _new_evt.set_label_idx(int(_lab))
                    _new_evt._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                    _new_evt._tmp_sidx = sidx
                    _new_evt._tmp_item = item
                    res_evts.append(_new_evt)
    # return
    res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
    return (res_bidxes_t, res_widxes_t), res_evts
def _get_extra_score(self, cand_score, insts, cand_res, arr_gold_items, use_cons: bool, use_lu: bool):
    """Combine the candidate score with an optional lexicon-constraint score.

    Invalid (constraint-violating) candidates receive a large negative
    penalty (`Constants.REAL_PRAC_MIN`) so they are effectively masked out.

    Args:
        cand_score: base candidate score (extended via `_extend_cand_score`).
        insts: batch instances, parallel to `cand_res` rows.
        cand_res: candidate result holding widx/wlen/mask/gaddr expressions.
        arr_gold_items: object array of gold items; indexed by gaddr when
            `use_lu` derives the feature from a gold item's "luName".
        use_cons: whether to apply the lexicon constraint at all.
        use_lu: prefer lexical-unit features from aligned gold items.

    Returns:
        Summed score via `_sum_scores` (cons part may be None).
    """
    # conf: DirectExtractorConf = self.conf
    # --
    # first cand score
    cand_score = self._extend_cand_score(cand_score)
    # then cons_lex score
    cons_lex_node = self.cons_lex_node
    if use_cons and cons_lex_node is not None:
        cons_lex = cons_lex_node.cons
        flt_arr_gold_items = arr_gold_items.flatten()
        _shape = BK.get_shape(cand_res.mask_expr)
        if cand_res.gaddr_expr is None:
            gaddr_expr = BK.constants(_shape, -1, dtype=BK.long)  # -1 = no aligned gold item
        else:
            gaddr_expr = cand_res.gaddr_expr
        all_arrs = [BK.get_value(z) for z in [cand_res.widx_expr, cand_res.wlen_expr, cand_res.mask_expr, gaddr_expr]]
        arr_feats = np.full(_shape, None, dtype=object)
        for bidx, inst in enumerate(insts):
            one_arr_feats = arr_feats[bidx]
            _ii = -1
            for one_widx, one_wlen, one_mask, one_gaddr in zip(*[z[bidx] for z in all_arrs]):
                _ii += 1
                if one_mask == 0.:
                    continue  # skip invalid ones
                if use_lu and one_gaddr >= 0:
                    one_feat = cons_lex.lu2feat(flt_arr_gold_items[one_gaddr].info["luName"])
                else:
                    one_feat = cons_lex.span2feat(inst, one_widx, one_wlen)
                one_arr_feats[_ii] = one_feat
        cons_valids = cons_lex_node.lookup_with_feats(arr_feats)
        cons_score = (1. - cons_valids) * Constants.REAL_PRAC_MIN  # huge penalty for invalid ones
    else:
        cons_score = None
    # sum
    return self._sum_scores(cand_score, cons_score)
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    """Run a CPU MST routine on argmax-over-label scores; labels come from the argmax.

    Returns (heads, labels, scores), as numpy arrays if `ret_arr` else as exprs.
    """
    assert labeled
    with BK.no_grad_env():
        # best label (and its score) for every (m, h) pair: [BS, m, h]
        best_scores, best_labels = scores_expr.max(-1)
        # run the unlabeled MST on CPU
        heads_arr, _, scores_arr = CPU_f(BK.get_value(best_scores), lengths_arr, labeled=False)  # [BS, m]
        heads_expr = BK.input_idx(heads_arr)
        # pick the label of each chosen head
        labels_expr = BK.gather_one_lastdim(best_labels, heads_expr).squeeze(-1)
        if ret_arr:
            return heads_arr, BK.get_value(labels_expr), scores_arr
        return heads_expr, labels_expr, BK.input_real(scores_arr)
def put_results(self, insts, best_labs, best_scores):
    """Decode per-token tag sequences into spans and attach them to each inst."""
    conf: SeqExtractorConf = self.conf
    vocab: SeqVocab = self.vocab
    base_vocab = vocab.base_vocab
    arr_slabs, arr_scores = [BK.get_value(t) for t in [best_labs, best_scores]]
    for bidx, inst in enumerate(insts):
        self._clear_f(inst)  # drop previous results first
        cur_len = len(inst) if isinstance(inst, Sent) else len(inst.sent)
        slabs = arr_slabs[bidx][:cur_len]
        scores = arr_scores[bidx][:cur_len]
        spans = vocab.tags2spans_idx(slabs)
        inst.info["slab"] = [vocab.idx2word(z) for z in slabs]  # record the raw tag seq
        for one_widx, one_wlen, one_lab in spans:
            one_lab = int(one_lab)
            assert one_lab > 0, "Error: should not extract 'O'!?"
            # span score = mean of its token scores
            self._new_f(
                inst, int(one_widx), int(one_wlen), one_lab,
                np.mean(scores[one_widx:one_widx + one_wlen]).item(),
                vocab=base_vocab,
            )
def put_results(self, insts, best_labs, best_scores):
    """Create one single-token anchor item for every non-NIL position; return them all."""
    conf: AnchorExtractorConf = self.conf
    arr_labs, arr_scores = [BK.get_value(t) for t in [best_labs, best_scores]]
    flattened_items = []
    for bidx, inst in enumerate(insts):
        self._clear_f(inst)  # drop previous results first
        cur_len = len(inst) if isinstance(inst, Sent) else len(inst.sent)
        labs = arr_labs[bidx][:cur_len]
        scores = arr_scores[bidx][:cur_len]
        for widx in range(cur_len):
            lab = int(labs[widx])
            # todo(+N): again, assuming NON-idx == 0
            if lab == 0:
                continue  # invalid one: unmask or NON
            new_item = self._new_f(inst, widx, 1, lab, float(scores[widx]))
            new_item.mention.info["widxes1"] = [widx]
            flattened_items.append(new_item)
    return flattened_items
def predict(self, med: ZMediator): conf: ZDecoderUDEPConf = self.conf # -- # depth scores all_depth_raw_scores = med.main_scores.get((self.name, "depth"), []) # [*, slen] all_depth_logprobs = [BK.logsigmoid(z.squeeze(-1)) for z in all_depth_raw_scores] if len(all_depth_logprobs) > 0: final_depth_logprobs = self.depth_node.helper.pred(all_logprobs=all_depth_logprobs) # [*, slen] else: final_depth_logprobs = -99. # udep scores all_udep_raw_scores = med.main_scores.get((self.name, "udep"), []) # [*, slen, slen, L] all_udep_logprobs = [z.log_softmax(-1) for z in all_udep_raw_scores] final_udep_logprobs = self.udep_node.helper.pred(all_logprobs=all_udep_logprobs) # [*, slen, slen, L] # prepare final scores final_scores = BK.pad(final_udep_logprobs, [0,0,1,0,1,0], value=Constants.REAL_PRAC_MIN) # [*, 1+slen, 1+slen, L] final_scores[:, :, :, 0] = Constants.REAL_PRAC_MIN # force no 0!! final_scores[:, 1:, 0, self._label_idx_root] = (final_depth_logprobs + conf.udep_pred_root_penalty) # assign root score # decode from msp2.tools.algo.nmst import mst_unproj # decoding algorithm insts = med.insts arr_lengths = np.asarray([len(z.sent)+1 for z in insts]) # +1 for arti-root arr_scores = BK.get_value(final_scores) # [*, 1+slen, 1+slen, L] arr_ret_heads, arr_ret_labels, arr_ret_scores = mst_unproj(arr_scores, arr_lengths, labeled=True) # [*, 1+slen] self.helper.put_results(insts, [arr_ret_heads, arr_ret_labels, arr_ret_scores]) # -- return {}
def decode_evt_given(self, dec, ibatch, evt_scores_t: BK.Expr):
    """Collect already-given events and optionally re-assign their labels
    from the argmax of `evt_scores_t`.

    Args:
        dec: decoder holding `voc_evt` and `conf.assume_osof`.
        ibatch: batch of items (see `decode_frame`).
        evt_scores_t: raw event label scores [*, dlen, L]; only used when
            `self.conf.pred_evt_label` is True.

    Returns:
        ((bidxes_t, widxes_t), evts): flat index tensors and the collected
        events, each annotated with `_tmp_sstart/_tmp_sidx/_tmp_item`.
    """
    _voc_evt = dec.voc_evt
    _assume_osof = dec.conf.assume_osof  # one seq one frame
    _pred_evt_label = self.conf.pred_evt_label
    # --
    if _pred_evt_label:
        evt_logprobs_t = evt_scores_t.log_softmax(-1)  # [*, dlen, L]
        pred_evt_scores, pred_evt_labels = evt_logprobs_t.max(-1)  # [*, dlen], note: maximum!
        arr_evt_scores, arr_evt_labels = BK.get_value(pred_evt_scores), BK.get_value(pred_evt_labels)  # [*, dlen]
    else:
        arr_evt_scores = arr_evt_labels = None
    # --
    # read given results
    res_bidxes, res_widxes, res_evts = [], [], []  # flattened results
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _trg_evts = [item.inst] if _assume_osof else \
            sum([sent.events for sidx, sent in enumerate(item.sents)
                 if (item.center_sidx is None or sidx == item.center_sidx)], [])
        # --
        _dec_offsets = item.seq_info.dec_offsets
        for _evt in _trg_evts:
            sidx = item.sents.index(_evt.sent)
            _start = _dec_offsets[sidx]
            _full_hidx = _start + _evt.mention.shead_widx
            # add new one
            res_bidxes.append(bidx)
            res_widxes.append(_full_hidx)
            _evt._tmp_sstart = _start  # todo(+N): ugly tmp value ...
            _evt._tmp_sidx = sidx
            _evt._tmp_item = item
            res_evts.append(_evt)
            # assign label?
            if _pred_evt_label:
                _lab = int(arr_evt_labels[bidx, _full_hidx])  # label index
                _evt.set_label_idx(_lab)
                _evt.set_label(_voc_evt.idx2word(_lab))
                _evt.set_score(float(arr_evt_scores[bidx, _full_hidx]))
        # --
    # --
    # return
    res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
    return (res_bidxes_t, res_widxes_t), res_evts
def predict(self, med: ZMediator):
    """Predict UPOS labels/scores and write them onto the mediator's instances."""
    labs, scores = self._pred_upos()
    arrs = [BK.get_value(t) for t in (labs, scores)]
    self.helper.put_results(med.insts, arrs)
    return {}
def decode_arg(self, dec, res_evts: List, arg_scores_t: BK.Expr):
    """Decode up to K arguments per token position for each decoded event.

    Args:
        dec: decoder holding `conf.max_layer_arg`, `conf.arg_allowed_sent_gap`
            and `voc_arg`.
        res_evts: decoded events with cached `_tmp_item`/`_tmp_sidx`.
        arg_scores_t: raw argument label scores [??, dlen, L] (?? = #events).

    Returns:
        ((fidxes_t, widxes_t), args): flat event/word index tensors and the
        new argument links (each gets a single-token entity filler).
    """
    _pred_max_layer_arg = dec.conf.max_layer_arg
    _arg_allowed_sent_gap = dec.conf.arg_allowed_sent_gap
    _voc_arg = dec.voc_arg
    # --
    arg_logprobs_t = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
    pred_arg_scores, pred_arg_labels = arg_logprobs_t.topk(_pred_max_layer_arg)  # [??, dlen, K]
    arr_arg_scores, arr_arg_labels = BK.get_value(pred_arg_scores), BK.get_value(pred_arg_labels)  # [??, dlen, K]
    # put results
    res_fidxes, res_widxes, res_args = [], [], []  # flattened results
    for fidx, evt in enumerate(res_evts):  # for each evt
        item = evt._tmp_item  # cached
        _evt_sidx = evt._tmp_sidx
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            if abs(sidx - _evt_sidx) > _arg_allowed_sent_gap:
                continue  # larger than allowed sentence gap
            _start = _dec_offsets[sidx]
            _len = len(sent)
            _arr_scores, _arr_labels = arr_arg_scores[fidx][_start:_start + _len], arr_arg_labels[fidx][_start:_start + _len]
            for widx in range(_len):
                for _score, _lab in zip(_arr_scores[widx], _arr_labels[widx]):  # [K]
                    if _lab == 0:  # note: idx=0 means NIL
                        break
                    # add new one!!
                    res_fidxes.append(fidx)
                    res_widxes.append(_start + widx)
                    _new_ef = sent.make_entity_filler(widx, 1)
                    _new_arg = evt.add_arg(_new_ef, role=_voc_arg.idx2word(_lab), score=float(_score))
                    _new_arg._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                    res_args.append(_new_arg)
    # return
    res_fidxes_t, res_widxes_t = BK.input_idx(res_fidxes), BK.input_idx(res_widxes)  # [??]
    return (res_fidxes_t, res_widxes_t), res_args
def put_labels(self, arr_items, best_labs, best_scores):
    """Write the predicted label idx/string and score onto each non-None item."""
    assert BK.get_shape(best_labs) == BK.get_shape(arr_items)
    arr_labs, arr_scores = [BK.get_value(t) for t in [best_labs, best_scores]]
    for row_items, row_labs, row_scores in zip(arr_items, arr_labs, arr_scores):
        for item, lab, score in zip(row_items, row_labs, row_scores):
            if item is None:
                continue  # empty slot in the object array
            item.score = float(score)
            item.set_label_idx(int(lab))
            item.set_label(self.vocab.idx2word(int(lab)))
def decode_arg_bio(self, res_evts: List, arg_scores_t: BK.Expr, pred_max_layer: int, voc_bio, arg_allowed_sent_gap: int, arr_efs):
    """Decode argument spans from a BIO tag sequence (top-1 only).

    Args:
        res_evts: decoded events with cached `_tmp_item`/`_tmp_sidx`.
        arg_scores_t: raw BIO label scores [??, dlen, L].
        pred_max_layer: must be 1 (BIO supports a single sequence).
        voc_bio: BIO vocabulary; its `base_vocab` maps span labels to roles.
        arg_allowed_sent_gap: max sentence distance from the event's sentence.
        arr_efs: must be None (reusing given efs is unsupported here).

    Returns:
        None — arguments are attached to the events as a side effect.
    """
    assert pred_max_layer == 1, "Currently BIO only allow one!"
    assert arr_efs is None, "Currently BIO does not allow pick given efs!"
    _vocab_bio_arg = voc_bio
    _vocab_arg = _vocab_bio_arg.base_vocab
    # --
    arg_logprobs_t = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
    # todo(+N): allow multiple bio seqs?
    pred_arg_scores, pred_arg_labels = arg_logprobs_t.max(-1)  # note: here only top1, [??, dlen]
    arr_arg_scores, arr_arg_labels = BK.get_value(pred_arg_scores), BK.get_value(pred_arg_labels)  # [??, dlen]
    # put results
    for fidx, evt in enumerate(res_evts):  # for each evt
        item = evt._tmp_item  # cached
        _evt_sidx = evt._tmp_sidx
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            if abs(sidx - _evt_sidx) > arg_allowed_sent_gap:
                continue  # larger than allowed sentence gap
            _start = _dec_offsets[sidx]
            _len = len(sent)
            _arr_scores, _arr_labels = arr_arg_scores[fidx][_start:_start + _len], arr_arg_labels[fidx][_start:_start + _len]
            # decode bio
            _arg_results = _vocab_bio_arg.tags2spans_idx(_arr_labels)
            for a_widx, a_wlen, a_lab in _arg_results:
                a_lab = int(a_lab)
                assert a_lab > 0, "Error: should not extract 'O'!?"
                _new_ef = sent.make_entity_filler(a_widx, a_wlen)
                a_role = _vocab_arg.idx2word(a_lab)
                # span score = mean of its token scores
                _new_arg = evt.add_arg(_new_ef, a_role, score=np.mean(_arr_scores[a_widx:a_widx + a_wlen]).item())
    # --
    return  # no need to return anything here
def put_results(self, insts, best_labs, best_scores, widx_expr, wlen_expr, mask_expr):
    """Create one new item per valid (masked-in and, optionally, non-NIL) candidate span."""
    conf: DirectExtractorConf = self.conf
    ignore_non = conf.pred_ignore_non
    arrs = [BK.get_value(t) for t in [best_labs, best_scores, widx_expr, wlen_expr, mask_expr]]
    for bidx, inst in enumerate(insts):
        self._clear_f(inst)  # drop previous results first
        for lab, score, widx, wlen, mask in zip(*[a[bidx] for a in arrs]):
            lab = int(lab)
            # todo(+N): again, assuming NON-idx == 0
            if mask == 0. or (ignore_non and lab == 0):
                continue  # invalid one: unmask or NON
            self._new_f(inst, int(widx), int(wlen), lab, float(score))
def nmarginal_proj(scores_expr, mask_expr, lengths_arr, labeled=True):
    """Projective labeled marginals: compute unlabeled marginals on CPU, then
    redistribute over labels via the softmax ratio of labeled scores.

    Args:
        scores_expr: labeled arc scores [BS, m, h, L].
        mask_expr: (unused here) token mask.
        lengths_arr: per-instance lengths for the CPU marginal routine.
        labeled: must be True.

    Returns:
        Normalized labeled marginals [BS, m, h, L].
    """
    assert labeled
    with BK.no_grad_env():
        # first make it unlabeled by sum-exp
        scores_unlabeled = BK.logsumexp(scores_expr, dim=-1)  # [BS, m, h]
        # marginal for unlabeled
        scores_unlabeled_arr = BK.get_value(scores_unlabeled)
        marginals_unlabeled_arr = marginal_proj(scores_unlabeled_arr, lengths_arr, False)
        # back to labeled values: split the unlabeled mass by label softmax
        marginals_unlabeled_expr = BK.input_real(marginals_unlabeled_arr)
        marginals_labeled_expr = marginals_unlabeled_expr.unsqueeze(-1) * BK.exp(scores_expr - scores_unlabeled.unsqueeze(-1))  # [BS, m, h, L]
        return _ensure_margins_norm(marginals_labeled_expr)
def predict(self, flt_items, flt_input_expr, flt_pair_expr, flt_full_expr, flt_mask_expr):
    """Expand each item's mention to the span decoded by the span-expander node."""
    conf: ExtenderConf = self.conf
    if len(flt_items) <= 0:
        return None  # no input item!
    # --
    enc_t = self._forward_eenc(flt_input_expr, flt_full_expr, flt_mask_expr)  # [*, D]
    pair_in = flt_pair_expr if conf.ext_use_finput else None
    s_left, s_right = self.enode.score(enc_t, pair_in, flt_mask_expr)
    # --
    _, left_idxes, right_idxes = SpanExpanderNode.decode_with_scores(s_left, s_right, normalize=True)
    arr_left, arr_right = [BK.get_value(t) for t in [left_idxes, right_idxes]]
    for item, li, ri in zip(flt_items, arr_left, arr_right):
        widx = int(li)
        wlen = int(ri + 1 - li)
        self.ext_span_setter(item.mention, widx, wlen)
def put_results(self, insts, best_labs, best_scores, seg_masks, seg_ext_widxes0, seg_ext_widxes, seg_ext_masks0, seg_ext_masks, cand_full_scores, cand_decisions, split_decisions): conf: SoftExtractorConf = self.conf # -- all_arrs = [ BK.get_value(z) for z in [ best_labs, best_scores, seg_masks, seg_ext_widxes0, seg_ext_widxes, seg_ext_masks0, seg_ext_masks, cand_full_scores, cand_decisions, split_decisions ] ] flattened_items = [] for bidx, inst in enumerate(insts): self._clear_f(inst) # first clean things cur_len = len(inst) if isinstance(inst, Sent) else len(inst.sent) # first set general result res_cand_score = all_arrs[-3][bidx, :cur_len].tolist() res_cand = all_arrs[-2][bidx, :cur_len].tolist() res_split = [1.] + all_arrs[-1][bidx, :int(sum(res_cand)) - 1].tolist() # note: actually B-? inst.info.update({ "res_cand": res_cand, "res_cand_score": res_cand_score, "res_split": res_split }) # then set them separately for one_lab, one_score, one_mask, one_widxes0, one_widxes, one_wmasks0, one_wmasks in \ zip(*[z[bidx] for z in all_arrs[:-3]]): one_lab = int(one_lab) # todo(+N): again, assuming NON-idx == 0 if one_mask == 0. or one_lab == 0: continue # invalid one: unmask or NON # get widxes cur_widxes0 = sorted(one_widxes0[one_wmasks0 > 0.].tolist() ) # original selections from cand cur_widxes = sorted(one_widxes[ one_wmasks > 0.].tolist()) # further possible topk tmp_widx = cur_widxes[0] # currently simply set a tmp one! tmp_wlen = cur_widxes[-1] + 1 - tmp_widx # set it new_item = self._new_f(inst, tmp_widx, tmp_wlen, one_lab, float(one_score)) new_item.mention.info[ "widxes0"] = cur_widxes0 # idxes after cand new_item.mention.info[ "widxes1"] = cur_widxes # idxes after split flattened_items.append(new_item) # -- return flattened_items
def select_topk_non_overlapping(score_t: BK.Expr, topk_t: Union[int, BK.Expr], widx_t: BK.Expr, wlen_t: BK.Expr,
                                input_mask_t: BK.Expr, mask_t: BK.Expr = None, dim=-1):
    """Greedily select, per batch row, the top-K highest-scoring candidate
    spans that do not overlap each other.

    Args:
        score_t: candidate scores; selection runs along the last dim.
        topk_t: budget K — a shared int or a per-row idx tensor.
        widx_t / wlen_t: span start / length per candidate (same shape as score_t).
        input_mask_t: word-position validity mask (indexed by widx positions).
        mask_t: optional candidate-level validity mask; None means all valid.
        dim: must denote the last dimension.

    Returns:
        A 0/1 float mask over candidates, same shape as `score_t`.
    """
    score_shape = BK.get_shape(score_t)
    # fix: original read `len(score_shape - 1)` — a TypeError since score_shape is a list
    assert dim == -1 or dim == len(score_shape) - 1, "Currently only support last-dim!!"  # todo(+2): we can permute to allow any dim!
    # --
    # default candidate mask: everything valid (original crashed on mask_t=None)
    if mask_t is None:
        mask_t = BK.constants(score_shape, 1.)
    # prepare K as a per-row tensor
    if isinstance(topk_t, int):
        tmp_shape = score_shape.copy()
        tmp_shape[dim] = 1  # one budget per row
        topk_t = BK.constants_idx(tmp_shape, topk_t)
    # --
    reshape_trg = [np.prod(score_shape[:-1]).item(), -1]  # flatten leading dims: [bsize, NUM]
    _, sorted_idxes_t = score_t.sort(dim, descending=True)  # best candidates first
    # --
    # put it as CPU and use loop; todo(+N): more efficient ways?
    arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t = \
        [BK.get_value(z.reshape(reshape_trg)) for z in [sorted_idxes_t, topk_t, widx_t, wlen_t, input_mask_t, mask_t]]
    _bsize, _cnum = BK.get_shape(arr_sorted_idxes_t)  # [bsize, NUM]
    arr_topk_mask = np.full([_bsize, _cnum], 0.)  # output selection mask
    for _bidx, (aslice_sorted_idxes_t, aslice_topk_t, aslice_widx_t, aslice_wlen_t, aslice_input_mask_t, aslice_mask_t) \
            in enumerate(zip(arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t)):
        aslice_topk_mask = arr_topk_mask[_bidx]
        # --
        cur_ok_mask = np.copy(aslice_input_mask_t)  # word positions still free
        cur_budget = aslice_topk_t.item()
        for _cidx in aslice_sorted_idxes_t:
            _cidx = _cidx.item()
            if cur_budget <= 0:
                break  # no budget left
            if not aslice_mask_t[_cidx].item():
                continue  # non-valid candidate
            one_widx, one_wlen = aslice_widx_t[_cidx].item(), aslice_wlen_t[_cidx].item()
            if np.prod(cur_ok_mask[one_widx:one_widx + one_wlen]).item() == 0.:  # any hit one?
                continue  # overlaps an already-selected span
            # ok! add it: consume budget and claim the word positions
            cur_budget -= 1
            cur_ok_mask[one_widx:one_widx + one_wlen] = 0.
            aslice_topk_mask[_cidx] = 1.
    # note: no need to *=mask_t again since already checked in the loop
    return BK.input_real(arr_topk_mask).reshape(score_shape)
def predict(self, med: ZMediator):
    """Predict SRL events and arguments; report post-processing wall time."""
    conf: ZDecoderSRLConf = self.conf
    insts, mask_expr = med.insts, med.get_mask_t()
    # --
    evt_labels, evt_scores = self._pred_evt()
    arg_labels, arg_scores = self._pred_arg(mask_expr, evt_labels)
    # pulling values to CPU also forces pending device work to finish,
    # so the transfer time is counted with the computation
    arrs = [BK.get_value(t) for t in [evt_labels, evt_scores, arg_labels, arg_scores]]
    # =====
    # assign; also record post-processing (non-computing) time
    t0 = time.time()
    self.helper.put_results(insts, arrs)
    t1 = time.time()
    # --
    return {f"{self.name}_posttime": t1 - t0}
def decode_udep(self, ibatch, udep_logprobs_t: BK.Expr, root_logprobs_t: BK.Expr): conf: ZDecoderUdepConf = self.conf # -- arr_udep = BK.get_value(udep_logprobs_t.transpose(-2,-3)) # [*, m, h, L] arr_root = None if root_logprobs_t is None else BK.get_value(root_logprobs_t) # [*, dlen] _dim_label = arr_udep.shape[-1] _neg = -10000. # should be enough!! _voc, _lab_range = self.ztask.vpack _idx_root = self._label_idx_root # -- for bidx, item in enumerate(ibatch.items): # for each item in the batch _dec_offsets = item.seq_info.dec_offsets for sidx, sent in enumerate(item.sents): if conf.msent_pred_center and (sidx != item.center_sidx): continue # skip non-center sent in this mode! _start = _dec_offsets[sidx] _len = len(sent) _len_p1 = _len + 1 # -- _arr = np.full([_len_p1, _len_p1, _dim_label], _neg, dtype=np.float32) # [1+m, 1+h, L] # assign label scores _arr[1:_len_p1, 1:_len_p1, 1:_lab_range] = arr_udep[bidx, _start:_start+_len, _start:_start+_len, 1:_lab_range] # assign root scores if arr_root is not None: _arr[1:_len_p1, 0, _idx_root] = arr_root[bidx, _start:_start+_len] else: # todo(+N): currently simply assign a smaller "neg-inf" _arr[1:_len_p1, 0, _idx_root] = -99. # -- from msp2.tools.algo.nmst import mst_unproj # decoding algorithm arr_ret_heads, arr_ret_labels, arr_ret_scores = \ mst_unproj(_arr[None], np.asarray([_len_p1]), labeled=True) # [*, 1+slen] # assign list_dep_heads = arr_ret_heads[0, 1:_len_p1].tolist() list_dep_lidxes = arr_ret_labels[0, 1:_len_p1].tolist() list_dep_labels = _voc.seq_idx2word(list_dep_lidxes) sent.build_dep_tree(list_dep_heads, list_dep_labels)
def assign_boundaries(self, items: List, boundary_node, flat_mask_t: BK.Expr, flat_hid_t: BK.Expr, indicators: List): flat_indicators = boundary_node.prepare_indicators( indicators, BK.get_shape(flat_mask_t)) # -- _bsize, _dlen = BK.get_shape(flat_mask_t) # [???, dlen] _once_bsize = max(1, int(self.conf.boundary_bsize / max(1, _dlen))) # -- if _once_bsize >= _bsize: _, _left_idxes, _right_idxes = boundary_node.decode( flat_hid_t, flat_mask_t, flat_indicators) # [???] else: _all_left_idxes, _all_right_idxes = [], [] for ii in range(0, _bsize, _once_bsize): _, _one_left_idxes, _one_right_idxes = boundary_node.decode( flat_hid_t[ii:ii + _once_bsize], flat_mask_t[ii:ii + _once_bsize], [z[ii:ii + _once_bsize] for z in flat_indicators]) _all_left_idxes.append(_one_left_idxes) _all_right_idxes.append(_one_right_idxes) _left_idxes, _right_idxes = BK.concat(_all_left_idxes, 0), BK.concat( _all_right_idxes, 0) _arr_left, _arr_right = BK.get_value(_left_idxes), BK.get_value( _right_idxes) for ii, item in enumerate(items): _mention = item.mention _start = item._tmp_sstart # need to minus this!! _left_widx, _right_widx = _arr_left[ii].item( ) - _start, _arr_right[ii].item() - _start # todo(+N): sometimes we can have repeated ones, currently simply over-write! if _mention.get_span()[1] == 1: _mention.set_span(*(_mention.get_span()), shead=True) # first move to shead! _mention.set_span(_left_widx, _right_widx - _left_widx + 1)
def loss_on_batch(self, insts: List, loss_factor=1., training=True, force_lidx=None, **kwargs):
    """Run one forward/backward pass over a batch and return loss/count info.

    Args:
        insts: raw input instances; actual training items come from
            `self._yield_insts(insts)`.
        loss_factor: scaling applied at backward time.
        training: if True, also performs the backward pass.
        force_lidx: optional forced language/label index passed to the
            mediator for this batch only.

    Returns:
        Info dict with instance counts, fb flags, and collected loss stats.
    """
    conf: ZmtlModelConf = self.conf
    self.refresh_batch(training)
    # --
    # import torch
    # torch.autograd.set_detect_anomaly(True)
    # --
    actual_insts = list(self._yield_insts(insts))
    med = self.med
    enc_cached_input = self.enc.prepare_inputs(actual_insts)
    # ==
    # if needed, forward other models (can be self)
    aug_scores = {}
    with BK.no_grad_env():
        if conf.aug_times >= 1:
            # forward all at once!!
            _mm_input = enc_cached_input if (conf.aug_times == 1) else self.enc.prepare_inputs(actual_insts * conf.aug_times)
            for mm in self.aug_models:
                # add them all to aug_scores!!
                mm.enc_forward(_mm_input, aug_scores, conf.aug_training_flag)
    # ==
    self.refresh_batch(training)
    med.force_lidx = force_lidx  # note: special assign
    # enc
    self.enc.forward(None, med, cached_input=enc_cached_input)
    # dec
    med.aug_scores = aug_scores  # note: assign here!!
    all_losses = med.do_losses()
    # --
    # final loss and backward
    info = {"inst0": len(insts), "inst": len(actual_insts), "fb": 1, "fb0": 0}
    final_loss, loss_info = self.collect_loss(all_losses, ret_dict=(self.pcgrad is not None))
    info.update(loss_info)
    if training:
        if self.pcgrad is not None:
            # self.pcgrad.do_backward(self.parameters(), final_loss, loss_factor)
            # note: we only specially treat enc's, for others, grads will always be accumulated!
            self.pcgrad.do_backward(self.enc.parameters(), final_loss, loss_factor)
        else:  # as usual
            # assert final_loss.requires_grad
            if BK.get_value(final_loss).item() > 0:  # note: loss should be >0 usually!!
                BK.backward(final_loss, loss_factor)
            else:  # no need to backward if no loss
                info["fb0"] = 1
    med.restart()  # clean!
    med.force_lidx = None  # clear!
    return info
def put_results(self, insts: List[Sent], best_evt_labs, best_evt_scores, best_arg_labs, best_arg_scores):
    """Write SRL predictions (events + arguments) back onto the sentences.

    Two modes: if `conf.evt_pred_use_posi`, keep the existing event positions
    and only re-assign labels/args; otherwise delete old events and predict
    a new event at every non-NIL token.

    Args:
        insts: sentences to annotate.
        best_evt_labs / best_evt_scores: per-token event predictions.
        best_arg_labs / best_arg_scores: per-(evt-token, token) arg predictions.
    """
    conf: MySRLConf = self.conf
    _evt_pred_use_posi = conf.evt_pred_use_posi
    vocab_evt = self.vocab_evt
    vocab_arg = self.vocab_arg
    if conf.arg_use_bio:  # BIO tags decode against the base (span-label) vocab
        real_vocab_arg = vocab_arg.base_vocab
    else:
        real_vocab_arg = vocab_arg
    # --
    all_arrs = [BK.get_value(z) for z in [best_evt_labs, best_evt_scores, best_arg_labs, best_arg_scores]]
    for bidx, inst in enumerate(insts):
        inst.delete_frames(conf.arg_ftag)  # delete old args
        # --
        cur_len = len(inst)
        cur_evt_labs, cur_evt_scores, cur_arg_labs, cur_arg_scores = [z[bidx][:cur_len] for z in all_arrs]
        inst.info["evt_lab"] = [vocab_evt.idx2word(z) if z > 0 else 'O' for z in cur_evt_labs]
        # --
        if _evt_pred_use_posi:
            # special mode: reuse posi but re-assign label!
            for evt in inst.get_frames(conf.evt_ftag):
                one_widx = evt.mention.shead_widx
                one_lab, one_score = cur_evt_labs[one_widx].item(), cur_evt_scores[one_widx].item()
                evt.set_label(vocab_evt.idx2word(one_lab))
                evt.set_label_idx(one_lab)
                evt.score = one_score
                # args
                new_arg_scores = cur_arg_scores[one_widx][:cur_len]
                new_arg_label_idxes = cur_arg_labs[one_widx][:cur_len]
                self.decode_arg(evt, new_arg_label_idxes, new_arg_scores, vocab_arg, real_vocab_arg)
        else:  # pred everything!
            inst.delete_frames(conf.evt_ftag)
            for one_widx in range(cur_len):
                one_lab, one_score = cur_evt_labs[one_widx].item(), cur_evt_scores[one_widx].item()
                if one_lab == 0:
                    continue  # NIL: no event at this token
                # make new evt
                new_evt = inst.make_frame(one_widx, 1, conf.evt_ftag, type=vocab_evt.idx2word(one_lab), score=one_score)
                new_evt.set_label_idx(one_lab)
                self.evt_span_setter(new_evt.mention, one_widx, 1)
                # args
                new_arg_scores = cur_arg_scores[one_widx][:cur_len]
                new_arg_label_idxes = cur_arg_labs[one_widx][:cur_len]
                self.decode_arg(new_evt, new_arg_label_idxes, new_arg_scores, vocab_arg, real_vocab_arg)
def nmst_greedy(scores_expr, mask_expr, lengths_arr, labeled=True, ret_arr=False):
    """Greedy 'MST': independently pick the best (head, label) per modifier.

    Returns (heads, labels, scores) as numpy arrays if `ret_arr` else exprs.
    """
    assert labeled
    with BK.no_grad_env():
        shape = BK.get_shape(scores_expr)
        maxlen = shape[1]
        # forbid self-loops by pushing the diagonal to a huge negative value
        scores_expr += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
        # flatten (head, label) into one dim and take the joint argmax
        flat_scores = scores_expr.view(shape[:-2] + [-1])
        best_scores, best_flat_idxes = BK.max(flat_scores, dim=-1)
        # recover head index and label index from the flat position
        n_labels = shape[-1]
        greedy_heads = best_flat_idxes // n_labels
        greedy_labels = best_flat_idxes % n_labels
        if ret_arr:
            heads_arr, labels_arr, scores_arr = [BK.get_value(z) for z in (greedy_heads, greedy_labels, best_scores)]
            return heads_arr, labels_arr, scores_arr
        return greedy_heads, greedy_labels, best_scores
def decode_upos(self, ibatch, logprobs_t: BK.Expr):
    """Assign UPOS tags to each (eligible) sentence from per-token argmax labels.

    logprobs_t: per-token label log-probabilities, [*, dlen, L].
    Side effects only: calls sent.build_uposes(...) on each decoded sentence.
    """
    conf: ZDecoderUposConf = self.conf
    # argmax over the label dimension; the max scores are not used afterwards
    pred_upos_scores, pred_upos_labels = logprobs_t.max(-1)  # [*, dlen]
    label_arr = BK.get_value(pred_upos_labels)
    vocab = self.voc
    for batch_idx, item in enumerate(ibatch.items):  # each item in the batch
        offsets = item.seq_info.dec_offsets
        for sent_idx, sentence in enumerate(item.sents):
            # in center-only mode, skip every non-center sentence
            if conf.msent_pred_center and (sent_idx != item.center_sidx):
                continue
            begin = offsets[sent_idx]  # sentence start within the multi-sent sequence
            end = begin + len(sentence)
            idx_list = label_arr[batch_idx][begin:end].tolist()
            sentence.build_uposes(vocab.seq_idx2word(idx_list))
def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
         pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr = None):
    """Training loss for the anchor extractor: a weighted label loss over all tokens,
    plus an optional in-bag entropy loss and an optional extension (span) loss.

    input_expr: token representations, assumed [bs, slen, D] — TODO confirm.
    mask_expr: token mask, [bs, slen].
    external_extra_score: optional additive score bias passed to the labeler.
    Returns (combined_loss, None).
    """
    conf: AnchorExtractorConf = self.conf
    assert not lookup_flatten
    bsize, slen = BK.get_shape(mask_expr)
    # --
    # step 0: prepare gold sequences / grouping info (cached by the helper)
    arr_items, expr_seq_gaddr, expr_seq_labs, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
        self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
    arange3_t = arange2_t.unsqueeze(-1)  # [*, 1, 1]
    # --
    # step 1: label, simply scoring everything!
    _main_t, _pair_t = self.lab_node.transform_expr(input_expr, pair_expr)
    all_scores_t = self.lab_node.score_all(_main_t, _pair_t, mask_expr, None, local_normalize=False,
                                           extra_score=external_extra_score)  # unnormalized scores [*, slen, L]
    all_probs_t = all_scores_t.softmax(-1)  # [*, slen, L]
    # probability of each token's gold label
    all_gprob_t = all_probs_t.gather(-1, expr_seq_labs.unsqueeze(-1)).squeeze(-1)  # [*, slen]
    # how to weight: gold-prob of every token in the same group ("bag")
    extended_gprob_t = all_gprob_t[arange3_t, expr_group_widxes] * expr_group_masks  # [*, slen, MW]
    if BK.is_zero_shape(extended_gprob_t):  # guard: empty group dimension
        extended_gprob_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
    else:
        extended_gprob_max_t, _ = extended_gprob_t.max(-1)  # [*, slen]
    _w_alpha = conf.cand_loss_weight_alpha
    # ratio of this token's gold-prob to the best in its group, sharpened by alpha
    _weight = ((all_gprob_t * mask_expr) / (extended_gprob_max_t.clamp(min=1e-5)))**_w_alpha  # [*, slen]
    _label_smoothing = conf.lab_conf.labeler_conf.label_smoothing
    _loss1 = BK.loss_nll(all_scores_t, expr_seq_labs, label_smoothing=_label_smoothing)  # [*, slen]  (gold label)
    _loss2 = BK.loss_nll(all_scores_t, BK.constants_idx([bsize, slen], 0), label_smoothing=_label_smoothing)  # [*, slen]  (NIL=0)
    _weight1 = _weight.detach() if conf.detach_weight_lab else _weight
    # interpolate between gold-label loss and NIL loss by the (possibly detached) weight
    _raw_loss = _weight1 * _loss1 + (1. - _weight1) * _loss2  # [*, slen]
    # final weight it: NIL tokens get a down-weighted loss via loss_weight_non
    cand_loss_weights = BK.where(expr_seq_labs == 0, expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non, mask_expr)  # [*, slen]
    final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
    loss_lab_item = LossHelper.compile_leaf_loss(f"lab", (_raw_loss * final_cand_loss_weights).sum(),
                                                final_cand_loss_weights.sum(), loss_lambda=conf.loss_lab,
                                                gold=(expr_seq_labs > 0).float().sum())
    # --
    # step 1.5: optional entropy regularizer over each bag's gold-prob distribution
    all_losses = [loss_lab_item]
    _loss_cand_entropy = conf.loss_cand_entropy
    if _loss_cand_entropy > 0.:
        _prob = extended_gprob_t  # [*, slen, MW]
        _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
        # [*, slen], only first one in bag (gaddr change marks a new bag)
        _ent_mask = BK.concat([expr_seq_gaddr[:,:1]>=0, expr_seq_gaddr[:,1:]!=expr_seq_gaddr[:,:-1]],-1).float() \
            * (expr_seq_labs>0).float()
        _loss_ent_item = LossHelper.compile_leaf_loss(f"cand_ent", (_ent * _ent_mask).sum(), _ent_mask.sum(),
                                                     loss_lambda=_loss_cand_entropy)
        all_losses.append(_loss_ent_item)
    # --
    # step 4: extend (select topk): train the extension node on confident gold anchors
    if conf.loss_ext > 0.:
        if BK.is_zero_shape(extended_gprob_t):  # guard: nothing to select
            flt_mask = (BK.zeros(mask_expr.shape) > 0)
        else:
            _topk = min(conf.ext_loss_topk, BK.get_shape(extended_gprob_t, -1))  # number to extract
            _topk_grpob_t, _ = extended_gprob_t.topk(_topk, dim=-1)  # [*, slen, K]
            # keep gold tokens that are in their bag's top-k and above the weight threshold
            flt_mask = (expr_seq_labs > 0) & (all_gprob_t >= _topk_grpob_t.min(-1)[0]) & \
                (_weight > conf.ext_loss_thresh)  # [*, slen]
        flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?]  batch idx of each kept token
        flt_expr = input_expr[flt_mask]  # [?, D]
        flt_full_expr = self._prepare_full_expr(flt_mask)  # [?, slen, D]
        flt_items = arr_items.flatten()[BK.get_value(expr_seq_gaddr[flt_mask])]  # [?]  gold items via gaddr
        flt_weights = _weight.detach()[flt_mask] if conf.detach_weight_ext else _weight[flt_mask]  # [?]
        loss_ext_item = self.ext_node.loss(flt_items, input_expr[flt_sidx], flt_expr, flt_full_expr,
                                           mask_expr[flt_sidx], flt_extra_weights=flt_weights)
        all_losses.append(loss_ext_item)
    # --
    # return loss
    ret_loss = LossHelper.combine_multiple_losses(all_losses)
    return ret_loss, None
def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
         pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr = None):
    """Training loss for the soft extractor, staged as:
    cand selection -> split into segments -> label segments -> extend (spans).

    input_expr: token representations, assumed [bs, slen, D] — TODO confirm.
    mask_expr: token mask, [bs, slen].
    Returns (combined_loss, None).
    """
    conf: SoftExtractorConf = self.conf
    assert not lookup_flatten
    bsize, slen = BK.get_shape(mask_expr)
    # --
    # step 0: prepare gold addresses / grouping info (cached by the helper)
    arr_items, expr_seq_gaddr, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
        self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
    # --
    # step 1: cand — score every token and select candidate positions
    cand_full_scores, pred_cand_decisions = self._cand_score_and_select(input_expr, mask_expr)  # [*, slen]
    loss_cand_items, cand_widxes, cand_masks = self._loss_feed_cand(
        mask_expr, cand_full_scores, pred_cand_decisions,
        expr_seq_gaddr, expr_group_widxes, expr_group_masks, expr_loss_weight_non,
    )  # ~, [*, clen]
    # --
    # step 2: split — decide boundaries between adjacent candidates
    cand_expr, cand_scores = input_expr[arange2_t, cand_widxes], cand_full_scores[arange2_t, cand_widxes]  # [*, clen]
    split_scores, pred_split_decisions = self._split_score(cand_expr, cand_masks)  # [*, clen-1]
    loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr = self._loss_feed_split(
        mask_expr, split_scores, pred_split_decisions,
        cand_widxes, cand_masks, cand_expr, cand_scores, expr_seq_gaddr,
    )  # ~, [*, seglen, *?]
    # --
    # step 3: lab — label each segment against its oracle gold item
    # todo(note): add a 0 as idx=-1 to make NEG ones as 0!!
    flatten_gold_label_idxes = BK.input_idx([(0 if z is None else z.label_idx) for z in arr_items.flatten()] + [0])
    gold_label_idxes = flatten_gold_label_idxes[oracle_gaddr]  # [*, seglen]  (gaddr=-1 hits the appended 0)
    lab_loss_weights = BK.where(oracle_gaddr >= 0, expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                seg_masks)  # [*, seglen]
    final_lab_loss_weights = lab_loss_weights * seg_masks  # [*, seglen]
    # go
    loss_lab, loss_count = self.lab_node.loss(
        seg_weighted_expr, pair_expr, seg_masks, gold_label_idxes,
        loss_weight_expr=final_lab_loss_weights, extra_score=external_extra_score)
    loss_lab_item = LossHelper.compile_leaf_loss(
        f"lab", loss_lab, loss_count, loss_lambda=conf.loss_lab,
        gold=(gold_label_idxes > 0).float().sum())
    # step 4: extend — train the span-extension node on positively-labeled segments
    flt_mask = ((gold_label_idxes > 0) & (seg_masks > 0.))  # [*, seglen]
    flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?]  batch idx per kept segment
    flt_expr = seg_weighted_expr[flt_mask]  # [?, D]
    flt_full_expr = self._prepare_full_expr(seg_ext_widxes[flt_mask], seg_ext_masks[flt_mask], slen)  # [?, slen, D]
    flt_items = arr_items.flatten()[BK.get_value(oracle_gaddr[flt_mask])]  # [?]  gold items via oracle gaddr
    loss_ext_item = self.ext_node.loss(flt_items, input_expr[flt_sidx], flt_expr, flt_full_expr, mask_expr[flt_sidx])
    # --
    # return loss
    ret_loss = LossHelper.combine_multiple_losses(loss_cand_items + [loss_split_item, loss_lab_item, loss_ext_item])
    return ret_loss, None
def predict(self, med: ZMediator):
    """Full SRL inference: predict evt (and optionally ef) frames, their boundaries,
    then the arguments of each predicted event, finishing with post-processing.

    med: mediator holding the input batch and cached encoder/decoder scores.
    Side effects only (frames/args attached to sentences); returns None.
    """
    conf: SrlInferenceHelperConf = self.conf
    dec = self.dec
    dec_conf = dec.conf
    # --
    # prepare
    ibatch = med.ibatch
    hid_t = med.get_enc_cache("hid").get_cached_value()  # last enc layer, [*, dlen, D]
    _ds_idxes = ibatch.seq_info.dec_sent_idxes  # [*, dlen]  (sentence id of each dec position)
    # --
    # first predict the evt/ef frames
    frame_results = []
    # --
    # settings tuple layout: tag, labeler, nil-add, given?, boundary node, max-layer,
    #   vocab, pred-label?, pred-boundary?, assume-one-seq-one-frame, check-layer
    _evt_settings = ['evt', dec.lab_evt, conf.pred_evt_nil_add, conf.pred_given_evt, dec.boundary_evt,
                     dec_conf.max_layer_evt, dec.voc_evt, conf.pred_evt_label, conf.pred_evt_boundary,
                     dec_conf.assume_osof, conf.pred_evt_check_layer]
    _ef_settings = ['ef', dec.lab_ef, conf.pred_ef_nil_add, conf.pred_given_ef, dec.boundary_ef,
                    dec_conf.max_layer_ef, dec.voc_ef, conf.pred_ef_label, conf.pred_ef_boundary, False, -1]
    if not conf.pred_ef_first:
        _ef_settings[0] = None  # no predicting efs first!!
    # --
    for pred_tag, node_lab, pred_nil_add, pred_given, node_boundary, pred_max_layer, voc, \
            pred_label, pred_boundary, assume_osof, pred_check_layer in [_evt_settings, _ef_settings]:
        if pred_tag is None:  # ef prediction disabled: keep placeholder so indices line up
            frame_results.append((None, None, None))
            continue
        # --
        if pred_given:
            for item in ibatch.items:
                set_ee_heads(item.sents)  # we may need a head-widx later!
        # --
        score_cache = med.get_cache((dec.name, pred_tag))
        scores_t = node_lab.score_labels(score_cache.vals, nil_add_score=pred_nil_add)  # [*, dlen, L]
        if pred_tag == 'evt' and conf.use_cons_evt:  # modify scores by constraints
            scores_t = self.cons_score_lu2frame(scores_t, ibatch,
                                                given_f=((lambda s: s.events) if pred_given else None))
        # -> ([??], [??]), [??], [*, dlen, K]
        if pred_given:  # labels for pre-existing frames vs. decode frames from scratch
            res = self.decode_frame_given(ibatch, scores_t, pred_max_layer, voc, pred_label, pred_tag, assume_osof)
        else:
            res = self.decode_frame(ibatch, scores_t, pred_max_layer, voc, pred_label, pred_tag, pred_check_layer)
        # boundary?
        res_idxes_t, res_frames, _ = res
        if pred_boundary and node_boundary is not None and len(res_frames) > 0:
            # mask: positions in the same sentence as each predicted frame's token
            _flat_mask_t = (_ds_idxes[res_idxes_t[0]] == _ds_idxes[res_idxes_t].unsqueeze(-1)).float()  # [??, dlen]
            _flat_hid_t = hid_t[res_idxes_t[0]]  # [??, dlen, D]
            self.assign_boundaries(res_frames, node_boundary, _flat_mask_t, _flat_hid_t, [res_idxes_t[1]])
        # --
        frame_results.append(res)
    # --
    # then predict args!
    res_evt_idxes_t, res_evts, _ = frame_results[0]  # evt results!
    if len(res_evts) == 0:
        return  # note: no need to do anything!
    _, _, arr_efs = frame_results[1]  # ef results!
    if arr_efs is not None:
        arr_efs = arr_efs[BK.get_value(res_evt_idxes_t[0])]  # change to fidx at first: [??, dlen, K]
    # --
    arg_score_cache = med.get_cache((dec.name, 'arg'))
    # --
    base_mask_t = dec.get_dec_mask(ibatch, dec_conf.msent_pred_center)  # [bs, dlen]
    arg_seq_mask = base_mask_t.unsqueeze(-2).expand(-1, BK.get_shape(base_mask_t, -1), -1)  # [bs, dlen, dlen]
    # --
    arg_scores_t = dec.lab_arg.score_labels(arg_score_cache.vals, seq_mask_t=arg_seq_mask,
                                            preidx_t=res_evt_idxes_t,
                                            nil_add_score=conf.pred_arg_nil_add)  # [??, dlen, L]
    if conf.use_cons_arg:  # modify scores by constraints
        arg_scores_t = self.cons_score_frame2role(arg_scores_t, res_evts)
    _pred_max_layer, _arg_allowed_sent_gap = dec_conf.max_layer_arg, dec_conf.arg_allowed_sent_gap
    if dec_conf.arg_use_bio:  # if use BIO, then checking the seq will be fine
        self.decode_arg_bio(res_evts, arg_scores_t, _pred_max_layer, dec.vocab_bio_arg,
                            _arg_allowed_sent_gap, arr_efs)
    else:  # otherwise, still need two steps
        res_arg_idxes_t, res_args = self.decode_arg(res_evts, arg_scores_t, _pred_max_layer, dec.voc_arg,
                                                    _arg_allowed_sent_gap, arr_efs)  # [???]
        # arg(ef) boundary: only when args were not anchored to pre-decoded efs
        if dec.boundary_arg is not None and len(res_args) > 0 and arr_efs is None:
            _ab_fidxes_t, _ab_awidxes_t = res_arg_idxes_t  # [???]  (frame idx, arg token idx)
            _ab_bidxes_t, _ab_ewidxes_t = [z[_ab_fidxes_t] for z in res_evt_idxes_t]  # [???]  (batch idx, evt token idx)
            # mask: positions in the same sentence as the arg token
            _ab_mask_t = (_ds_idxes[_ab_bidxes_t] == _ds_idxes[_ab_bidxes_t, _ab_awidxes_t].unsqueeze(-1)).float()
            flat_hid_t = hid_t[_ab_bidxes_t]  # [???, dlen, D]
            self.assign_boundaries(res_args, dec.boundary_arg, _ab_mask_t, flat_hid_t,
                                   [_ab_ewidxes_t, _ab_awidxes_t])
    # --
    # final arg post-process
    if self.pred_evt_filter:  # drop events whose type is not whitelisted
        for evt in list(res_evts):
            if evt.type not in self.pred_evt_filter:
                evt.sent.delete_frame(evt, 'evt')
    for evt in res_evts:
        self.arg_pp.process(evt)  # modified inplace!
    # --
    return