コード例 #1
0
 def decode_frame(self, ibatch, scores_t: BK.Expr, pred_max_layer: int, voc,
                  pred_label: bool, pred_tag: str, pred_check_layer: int):
     """Greedily decode frames from per-position label scores.

     Takes the top-K (K=pred_max_layer) labels at every decoder position,
     stops at the NIL label (idx 0), optionally de-duplicates by type-name
     prefix, and creates a new frame on the sentence for each kept label.

     Args:
         ibatch: input batch; iterated via ibatch.items, each item carrying
             seq_info.dec_offsets, sents and center_sidx.
         scores_t: raw label scores, shape [*, dlen, L].
         pred_max_layer: K, the max number of frames kept per position.
         voc: label vocabulary used via voc.idx2word when pred_label is True.
         pred_label: if False, all created frames get the placeholder "UNK".
         pred_tag: tag string forwarded to sent.make_frame.
         pred_check_layer: if >=0, skip a label whose dotted-type prefix
             (first pred_check_layer components) was already emitted at the
             same position.

     Returns:
         ((res_bidxes_t, res_widxes_t), res_frames, res_farrs): flattened
         batch/position index tensors [??], the list of new frames, and an
         object array [*, dlen, K] with each frame at its (b, widx, k) slot.
     """
     # --
     # first get topk for each position
     logprobs_t = scores_t.log_softmax(-1)  # [*, dlen, L]
     pred_scores, pred_labels = logprobs_t.topk(
         pred_max_layer)  # [*, dlen, K]
     arr_scores, arr_labels = BK.get_value(pred_scores), BK.get_value(
         pred_labels)  # [*, dlen, K]
     # put results
     res_bidxes, res_widxes, res_frames = [], [], []  # flattened results
     res_farrs = np.full(arr_scores.shape, None,
                         dtype=object)  # [*, dlen, K]
     for bidx, item in enumerate(
             ibatch.items):  # for each item in the batch
         _dec_offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             # todo(+N): currently we only predict for center if there is!
             if item.center_sidx is not None and sidx != item.center_sidx:
                 continue  # skip non-center sent in this mode!
             _start = _dec_offsets[sidx]  # offset of this sent inside the concat seq
             _len = len(sent)
             _arr_scores, _arr_labels = arr_scores[bidx][
                 _start:_start + _len], arr_labels[bidx][_start:_start +
                                                         _len]
             for widx in range(_len):
                 _full_widx = widx + _start  # idx in the msent
                 _tmp_set = set()  # type-prefixes already emitted here
                 for _k in range(pred_max_layer):
                     _score, _lab = float(_arr_scores[widx][_k]), int(
                         _arr_labels[widx][_k])
                     if _lab == 0:  # note: lab=0 means NIL
                         break
                     _type_str = (voc.idx2word(_lab)
                                  if pred_label else "UNK")
                     _type_str_prefix = '.'.join(
                         _type_str.split('.')[:pred_check_layer])
                     if pred_check_layer >= 0 and _type_str_prefix in _tmp_set:
                         continue  # ignore since constraint
                     _tmp_set.add(_type_str_prefix)
                     # add new one!
                     res_bidxes.append(bidx)
                     res_widxes.append(_full_widx)
                     _new_frame = sent.make_frame(widx,
                                                  1,
                                                  tag=pred_tag,
                                                  type=_type_str,
                                                  score=float(_score))
                     _new_frame.set_label_idx(int(_lab))
                     _new_frame._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                     _new_frame._tmp_sidx = sidx
                     _new_frame._tmp_item = item
                     res_frames.append(_new_frame)
                     res_farrs[bidx, _full_widx, _k] = _new_frame
     # return
     res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(
         res_widxes)  # [??]
     return (res_bidxes_t,
             res_widxes_t), res_frames, res_farrs  # [??], [*, dlen, K]
コード例 #2
0
 def decode_frame_given(self, ibatch, scores_t: BK.Expr,
                        pred_max_layer: int, voc, pred_label: bool,
                        pred_tag: str, assume_osof: bool):
     """Collect already-given frames (instead of predicting new ones) and
     optionally overwrite their labels with the argmax of scores_t.

     Args:
         ibatch: input batch; iterated via ibatch.items.
         scores_t: raw label scores, shape [*, dlen, L].
         pred_max_layer: K, used only to size/truncate the output array.
         voc: label vocabulary (idx2word) used when pred_label is True.
         pred_label: if True, re-assign each frame's label/score from scores_t.
         pred_tag: tag used to fetch frames via sent.get_frames.
         assume_osof: "one seq one frame" — take item.inst as the single
             target instead of gathering frames from (center) sentences.

     Returns:
         ((res_bidxes_t, res_widxes_t), res_frames, res_farrs): flattened
         batch/head-position index tensors [??], the frame list, and an
         object array [*, dlen, K] keyed by (bidx, head-idx).
     """
     if pred_label:  # if overwrite label!
         logprobs_t = scores_t.log_softmax(-1)  # [*, dlen, L]
         pred_scores, pred_labels = logprobs_t.max(
             -1)  # [*, dlen], note: maximum!
         arr_scores, arr_labels = BK.get_value(pred_scores), BK.get_value(
             pred_labels)  # [*, dlen]
     else:
         arr_scores = arr_labels = None
     # --
     # read given results
     res_bidxes, res_widxes, res_frames = [], [], []  # flattened results
     tmp_farrs = defaultdict(list)  # later assign
     for bidx, item in enumerate(
             ibatch.items):  # for each item in the batch
         _trg_frames = [item.inst] if assume_osof else \
             sum([sent.get_frames(pred_tag) for sidx,sent in enumerate(item.sents)
                  if (item.center_sidx is None or sidx == item.center_sidx)],[])  # still only pick center ones!
         # --
         _dec_offsets = item.seq_info.dec_offsets
         for _frame in _trg_frames:  # note: simply sort by original order!
             sidx = item.sents.index(_frame.sent)
             _start = _dec_offsets[sidx]
             _full_hidx = _start + _frame.mention.shead_widx  # head idx in the concat seq
             # add new one
             res_bidxes.append(bidx)
             res_widxes.append(_full_hidx)
             _frame._tmp_sstart = _start  # todo(+N): ugly tmp value ...
             _frame._tmp_sidx = sidx
             _frame._tmp_item = item
             res_frames.append(_frame)
             tmp_farrs[(bidx, _full_hidx)].append(_frame)
             # assign/rewrite label?
             if pred_label:
                 _lab = int(arr_labels[bidx, _full_hidx])  # label index
                 _frame.set_label_idx(_lab)
                 _frame.set_label(voc.idx2word(_lab))
                 _frame.set_score(float(arr_scores[bidx, _full_hidx]))
         # --
     # --
     # scatter collected frames into a [*, dlen, K] object array
     res_farrs = np.full(BK.get_shape(scores_t)[:-1] + [pred_max_layer],
                         None,
                         dtype=object)  # [*, dlen, K]
     for _key, _values in tmp_farrs.items():
         bidx, widx = _key
         _values = _values[:pred_max_layer]  # truncate if more!
         res_farrs[bidx, widx, :len(_values)] = _values
     # return
     res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(
         res_widxes)  # [??]
     return (res_bidxes_t,
             res_widxes_t), res_frames, res_farrs  # [??], [*, dlen, K]
コード例 #3
0
 def assign_boundaries(self, items: List, left_idxes: BK.Expr,
                       right_idxes: BK.Expr):
     """Write predicted [left, right] token boundaries back onto each item's mention."""
     arr_left = BK.get_value(left_idxes)
     arr_right = BK.get_value(right_idxes)
     for one_ii, one_item in enumerate(items):
         mention = one_item.mention
         offset = one_item._tmp_sstart  # need to minus this!!
         left_widx = arr_left[one_ii].item() - offset
         right_widx = arr_right[one_ii].item() - offset
         # first move to shead, then overwrite with the expanded span
         mention.set_span(*(mention.get_span()), shead=True)
         mention.set_span(left_widx, right_widx + 1 - left_widx)
コード例 #4
0
 def decode_arg(self, res_evts: List, arg_scores_t: BK.Expr,
                pred_max_layer: int, voc, arg_allowed_sent_gap: int,
                arr_efs):
     """Decode top-K argument roles per token for each predicted event.

     For every event, loops over sentences within arg_allowed_sent_gap of
     the event's sentence, takes the top-K role labels per token (NIL,
     idx 0, cuts the K-list), and attaches new args to the event. An
     entity-filler is either taken from arr_efs (first slot only) or
     created once per token and shared across the K labels.

     Args:
         res_evts: predicted events; each carries cached _tmp_item/_tmp_sidx.
         arg_scores_t: raw role scores, shape [??, dlen, L] (?? = #events).
         pred_max_layer: K, max number of roles kept per token.
         voc: role vocabulary (idx2word).
         arg_allowed_sent_gap: max |sidx - evt_sidx| for candidate sents.
         arr_efs: optional object array of given entity-fillers
             [??, dlen, K]; when not None, tokens with no ef are skipped.

     Returns:
         ((res_fidxes_t, res_widxes_t), res_args): flattened event/position
         index tensors [??] and the list of new argument links.
     """
     # first get topk
     arg_logprobs_t = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
     pred_arg_scores, pred_arg_labels = arg_logprobs_t.topk(
         pred_max_layer)  # [??, dlen, K]
     arr_arg_scores, arr_arg_labels = BK.get_value(
         pred_arg_scores), BK.get_value(pred_arg_labels)  # [??, dlen, K]
     # put results
     res_fidxes, res_widxes, res_args = [], [], []  # flattened results
     for fidx, evt in enumerate(res_evts):  # for each evt
         item = evt._tmp_item  # cached
         _evt_sidx = evt._tmp_sidx
         _dec_offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             if abs(sidx - _evt_sidx) > arg_allowed_sent_gap:
                 continue  # larger than allowed sentence gap
             _start = _dec_offsets[sidx]
             _len = len(sent)
             _arr_scores, _arr_labels = arr_arg_scores[fidx][
                 _start:_start + _len], arr_arg_labels[fidx][_start:_start +
                                                             _len]
             for widx in range(_len):
                 _full_widx = widx + _start  # idx in the msent
                 _new_ef = None
                 if arr_efs is not None:  # note: arr_efs should also expand to frames!
                     _new_ef = arr_efs[
                         fidx, _full_widx,
                         0]  # todo(+N): only get the first one!
                     if _new_ef is None:
                         continue  # no ef!
                 for _score, _lab in zip(_arr_scores[widx],
                                         _arr_labels[widx]):  # [K]
                     if _lab == 0:  # note: idx=0 means NIL
                         break
                     # add new one!!
                     res_fidxes.append(fidx)
                     res_widxes.append(_full_widx)
                     if _new_ef is None:
                         _new_ef = sent.make_entity_filler(
                             widx, 1)  # share them if possible!
                     _new_arg = evt.add_arg(_new_ef,
                                            role=voc.idx2word(_lab),
                                            score=float(_score))
                     _new_arg._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                     res_args.append(_new_arg)
     # return
     res_fidxes_t, res_widxes_t = BK.input_idx(res_fidxes), BK.input_idx(
         res_widxes)  # [??]
     return (res_fidxes_t, res_widxes_t), res_args
コード例 #5
0
 def decode_evt(self, dec, ibatch, evt_scores_t: BK.Expr):
     """Decode events from per-position scores: keep the top-K labels at
     every position; the NIL label (idx 0) terminates the K-list."""
     conf_max_layer = dec.conf.max_layer_evt
     voc = dec.voc_evt
     use_label = self.conf.pred_evt_label
     # --
     # topk over the label dim
     logprobs = evt_scores_t.log_softmax(-1)  # [*, dlen, L]
     topk_scores, topk_labels = logprobs.topk(conf_max_layer)  # [*, dlen, K]
     arr_scores = BK.get_value(topk_scores)
     arr_labels = BK.get_value(topk_labels)
     # collect flattened results
     res_bidxes, res_widxes, res_evts = [], [], []
     for bidx, item in enumerate(ibatch.items):  # per batch item
         offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             # note: here we only predict for center if there is!
             if item.center_sidx is not None and sidx != item.center_sidx:
                 continue  # skip non-center sent in this mode!
             cur_start = offsets[sidx]
             cur_len = len(sent)
             cur_scores = arr_scores[bidx][cur_start:cur_start + cur_len]
             cur_labels = arr_labels[bidx][cur_start:cur_start + cur_len]
             for widx in range(cur_len):
                 for one_score, one_lab in zip(cur_scores[widx],
                                               cur_labels[widx]):  # [K]
                     if one_lab == 0:  # note: idx=0 means NIL
                         break
                     # record this prediction
                     res_bidxes.append(bidx)
                     res_widxes.append(cur_start + widx)  # note: remember to add offset!
                     evt_type = voc.idx2word(one_lab) if use_label else "UNK"
                     new_evt = sent.make_event(widx, 1, type=evt_type,
                                               score=float(one_score))
                     new_evt.set_label_idx(int(one_lab))
                     new_evt._tmp_sstart = cur_start  # todo(+N): ugly tmp value ...
                     new_evt._tmp_sidx = sidx
                     new_evt._tmp_item = item
                     res_evts.append(new_evt)
     # return
     res_bidxes_t = BK.input_idx(res_bidxes)  # [??]
     res_widxes_t = BK.input_idx(res_widxes)  # [??]
     return (res_bidxes_t, res_widxes_t), res_evts
コード例 #6
0
ファイル: direct.py プロジェクト: zzsfornlp/zmsp
 def _get_extra_score(self, cand_score, insts, cand_res, arr_gold_items, use_cons: bool, use_lu: bool):
     """Combine the extended candidate score with an optional lexicon-constraint score.

     For each valid candidate span, looks up a constraint feature — either
     from the gold item's "luName" (when use_lu and a gold address exists)
     or from the span itself — and turns lexicon validity into a large
     negative penalty for invalid entries.

     Args:
         cand_score: base candidate score (passed through _extend_cand_score).
         insts: instances aligned with the batch dim of cand_res.
         cand_res: candidate results with widx/wlen/mask/gaddr exprs.
         arr_gold_items: array of gold items; flattened and indexed by gaddr.
         use_cons: enable the lexicon-constraint scoring branch.
         use_lu: prefer the gold item's LU name over span lookup.

     Returns:
         Sum of candidate score and (possibly None) constraint score
         via self._sum_scores.
     """
     # conf: DirectExtractorConf = self.conf
     # --
     # first cand score
     cand_score = self._extend_cand_score(cand_score)
     # then cons_lex score
     cons_lex_node = self.cons_lex_node
     if use_cons and cons_lex_node is not None:
         cons_lex = cons_lex_node.cons
         flt_arr_gold_items = arr_gold_items.flatten()
         _shape = BK.get_shape(cand_res.mask_expr)
         if cand_res.gaddr_expr is None:
             gaddr_expr = BK.constants(_shape, -1, dtype=BK.long)  # -1 = no gold address
         else:
             gaddr_expr = cand_res.gaddr_expr
         all_arrs = [BK.get_value(z) for z in [cand_res.widx_expr, cand_res.wlen_expr, cand_res.mask_expr, gaddr_expr]]
         arr_feats = np.full(_shape, None, dtype=object)  # per-candidate feature keys
         for bidx, inst in enumerate(insts):
             one_arr_feats = arr_feats[bidx]
             _ii = -1
             for one_widx, one_wlen, one_mask, one_gaddr in zip(*[z[bidx] for z in all_arrs]):
                 _ii += 1
                 if one_mask == 0.: continue  # skip invalid ones
                 if use_lu and one_gaddr>=0:
                     one_feat = cons_lex.lu2feat(flt_arr_gold_items[one_gaddr].info["luName"])
                 else:
                     one_feat = cons_lex.span2feat(inst, one_widx, one_wlen)
                 one_arr_feats[_ii] = one_feat
         cons_valids = cons_lex_node.lookup_with_feats(arr_feats)
         # invalid (0.) entries get a practically -inf penalty
         cons_score = (1.-cons_valids) * Constants.REAL_PRAC_MIN
     else:
         cons_score = None
     # sum
     return self._sum_scores(cand_score, cons_score)
コード例 #7
0
ファイル: nmst.py プロジェクト: zzsfornlp/zmsp
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    """Shared labeled-MST driver: take the per-arc best label on the device,
    run the CPU MST routine on the unlabeled max-scores, then gather the
    label of each chosen head. Returns arrays or exprs per ret_arr."""
    assert labeled
    with BK.no_grad_env():
        # best score and its label index over the last (label) dim: [BS, m, h]
        best_scores, best_labels = scores_expr.max(-1)
        # unlabeled MST on CPU arrays
        heads_arr, _, out_scores_arr = CPU_f(
            BK.get_value(best_scores), lengths_arr, labeled=False)
        # [BS, m]: label of the selected head for each modifier
        heads_expr = BK.input_idx(heads_arr)
        labels_expr = BK.gather_one_lastdim(best_labels, heads_expr).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return heads_arr, BK.get_value(labels_expr), out_scores_arr
        return heads_expr, labels_expr, BK.input_real(out_scores_arr)
コード例 #8
0
ファイル: seq.py プロジェクト: zzsfornlp/zmsp
 def put_results(self, insts, best_labs, best_scores):
     """Decode predicted tag sequences into span items on each instance."""
     conf: SeqExtractorConf = self.conf
     vocab: SeqVocab = self.vocab
     # --
     base_vocab = vocab.base_vocab
     arr_slabs = BK.get_value(best_labs)
     arr_scores = BK.get_value(best_scores)
     for bidx, inst in enumerate(insts):
         self._clear_f(inst)  # first clean things
         cur_len = len(inst) if isinstance(inst, Sent) else len(inst.sent)
         # --
         cur_slabs = arr_slabs[bidx][:cur_len]
         cur_scores = arr_scores[bidx][:cur_len]
         spans = vocab.tags2spans_idx(cur_slabs)
         inst.info["slab"] = [vocab.idx2word(z) for z in cur_slabs]  # put seq-lab
         for span_widx, span_wlen, span_lab in spans:
             span_lab = int(span_lab)
             assert span_lab > 0, "Error: should not extract 'O'!?"
             # span score = mean of the token scores it covers
             span_score = np.mean(
                 cur_scores[span_widx:span_widx + span_wlen]).item()
             new_item = self._new_f(
                 inst,
                 int(span_widx),
                 int(span_wlen),
                 span_lab,
                 span_score,
                 vocab=base_vocab,
             )
コード例 #9
0
 def put_results(self, insts, best_labs, best_scores):
     """Attach one single-token item per non-NIL predicted position."""
     conf: AnchorExtractorConf = self.conf
     # --
     arr_labs = BK.get_value(best_labs)
     arr_scores = BK.get_value(best_scores)
     flattened_items = []
     for bidx, inst in enumerate(insts):
         self._clear_f(inst)  # first clean things
         cur_len = len(inst) if isinstance(inst, Sent) else len(inst.sent)
         cur_labs = arr_labs[bidx][:cur_len]
         cur_scores = arr_scores[bidx][:cur_len]
         # simply put them
         for widx in range(cur_len):
             lab = int(cur_labs[widx])
             score = float(cur_scores[widx])
             # todo(+N): again, assuming NON-idx == 0
             if lab == 0:
                 continue  # invalid one: unmask or NON
             # set it
             new_item = self._new_f(inst, widx, 1, lab, float(score))
             new_item.mention.info["widxes1"] = [widx]  # save it
             flattened_items.append(new_item)
         # --
     return flattened_items
コード例 #10
0
 def predict(self, med: ZMediator):
     """Predict UDEP dependency trees for the batch held by the mediator.

     Builds a [*, 1+slen, 1+slen, L] score tensor (extra row/col for the
     artificial root), injects depth-based root scores, and runs the
     unprojective MST decoder on CPU arrays; results are written back via
     self.helper.put_results.

     Returns:
         An empty dict (no extra outputs).
     """
     conf: ZDecoderUDEPConf = self.conf
     # --
     # depth scores
     all_depth_raw_scores = med.main_scores.get((self.name, "depth"), [])  # [*, slen]
     all_depth_logprobs = [BK.logsigmoid(z.squeeze(-1)) for z in all_depth_raw_scores]
     if len(all_depth_logprobs) > 0:
         final_depth_logprobs = self.depth_node.helper.pred(all_logprobs=all_depth_logprobs)  # [*, slen]
     else:
         final_depth_logprobs = -99.  # no depth scores: broadcast constant penalty
     # udep scores
     all_udep_raw_scores = med.main_scores.get((self.name, "udep"), [])  # [*, slen, slen, L]
     all_udep_logprobs = [z.log_softmax(-1) for z in all_udep_raw_scores]
     final_udep_logprobs = self.udep_node.helper.pred(all_logprobs=all_udep_logprobs)  # [*, slen, slen, L]
     # prepare final scores
     # pad one leading slot on both token dims for the artificial root
     final_scores = BK.pad(final_udep_logprobs, [0,0,1,0,1,0], value=Constants.REAL_PRAC_MIN)  # [*, 1+slen, 1+slen, L]
     final_scores[:, :, :, 0] = Constants.REAL_PRAC_MIN  # force no 0!!
     final_scores[:, 1:, 0, self._label_idx_root] = (final_depth_logprobs + conf.udep_pred_root_penalty)  # assign root score
     # decode
     from msp2.tools.algo.nmst import mst_unproj  # decoding algorithm
     insts = med.insts
     arr_lengths = np.asarray([len(z.sent)+1 for z in insts])  # +1 for arti-root
     arr_scores = BK.get_value(final_scores)  # [*, 1+slen, 1+slen, L]
     arr_ret_heads, arr_ret_labels, arr_ret_scores = mst_unproj(arr_scores, arr_lengths, labeled=True)  # [*, 1+slen]
     self.helper.put_results(insts, [arr_ret_heads, arr_ret_labels, arr_ret_scores])
     # --
     return {}
コード例 #11
0
 def decode_evt_given(self, dec, ibatch, evt_scores_t: BK.Expr):
     """Collect already-given events (instead of predicting new ones) and
     optionally overwrite their labels with the argmax of evt_scores_t.

     Args:
         dec: decoder holding voc_evt and conf.assume_osof
             ("one seq one frame": take item.inst as the single target).
         ibatch: input batch; iterated via ibatch.items.
         evt_scores_t: raw label scores, shape [*, dlen, L].

     Returns:
         ((res_bidxes_t, res_widxes_t), res_evts): flattened batch/head-idx
         index tensors [??] and the list of collected events.
     """
     _voc_evt = dec.voc_evt
     _assume_osof = dec.conf.assume_osof  # one seq one frame
     _pred_evt_label = self.conf.pred_evt_label
     # --
     if _pred_evt_label:
         evt_logprobs_t = evt_scores_t.log_softmax(-1)  # [*, dlen, L]
         pred_evt_scores, pred_evt_labels = evt_logprobs_t.max(
             -1)  # [*, dlen], note: maximum!
         arr_evt_scores, arr_evt_labels = BK.get_value(
             pred_evt_scores), BK.get_value(pred_evt_labels)  # [*, dlen]
     else:
         arr_evt_scores = arr_evt_labels = None
     # --
     # read given results
     res_bidxes, res_widxes, res_evts = [], [], []  # flattened results
     for bidx, item in enumerate(
             ibatch.items):  # for each item in the batch
         _trg_evts = [item.inst] if _assume_osof else \
             sum([sent.events for sidx,sent in enumerate(item.sents) if (item.center_sidx is None or sidx == item.center_sidx)],[])
         # --
         _dec_offsets = item.seq_info.dec_offsets
         for _evt in _trg_evts:
             sidx = item.sents.index(_evt.sent)
             _start = _dec_offsets[sidx]
             _full_hidx = _start + _evt.mention.shead_widx  # head idx in the concat seq
             # add new one
             res_bidxes.append(bidx)
             res_widxes.append(_full_hidx)
             _evt._tmp_sstart = _start  # todo(+N): ugly tmp value ...
             _evt._tmp_sidx = sidx
             _evt._tmp_item = item
             res_evts.append(_evt)
             # assign label?
             if _pred_evt_label:
                 _lab = int(arr_evt_labels[bidx, _full_hidx])  # label index
                 _evt.set_label_idx(_lab)
                 _evt.set_label(_voc_evt.idx2word(_lab))
                 _evt.set_score(float(arr_evt_scores[bidx, _full_hidx]))
         # --
     # --
     # return
     res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(
         res_widxes)  # [??]
     return (res_bidxes_t, res_widxes_t), res_evts
コード例 #12
0
ファイル: upos.py プロジェクト: zzsfornlp/zmsp
 def predict(self, med: ZMediator):
     # --
     pred_upos_labels, pred_upos_scores = self._pred_upos()
     all_arrs = [
         BK.get_value(z) for z in [pred_upos_labels, pred_upos_scores]
     ]
     self.helper.put_results(med.insts, all_arrs)
     # --
     return {}
コード例 #13
0
 def decode_arg(self, dec, res_evts: List, arg_scores_t: BK.Expr):
     """For each predicted event, decode top-K argument roles per token
     inside the allowed sentence window; NIL (idx 0) cuts the K-list."""
     conf_max_layer = dec.conf.max_layer_arg
     sent_gap = dec.conf.arg_allowed_sent_gap
     voc = dec.voc_arg
     # --
     # topk over the label dim
     logprobs = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
     topk_scores, topk_labels = logprobs.topk(conf_max_layer)  # [??, dlen, K]
     arr_scores = BK.get_value(topk_scores)
     arr_labels = BK.get_value(topk_labels)
     # collect flattened results
     res_fidxes, res_widxes, res_args = [], [], []
     for fidx, evt in enumerate(res_evts):  # per event
         item = evt._tmp_item  # cached
         evt_sidx = evt._tmp_sidx
         offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             if abs(sidx - evt_sidx) > sent_gap:
                 continue  # larger than allowed sentence gap
             cur_start = offsets[sidx]
             cur_len = len(sent)
             cur_scores = arr_scores[fidx][cur_start:cur_start + cur_len]
             cur_labels = arr_labels[fidx][cur_start:cur_start + cur_len]
             for widx in range(cur_len):
                 for one_score, one_lab in zip(cur_scores[widx],
                                               cur_labels[widx]):  # [K]
                     if one_lab == 0:  # note: idx=0 means NIL
                         break
                     # record this prediction
                     res_fidxes.append(fidx)
                     res_widxes.append(cur_start + widx)
                     new_ef = sent.make_entity_filler(widx, 1)
                     new_arg = evt.add_arg(new_ef,
                                           role=voc.idx2word(one_lab),
                                           score=float(one_score))
                     new_arg._tmp_sstart = cur_start  # todo(+N): ugly tmp value ...
                     res_args.append(new_arg)
     # return
     res_fidxes_t = BK.input_idx(res_fidxes)  # [??]
     res_widxes_t = BK.input_idx(res_widxes)  # [??]
     return (res_fidxes_t, res_widxes_t), res_args
コード例 #14
0
ファイル: direct.py プロジェクト: zzsfornlp/zmsp
 def put_labels(self, arr_items, best_labs, best_scores):
     """Assign predicted label idx/str and score onto every non-None item."""
     assert BK.get_shape(best_labs) == BK.get_shape(arr_items)
     # --
     arr_labs = BK.get_value(best_labs)
     arr_scores = BK.get_value(best_scores)
     for row_items, row_labs, row_scores in zip(arr_items, arr_labs, arr_scores):
         for item, lab, score in zip(row_items, row_labs, row_scores):
             if item is None:
                 continue
             lab = int(lab)
             item.score = float(score)
             item.set_label_idx(lab)
             item.set_label(self.vocab.idx2word(lab))
コード例 #15
0
 def decode_arg_bio(self, res_evts: List, arg_scores_t: BK.Expr,
                    pred_max_layer: int, voc_bio, arg_allowed_sent_gap: int,
                    arr_efs):
     """Decode arguments as a single BIO tag sequence per sentence.

     Takes the argmax BIO tag per token, converts tag runs into spans via
     the BIO vocabulary, and attaches one new entity-filler + arg per span.
     The span score is the mean of its tokens' tag log-probs.

     Args:
         res_evts: predicted events; each carries cached _tmp_item/_tmp_sidx.
         arg_scores_t: raw BIO-tag scores, shape [??, dlen, L].
         pred_max_layer: must be 1 (only one BIO seq is supported).
         voc_bio: BIO vocabulary; its base_vocab maps span labels to roles.
         arg_allowed_sent_gap: max |sidx - evt_sidx| for candidate sents.
         arr_efs: must be None (given efs not supported in BIO mode).

     Returns:
         None (args are attached to the events directly).
     """
     assert pred_max_layer == 1, "Currently BIO only allow one!"
     assert arr_efs is None, "Currently BIO does not allow pick given efs!"
     _vocab_bio_arg = voc_bio
     _vocab_arg = _vocab_bio_arg.base_vocab
     # --
     arg_logprobs_t = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
     # todo(+N): allow multiple bio seqs?
     pred_arg_scores, pred_arg_labels = arg_logprobs_t.max(
         -1)  # note: here only top1, [??, dlen]
     arr_arg_scores, arr_arg_labels = BK.get_value(
         pred_arg_scores), BK.get_value(pred_arg_labels)  # [??, dlen]
     # put results
     for fidx, evt in enumerate(res_evts):  # for each evt
         item = evt._tmp_item  # cached
         _evt_sidx = evt._tmp_sidx
         _dec_offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             if abs(sidx - _evt_sidx) > arg_allowed_sent_gap:
                 continue  # larger than allowed sentence gap
             _start = _dec_offsets[sidx]
             _len = len(sent)
             _arr_scores, _arr_labels = arr_arg_scores[fidx][
                 _start:_start + _len], arr_arg_labels[fidx][_start:_start +
                                                             _len]
             # decode bio
             _arg_results = _vocab_bio_arg.tags2spans_idx(_arr_labels)
             for a_widx, a_wlen, a_lab in _arg_results:  # sent-local spans
                 a_lab = int(a_lab)
                 assert a_lab > 0, "Error: should not extract 'O'!?"
                 _new_ef = sent.make_entity_filler(a_widx, a_wlen)
                 a_role = _vocab_arg.idx2word(a_lab)
                 # span score = mean of covered token scores
                 _new_arg = evt.add_arg(_new_ef,
                                        a_role,
                                        score=np.mean(
                                            _arr_scores[a_widx:a_widx +
                                                        a_wlen]).item())
     # --
     return  # no need to return anything here
コード例 #16
0
ファイル: direct.py プロジェクト: zzsfornlp/zmsp
 def put_results(self, insts, best_labs, best_scores, widx_expr, wlen_expr, mask_expr):
     """Create one new item for every valid (masked-in, optionally non-NON) candidate span."""
     conf: DirectExtractorConf = self.conf
     conf_pred_ignore_non = conf.pred_ignore_non
     # --
     arr_list = [BK.get_value(z) for z in [best_labs, best_scores, widx_expr, wlen_expr, mask_expr]]
     for bidx, inst in enumerate(insts):
         self._clear_f(inst)  # first clean things
         rows = [z[bidx] for z in arr_list]
         for cur_lab, cur_score, cur_widx, cur_wlen, cur_mask in zip(*rows):
             cur_lab = int(cur_lab)
             # todo(+N): again, assuming NON-idx == 0
             if cur_mask == 0. or (conf_pred_ignore_non and cur_lab == 0):
                 continue  # invalid one: unmask or NON
             new_item = self._new_f(inst, int(cur_widx), int(cur_wlen), cur_lab, float(cur_score))
コード例 #17
0
ファイル: nmst.py プロジェクト: zzsfornlp/zmsp
def nmarginal_proj(scores_expr, mask_expr, lengths_arr, labeled=True):
    """Projective marginals for labeled scores [BS, m, h, L]: compute
    unlabeled arc marginals on CPU, then redistribute each arc's mass
    over labels with a softmax of the label scores."""
    assert labeled
    with BK.no_grad_env():
        # collapse the label dim by log-sum-exp -> unlabeled scores [BS, m, h]
        unlabeled_t = BK.logsumexp(scores_expr, dim=-1)
        # unlabeled projective marginals, computed on CPU arrays
        unlabeled_marg_arr = marginal_proj(BK.get_value(unlabeled_t), lengths_arr, False)
        # back to labeled values: p(arc) * p(label|arc)
        unlabeled_marg_t = BK.input_real(unlabeled_marg_arr)
        labeled_marg_t = unlabeled_marg_t.unsqueeze(-1) * BK.exp(scores_expr - unlabeled_t.unsqueeze(-1))
        # [BS, m, h, L]
        return _ensure_margins_norm(labeled_marg_t)
コード例 #18
0
 def predict(self, flt_items, flt_pair_expr, flt_input_expr=None, flt_full_expr=None, flt_mask_expr=None):
     """Predict expanded [left, right] spans for the given items and write
     them back through the span setter. No-op (returns None) when empty."""
     conf: ExtenderConf = self.conf
     if len(flt_items) <= 0:
         return None  # no input item!
     # --
     # encode, then score candidate boundary positions
     enc_t = self._forward_eenc(flt_input_expr, flt_full_expr, flt_mask_expr)  # [*, D]
     pair_t = flt_pair_expr if conf.ext_use_finput else None
     s_left, s_right = self.enode.score(enc_t, pair_t, flt_mask_expr)
     # --
     max_scores, left_idxes, right_idxes = SpanExpanderNode.decode_with_scores(s_left, s_right, normalize=True)
     arr_left = BK.get_value(left_idxes)
     arr_right = BK.get_value(right_idxes)
     for cur_item, cur_left, cur_right in zip(flt_items, arr_left, arr_right):
         new_widx = int(cur_left)
         new_wlen = int(cur_right + 1 - cur_left)
         self.ext_span_setter(cur_item.mention, new_widx, new_wlen)
コード例 #19
0
ファイル: soft.py プロジェクト: zzsfornlp/zmsp
 def put_results(self, insts, best_labs, best_scores, seg_masks,
                 seg_ext_widxes0, seg_ext_widxes, seg_ext_masks0,
                 seg_ext_masks, cand_full_scores, cand_decisions,
                 split_decisions):
     """Write soft-extractor segment predictions back onto the instances.

     Stores the general cand/split decisions in inst.info, then creates one
     item per valid (masked-in, non-NIL) segment; a segment's tentative
     span is the [min, max] of its selected word indexes, with the raw
     index sets saved in the mention's info.

     Args (all tensors; converted to arrays here):
         best_labs, best_scores: per-segment label idx and score.
         seg_masks: per-segment validity mask.
         seg_ext_widxes0 / seg_ext_masks0: selected word idxes (+mask)
             after the cand stage.
         seg_ext_widxes / seg_ext_masks: selected word idxes (+mask)
             after the split stage.
         cand_full_scores, cand_decisions, split_decisions: per-token
             general decisions; last three of all_arrs below.

     Returns:
         Flattened list of all newly created items.
     """
     conf: SoftExtractorConf = self.conf
     # --
     all_arrs = [
         BK.get_value(z) for z in [
             best_labs, best_scores, seg_masks, seg_ext_widxes0,
             seg_ext_widxes, seg_ext_masks0, seg_ext_masks,
             cand_full_scores, cand_decisions, split_decisions
         ]
     ]
     flattened_items = []
     for bidx, inst in enumerate(insts):
         self._clear_f(inst)  # first clean things
         cur_len = len(inst) if isinstance(inst, Sent) else len(inst.sent)
         # first set general result
         # note: all_arrs[-3:] are cand_full_scores/cand_decisions/split_decisions
         res_cand_score = all_arrs[-3][bidx, :cur_len].tolist()
         res_cand = all_arrs[-2][bidx, :cur_len].tolist()
         res_split = [1.] + all_arrs[-1][bidx, :int(sum(res_cand)) -
                                         1].tolist()  # note: actually B-?
         inst.info.update({
             "res_cand": res_cand,
             "res_cand_score": res_cand_score,
             "res_split": res_split
         })
         # then set them separately
         for one_lab, one_score, one_mask, one_widxes0, one_widxes, one_wmasks0, one_wmasks in \
                 zip(*[z[bidx] for z in all_arrs[:-3]]):
             one_lab = int(one_lab)
             # todo(+N): again, assuming NON-idx == 0
             if one_mask == 0. or one_lab == 0:
                 continue  # invalid one: unmask or NON
             # get widxes
             cur_widxes0 = sorted(one_widxes0[one_wmasks0 > 0.].tolist()
                                  )  # original selections from cand
             cur_widxes = sorted(one_widxes[
                 one_wmasks > 0.].tolist())  # further possible topk
             tmp_widx = cur_widxes[0]  # currently simply set a tmp one!
             tmp_wlen = cur_widxes[-1] + 1 - tmp_widx
             # set it
             new_item = self._new_f(inst, tmp_widx, tmp_wlen, one_lab,
                                    float(one_score))
             new_item.mention.info[
                 "widxes0"] = cur_widxes0  # idxes after cand
             new_item.mention.info[
                 "widxes1"] = cur_widxes  # idxes after split
             flattened_items.append(new_item)
         # --
     return flattened_items
コード例 #20
0
ファイル: helper.py プロジェクト: zzsfornlp/zmsp
def select_topk_non_overlapping(score_t: BK.Expr,
                                topk_t: Union[int, BK.Expr],
                                widx_t: BK.Expr,
                                wlen_t: BK.Expr,
                                input_mask_t: BK.Expr,
                                mask_t: BK.Expr = None,
                                dim=-1):
    """Greedily pick up to top-k highest-scored candidate spans per row such
    that selected spans never overlap (w.r.t. [widx, widx+wlen)).

    Args:
        score_t: candidate scores, selection happens along the last dim.
        topk_t: per-row budget; an int (broadcast to all rows) or an idx tensor.
        widx_t, wlen_t: span start / length for each candidate (same shape as score_t).
        input_mask_t: word-level validity mask; a word slot with mask 0 blocks
            any span covering it (and is also how overlap is tracked).
        mask_t: candidate-level validity mask; None means all candidates valid.
        dim: must denote the last dimension (only last-dim supported).
    Returns:
        A float 0/1 selection mask with the same shape as score_t.
    """
    score_shape = BK.get_shape(score_t)
    # note: fix — original wrote `len(score_shape - 1)`, a TypeError on a list
    assert dim == -1 or dim == len(
        score_shape
    ) - 1, "Currently only support last-dim!!"  # todo(+2): we can permute to allow any dim!
    # --
    # prepare K
    if isinstance(topk_t, int):
        tmp_shape = score_shape.copy()
        tmp_shape[dim] = 1  # set it as 1
        topk_t = BK.constants_idx(tmp_shape, topk_t)
    # note: fix — allow mask_t=None (previously reshaping None crashed below)
    if mask_t is None:
        mask_t = BK.constants(score_shape, 1.)
    # --
    reshape_trg = [np.prod(score_shape[:-1]).item(), -1]  # [*, ?]
    _, sorted_idxes_t = score_t.sort(dim, descending=True)
    # --
    # put it as CPU and use loop; todo(+N): more efficient ways?
    arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t = \
        [BK.get_value(z.reshape(reshape_trg)) for z in [sorted_idxes_t, topk_t, widx_t, wlen_t, input_mask_t, mask_t]]
    _bsize, _cnum = BK.get_shape(arr_sorted_idxes_t)  # [bsize, NUM]
    arr_topk_mask = np.full([_bsize, _cnum], 0.)  # [bsize, NUM]
    _bidx = 0
    for aslice_sorted_idxes_t, aslice_topk_t, aslice_widx_t, aslice_wlen_t, aslice_input_mask_t, aslice_mask_t \
            in zip(arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t):
        aslice_topk_mask = arr_topk_mask[_bidx]
        # --
        cur_ok_mask = np.copy(aslice_input_mask_t)  # word slots still free
        cur_budget = aslice_topk_t.item()
        for _cidx in aslice_sorted_idxes_t:  # visit candidates best-first
            _cidx = _cidx.item()
            if cur_budget <= 0: break  # no budget left
            if not aslice_mask_t[_cidx].item(): continue  # non-valid candidate
            one_widx, one_wlen = aslice_widx_t[_cidx].item(
            ), aslice_wlen_t[_cidx].item()
            if np.prod(cur_ok_mask[one_widx:one_widx +
                                   one_wlen]).item() == 0.:  # any hit one?
                continue  # overlaps an already-selected span (or masked word)
            # ok! add it!
            cur_budget -= 1
            cur_ok_mask[one_widx:one_widx + one_wlen] = 0.  # occupy the words
            aslice_topk_mask[_cidx] = 1.
        _bidx += 1
    # note: no need to *=mask_t again since already check in the loop
    return BK.input_real(arr_topk_mask).reshape(score_shape)
コード例 #21
0
ファイル: srl.py プロジェクト: zzsfornlp/zmsp
 def predict(self, med: ZMediator):
     """Run SRL prediction for the current batch and write results onto the insts.

     Returns an info dict recording post-processing (non-computing) time.
     """
     conf: ZDecoderSRLConf = self.conf
     insts = med.insts
     mask_expr = med.get_mask_t()
     # --
     # score events first, then arguments conditioned on the event labels
     evt_labs, evt_scores = self._pred_evt()
     arg_labs, arg_scores = self._pred_arg(mask_expr, evt_labs)
     # pull everything to CPU (this also makes sure pending gpu work is done)
     all_arrs = [BK.get_value(t) for t in (evt_labs, evt_scores, arg_labs, arg_scores)]
     # =====
     # assign; time only the post-processing (non-computing) part
     _t0 = time.time()
     self.helper.put_results(insts, all_arrs)
     _t1 = time.time()
     # --
     return {f"{self.name}_posttime": _t1 - _t0}
コード例 #22
0
ファイル: dec_udep.py プロジェクト: zzsfornlp/zmsp
 def decode_udep(self, ibatch, udep_logprobs_t: BK.Expr, root_logprobs_t: BK.Expr):
     """Decode UD dependency trees per sentence with unprojective MST.

     For each (center) sentence, builds a [1+m, 1+h, L] score array (slot 0 is
     the artificial ROOT), fills in label log-probs and root scores, runs
     `mst_unproj`, and writes the resulting heads/labels onto the sentence.

     Args:
         ibatch: the input batch; offsets into the flattened dec sequence come
             from each item's `seq_info.dec_offsets`.
         udep_logprobs_t: label log-probs; transposed below so the layout
             becomes [*, m(odifier), h(ead), L].
         root_logprobs_t: per-token root log-probs [*, dlen], or None.
     """
     conf: ZDecoderUdepConf = self.conf
     # --
     arr_udep = BK.get_value(udep_logprobs_t.transpose(-2,-3))  # [*, m, h, L]
     arr_root = None if root_logprobs_t is None else BK.get_value(root_logprobs_t)  # [*, dlen]
     _dim_label = arr_udep.shape[-1]
     _neg = -10000.  # should be enough!!
     _voc, _lab_range = self.ztask.vpack  # label vocab + number of usable label idxes
     _idx_root = self._label_idx_root  # label idx reserved for the root arc
     # --
     for bidx, item in enumerate(ibatch.items):  # for each item in the batch
         _dec_offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             if conf.msent_pred_center and (sidx != item.center_sidx):
                 continue  # skip non-center sent in this mode!
             _start = _dec_offsets[sidx]
             _len = len(sent)
             _len_p1 = _len + 1  # +1 for the artificial ROOT at position 0
             # --
             _arr = np.full([_len_p1, _len_p1, _dim_label], _neg, dtype=np.float32)  # [1+m, 1+h, L]
             # assign label scores
             _arr[1:_len_p1, 1:_len_p1, 1:_lab_range] = arr_udep[bidx, _start:_start+_len, _start:_start+_len, 1:_lab_range]
             # assign root scores
             if arr_root is not None:
                 _arr[1:_len_p1, 0, _idx_root] = arr_root[bidx, _start:_start+_len]
             else:  # todo(+N): currently simply assign a smaller "neg-inf"
                 _arr[1:_len_p1, 0, _idx_root] = -99.
             # --
             from msp2.tools.algo.nmst import mst_unproj  # decoding algorithm
             arr_ret_heads, arr_ret_labels, arr_ret_scores = \
                 mst_unproj(_arr[None], np.asarray([_len_p1]), labeled=True)  # [*, 1+slen]
             # assign (drop the ROOT slot at idx 0)
             list_dep_heads = arr_ret_heads[0, 1:_len_p1].tolist()
             list_dep_lidxes = arr_ret_labels[0, 1:_len_p1].tolist()
             list_dep_labels = _voc.seq_idx2word(list_dep_lidxes)
             sent.build_dep_tree(list_dep_heads, list_dep_labels)
コード例 #23
0
 def assign_boundaries(self, items: List, boundary_node,
                       flat_mask_t: BK.Expr, flat_hid_t: BK.Expr,
                       indicators: List):
     """Decode left/right boundary indexes for each item's mention and
     overwrite the mention spans accordingly (chunking large batches to
     respect conf.boundary_bsize)."""
     shape = BK.get_shape(flat_mask_t)  # [???, dlen]
     flat_indicators = boundary_node.prepare_indicators(indicators, shape)
     # --
     total, dlen = shape
     # cap how many rows go through the boundary decoder in one call
     chunk = max(1, int(self.conf.boundary_bsize / max(1, dlen)))
     # --
     if chunk >= total:  # everything fits in one shot
         _, left_t, right_t = boundary_node.decode(
             flat_hid_t, flat_mask_t, flat_indicators)  # [???]
     else:  # decode in mini-chunks, then concat along the batch dim
         lefts, rights = [], []
         for beg in range(0, total, chunk):
             end = beg + chunk
             _, one_left, one_right = boundary_node.decode(
                 flat_hid_t[beg:end], flat_mask_t[beg:end],
                 [z[beg:end] for z in flat_indicators])
             lefts.append(one_left)
             rights.append(one_right)
         left_t = BK.concat(lefts, 0)
         right_t = BK.concat(rights, 0)
     arr_left = BK.get_value(left_t)
     arr_right = BK.get_value(right_t)
     for ii, item in enumerate(items):
         mention = item.mention
         offset = item._tmp_sstart  # need to minus this!!
         left_widx = arr_left[ii].item() - offset
         right_widx = arr_right[ii].item() - offset
         # todo(+N): sometimes we can have repeated ones, currently simply over-write!
         if mention.get_span()[1] == 1:
             # single-token span: first preserve it as the shead before widening
             mention.set_span(*(mention.get_span()), shead=True)
         mention.set_span(left_widx, right_widx - left_widx + 1)
コード例 #24
0
 def loss_on_batch(self, insts: List, loss_factor=1., training=True, force_lidx=None, **kwargs):
     """Run one forward/backward pass over a batch and return an info dict.

     Flow: optionally forward auxiliary ("aug") models under no-grad to
     collect extra scores, then run the real encoder/decoders through the
     mediator, collect the losses, and (if training) backprop — either via
     pcgrad on the encoder params or plain backward.

     Args:
         insts: raw input instances; expanded via self._yield_insts.
         loss_factor: scaling factor passed to the backward step.
         training: whether to refresh in training mode and do backward.
         force_lidx: special label-idx override assigned onto the mediator
             for the duration of this call (cleared at the end).
     Returns:
         info dict with instance counts, fb flags and per-loss stats.
     """
     conf: ZmtlModelConf = self.conf
     self.refresh_batch(training)
     # --
     # import torch
     # torch.autograd.set_detect_anomaly(True)
     # --
     actual_insts = list(self._yield_insts(insts))
     med = self.med
     enc_cached_input = self.enc.prepare_inputs(actual_insts)
     # ==
     # if needed, forward other models (can be self)
     aug_scores = {}
     with BK.no_grad_env():
         if conf.aug_times >= 1:
             # forward all at once!!
             _mm_input = enc_cached_input if (conf.aug_times == 1) else self.enc.prepare_inputs(actual_insts*conf.aug_times)
             for mm in self.aug_models:  # add them all to aug_scores!!
                 mm.enc_forward(_mm_input, aug_scores, conf.aug_training_flag)
     # ==
     # note: refresh again since the aug forward above may have changed state
     self.refresh_batch(training)
     med.force_lidx = force_lidx  # note: special assign
     # enc
     self.enc.forward(None, med, cached_input=enc_cached_input)
     # dec
     med.aug_scores = aug_scores  # note: assign here!!
     all_losses = med.do_losses()
     # --
     # final loss and backward
     info = {"inst0": len(insts), "inst": len(actual_insts), "fb": 1, "fb0": 0}
     final_loss, loss_info = self.collect_loss(all_losses, ret_dict=(self.pcgrad is not None))
     info.update(loss_info)
     if training:
         if self.pcgrad is not None:
             # self.pcgrad.do_backward(self.parameters(), final_loss, loss_factor)
             # note: we only specially treat enc's, for others, grads will always be accumulated!
             self.pcgrad.do_backward(self.enc.parameters(), final_loss, loss_factor)
         else:  # as usual
             # assert final_loss.requires_grad
             if BK.get_value(final_loss).item() > 0:  # note: loss should be >0 usually!!
                 BK.backward(final_loss, loss_factor)
             else:  # no need to backward if no loss
                 info["fb0"] = 1
     med.restart()  # clean!
     med.force_lidx = None  # clear!
     return info
コード例 #25
0
ファイル: srl.py プロジェクト: zzsfornlp/zmsp
 def put_results(self, insts: List[Sent], best_evt_labs, best_evt_scores, best_arg_labs, best_arg_scores):
     """Write predicted events (and their args) back onto the sentences.

     Two modes:
       * evt_pred_use_posi: keep the gold/given event positions, only
         re-assign labels/scores and decode args for those positions.
       * otherwise: delete old events and create a new single-token event
         frame at every position whose predicted label is non-NIL (idx 0).

     Args:
         insts: sentences, aligned with the batch dim of the arrays.
         best_evt_labs/best_evt_scores: per-token event label idx / score.
         best_arg_labs/best_arg_scores: per-(evt-token, token) arg label idx / score.
     """
     conf: MySRLConf = self.conf
     _evt_pred_use_posi = conf.evt_pred_use_posi
     vocab_evt = self.vocab_evt
     vocab_arg = self.vocab_arg
     if conf.arg_use_bio:
         # BIO-tagged arg vocab wraps a base vocab of the actual roles
         real_vocab_arg = vocab_arg.base_vocab
     else:
         real_vocab_arg = vocab_arg
     # --
     all_arrs = [BK.get_value(z) for z in [best_evt_labs, best_evt_scores, best_arg_labs, best_arg_scores]]
     for bidx, inst in enumerate(insts):
         inst.delete_frames(conf.arg_ftag)  # delete old args
         # --
         cur_len = len(inst)
         # trim padding: keep only this sentence's first cur_len entries
         cur_evt_labs, cur_evt_scores, cur_arg_labs, cur_arg_scores = [z[bidx][:cur_len] for z in all_arrs]
         inst.info["evt_lab"] = [vocab_evt.idx2word(z) if z>0 else 'O' for z in cur_evt_labs]
         # --
         if _evt_pred_use_posi:  # special mode
             for evt in inst.get_frames(conf.evt_ftag):
                 # reuse posi but re-assign label!
                 one_widx = evt.mention.shead_widx
                 one_lab, one_score = cur_evt_labs[one_widx].item(), cur_evt_scores[one_widx].item()
                 evt.set_label(vocab_evt.idx2word(one_lab))
                 evt.set_label_idx(one_lab)
                 evt.score = one_score
                 # args
                 new_arg_scores = cur_arg_scores[one_widx][:cur_len]
                 new_arg_label_idxes = cur_arg_labs[one_widx][:cur_len]
                 self.decode_arg(evt, new_arg_label_idxes, new_arg_scores, vocab_arg, real_vocab_arg)
         else:  # pred everything!
             inst.delete_frames(conf.evt_ftag)
             for one_widx in range(cur_len):
                 one_lab, one_score = cur_evt_labs[one_widx].item(), cur_evt_scores[one_widx].item()
                 if one_lab == 0:
                     continue  # note: lab==0 means NIL (no event here)
                 # make new evt
                 new_evt = inst.make_frame(one_widx, 1, conf.evt_ftag, type=vocab_evt.idx2word(one_lab), score=one_score)
                 new_evt.set_label_idx(one_lab)
                 self.evt_span_setter(new_evt.mention, one_widx, 1)
                 # args
                 new_arg_scores = cur_arg_scores[one_widx][:cur_len]
                 new_arg_label_idxes = cur_arg_labs[one_widx][:cur_len]
                 self.decode_arg(new_evt, new_arg_label_idxes, new_arg_scores, vocab_arg, real_vocab_arg)
コード例 #26
0
ファイル: nmst.py プロジェクト: zzsfornlp/zmsp
def nmst_greedy(scores_expr, mask_expr, lengths_arr, labeled=True, ret_arr=False):
    """Greedy (non-tree-constrained) decoding: for each modifier, independently
    pick the best (head, label) via one argmax over the flattened head*label dim.

    Args:
        scores_expr: arc-label scores; assumed [bs, slen, slen, L] — the
            second dim is used for diagonal (self-head) masking. TODO confirm
            the exact (m, h) ordering against callers.
        mask_expr, lengths_arr: unused here; kept for interface parity with
            the other nmst_* decoders.
        labeled: must be True (only labeled decoding is implemented).
        ret_arr: if True return numpy arrays, otherwise tensors.
    Returns:
        (heads, labels, scores), each of shape scores_expr.shape[:-2].
    """
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # mask out diag (self as head) with a very negative constant
        # note: fix — out-of-place add; the original `+=` mutated the caller's tensor
        scores_expr = scores_expr + BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
        # combined last two dimension and Max over them
        combined_scores_expr = scores_expr.view(scores_shape[:-2] + [-1])
        combine_max_scores, combined_max_idxes = BK.max(combined_scores_expr, dim=-1)
        # back to real idxes: flat idx -> (head, label)
        last_size = scores_shape[-1]
        greedy_heads = combined_max_idxes // last_size
        greedy_labels = combined_max_idxes % last_size
        if ret_arr:
            mst_heads_arr, mst_labels_arr, mst_scores_arr = [BK.get_value(z) for z in (greedy_heads, greedy_labels, combine_max_scores)]
            return mst_heads_arr, mst_labels_arr, mst_scores_arr
        else:
            return greedy_heads, greedy_labels, combine_max_scores
コード例 #27
0
ファイル: dec_upos.py プロジェクト: zzsfornlp/zmsp
 def decode_upos(self, ibatch, logprobs_t: BK.Expr):
     """Greedy UPOS tagging: argmax label per token, written onto each
     (center) sentence via build_uposes."""
     conf: ZDecoderUposConf = self.conf
     # argmax over the label dimension -> [*, dlen]
     pred_upos_scores, pred_upos_labels = logprobs_t.max(-1)
     arr_labels = BK.get_value(pred_upos_labels)
     # put results
     voc = self.voc
     for bidx, item in enumerate(ibatch.items):  # each item in the batch
         offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             if conf.msent_pred_center and (sidx != item.center_sidx):
                 # multi-sent mode: only tag the center sentence
                 continue
             start = offsets[sidx]
             idx_list = arr_labels[bidx][start:start + len(sent)].tolist()
             sent.build_uposes(voc.seq_idx2word(idx_list))
コード例 #28
0
 def loss(self,
          insts: Union[List[Sent], List[Frame]],
          input_expr: BK.Expr,
          mask_expr: BK.Expr,
          pair_expr: BK.Expr = None,
          lookup_flatten=False,
          external_extra_score: BK.Expr = None):
     """Compute the anchor-extractor training loss.

     Steps: (0) prepare gold items/labels/groups; (1) score all positions,
     build a per-position weight from the gold-label probability relative to
     the best position in its group, and mix gold-label NLL with NIL NLL by
     that weight; (1.5) optional in-group entropy loss; (4) optional
     extension loss on selected high-confidence positions.

     Returns:
         (combined_loss, None)
     """
     conf: AnchorExtractorConf = self.conf
     assert not lookup_flatten
     bsize, slen = BK.get_shape(mask_expr)
     # --
     # step 0: prepare
     arr_items, expr_seq_gaddr, expr_seq_labs, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
         self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
     arange3_t = arange2_t.unsqueeze(-1)  # [*, 1, 1]
     # --
     # step 1: label, simply scoring everything!
     _main_t, _pair_t = self.lab_node.transform_expr(input_expr, pair_expr)
     all_scores_t = self.lab_node.score_all(
         _main_t,
         _pair_t,
         mask_expr,
         None,
         local_normalize=False,
         extra_score=external_extra_score
     )  # unnormalized scores [*, slen, L]
     all_probs_t = all_scores_t.softmax(-1)  # [*, slen, L]
     # probability assigned to the gold label at each position
     all_gprob_t = all_probs_t.gather(-1,
                                      expr_seq_labs.unsqueeze(-1)).squeeze(
                                          -1)  # [*, slen]
     # how to weight: gold-prob of every position in the same group
     extended_gprob_t = all_gprob_t[
         arange3_t, expr_group_widxes] * expr_group_masks  # [*, slen, MW]
     if BK.is_zero_shape(extended_gprob_t):
         extended_gprob_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
     else:
         extended_gprob_max_t, _ = extended_gprob_t.max(-1)  # [*, slen]
     _w_alpha = conf.cand_loss_weight_alpha
     # weight = (own gold-prob / best gold-prob in group) ** alpha
     _weight = (
         (all_gprob_t * mask_expr) /
         (extended_gprob_max_t.clamp(min=1e-5)))**_w_alpha  # [*, slen]
     _label_smoothing = conf.lab_conf.labeler_conf.label_smoothing
     _loss1 = BK.loss_nll(all_scores_t,
                          expr_seq_labs,
                          label_smoothing=_label_smoothing)  # [*, slen]
     # loss against the NIL label (idx 0) for the down-weighted part
     _loss2 = BK.loss_nll(all_scores_t,
                          BK.constants_idx([bsize, slen], 0),
                          label_smoothing=_label_smoothing)  # [*, slen]
     _weight1 = _weight.detach() if conf.detach_weight_lab else _weight
     _raw_loss = _weight1 * _loss1 + (1. - _weight1) * _loss2  # [*, slen]
     # final weight it: down-weight NIL positions by loss_weight_non
     cand_loss_weights = BK.where(expr_seq_labs == 0,
                                  expr_loss_weight_non.unsqueeze(-1) *
                                  conf.loss_weight_non,
                                  mask_expr)  # [*, slen]
     final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
     loss_lab_item = LossHelper.compile_leaf_loss(
         f"lab", (_raw_loss * final_cand_loss_weights).sum(),
         final_cand_loss_weights.sum(),
         loss_lambda=conf.loss_lab,
         gold=(expr_seq_labs > 0).float().sum())
     # --
     # step 1.5: optional entropy loss over each group's gold-prob distribution
     all_losses = [loss_lab_item]
     _loss_cand_entropy = conf.loss_cand_entropy
     if _loss_cand_entropy > 0.:
         _prob = extended_gprob_t  # [*, slen, MW]
         _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
         # [*, slen], only first one in bag
         _ent_mask = BK.concat([expr_seq_gaddr[:,:1]>=0, expr_seq_gaddr[:,1:]!=expr_seq_gaddr[:,:-1]],-1).float() \
                     * (expr_seq_labs>0).float()
         _loss_ent_item = LossHelper.compile_leaf_loss(
             f"cand_ent", (_ent * _ent_mask).sum(),
             _ent_mask.sum(),
             loss_lambda=_loss_cand_entropy)
         all_losses.append(_loss_ent_item)
     # --
     # step 4: extend (select topk)
     if conf.loss_ext > 0.:
         if BK.is_zero_shape(extended_gprob_t):
             flt_mask = (BK.zeros(mask_expr.shape) > 0)
         else:
             _topk = min(conf.ext_loss_topk,
                         BK.get_shape(extended_gprob_t,
                                      -1))  # number to extract
             _topk_grpob_t, _ = extended_gprob_t.topk(
                 _topk, dim=-1)  # [*, slen, K]
             # keep gold positions whose gold-prob makes the group's topk
             # and whose weight clears the threshold
             flt_mask = (expr_seq_labs >
                         0) & (all_gprob_t >= _topk_grpob_t.min(-1)[0]) & (
                             _weight > conf.ext_loss_thresh)  # [*, slen]
         flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[
             flt_mask]  # [?]
         flt_expr = input_expr[flt_mask]  # [?, D]
         flt_full_expr = self._prepare_full_expr(flt_mask)  # [?, slen, D]
         flt_items = arr_items.flatten()[BK.get_value(
             expr_seq_gaddr[flt_mask])]  # [?]
         flt_weights = _weight.detach(
         )[flt_mask] if conf.detach_weight_ext else _weight[flt_mask]  # [?]
         loss_ext_item = self.ext_node.loss(flt_items,
                                            input_expr[flt_sidx],
                                            flt_expr,
                                            flt_full_expr,
                                            mask_expr[flt_sidx],
                                            flt_extra_weights=flt_weights)
         all_losses.append(loss_ext_item)
     # --
     # return loss
     ret_loss = LossHelper.combine_multiple_losses(all_losses)
     return ret_loss, None
コード例 #29
0
ファイル: soft.py プロジェクト: zzsfornlp/zmsp
 def loss(self,
          insts: Union[List[Sent], List[Frame]],
          input_expr: BK.Expr,
          mask_expr: BK.Expr,
          pair_expr: BK.Expr = None,
          lookup_flatten=False,
          external_extra_score: BK.Expr = None):
     """Compute the soft-extractor training loss.

     Pipeline: (1) candidate scoring/selection, (2) segment splitting on the
     selected candidates, (3) labeling the resulting segments against oracle
     gold addresses, (4) span-extension loss for gold-labeled segments.

     Returns:
         (combined_loss, None)
     """
     conf: SoftExtractorConf = self.conf
     assert not lookup_flatten
     bsize, slen = BK.get_shape(mask_expr)
     # --
     # step 0: prepare
     arr_items, expr_seq_gaddr, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
         self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
     # --
     # step 1: cand
     cand_full_scores, pred_cand_decisions = self._cand_score_and_select(
         input_expr, mask_expr)  # [*, slen]
     loss_cand_items, cand_widxes, cand_masks = self._loss_feed_cand(
         mask_expr,
         cand_full_scores,
         pred_cand_decisions,
         expr_seq_gaddr,
         expr_group_widxes,
         expr_group_masks,
         expr_loss_weight_non,
     )  # ~, [*, clen]
     # --
     # step 2: split
     # gather the selected candidates' reprs and scores
     cand_expr, cand_scores = input_expr[
         arange2_t, cand_widxes], cand_full_scores[arange2_t,
                                                   cand_widxes]  # [*, clen]
     split_scores, pred_split_decisions = self._split_score(
         cand_expr, cand_masks)  # [*, clen-1]
     loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr = self._loss_feed_split(
         mask_expr,
         split_scores,
         pred_split_decisions,
         cand_widxes,
         cand_masks,
         cand_expr,
         cand_scores,
         expr_seq_gaddr,
     )  # ~, [*, seglen, *?]
     # --
     # step 3: lab
     # todo(note): add a 0 as idx=-1 to make NEG ones as 0!!
     flatten_gold_label_idxes = BK.input_idx(
         [(0 if z is None else z.label_idx)
          for z in arr_items.flatten()] + [0])
     gold_label_idxes = flatten_gold_label_idxes[
         oracle_gaddr]  # [*, seglen]
     # down-weight segments that match a gold item (gaddr>=0? — see BK.where cond)
     lab_loss_weights = BK.where(oracle_gaddr >= 0,
                                 expr_loss_weight_non.unsqueeze(-1) *
                                 conf.loss_weight_non,
                                 seg_masks)  # [*, seglen]
     final_lab_loss_weights = lab_loss_weights * seg_masks  # [*, seglen]
     # go
     loss_lab, loss_count = self.lab_node.loss(
         seg_weighted_expr,
         pair_expr,
         seg_masks,
         gold_label_idxes,
         loss_weight_expr=final_lab_loss_weights,
         extra_score=external_extra_score)
     loss_lab_item = LossHelper.compile_leaf_loss(
         f"lab",
         loss_lab,
         loss_count,
         loss_lambda=conf.loss_lab,
         gold=(gold_label_idxes > 0).float().sum())
     # step 4: extend, only for valid gold-labeled segments
     flt_mask = ((gold_label_idxes > 0) & (seg_masks > 0.))  # [*, seglen]
     flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[
         flt_mask]  # [?]
     flt_expr = seg_weighted_expr[flt_mask]  # [?, D]
     flt_full_expr = self._prepare_full_expr(seg_ext_widxes[flt_mask],
                                             seg_ext_masks[flt_mask],
                                             slen)  # [?, slen, D]
     flt_items = arr_items.flatten()[BK.get_value(
         oracle_gaddr[flt_mask])]  # [?]
     loss_ext_item = self.ext_node.loss(flt_items, input_expr[flt_sidx],
                                        flt_expr, flt_full_expr,
                                        mask_expr[flt_sidx])
     # --
     # return loss
     ret_loss = LossHelper.combine_multiple_losses(
         loss_cand_items + [loss_split_item, loss_lab_item, loss_ext_item])
     return ret_loss, None
コード例 #30
0
 def predict(self, med: ZMediator):
     """Full SRL inference for one batch: decode evt (and optionally ef)
     frames, optionally refine their boundaries, then decode arguments for
     each predicted event and run final post-processing (filtering + arg_pp).

     Results are written onto the sentences as a side effect; returns None.
     """
     conf: SrlInferenceHelperConf = self.conf
     dec = self.dec
     dec_conf = dec.conf
     # --
     # prepare
     ibatch = med.ibatch
     hid_t = med.get_enc_cache(
         "hid").get_cached_value()  # last enc layer, [*, dlen, D]
     _ds_idxes = ibatch.seq_info.dec_sent_idxes  # [*, dlen]
     # --
     # first predict the evt/ef frames
     frame_results = []
     # --
     # parallel setting tuples: tag, label node, nil-add, given?, boundary
     # node, max layer, vocab, pred-label?, pred-boundary?, osof?, check-layer
     _evt_settings = [
         'evt', dec.lab_evt, conf.pred_evt_nil_add, conf.pred_given_evt,
         dec.boundary_evt, dec_conf.max_layer_evt, dec.voc_evt,
         conf.pred_evt_label, conf.pred_evt_boundary, dec_conf.assume_osof,
         conf.pred_evt_check_layer
     ]
     _ef_settings = [
         'ef', dec.lab_ef, conf.pred_ef_nil_add, conf.pred_given_ef,
         dec.boundary_ef, dec_conf.max_layer_ef, dec.voc_ef,
         conf.pred_ef_label, conf.pred_ef_boundary, False, -1
     ]
     if not conf.pred_ef_first:
         _ef_settings[0] = None  # no predicting efs first!!
     # --
     for pred_tag, node_lab, pred_nil_add, pred_given, node_boundary, pred_max_layer, voc, \
         pred_label, pred_boundary, assume_osof, pred_check_layer in [_evt_settings, _ef_settings]:
         if pred_tag is None:
             frame_results.append((None, None, None))
             continue
         # --
         if pred_given:
             for item in ibatch.items:
                 set_ee_heads(item.sents)  # we may need a head-widx later!
         # --
         score_cache = med.get_cache((dec.name, pred_tag))
         scores_t = node_lab.score_labels(
             score_cache.vals, nil_add_score=pred_nil_add)  # [*, dlen, L]
         if pred_tag == 'evt' and conf.use_cons_evt:  # modify scores by constraints
             scores_t = self.cons_score_lu2frame(
                 scores_t,
                 ibatch,
                 given_f=((lambda s: s.events) if pred_given else None))
         # -> ([??], [??]), [??], [*, dlen, K]
         if pred_given:
             res = self.decode_frame_given(ibatch, scores_t, pred_max_layer,
                                           voc, pred_label, pred_tag,
                                           assume_osof)
         else:
             res = self.decode_frame(ibatch, scores_t, pred_max_layer, voc,
                                     pred_label, pred_tag, pred_check_layer)
         # boundary?
         res_idxes_t, res_frames, _ = res
         if pred_boundary and node_boundary is not None and len(
                 res_frames) > 0:
             # mask: only tokens from the same sentence as each frame's token
             _flat_mask_t = (
                 _ds_idxes[res_idxes_t[0]] ==
                 _ds_idxes[res_idxes_t].unsqueeze(-1)).float()  # [??, dlen]
             _flat_hid_t = hid_t[res_idxes_t[0]]  # [??, dlen, D]
             self.assign_boundaries(res_frames, node_boundary, _flat_mask_t,
                                    _flat_hid_t, [res_idxes_t[1]])
         # --
         frame_results.append(res)
     # --
     # then predict args!
     res_evt_idxes_t, res_evts, _ = frame_results[0]  # evt results!
     if len(res_evts) == 0:
         return  # note: no need to do anything!
     _, _, arr_efs = frame_results[1]  # ef results!
     if arr_efs is not None:
         arr_efs = arr_efs[BK.get_value(
             res_evt_idxes_t[0])]  # change to fidx at first: [??, dlen, K]
     # --
     arg_score_cache = med.get_cache((dec.name, 'arg'))
     # --
     base_mask_t = dec.get_dec_mask(
         ibatch, dec_conf.msent_pred_center)  # [bs, dlen]
     arg_seq_mask = base_mask_t.unsqueeze(-2).expand(
         -1, BK.get_shape(base_mask_t, -1), -1)  # [bs, dlen, dlen]
     # --
     arg_scores_t = dec.lab_arg.score_labels(
         arg_score_cache.vals,
         seq_mask_t=arg_seq_mask,
         preidx_t=res_evt_idxes_t,
         nil_add_score=conf.pred_arg_nil_add)  # [??, dlen, L]
     if conf.use_cons_arg:  # modify scores by constraints
         arg_scores_t = self.cons_score_frame2role(arg_scores_t, res_evts)
     _pred_max_layer, _arg_allowed_sent_gap = dec_conf.max_layer_arg, dec_conf.arg_allowed_sent_gap
     if dec_conf.arg_use_bio:  # if use BIO, then checking the seq will be fine
         self.decode_arg_bio(res_evts, arg_scores_t, _pred_max_layer,
                             dec.vocab_bio_arg, _arg_allowed_sent_gap,
                             arr_efs)
     else:  # otherwise, still need two steps
         res_arg_idxes_t, res_args = self.decode_arg(
             res_evts, arg_scores_t, _pred_max_layer, dec.voc_arg,
             _arg_allowed_sent_gap, arr_efs)  # [???]
         # arg(ef) boundary: only when args were not tied to predicted efs
         if dec.boundary_arg is not None and len(
                 res_args) > 0 and arr_efs is None:
             _ab_fidxes_t, _ab_awidxes_t = res_arg_idxes_t  # [???]
             _ab_bidxes_t, _ab_ewidxes_t = [
                 z[_ab_fidxes_t] for z in res_evt_idxes_t
             ]  # [???]
             # restrict to tokens in the same sentence as the arg token
             _ab_mask_t = (_ds_idxes[_ab_bidxes_t] == _ds_idxes[
                 _ab_bidxes_t, _ab_awidxes_t].unsqueeze(-1)).float()
             flat_hid_t = hid_t[_ab_bidxes_t]  # [???, dlen, D]
             self.assign_boundaries(res_args, dec.boundary_arg, _ab_mask_t,
                                    flat_hid_t,
                                    [_ab_ewidxes_t, _ab_awidxes_t])
     # --
     # final arg post-process
     if self.pred_evt_filter:
         # drop predicted events whose type is not in the allowed filter set
         for evt in list(res_evts):
             if evt.type not in self.pred_evt_filter:
                 evt.sent.delete_frame(evt, 'evt')
     for evt in res_evts:
         self.arg_pp.process(evt)  # modified inplace!
     # --
     return