Example #1
 def prepare(self, insts: List[Sent], use_cache: bool):
     # get info
     if use_cache:
         zobjs = []
         attr_name = "_cache_srl"  # should be unique
         for s in insts:
             one = getattr(s, attr_name, None)
             if one is None:
                 one = self._prep_sent(s)
                 setattr(s, attr_name, one)  # set cache
             zobjs.append(one)
     else:
         zobjs = [self._prep_sent(s) for s in insts]
     # batch things
     bsize, mlen = len(insts), max(z.slen for z in zobjs) if len(zobjs)>0 else 1  # at least put one as padding
     batched_shape = (bsize, mlen)
     arr_items = np.full(batched_shape, None, dtype=object)
     arr_evt_labels = np.full(batched_shape, 0, dtype=np.int64)
     arr_arg_labels = np.full(batched_shape+(mlen,), 0, dtype=np.int64)
     for zidx, zobj in enumerate(zobjs):
         zlen = zobj.slen
         arr_items[zidx, :zlen] = zobj.evt_items
         arr_evt_labels[zidx, :zlen] = zobj.evt_arr
         arr_arg_labels[zidx, :zlen, :zlen] = zobj.arg_arr
     expr_evt_labels = BK.input_idx(arr_evt_labels)  # [*, slen]
     expr_arg_labels = BK.input_idx(arr_arg_labels)  # [*, slen, slen]
     expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
     return arr_items, expr_evt_labels, expr_arg_labels, expr_loss_weight_non
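The pattern above recurs throughout these examples: variable-length per-sentence label arrays are written into the top-left corner of fixed-size batch arrays, with 0 doubling as the NIL/padding label, before being wrapped as tensors. A minimal self-contained sketch of the same padding idea in plain NumPy (the names here are illustrative, not from the library):

    import numpy as np

    def pad_batch(label_seqs, pair_mats):
        # label_seqs: list of 1D int arrays; pair_mats: matching [slen, slen] int arrays
        bsize = len(label_seqs)
        mlen = max((len(z) for z in label_seqs), default=1)  # at least one as padding
        arr_labels = np.full((bsize, mlen), 0, dtype=np.int64)       # 0 = NIL
        arr_pairs = np.full((bsize, mlen, mlen), 0, dtype=np.int64)  # pairwise labels
        for i, (labs, mat) in enumerate(zip(label_seqs, pair_mats)):
            zlen = len(labs)
            arr_labels[i, :zlen] = labs
            arr_pairs[i, :zlen, :zlen] = mat
        return arr_labels, arr_pairs

    # usage: two sentences of lengths 3 and 2
    labs = [np.array([1, 0, 2]), np.array([3, 1])]
    mats = [np.zeros((3, 3), dtype=np.int64), np.ones((2, 2), dtype=np.int64)]
    arr_labels, arr_pairs = pad_batch(labs, mats)
    assert arr_labels.shape == (2, 3) and arr_pairs.shape == (2, 3, 3)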
Example #2
 def decode_frame(self, ibatch, scores_t: BK.Expr, pred_max_layer: int, voc,
                  pred_label: bool, pred_tag: str, pred_check_layer: int):
     # --
     # first get topk for each position
     logprobs_t = scores_t.log_softmax(-1)  # [*, dlen, L]
     pred_scores, pred_labels = logprobs_t.topk(pred_max_layer)  # [*, dlen, K]
     arr_scores, arr_labels = BK.get_value(pred_scores), BK.get_value(pred_labels)  # [*, dlen, K]
     # put results
     res_bidxes, res_widxes, res_frames = [], [], []  # flattened results
     res_farrs = np.full(arr_scores.shape, None, dtype=object)  # [*, dlen, K]
     for bidx, item in enumerate(ibatch.items):  # for each item in the batch
         _dec_offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             # todo(+N): currently we only predict for center if there is!
             if item.center_sidx is not None and sidx != item.center_sidx:
                 continue  # skip non-center sent in this mode!
             _start = _dec_offsets[sidx]
             _len = len(sent)
             _arr_scores = arr_scores[bidx][_start:_start + _len]
             _arr_labels = arr_labels[bidx][_start:_start + _len]
             for widx in range(_len):
                 _full_widx = widx + _start  # idx in the msent
                 _tmp_set = set()
                 for _k in range(pred_max_layer):
                     _score, _lab = float(_arr_scores[widx][_k]), int(_arr_labels[widx][_k])
                     if _lab == 0:  # note: lab=0 means NIL
                         break
                     _type_str = voc.idx2word(_lab) if pred_label else "UNK"
                     _type_str_prefix = '.'.join(_type_str.split('.')[:pred_check_layer])
                     if pred_check_layer >= 0 and _type_str_prefix in _tmp_set:
                         continue  # ignore since constraint
                     _tmp_set.add(_type_str_prefix)
                     # add new one!
                     res_bidxes.append(bidx)
                     res_widxes.append(_full_widx)
                     _new_frame = sent.make_frame(widx, 1, tag=pred_tag, type=_type_str, score=float(_score))
                     _new_frame.set_label_idx(int(_lab))
                     _new_frame._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                     _new_frame._tmp_sidx = sidx
                     _new_frame._tmp_item = item
                     res_frames.append(_new_frame)
                     res_farrs[bidx, _full_widx, _k] = _new_frame
     # return
     res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
     return (res_bidxes_t, res_widxes_t), res_frames, res_farrs  # [??], [*, dlen, K]
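Two conventions drive the inner loop above: label index 0 means NIL, so hitting it among the (score-sorted) top-K truncates the candidate list, and a per-position set suppresses further predictions that share the same type prefix up to pred_check_layer. A hedged stand-alone sketch of just that loop, with the vocabulary mocked as a dict and frame creation omitted:

    import numpy as np

    def topk_decode(arr_scores, arr_labels, idx2word, check_layer=1):
        # arr_scores/arr_labels: [slen, K], sorted by score; label 0 = NIL
        results = []
        for widx in range(arr_scores.shape[0]):
            seen_prefixes = set()
            for k in range(arr_scores.shape[1]):
                lab = int(arr_labels[widx, k])
                if lab == 0:  # NIL: all lower-scored labels are dropped too
                    break
                type_str = idx2word[lab]
                prefix = '.'.join(type_str.split('.')[:check_layer])
                if check_layer >= 0 and prefix in seen_prefixes:
                    continue  # constraint: keep one prediction per type prefix
                seen_prefixes.add(prefix)
                results.append((widx, type_str, float(arr_scores[widx, k])))
        return results

    idx2word = {1: "Life.Die", 2: "Life.Injure", 3: "Conflict.Attack"}
    scores = np.array([[-0.1, -1.2], [-0.5, -0.9]])
    labels = np.array([[1, 2], [3, 0]])
    print(topk_decode(scores, labels, idx2word))  # "Life.Injure" dropped: same "Life" prefix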
Example #3
 def decode_frame_given(self, ibatch, scores_t: BK.Expr,
                        pred_max_layer: int, voc, pred_label: bool,
                        pred_tag: str, assume_osof: bool):
     if pred_label:  # if overwrite label!
         logprobs_t = scores_t.log_softmax(-1)  # [*, dlen, L]
         pred_scores, pred_labels = logprobs_t.max(-1)  # [*, dlen], note: maximum!
         arr_scores, arr_labels = BK.get_value(pred_scores), BK.get_value(pred_labels)  # [*, dlen]
     else:
         arr_scores = arr_labels = None
     # --
     # read given results
     res_bidxes, res_widxes, res_frames = [], [], []  # flattened results
     tmp_farrs = defaultdict(list)  # later assign
     for bidx, item in enumerate(ibatch.items):  # for each item in the batch
         _trg_frames = [item.inst] if assume_osof else \
             sum([sent.get_frames(pred_tag) for sidx,sent in enumerate(item.sents)
                  if (item.center_sidx is None or sidx == item.center_sidx)],[])  # still only pick center ones!
         # --
         _dec_offsets = item.seq_info.dec_offsets
         for _frame in _trg_frames:  # note: simply sort by original order!
             sidx = item.sents.index(_frame.sent)
             _start = _dec_offsets[sidx]
             _full_hidx = _start + _frame.mention.shead_widx
             # add new one
             res_bidxes.append(bidx)
             res_widxes.append(_full_hidx)
             _frame._tmp_sstart = _start  # todo(+N): ugly tmp value ...
             _frame._tmp_sidx = sidx
             _frame._tmp_item = item
             res_frames.append(_frame)
             tmp_farrs[(bidx, _full_hidx)].append(_frame)
             # assign/rewrite label?
             if pred_label:
                 _lab = int(arr_labels[bidx, _full_hidx])  # label index
                 _frame.set_label_idx(_lab)
                 _frame.set_label(voc.idx2word(_lab))
                 _frame.set_score(float(arr_scores[bidx, _full_hidx]))
         # --
     # --
     res_farrs = np.full(BK.get_shape(scores_t)[:-1] + [pred_max_layer], None, dtype=object)  # [*, dlen, K]
     for _key, _values in tmp_farrs.items():
         bidx, widx = _key
         _values = _values[:pred_max_layer]  # truncate if more!
         res_farrs[bidx, widx, :len(_values)] = _values
     # return
     res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
     return (res_bidxes_t, res_widxes_t), res_frames, res_farrs  # [??], [*, dlen, K]
Example #4
 def decode_arg(self, res_evts: List, arg_scores_t: BK.Expr,
                pred_max_layer: int, voc, arg_allowed_sent_gap: int,
                arr_efs):
     # first get topk
     arg_logprobs_t = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
     pred_arg_scores, pred_arg_labels = arg_logprobs_t.topk(pred_max_layer)  # [??, dlen, K]
     arr_arg_scores, arr_arg_labels = BK.get_value(pred_arg_scores), BK.get_value(pred_arg_labels)  # [??, dlen, K]
     # put results
     res_fidxes, res_widxes, res_args = [], [], []  # flattened results
     for fidx, evt in enumerate(res_evts):  # for each evt
         item = evt._tmp_item  # cached
         _evt_sidx = evt._tmp_sidx
         _dec_offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             if abs(sidx - _evt_sidx) > arg_allowed_sent_gap:
                 continue  # larger than allowed sentence gap
             _start = _dec_offsets[sidx]
             _len = len(sent)
             _arr_scores = arr_arg_scores[fidx][_start:_start + _len]
             _arr_labels = arr_arg_labels[fidx][_start:_start + _len]
             for widx in range(_len):
                 _full_widx = widx + _start  # idx in the msent
                 _new_ef = None
                 if arr_efs is not None:  # note: arr_efs should also expand to frames!
                     _new_ef = arr_efs[fidx, _full_widx, 0]  # todo(+N): only get the first one!
                     if _new_ef is None:
                         continue  # no ef!
                 for _score, _lab in zip(_arr_scores[widx], _arr_labels[widx]):  # [K]
                     if _lab == 0:  # note: idx=0 means NIL
                         break
                     # add new one!!
                     res_fidxes.append(fidx)
                     res_widxes.append(_full_widx)
                     if _new_ef is None:
                         _new_ef = sent.make_entity_filler(widx, 1)  # share them if possible!
                     _new_arg = evt.add_arg(_new_ef, role=voc.idx2word(_lab), score=float(_score))
                     _new_arg._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                     res_args.append(_new_arg)
     # return
     res_fidxes_t, res_widxes_t = BK.input_idx(res_fidxes), BK.input_idx(res_widxes)  # [??]
     return (res_fidxes_t, res_widxes_t), res_args
Example #5
 def decode_evt(self, dec, ibatch, evt_scores_t: BK.Expr):
     _pred_max_layer_evt = dec.conf.max_layer_evt
     _voc_evt = dec.voc_evt
     _pred_evt_label = self.conf.pred_evt_label
     # --
     evt_logprobs_t = evt_scores_t.log_softmax(-1)  # [*, dlen, L]
     pred_evt_scores, pred_evt_labels = evt_logprobs_t.topk(_pred_max_layer_evt)  # [*, dlen, K]
     arr_evt_scores, arr_evt_labels = BK.get_value(pred_evt_scores), BK.get_value(pred_evt_labels)  # [*, dlen, K]
     # put results
     res_bidxes, res_widxes, res_evts = [], [], []  # flattened results
     for bidx, item in enumerate(ibatch.items):  # for each item in the batch
         _dec_offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             # note: here we only predict for center if there is!
             if item.center_sidx is not None and sidx != item.center_sidx:
                 continue  # skip non-center sent in this mode!
             _start = _dec_offsets[sidx]
             _len = len(sent)
             _arr_scores = arr_evt_scores[bidx][_start:_start + _len]
             _arr_labels = arr_evt_labels[bidx][_start:_start + _len]
             for widx in range(_len):
                 for _score, _lab in zip(_arr_scores[widx], _arr_labels[widx]):  # [K]
                     if _lab == 0:  # note: idx=0 means NIL
                         break
                     # add new one!!
                     res_bidxes.append(bidx)
                     res_widxes.append(_start + widx)  # note: remember to add offset!
                     _evt_type = _voc_evt.idx2word(_lab) if _pred_evt_label else "UNK"
                     _new_evt = sent.make_event(widx, 1, type=_evt_type, score=float(_score))
                     _new_evt.set_label_idx(int(_lab))
                     _new_evt._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                     _new_evt._tmp_sidx = sidx
                     _new_evt._tmp_item = item
                     res_evts.append(_new_evt)
     # return
     res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
     return (res_bidxes_t, res_widxes_t), res_evts
Example #6
 def forward(self, inputs, add_bos=False, add_eos=False):
     conf: PlainInputEmbedderConf = self.conf
     # --
     voc = self.voc
     input_t = BK.input_idx(inputs)  # [*, len]
     # rare unk in training
     if self.is_training() and self.use_rare_unk:
         rare_unk_rate = conf.rare_unk_rate
         cur_unk_imask = (
             self.rare_unk_mask[input_t] *
             (BK.rand(BK.get_shape(input_t)) < rare_unk_rate)).long()
         input_t = input_t * (1 - cur_unk_imask) + voc.unk * cur_unk_imask
     # bos and eos
     all_input_slices = []
     slice_shape = BK.get_shape(input_t)[:-1] + [1]
     if add_bos:
         all_input_slices.append(
             BK.constants(slice_shape, voc.bos, dtype=input_t.dtype))
     all_input_slices.append(input_t)  # [*, len]
     if add_eos:
         all_input_slices.append(
             BK.constants(slice_shape, voc.eos, dtype=input_t.dtype))
     final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
     # finally
     ret = self.E(final_input_t)  # [*, ??, dim]
     return ret
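The BOS/EOS handling above builds a [*, 1] slice filled with the special index and concatenates along the last axis; in these examples BK behaves like a thin wrapper over torch. A minimal sketch of the same concatenation in plain PyTorch (the vocabulary indices are made up):

    import torch

    def add_special_tokens(input_t, bos_idx=1, eos_idx=2, add_bos=True, add_eos=True):
        # input_t: [*, len] integer indices
        slice_shape = list(input_t.shape[:-1]) + [1]
        pieces = []
        if add_bos:
            pieces.append(torch.full(slice_shape, bos_idx, dtype=input_t.dtype))
        pieces.append(input_t)
        if add_eos:
            pieces.append(torch.full(slice_shape, eos_idx, dtype=input_t.dtype))
        return torch.cat(pieces, dim=-1)  # [*, 1?+len+1?]

    x = torch.tensor([[5, 6, 7], [8, 9, 10]])
    print(add_special_tokens(x))  # tensor([[1, 5, 6, 7, 2], [1, 8, 9, 10, 2]])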
Example #7
 def prepare(self, insts: Union[List[Sent], List[Frame]], mlen: int,
             use_cache: bool):
     conf: SeqExtractorConf = self.conf
     # get info
     if use_cache:
         zobjs = []
         attr_name = f"_scache_{conf.ftag}"  # should be unique
         for s in insts:
             one = getattr(s, attr_name, None)
             if one is None:
                 one = self._prep_f(s)
                 setattr(s, attr_name, one)  # set cache
             zobjs.append(one)
     else:
         zobjs = [self._prep_f(s) for s in insts]
     # batch things
     bsize = len(insts)
     # mlen = max(z.len for z in zobjs)  # note: fed by outside!!
     batched_shape = (bsize, mlen)
     # arr_first_items = np.full(batched_shape, None, dtype=object)
     arr_slabs = np.full(batched_shape, 0, dtype=np.int64)
     for zidx, zobj in enumerate(zobjs):
         # arr_first_items[zidx, zobj.len] = zobj.first_items
         arr_slabs[zidx, :zobj.len] = zobj.tags
     # final setup things
     expr_slabs = BK.input_idx(arr_slabs)
     expr_loss_weight_non = BK.input_real(
         [z.loss_weight_non for z in zobjs])  # [*]
     # return arr_first_items, expr_slabs, expr_loss_weight_non
     return expr_slabs, expr_loss_weight_non
Example #8
 def forward(self, inputs, add_bos=False, add_eos=False):
     conf: PosiInputEmbedderConf = self.conf
     # --
     try:
         # input is a shape as prepared by "PosiHelper"
         batch_size, max_len = inputs
         if add_bos:
             max_len += 1
         if add_eos:
             max_len += 1
         posi_idxes = BK.arange_idx(max_len)  # [?len?]
         ret = self.E(posi_idxes).unsqueeze(0).expand(batch_size, -1, -1)
     except Exception:
         # input is tensor
         posi_idxes = BK.input_idx(inputs)  # [*, len]
         cur_maxlen = BK.get_shape(posi_idxes, -1)
         # --
         all_input_slices = []
         slice_shape = BK.get_shape(posi_idxes)[:-1] + [1]
         if add_bos:  # add 0 and offset
             all_input_slices.append(
                 BK.constants(slice_shape, 0, dtype=posi_idxes.dtype))
             cur_maxlen += 1
             posi_idxes += 1
         all_input_slices.append(posi_idxes)  # [*, len]
         if add_eos:
             all_input_slices.append(
                 BK.constants(slice_shape, cur_maxlen, dtype=posi_idxes.dtype))
         final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
         # finally
         ret = self.E(final_input_t)  # [*, ??, dim]
     return ret
Example #9
 def get_dec_mask(self, ibatch, center_only: bool):
     if center_only:
         center_idxes = BK.input_idx([z.center_sidx for z in ibatch.items]).unsqueeze(-1)  # [bs, 1]
         ret_mask = (ibatch.seq_info.dec_sent_idxes == center_idxes).float()  # [bs, dlen]
     else:  # otherwise, simply further exclude CLS/PAD
         ret_mask = (ibatch.seq_info.dec_sent_idxes >= 0).float()  # [*, dlen]
     # ret_mask *= ibatch.seq_info.dec_sel_masks  # [*, dlen], note: no need for this
     return ret_mask
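This mask logic composes with the dec_sent_idxes convention from Example #22 below (-1 for CLS, -2 for PAD, 0,1,2,... for sentences). A quick stand-alone check of both branches in plain PyTorch:

    import torch

    def get_dec_mask(dec_sent_idxes, center_sidxes, center_only):
        # dec_sent_idxes: [bs, dlen]; -1 = CLS, -2 = PAD, >= 0 = sentence index
        if center_only:
            return (dec_sent_idxes == center_sidxes.unsqueeze(-1)).float()  # [bs, dlen]
        return (dec_sent_idxes >= 0).float()  # exclude CLS/PAD

    idxes = torch.tensor([[-1, 0, 0, 1, -2]])
    print(get_dec_mask(idxes, torch.tensor([1]), True))   # [[0., 0., 0., 1., 0.]]
    print(get_dec_mask(idxes, torch.tensor([1]), False))  # [[0., 1., 1., 1., 0.]]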
Example #10
 def _get_arg_external_extra_score(self, flt_items):
     if self.cons_arg is not None:
         evt_idxes = [(0 if z is None else z.label_idx) for z in flt_items]
         valid_masks = self.cons_arg.lookup(BK.input_idx(evt_idxes))  # [*, L]
         ret = Constants.REAL_PRAC_MIN * (1. - valid_masks)  # [*, L]
         return ret.unsqueeze(-2)  # [bs, 1, L], let later broadcast!
     else:
         return None
Example #11
 def batched_rev_idxes(self):
     if self._batched_rev_idxes is None:
         padder = DataPadder(2, pad_vals=0)  # again pad 0
         batched_rev_idxes, _ = padder.pad([s.align_info.split2orig for s in self.seq_subs])  # [bsize, sub_len]
         self._batched_rev_idxes = BK.input_idx(batched_rev_idxes)
     return self._batched_rev_idxes  # [bsize, sub_len]
Example #12
 def decode_evt_given(self, dec, ibatch, evt_scores_t: BK.Expr):
     _voc_evt = dec.voc_evt
     _assume_osof = dec.conf.assume_osof  # one seq one frame
     _pred_evt_label = self.conf.pred_evt_label
     # --
     if _pred_evt_label:
         evt_logprobs_t = evt_scores_t.log_softmax(-1)  # [*, dlen, L]
         pred_evt_scores, pred_evt_labels = evt_logprobs_t.max(-1)  # [*, dlen], note: maximum!
         arr_evt_scores, arr_evt_labels = BK.get_value(pred_evt_scores), BK.get_value(pred_evt_labels)  # [*, dlen]
     else:
         arr_evt_scores = arr_evt_labels = None
     # --
     # read given results
     res_bidxes, res_widxes, res_evts = [], [], []  # flattened results
     for bidx, item in enumerate(ibatch.items):  # for each item in the batch
         _trg_evts = [item.inst] if _assume_osof else \
             sum([sent.events for sidx,sent in enumerate(item.sents) if (item.center_sidx is None or sidx == item.center_sidx)],[])
         # --
         _dec_offsets = item.seq_info.dec_offsets
         for _evt in _trg_evts:
             sidx = item.sents.index(_evt.sent)
             _start = _dec_offsets[sidx]
             _full_hidx = _start + _evt.mention.shead_widx
             # add new one
             res_bidxes.append(bidx)
             res_widxes.append(_full_hidx)
             _evt._tmp_sstart = _start  # todo(+N): ugly tmp value ...
             _evt._tmp_sidx = sidx
             _evt._tmp_item = item
             res_evts.append(_evt)
             # assign label?
             if _pred_evt_label:
                 _lab = int(arr_evt_labels[bidx, _full_hidx])  # label index
                 _evt.set_label_idx(_lab)
                 _evt.set_label(_voc_evt.idx2word(_lab))
                 _evt.set_score(float(arr_evt_scores[bidx, _full_hidx]))
         # --
     # --
     # return
     res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
     return (res_bidxes_t, res_widxes_t), res_evts
Example #13
 def prepare(self, ibatch):
     b_seq_info = ibatch.seq_info
     bsize, dlen = BK.get_shape(b_seq_info.dec_sel_masks)
     arr_udep_labels = np.full([bsize, dlen, dlen], 0, dtype=np.int64)  # by default 0
     arr_head = np.full([bsize, dlen], -1, dtype=np.int64)  # the 0 ones are root
     for bidx, item in enumerate(ibatch.items):  # for each item in the batch
         _dec_offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):  # for each sent in the msent item
             tree = sent.tree_dep
             _start = _dec_offsets[sidx]
             _slen = len(sent)
             # note: here transpose it: (h,m), arti-root not included!
             arr_udep_labels[bidx, _start:_start+_slen, _start:_start+_slen] = tree.label_matrix[:, 1:].T
             arr_head[bidx, _start:_start+_slen] = tree.seq_head.vals
     # --
     expr_udep_labels = BK.input_idx(arr_udep_labels)  # [bs, dlen, dlen]
     expr_isroot = (BK.input_idx(arr_head) == 0).long()  # [bs, dlen]
     return expr_udep_labels, expr_isroot
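The isroot computation works because seq_head stores 1-based head indices with 0 marking the root word, while -1 fills the padding, so (arr_head == 0) is true exactly at each sentence's root. A small NumPy check of that convention (the helper name is invented):

    import numpy as np

    def build_isroot(head_seqs, dlen):
        # head_seqs: per-sentence 1-based head lists (0 marks the root word)
        bsize = len(head_seqs)
        arr_head = np.full((bsize, dlen), -1, dtype=np.int64)  # -1 = padding
        for bidx, heads in enumerate(head_seqs):
            arr_head[bidx, :len(heads)] = heads
        return (arr_head == 0).astype(np.int64)  # [bs, dlen], 1 where root

    print(build_isroot([[2, 0, 2], [0, 1]], dlen=4))
    # [[0 1 0 0]
    #  [1 0 0 0]]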
Example #14
 def prepare(self, insts: List, use_cache: bool = None):
     bsize, mlen = len(insts), max(len(z.sent) for z in insts) if len(insts) > 0 else 1
     batched_shape = (bsize, mlen)
     arr_upos_labels = np.full(batched_shape, 0, dtype=np.int64)
     for bidx, inst in enumerate(insts):
         zlen = len(inst.sent)
         arr_upos_labels[bidx, :zlen] = inst.sent.seq_upos.idxes
     expr_upos_labels = BK.input_idx(arr_upos_labels)
     return expr_upos_labels
Example #15
 def prepare(self, insts: Union[List[Sent], List[Frame]], mlen: int,
             use_cache: bool):
     conf: AnchorExtractorConf = self.conf
     # get info
     if use_cache:
         zobjs = []
         attr_name = f"_acache_{conf.ftag}"  # should be unique
         for s in insts:
             one = getattr(s, attr_name, None)
             if one is None:
                 one = self._prep_f(s)
                 setattr(s, attr_name, one)  # set cache
             zobjs.append(one)
     else:
         zobjs = [self._prep_f(s) for s in insts]
     # batch things
     bsize, mlen2 = len(insts), max(len(z.items) for z in zobjs) if len(zobjs) > 0 else 1
     mnum = max(len(g) for z in zobjs for g in z.group_widxes) if len(zobjs) > 0 else 1
     arr_items = np.full((bsize, mlen2), None, dtype=object)  # [*, ?]
     arr_seq_iidxes = np.full((bsize, mlen), -1, dtype=np.int64)
     arr_seq_labs = np.full((bsize, mlen), 0, dtype=np.int64)
     arr_group_widxes = np.full((bsize, mlen, mnum), 0, dtype=np.int64)
     arr_group_masks = np.full((bsize, mlen, mnum), 0., dtype=np.float64)
     for zidx, zobj in enumerate(zobjs):
         arr_items[zidx, :len(zobj.items)] = zobj.items
         iidx_offset = zidx * mlen2  # note: offset for valid ones!
         arr_seq_iidxes[zidx, :len(zobj.seq_iidxes)] = [
             (iidx_offset + ii) if ii >= 0 else ii for ii in zobj.seq_iidxes
         ]
         arr_seq_labs[zidx, :len(zobj.seq_labs)] = zobj.seq_labs
         for zidx2, zwidxes in enumerate(zobj.group_widxes):
             arr_group_widxes[zidx, zidx2, :len(zwidxes)] = zwidxes
             arr_group_masks[zidx, zidx2, :len(zwidxes)] = 1.
     # final setup things
     expr_seq_iidxes = BK.input_idx(arr_seq_iidxes)  # [*, slen]
     expr_seq_labs = BK.input_idx(arr_seq_labs)  # [*, slen]
     expr_group_widxes = BK.input_idx(arr_group_widxes)  # [*, slen, MW]
     expr_group_masks = BK.input_real(arr_group_masks)  # [*, slen, MW]
     expr_loss_weight_non = BK.input_real(
         [z.loss_weight_non for z in zobjs])  # [*]
     return arr_items, expr_seq_iidxes, expr_seq_labs, expr_group_widxes, expr_group_masks, expr_loss_weight_non
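The iidx offset (zidx * mlen2) makes arr_seq_iidxes index directly into the flattened [bsize * mlen2] item array, while -1 keeps meaning "no item". A quick check of that flattening trick with toy values:

    import numpy as np

    bsize, mlen2 = 2, 3
    arr_items = np.array([['a', 'b', None], ['c', None, None]], dtype=object)
    seq_iidxes = [[0, -1, 1], [0, 0, -1]]  # per-row indices into that row's items
    arr_flat = np.full((bsize, 4), -1, dtype=np.int64)
    for zidx, iidxes in enumerate(seq_iidxes):
        offset = zidx * mlen2  # row offset into the flattened item array
        arr_flat[zidx, :len(iidxes)] = [(offset + ii) if ii >= 0 else ii for ii in iidxes]
    flat_items = arr_items.flatten()
    print([None if ii < 0 else flat_items[ii] for ii in arr_flat[1]])
    # ['c', 'c', None, None]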
Example #16
 def decode_arg(self, dec, res_evts: List, arg_scores_t: BK.Expr):
     _pred_max_layer_arg = dec.conf.max_layer_arg
     _arg_allowed_sent_gap = dec.conf.arg_allowed_sent_gap
     _voc_arg = dec.voc_arg
     # --
     arg_logprobs_t = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
     pred_arg_scores, pred_arg_labels = arg_logprobs_t.topk(_pred_max_layer_arg)  # [??, dlen, K]
     arr_arg_scores, arr_arg_labels = BK.get_value(pred_arg_scores), BK.get_value(pred_arg_labels)  # [??, dlen, K]
     # put results
     res_fidxes, res_widxes, res_args = [], [], []  # flattened results
     for fidx, evt in enumerate(res_evts):  # for each evt
         item = evt._tmp_item  # cached
         _evt_sidx = evt._tmp_sidx
         _dec_offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             if abs(sidx - _evt_sidx) > _arg_allowed_sent_gap:
                 continue  # larger than allowed sentence gap
             _start = _dec_offsets[sidx]
             _len = len(sent)
             _arr_scores = arr_arg_scores[fidx][_start:_start + _len]
             _arr_labels = arr_arg_labels[fidx][_start:_start + _len]
             for widx in range(_len):
                 for _score, _lab in zip(_arr_scores[widx], _arr_labels[widx]):  # [K]
                     if _lab == 0:  # note: idx=0 means NIL
                         break
                     # add new one!!
                     res_fidxes.append(fidx)
                     res_widxes.append(_start + widx)
                     _new_ef = sent.make_entity_filler(widx, 1)
                     _new_arg = evt.add_arg(_new_ef, role=_voc_arg.idx2word(_lab), score=float(_score))
                     _new_arg._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                     res_args.append(_new_arg)
     # return
     res_fidxes_t, res_widxes_t = BK.input_idx(res_fidxes), BK.input_idx(res_widxes)  # [??]
     return (res_fidxes_t, res_widxes_t), res_args
Example #17
 def prepare(self, insts: Union[List[Sent], List[Frame]], use_cache: bool):
     conf: DirectExtractorConf = self.conf
     # get info
     if use_cache:
         zobjs = []
         attr_name = f"_dcache_{conf.ftag}"  # should be unique
         for s in insts:
             one = getattr(s, attr_name, None)
             if one is None:
                 one = self._prep_f(s)
                 setattr(s, attr_name, one)  # set cache
             zobjs.append(one)
     else:
         zobjs = [self._prep_f(s) for s in insts]
     # batch things
     bsize, mlen = len(insts), max(z.len for z in zobjs) if len(zobjs)>0 else 1  # at least put one as padding
     batched_shape = (bsize, mlen)
     arr_items = np.full(batched_shape, None, dtype=object)
     arr_gaddrs = np.arange(bsize*mlen).reshape(batched_shape)  # gold address
     arr_core_widxes = np.full(batched_shape, 0, dtype=np.int64)
     arr_core_wlens = np.full(batched_shape, 1, dtype=np.int64)
     # arr_ext_widxes = np.full(batched_shape, 0, dtype=np.int)
     # arr_ext_wlens = np.full(batched_shape, 1, dtype=np.int)
     for zidx, zobj in enumerate(zobjs):
         zlen = zobj.len
         arr_items[zidx, :zlen] = zobj.items
         arr_core_widxes[zidx, :zlen] = zobj.core_widxes
         arr_core_wlens[zidx, :zlen] = zobj.core_wlens
         # arr_ext_widxes[zidx, :zlen] = zobj.ext_widxes
         # arr_ext_wlens[zidx, :zlen] = zobj.ext_wlens
     arr_gaddrs[arr_items==None] = -1  # set -1 as gaddr
     # final setup things
     expr_gaddr = BK.input_idx(arr_gaddrs)  # [*, GLEN]
     expr_core_widxes = BK.input_idx(arr_core_widxes)
     expr_core_wlens = BK.input_idx(arr_core_wlens)
     # expr_ext_widxes = BK.input_idx(arr_ext_widxes)
     # expr_ext_wlens = BK.input_idx(arr_ext_wlens)
     expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
     # return arr_flatten_items, expr_gaddr, expr_core_widxes, expr_core_wlens, \
     #        expr_ext_widxes, expr_ext_wlens, expr_loss_weight_non
     return arr_items, expr_gaddr, expr_core_widxes, expr_core_wlens, expr_loss_weight_non
Example #18
 def prepare(self, insts: List[Sent], use_cache: bool):
     # get info
     zobjs = ZDecoderHelper.get_zobjs(insts, self._prep_sent, use_cache, "_cache_srl")
     # batch things
     bsize, mlen = len(insts), max(z.slen for z in zobjs) if len(zobjs)>0 else 1  # at least put one as padding
     batched_shape = (bsize, mlen)
     arr_items = np.full(batched_shape, None, dtype=object)
     arr_evt_labels = np.full(batched_shape, 0, dtype=np.int64)
     arr_arg_labels = np.full(batched_shape+(mlen,), 0, dtype=np.int64)
     arr_arg2_labels = np.full(batched_shape+(mlen,), 0, dtype=np.int64)
     for zidx, zobj in enumerate(zobjs):
         zlen = zobj.slen
         arr_items[zidx, :zlen] = zobj.evt_items
         arr_evt_labels[zidx, :zlen] = zobj.evt_arr
         arr_arg_labels[zidx, :zlen, :zlen] = zobj.arg_arr
         arr_arg2_labels[zidx, :zlen, :zlen] = zobj.arg2_arr
     expr_evt_labels = BK.input_idx(arr_evt_labels)  # [*, slen]
     expr_arg_labels = BK.input_idx(arr_arg_labels)  # [*, slen, slen]
     expr_arg2_labels = BK.input_idx(arr_arg2_labels)  # [*, slen, slen]
     expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
     return arr_items, expr_evt_labels, expr_arg_labels, expr_arg2_labels, expr_loss_weight_non
Example #19
 def loss(self, flt_items, flt_input_expr, flt_pair_expr, flt_full_expr, flt_mask_expr, flt_extra_weights=None):
     conf: ExtenderConf = self.conf
     _loss_lambda = conf._loss_lambda
     # --
     enc_t = self._forward_eenc(flt_input_expr, flt_full_expr, flt_mask_expr)  # [*, slen, D]
     s_left, s_right = self.enode.score(enc_t, flt_pair_expr if conf.ext_use_finput else None, flt_mask_expr)  # [*, slen]
     # --
     gold_posi = [self.ext_span_getter(z.mention) for z in flt_items]  # List[(widx, wlen)]
     widx_t = BK.input_idx([z[0] for z in gold_posi])  # [*]
     wlen_t = BK.input_idx([z[1] for z in gold_posi])
     loss_left_t, loss_right_t = BK.loss_nll(s_left, widx_t), BK.loss_nll(s_right, widx_t+wlen_t-1)  # [*]
     if flt_extra_weights is not None:
         loss_left_t *= flt_extra_weights
         loss_right_t *= flt_extra_weights
         loss_div = flt_extra_weights.sum()  # note: also use this!
     else:
         loss_div = BK.constants([len(flt_items)], value=1.).sum()
     loss_left_item = LossHelper.compile_leaf_loss("left", loss_left_t.sum(), loss_div, loss_lambda=_loss_lambda)
     loss_right_item = LossHelper.compile_leaf_loss("right", loss_right_t.sum(), loss_div, loss_lambda=_loss_lambda)
     ret_loss = LossHelper.combine_multiple_losses([loss_left_item, loss_right_item])
     return ret_loss
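BK.loss_nll over the left/right boundary scores amounts to two per-example cross-entropies over sequence positions: one for the start index widx and one for the end index widx + wlen - 1. A rough equivalent in plain PyTorch (assuming the scores are unnormalized logits):

    import torch
    import torch.nn.functional as F

    def span_boundary_loss(s_left, s_right, widx_t, wlen_t):
        # s_left, s_right: [*, slen] boundary scores; widx_t, wlen_t: [*] gold spans
        loss_left = F.cross_entropy(s_left, widx_t, reduction='none')             # [*]
        loss_right = F.cross_entropy(s_right, widx_t + wlen_t - 1, reduction='none')
        return loss_left, loss_right

    s_l, s_r = torch.randn(4, 10), torch.randn(4, 10)
    widx, wlen = torch.tensor([0, 2, 5, 1]), torch.tensor([3, 1, 4, 2])
    print(span_boundary_loss(s_l, s_r, widx, wlen))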
Example #20
 def forward(self, med: ZMediator):
     # --
     # get hid_t
     hid_t0 = med.get_enc_cache_val("hid")
     sinfo = med.ibatch.seq_info
     _arange_t, _sel_t = sinfo.arange2_t, sinfo.dec_sel_idxes
     hid_t = hid_t0[_arange_t, _sel_t]  # [*, dlen, D]
     # --
     # prepare relations
     bsize, dlen = BK.get_shape(sinfo.dec_sel_masks)
     arr_rels = np.full([bsize, dlen, dlen], 0, dtype=np.int64)  # by default 0
     arr_labs = np.full([bsize, dlen], 0, dtype=np.int64)
     for bidx, item in enumerate(med.ibatch.items):  # for each item in the batch
         _dec_offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):  # for each sent in the msent item
             tree = sent.tree_dep
             _start = _dec_offsets[sidx]
             _slen = len(sent)
             _arr_ms = np.asarray(range(_slen)) + _start  # [??]
             _arr_hs = np.asarray(tree.seq_head.vals) + (_start - 1)  # note(+N): need to do more if msent!!
             _arr_labs = np.asarray(tree.seq_label.idxes)  # [??]
             arr_labs[bidx, _start:_start + _slen] = _arr_labs
             arr_rels[bidx, _arr_hs, _arr_ms] = _arr_labs
             arr_rels[bidx, _arr_ms, _arr_hs] = -_arr_labs
     expr_labs = BK.input_idx(arr_rels)  # [*, dlen, dlen]
     # --
     # go through
     res_t = hid_t
     if self.type_emb is not None:
         expr_seq_labs = BK.input_idx(arr_labs)  # [*, dlen]
         lab_t = self.type_emb(expr_seq_labs)
         res_t = res_t + lab_t
     for node in self.nodes:
         res_t = node.forward(res_t, expr_labs, sinfo.dec_sel_masks)
         med.layer_end({'hid': res_t})  # step once!
     return res_t
Example #21
 def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
          pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr=None):
     conf: DirectExtractorConf = self.conf
     # step 0: prepare golds
     arr_gold_items, expr_gold_gaddr, expr_gold_widxes, expr_gold_wlens, expr_loss_weight_non = \
         self.helper.prepare(insts, use_cache=True)
     # step 1: extract cands
     if conf.loss_use_posi:
         cand_res = self.extract_node.go_lookup(
             input_expr, expr_gold_widxes, expr_gold_wlens, (expr_gold_gaddr>=0).float(), gaddr_expr=expr_gold_gaddr)
     else:  # todo(note): assume no in-middle mask!!
         cand_widx, cand_wlen, cand_mask, cand_gaddr = self.extract_node.prepare_with_lengths(
             BK.get_shape(mask_expr), mask_expr.sum(-1).long(), expr_gold_widxes, expr_gold_wlens, expr_gold_gaddr)
         if conf.span_train_sample:  # simply do sampling
             cand_res = self.extract_node.go_sample(
                 input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
                 rate=conf.span_train_sample_rate, count=conf.span_train_sample_count,
                 gaddr_expr=cand_gaddr, add_gold_rate=1.0)  # note: always fully add gold for sampling!!
         else:  # beam pruner using topk
             cand_res = self.extract_node.go_topk(
                 input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
                 rate=conf.span_topk_rate, count=conf.span_topk_count,
                 gaddr_expr=cand_gaddr, add_gold_rate=conf.span_train_topk_add_gold_rate)
     # step 1+: prepare for labeling
     cand_gold_mask = (cand_res.gaddr_expr>=0).float() * cand_res.mask_expr  # [*, cand_len]
     # todo(note): add a 0 as idx=-1 to make NEG ones as 0!!
     flatten_gold_label_idxes = BK.input_idx([(0 if z is None else z.label_idx) for z in arr_gold_items.flatten()] + [0])
     gold_label_idxes = flatten_gold_label_idxes[cand_res.gaddr_expr]
     cand_loss_weights = BK.where(gold_label_idxes==0, expr_loss_weight_non.unsqueeze(-1)*conf.loss_weight_non, cand_res.mask_expr)
     final_loss_weights = cand_loss_weights * cand_res.mask_expr
     # cand loss
     if conf.loss_cand > 0. and not conf.loss_use_posi:
         loss_cand0 = BK.loss_binary(cand_res.score_expr, cand_gold_mask, label_smoothing=conf.cand_label_smoothing)
         loss_cand = (loss_cand0 * final_loss_weights).sum()
         loss_cand_item = LossHelper.compile_leaf_loss("cand", loss_cand, final_loss_weights.sum(),
                                                       loss_lambda=conf.loss_cand)
     else:
         loss_cand_item = None
     # extra score
     cand_extra_score = self._get_extra_score(
         cand_res.score_expr, insts, cand_res, arr_gold_items, conf.loss_use_cons, conf.loss_use_lu)
     final_extra_score = self._sum_scores(external_extra_score, cand_extra_score)
     # step 2: label; with special weights
     loss_lab, loss_count = self.lab_node.loss(
         cand_res.span_expr, pair_expr, cand_res.mask_expr, gold_label_idxes,
         loss_weight_expr=final_loss_weights, extra_score=final_extra_score)
     loss_lab_item = LossHelper.compile_leaf_loss("lab", loss_lab, loss_count,
                                                  loss_lambda=conf.loss_lab, gold=cand_gold_mask.sum())
     # ==
     # return loss
     ret_loss = LossHelper.combine_multiple_losses([loss_cand_item, loss_lab_item])
     return self._finish_loss(ret_loss, insts, input_expr, mask_expr, pair_expr, lookup_flatten)
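The negative-example down-weighting above follows one pattern: positions whose gold label is NIL get the per-instance loss_weight_non scaled by a config weight, everything else keeps its mask weight, and the final multiply by the mask zeroes the padding either way. In plain PyTorch:

    import torch

    def neg_loss_weights(gold_label_idxes, mask_expr, loss_weight_non_t, weight_non=0.5):
        # gold_label_idxes, mask_expr: [*, clen]; loss_weight_non_t: [*]
        w = torch.where(gold_label_idxes == 0,
                        loss_weight_non_t.unsqueeze(-1) * weight_non,  # NIL positions
                        mask_expr)                                     # gold positions
        return w * mask_expr  # padding ends up zero either way

    gold = torch.tensor([[2, 0, 1, 0]])
    mask = torch.tensor([[1., 1., 1., 0.]])
    print(neg_loss_weights(gold, mask, torch.tensor([1.0])))  # [[1.0, 0.5, 1.0, 0.0]]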
Example #22
 def __init__(self, ibatch: InputBatch, IDX_PAD: int):
     # preps
     self.bsize = len(ibatch)
     self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
     self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
     self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
     # batched them
     all_seq_infos = [z.seq_info for z in ibatch.items]
     # enc: [*, len_enc]: ids(pad IDX_PAD), masks, segids(pad 0)
     self.enc_input_ids = BK.input_idx(
         DataPadder.go_batch_2d([z.enc_input_ids for z in all_seq_infos], int(IDX_PAD)))
     self.enc_input_masks = BK.input_real(
         DataPadder.lengths2mask([len(z.enc_input_ids) for z in all_seq_infos]))
     self.enc_input_segids = BK.input_idx(
         DataPadder.go_batch_2d([z.enc_input_segids for z in all_seq_infos], 0))
     # dec: [*, len_dec]: sel_idxes(pad 0), sel_lens(pad 1), masks, sent_idxes(pad ??)
     self.dec_sel_idxes = BK.input_idx(
         DataPadder.go_batch_2d([z.dec_sel_idxes for z in all_seq_infos], 0))
     self.dec_sel_lens = BK.input_idx(
         DataPadder.go_batch_2d([z.dec_sel_lens for z in all_seq_infos], 1))
     self.dec_sel_masks = BK.input_real(
         DataPadder.lengths2mask([len(z.dec_sel_idxes) for z in all_seq_infos]))
     _max_dec_len = BK.get_shape(self.dec_sel_masks, 1)
     _dec_offsets = BK.input_idx(
         DataPadder.go_batch_2d([z.dec_offsets for z in all_seq_infos], _max_dec_len))
     # note: CLS as -1, then 0,1,2,..., PAD gets -2!
     self.dec_sent_idxes = \
         (BK.arange_idx(_max_dec_len).unsqueeze(0).unsqueeze(-1) >= _dec_offsets.unsqueeze(-2)).sum(-1).long() - 1
     self.dec_sent_idxes[self.dec_sel_masks <= 0.] = -2
     # dec -> enc: [*, len_enc] (calculated on needed!)
     # note: require 1-to-1 mapping (except pads)!!
     self._enc_back_hits = None
     self._enc_back_sel_idxes = None
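The dec_sent_idxes trick near the end compares every decoding position against the per-sentence start offsets: counting how many offsets a position has reached, minus one, yields -1 for CLS and 0,1,2,... for sentences, after which the mask forces PAD to -2. Verified in isolation with PyTorch (offsets padded with max_dec_len, as above):

    import torch

    def dec_sent_idxes(dec_offsets, dec_sel_masks):
        # dec_offsets: [bs, n_sent] start of each sentence; dec_sel_masks: [bs, dlen]
        max_dec_len = dec_sel_masks.shape[1]
        posi = torch.arange(max_dec_len).unsqueeze(0).unsqueeze(-1)     # [1, dlen, 1]
        idxes = (posi >= dec_offsets.unsqueeze(-2)).sum(-1).long() - 1  # [bs, dlen]
        idxes[dec_sel_masks <= 0.] = -2  # PAD
        return idxes

    offsets = torch.tensor([[1, 4, 7]])  # CLS at 0, sents at 1 and 4, pad offset = 7
    masks = torch.tensor([[1., 1., 1., 1., 1., 1., 0.]])
    print(dec_sent_idxes(offsets, masks))  # [[-1, 0, 0, 0, 1, 1, -2]]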
Example #23
 def __init__(self, berter: BertEncoder,
              seq_subs: List[InputSubwordSeqField]):
     self.seq_subs = seq_subs
     self.berter = berter
     self.bsize = len(seq_subs)
     self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
     self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
     self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
     # --
     tokenizer = self.berter.tokenizer
     PAD_IDX = tokenizer.pad_token_id
     # MASK_IDX = tokenizer.mask_token_id
     # CLS_IDX_l = [tokenizer.cls_token_id]
     # SEP_IDX_l = [tokenizer.sep_token_id]
     # make batched idxes
     padder = DataPadder(2, pad_vals=PAD_IDX, mask_range=2)
     batched_sublens = [len(s.idxes) for s in seq_subs]  # [bsize]
     batched_input_ids, batched_input_mask = padder.pad(
         [s.idxes for s in seq_subs])  # [bsize, sub_len]
     self.batched_sublens_p1 = BK.input_idx(batched_sublens) + 1  # also the idx of EOS (if counting including BOS)
     self.batched_input_ids = BK.input_idx(batched_input_ids)
     self.batched_input_mask = BK.input_real(batched_input_mask)
     # make batched mappings (sub->orig)
     padder2 = DataPadder(2, pad_vals=0, mask_range=2)  # pad as 0 to avoid out-of-range
     batched_first_idxes, batched_first_mask = padder2.pad(
         [s.align_info.orig2begin for s in seq_subs])  # [bsize, orig_len]
     self.batched_first_idxes = BK.input_idx(batched_first_idxes)
     self.batched_first_mask = BK.input_real(batched_first_mask)
     # reversed batched_mappings (orig->sub) (created when needed)
     self._batched_rev_idxes = None  # [bsize, sub_len]
     # --
     self.batched_repl_masks = None  # [bsize, sub_len], to replace with MASK
     self.batched_token_type_ids = None  # [bsize, 1+sub_len+1]
     self.batched_position_ids = None  # [bsize, 1+sub_len+1]
     self.other_factors = {}  # name -> aug_batched_ids
Example #24
 def prepare(self, ibatch):
     b_seq_info = ibatch.seq_info
     arr_upos_labels = np.full(BK.get_shape(b_seq_info.dec_sel_masks), 0, dtype=np.int64)  # by default 0
     for bidx, item in enumerate(ibatch.items):  # for each item in the batch
         _dec_offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):  # for each sent in the msent item
             _start = _dec_offsets[sidx]
             arr_upos_labels[bidx, _start:_start + len(sent)] = sent.seq_upos.idxes
     expr_upos_labels = BK.input_idx(arr_upos_labels)  # [bs, dlen]
     return expr_upos_labels
Example #25
 def prepare(self, insts: List, use_cache: bool):
     # get info
     zobjs = ZDecoderHelper.get_zobjs(insts, self._prep_inst, use_cache, "_cache_udep")
     # then
     bsize, mlen = len(insts), max(z.slen for z in zobjs) if len(zobjs)>0 else 1
     batched_shape = (bsize, mlen)
     arr_depth = np.full(batched_shape, 0., dtype=np.float64)  # [*, slen]
     arr_udep = np.full(batched_shape+(mlen,), 0, dtype=np.int64)  # [*, slen_m, slen_h]
     for zidx, zobj in enumerate(zobjs):
         zlen = zobj.slen
         arr_depth[zidx, :zlen] = zobj.depth_arr
         arr_udep[zidx, :zlen, :zlen] = zobj.udep_arr
     expr_depth = BK.input_real(arr_depth)  # [*, slen]
     expr_udep = BK.input_idx(arr_udep)  # [*, slen_m, slen_h]
     return expr_depth, expr_udep
Example #26
 def _transform_factors(self, factors: Union[List[List[int]], BK.Expr],
                        is_orig: bool, PAD_IDX: Union[int, float]):
     if isinstance(factors, BK.Expr):  # already padded
         batched_ids = factors
     else:
         padder = DataPadder(2, pad_vals=PAD_IDX)
         batched_ids, _ = padder.pad(factors)
         batched_ids = BK.input_idx(batched_ids)  # [bsize, orig-len if is_orig else sub_len]
     if is_orig:  # map to subtoks
         final_batched_ids = batched_ids[self.arange2_t, self.batched_rev_idxes]  # [bsize, sub_len]
     else:
         final_batched_ids = batched_ids  # [bsize, sub_len]
     return final_batched_ids
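Mapping per-original-token factors onto subword positions is a batched gather: indexing with (arange2_t, batched_rev_idxes) picks, for every subword slot, the value of the original token it was split from. The same gather in plain PyTorch:

    import torch

    def orig_to_sub(batched_ids, rev_idxes):
        # batched_ids: [bsize, orig_len]; rev_idxes: [bsize, sub_len], sub -> orig
        arange2_t = torch.arange(batched_ids.shape[0]).unsqueeze(-1)  # [bsize, 1]
        return batched_ids[arange2_t, rev_idxes]  # [bsize, sub_len]

    ids = torch.tensor([[10, 20, 30]])  # one factor value per original token
    rev = torch.tensor([[0, 1, 1, 2]])  # token 1 was split into two subwords
    print(orig_to_sub(ids, rev))        # tensor([[10, 20, 20, 30]])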
Example #27
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    assert labeled
    with BK.no_grad_env():
        # argmax-label: [BS, m, h]
        scores_unlabeled_max, labels_argmax = scores_expr.max(-1)
        #
        scores_unlabeled_max_arr = BK.get_value(scores_unlabeled_max)
        mst_heads_arr, _, mst_scores_arr = CPU_f(scores_unlabeled_max_arr, lengths_arr, labeled=False)
        # [BS, m]
        mst_heads_expr = BK.input_idx(mst_heads_arr)
        mst_labels_expr = BK.gather_one_lastdim(labels_argmax, mst_heads_expr).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return mst_heads_arr, BK.get_value(mst_labels_expr), mst_scores_arr
        else:
            return mst_heads_expr, mst_labels_expr, BK.input_real(mst_scores_arr)
Example #28
 def __init__(self, conf: SrlInferenceHelperConf, dec: 'ZDecoderSrl',
              **kwargs):
     super().__init__(conf, **kwargs)
     conf: SrlInferenceHelperConf = self.conf
     # --
     self.setattr_borrow('dec', dec)
     self.arg_pp = PostProcessor(conf.arg_pp)
     # --
     self.lu_cons, self.role_cons = None, None
     if conf.frames_name:  # currently only frame->role
         from msp2.data.resources import get_frames_label_budgets
         flb = get_frames_label_budgets(conf.frames_name)
         _voc_ef, _voc_evt, _voc_arg = dec.ztask.vpack
         _role_cons = fchelper.build_constraint_arrs(flb, _voc_arg, _voc_evt)
         self.role_cons = BK.input_real(_role_cons)
     if conf.frames_file:
         _voc_ef, _voc_evt, _voc_arg = dec.ztask.vpack
         _fc = default_pickle_serializer.from_file(conf.frames_file)
         _lu_cons = fchelper.build_constraint_arrs(
             fchelper.build_lu_map(_fc), _voc_evt, warning=False)  # lexicon->frame
         _role_cons = fchelper.build_constraint_arrs(
             fchelper.build_role_map(_fc), _voc_arg, _voc_evt)  # frame->role
         self.lu_cons, self.role_cons = _lu_cons, BK.input_real(_role_cons)
     # --
     self.cons_evt_tok_f = conf.get_cons_evt_tok()
     self.cons_evt_frame_f = conf.get_cons_evt_frame()
     if self.dec.conf.arg_use_bio:  # extend for bio!
         self.cons_arg_bio_sels = BK.input_idx(self.dec.vocab_bio_arg.get_bio2origin())
     else:
         self.cons_arg_bio_sels = None
     # --
     from msp2.data.resources.frames import KBP17_TYPES
     self.pred_evt_filter = {'kbp17': KBP17_TYPES}.get(conf.pred_evt_filter, None)
Example #29
 def forward(self, inputs, add_bos=False, add_eos=False):
     conf: CharCnnInputEmbedderConf = self.conf
     # --
     voc = self.voc
     char_input_t = BK.input_idx(inputs)  # [*, len]
     # todo(note): no need for replacing to unk for char!!
     # bos and eos
     all_input_slices = []
     slice_shape = BK.get_shape(char_input_t)
     slice_shape[-2] = 1  # [*, 1, clen]
     if add_bos:
         all_input_slices.append(
             BK.constants(slice_shape, voc.bos, dtype=char_input_t.dtype))
     all_input_slices.append(char_input_t)  # [*, len, clen]
     if add_eos:
         all_input_slices.append(
             BK.constants(slice_shape, voc.eos, dtype=char_input_t.dtype))
     final_input_t = BK.concat(all_input_slices, -2)  # [*, 1?+len+1?, clen]
     # char embeddings
     char_embed_expr = self.E(final_input_t)  # [*, ??, dim]
     # char cnn
     ret = self.cnn(char_embed_expr)
     return ret
Example #30
 def loss(self,
          insts: Union[List[Sent], List[Frame]],
          input_expr: BK.Expr,
          mask_expr: BK.Expr,
          pair_expr: BK.Expr = None,
          lookup_flatten=False,
          external_extra_score: BK.Expr = None):
     conf: SoftExtractorConf = self.conf
     assert not lookup_flatten
     bsize, slen = BK.get_shape(mask_expr)
     # --
     # step 0: prepare
     arr_items, expr_seq_gaddr, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
         self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
     # --
     # step 1: cand
     cand_full_scores, pred_cand_decisions = self._cand_score_and_select(input_expr, mask_expr)  # [*, slen]
     loss_cand_items, cand_widxes, cand_masks = self._loss_feed_cand(
         mask_expr, cand_full_scores, pred_cand_decisions, expr_seq_gaddr,
         expr_group_widxes, expr_group_masks, expr_loss_weight_non)  # ~, [*, clen]
     # --
     # step 2: split
     cand_expr = input_expr[arange2_t, cand_widxes]  # [*, clen]
     cand_scores = cand_full_scores[arange2_t, cand_widxes]  # [*, clen]
     split_scores, pred_split_decisions = self._split_score(cand_expr, cand_masks)  # [*, clen-1]
     loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr = \
         self._loss_feed_split(mask_expr, split_scores, pred_split_decisions, cand_widxes,
                               cand_masks, cand_expr, cand_scores, expr_seq_gaddr)  # ~, [*, seglen, *?]
     # --
     # step 3: lab
     # todo(note): add a 0 as idx=-1 to make NEG ones as 0!!
     flatten_gold_label_idxes = BK.input_idx(
         [(0 if z is None else z.label_idx) for z in arr_items.flatten()] + [0])
     gold_label_idxes = flatten_gold_label_idxes[oracle_gaddr]  # [*, seglen]
     lab_loss_weights = BK.where(oracle_gaddr >= 0,
                                 expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                 seg_masks)  # [*, seglen]
     final_lab_loss_weights = lab_loss_weights * seg_masks  # [*, seglen]
     # go
     loss_lab, loss_count = self.lab_node.loss(
         seg_weighted_expr, pair_expr, seg_masks, gold_label_idxes,
         loss_weight_expr=final_lab_loss_weights, extra_score=external_extra_score)
     loss_lab_item = LossHelper.compile_leaf_loss(
         "lab", loss_lab, loss_count, loss_lambda=conf.loss_lab,
         gold=(gold_label_idxes > 0).float().sum())
     # step 4: extend
     flt_mask = ((gold_label_idxes > 0) & (seg_masks > 0.))  # [*, seglen]
     flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?]
     flt_expr = seg_weighted_expr[flt_mask]  # [?, D]
     flt_full_expr = self._prepare_full_expr(seg_ext_widxes[flt_mask], seg_ext_masks[flt_mask], slen)  # [?, slen, D]
     flt_items = arr_items.flatten()[BK.get_value(oracle_gaddr[flt_mask])]  # [?]
     loss_ext_item = self.ext_node.loss(flt_items, input_expr[flt_sidx], flt_expr,
                                        flt_full_expr, mask_expr[flt_sidx])
     # --
     # return loss
     ret_loss = LossHelper.combine_multiple_losses(
         loss_cand_items + [loss_split_item, loss_lab_item, loss_ext_item])
     return ret_loss, None