예제 #1
0
 def decode_frame(self, ibatch, scores_t: BK.Expr, pred_max_layer: int, voc,
                  pred_label: bool, pred_tag: str, pred_check_layer: int):
     """Create new frames from the top-K labels predicted at every token.

     For each batch item, every sentence is scanned word by word (only the
     center sentence when ``item.center_sidx`` is set). For one word the
     candidates are visited in top-k order until label index 0 (NIL) shows
     up. When ``pred_check_layer >= 0`` a candidate is skipped if the first
     ``pred_check_layer`` dot-separated parts of its type string were
     already emitted for that word.

     Args:
         ibatch: batch object exposing ``items``.
         scores_t: label scores, annotated upstream as ``[*, dlen, L]``.
         pred_max_layer: K, number of candidates taken per token.
         voc: vocabulary used to map a label index to its type string.
         pred_label: if false, every frame gets the type ``"UNK"``.
         pred_tag: tag passed to ``sent.make_frame``.
         pred_check_layer: prefix depth for the de-duplication constraint;
             a negative value turns the constraint off.

     Returns:
         ``((bidxes_t, widxes_t), frames, frame_grid)`` where the two index
         tensors and ``frames`` are flattened in creation order, and
         ``frame_grid`` is an object array shaped like the top-k scores that
         holds each frame at ``[bidx, full_widx, k]`` (``None`` elsewhere).
     """
     # top-K candidates for each position
     topk_scores_t, topk_labels_t = scores_t.log_softmax(-1).topk(
         pred_max_layer)  # [*, dlen, K]
     score_arr = BK.get_value(topk_scores_t)  # [*, dlen, K]
     label_arr = BK.get_value(topk_labels_t)  # [*, dlen, K]
     # containers for the (flattened) outputs
     out_bidxes, out_widxes, out_frames = [], [], []
     frame_grid = np.full(score_arr.shape, None, dtype=object)
     for bidx, item in enumerate(ibatch.items):
         offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             # only the center sentence is decoded when one is designated
             if not (item.center_sidx is None or sidx == item.center_sidx):
                 continue
             sent_start = offsets[sidx]
             sent_len = len(sent)
             sent_end = sent_start + sent_len
             sent_scores = score_arr[bidx][sent_start:sent_end]
             sent_labels = label_arr[bidx][sent_start:sent_end]
             for widx in range(sent_len):
                 full_widx = sent_start + widx  # position in the joined sequence
                 seen_prefixes = set()
                 for rank in range(pred_max_layer):
                     cand_score = float(sent_scores[widx][rank])
                     cand_label = int(sent_labels[widx][rank])
                     if cand_label == 0:
                         # NIL: stop looking at lower-ranked candidates
                         break
                     if pred_label:
                         type_str = voc.idx2word(cand_label)
                     else:
                         type_str = "UNK"
                     prefix = '.'.join(type_str.split('.')[:pred_check_layer])
                     if pred_check_layer >= 0 and prefix in seen_prefixes:
                         continue  # same prefix already produced for this word
                     seen_prefixes.add(prefix)
                     frame = sent.make_frame(widx,
                                             1,
                                             tag=pred_tag,
                                             type=type_str,
                                             score=float(cand_score))
                     frame.set_label_idx(int(cand_label))
                     # temporary bookkeeping attributes attached to the frame
                     frame._tmp_sstart = sent_start
                     frame._tmp_sidx = sidx
                     frame._tmp_item = item
                     out_bidxes.append(bidx)
                     out_widxes.append(full_widx)
                     out_frames.append(frame)
                     frame_grid[bidx, full_widx, rank] = frame
     bidxes_t = BK.input_idx(out_bidxes)
     widxes_t = BK.input_idx(out_widxes)
     return (bidxes_t, widxes_t), out_frames, frame_grid
예제 #2
0
 def decode_frame_given(self, ibatch, scores_t: BK.Expr,
                        pred_max_layer: int, voc, pred_label: bool,
                        pred_tag: str, assume_osof: bool):
     """Collect already-existing frames and optionally re-label them.

     The target frames of an item are ``[item.inst]`` when ``assume_osof``
     is set; otherwise they are the ``pred_tag`` frames of its sentences
     (only the center sentence when ``item.center_sidx`` is set). When
     ``pred_label`` is true, each frame's label index, label and score are
     overwritten with the arg-max prediction at the frame's position
     ``dec_offsets[sidx] + frame.mention.shead_widx``.

     Args:
         ibatch: batch object exposing ``items``.
         scores_t: label scores, annotated upstream as ``[*, dlen, L]``.
         pred_max_layer: K, size of the last axis of the returned grid;
             surplus frames at one position are dropped from the grid only.
         voc: vocabulary used to map a label index to its label string.
         pred_label: whether to overwrite labels/scores with predictions.
         pred_tag: tag passed to ``sent.get_frames``.
         assume_osof: take ``item.inst`` as the single frame of the item.

     Returns:
         ``((bidxes_t, widxes_t), frames, frame_grid)``; ``frame_grid`` is
         an object array of shape ``get_shape(scores_t)[:-1] + [K]`` with
         the frames of each position packed from index 0.
     """
     if pred_label:  # if overwrite label!
         logprobs_t = scores_t.log_softmax(-1)  # [*, dlen, L]
         pred_scores, pred_labels = logprobs_t.max(
             -1)  # [*, dlen], note: maximum!
         arr_scores, arr_labels = BK.get_value(pred_scores), BK.get_value(
             pred_labels)  # [*, dlen]
     else:
         arr_scores = arr_labels = None
     # --
     # read given results
     res_bidxes, res_widxes, res_frames = [], [], []  # flattened results
     tmp_farrs = defaultdict(list)  # (bidx, widx) -> frames, assigned later
     for bidx, item in enumerate(
             ibatch.items):  # for each item in the batch
         if assume_osof:
             _trg_frames = [item.inst]
         else:
             # Grow one list in place; ``sum(lists, [])`` would copy the
             # accumulated result again for every sentence.
             _trg_frames = []
             for sidx, sent in enumerate(item.sents):
                 # still only pick center ones!
                 if item.center_sidx is None or sidx == item.center_sidx:
                     _trg_frames.extend(sent.get_frames(pred_tag))
         # --
         _dec_offsets = item.seq_info.dec_offsets
         for _frame in _trg_frames:  # note: simply sort by original order!
             sidx = item.sents.index(_frame.sent)
             _start = _dec_offsets[sidx]
             _full_hidx = _start + _frame.mention.shead_widx
             # add new one
             res_bidxes.append(bidx)
             res_widxes.append(_full_hidx)
             _frame._tmp_sstart = _start  # todo(+N): ugly tmp value ...
             _frame._tmp_sidx = sidx
             _frame._tmp_item = item
             res_frames.append(_frame)
             tmp_farrs[(bidx, _full_hidx)].append(_frame)
             # assign/rewrite label?
             if pred_label:
                 _lab = int(arr_labels[bidx, _full_hidx])  # label index
                 _frame.set_label_idx(_lab)
                 _frame.set_label(voc.idx2word(_lab))
                 _frame.set_score(float(arr_scores[bidx, _full_hidx]))
         # --
     # --
     res_farrs = np.full(BK.get_shape(scores_t)[:-1] + [pred_max_layer],
                         None,
                         dtype=object)  # [*, dlen, K]
     for (bidx, widx), _values in tmp_farrs.items():
         _values = _values[:pred_max_layer]  # truncate if more!
         res_farrs[bidx, widx, :len(_values)] = _values
     # return
     res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(
         res_widxes)  # [??]
     return (res_bidxes_t,
             res_widxes_t), res_frames, res_farrs  # [??], [*, dlen, K]
예제 #3
0
 def decode_arg(self, res_evts: List, arg_scores_t: BK.Expr,
                pred_max_layer: int, voc, arg_allowed_sent_gap: int,
                arr_efs):
     """Attach predicted arguments to each event from top-K role scores.

     For every event, the sentences of its cached item that lie within
     ``arg_allowed_sent_gap`` sentences of the event's own sentence are
     scanned word by word. Role candidates of a word are visited in top-k
     order until label index 0 (NIL) appears; each remaining candidate
     becomes one argument of the event.

     Args:
         res_evts: events carrying ``_tmp_item`` and ``_tmp_sidx``.
         arg_scores_t: role scores, annotated upstream as ``[??, dlen, L]``.
         pred_max_layer: K, number of role candidates per token.
         voc: vocabulary used to map a label index to a role name.
         arg_allowed_sent_gap: maximum allowed ``abs(sidx - evt_sidx)``.
         arr_efs: optional object array indexed ``[fidx, full_widx, 0]``;
             when given, a word without an entry there is skipped and the
             entry is used as the argument's filler. When ``None``, one
             filler per word is created on demand and shared by all of
             that word's arguments.

     Returns:
         ``((fidxes_t, widxes_t), args)``, flattened in creation order.
     """
     # top-K role candidates for each position
     topk_scores_t, topk_labels_t = arg_scores_t.log_softmax(-1).topk(
         pred_max_layer)  # [??, dlen, K]
     score_arr = BK.get_value(topk_scores_t)  # [??, dlen, K]
     label_arr = BK.get_value(topk_labels_t)  # [??, dlen, K]
     out_fidxes, out_widxes, out_args = [], [], []
     for fidx, evt in enumerate(res_evts):
         item = evt._tmp_item  # cached on the event
         evt_sidx = evt._tmp_sidx
         offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             if abs(sidx - evt_sidx) > arg_allowed_sent_gap:
                 continue  # sentence too far from the event
             sent_start = offsets[sidx]
             sent_len = len(sent)
             sent_end = sent_start + sent_len
             sent_scores = score_arr[fidx][sent_start:sent_end]
             sent_labels = label_arr[fidx][sent_start:sent_end]
             for widx in range(sent_len):
                 full_widx = sent_start + widx  # position in the joined sequence
                 if arr_efs is None:
                     filler = None  # created lazily below
                 else:
                     # only the first given filler at this position is used
                     filler = arr_efs[fidx, full_widx, 0]
                     if filler is None:
                         continue  # nothing given here
                 for cand_score, cand_label in zip(sent_scores[widx],
                                                   sent_labels[widx]):  # [K]
                     if cand_label == 0:
                         break  # NIL: ignore lower-ranked roles
                     out_fidxes.append(fidx)
                     out_widxes.append(full_widx)
                     if filler is None:
                         # one filler per word, reused across its roles
                         filler = sent.make_entity_filler(widx, 1)
                     arg = evt.add_arg(filler,
                                       role=voc.idx2word(cand_label),
                                       score=float(cand_score))
                     arg._tmp_sstart = sent_start  # temporary bookkeeping
                     out_args.append(arg)
     fidxes_t = BK.input_idx(out_fidxes)
     widxes_t = BK.input_idx(out_widxes)
     return (fidxes_t, widxes_t), out_args
예제 #4
0
 def decode_evt(self, dec, ibatch, evt_scores_t: BK.Expr):
     """Create new events from the top-K labels predicted at every token.

     K is ``dec.conf.max_layer_evt`` and the label vocabulary is
     ``dec.voc_evt``. Only the center sentence of an item is decoded when
     ``item.center_sidx`` is set. Candidates of a word are visited in top-k
     order until label index 0 (NIL) appears. The event type is the
     vocabulary word when ``self.conf.pred_evt_label`` is true, otherwise
     ``"UNK"``.

     Args:
         dec: decoder object providing ``conf.max_layer_evt`` and ``voc_evt``.
         ibatch: batch object exposing ``items``.
         evt_scores_t: event scores, annotated upstream as ``[*, dlen, L]``.

     Returns:
         ``((bidxes_t, widxes_t), events)``, flattened in creation order;
         the word indexes include the sentence offset.
     """
     use_pred_label = self.conf.pred_evt_label
     evt_vocab = dec.voc_evt
     # --
     topk_scores_t, topk_labels_t = evt_scores_t.log_softmax(-1).topk(
         dec.conf.max_layer_evt)  # [*, dlen, K]
     score_arr = BK.get_value(topk_scores_t)  # [*, dlen, K]
     label_arr = BK.get_value(topk_labels_t)  # [*, dlen, K]
     out_bidxes, out_widxes, out_evts = [], [], []
     for bidx, item in enumerate(ibatch.items):
         offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             # only the center sentence is decoded when one is designated
             if not (item.center_sidx is None or sidx == item.center_sidx):
                 continue
             sent_start = offsets[sidx]
             sent_len = len(sent)
             sent_end = sent_start + sent_len
             sent_scores = score_arr[bidx][sent_start:sent_end]
             sent_labels = label_arr[bidx][sent_start:sent_end]
             for widx in range(sent_len):
                 for cand_score, cand_label in zip(sent_scores[widx],
                                                   sent_labels[widx]):  # [K]
                     if cand_label == 0:
                         break  # NIL: ignore lower-ranked labels
                     if use_pred_label:
                         evt_type = evt_vocab.idx2word(cand_label)
                     else:
                         evt_type = "UNK"
                     evt = sent.make_event(widx,
                                           1,
                                           type=evt_type,
                                           score=float(cand_score))
                     evt.set_label_idx(int(cand_label))
                     # temporary bookkeeping attributes attached to the event
                     evt._tmp_sstart = sent_start
                     evt._tmp_sidx = sidx
                     evt._tmp_item = item
                     out_bidxes.append(bidx)
                     out_widxes.append(sent_start + widx)  # offset included
                     out_evts.append(evt)
     bidxes_t = BK.input_idx(out_bidxes)
     widxes_t = BK.input_idx(out_widxes)
     return (bidxes_t, widxes_t), out_evts
예제 #5
0
 def decode_evt_given(self, dec, ibatch, evt_scores_t: BK.Expr):
     """Collect already-existing events and optionally re-label them.

     The target events of an item are ``[item.inst]`` when
     ``dec.conf.assume_osof`` is set; otherwise they are the ``events`` of
     its sentences (only the center sentence when ``item.center_sidx`` is
     set). When ``self.conf.pred_evt_label`` is true, each event's label
     index, label and score are overwritten with the arg-max prediction at
     position ``dec_offsets[sidx] + evt.mention.shead_widx``.

     Args:
         dec: decoder object providing ``voc_evt`` and ``conf.assume_osof``.
         ibatch: batch object exposing ``items``.
         evt_scores_t: event scores, annotated upstream as ``[*, dlen, L]``;
             only read when labels are being predicted.

     Returns:
         ``((bidxes_t, widxes_t), events)``, flattened in visiting order.
     """
     _voc_evt = dec.voc_evt
     _assume_osof = dec.conf.assume_osof  # one seq one frame
     _pred_evt_label = self.conf.pred_evt_label
     # --
     if _pred_evt_label:
         evt_logprobs_t = evt_scores_t.log_softmax(-1)  # [*, dlen, L]
         pred_evt_scores, pred_evt_labels = evt_logprobs_t.max(
             -1)  # [*, dlen], note: maximum!
         arr_evt_scores, arr_evt_labels = BK.get_value(
             pred_evt_scores), BK.get_value(pred_evt_labels)  # [*, dlen]
     else:
         arr_evt_scores = arr_evt_labels = None
     # --
     # read given results
     res_bidxes, res_widxes, res_evts = [], [], []  # flattened results
     for bidx, item in enumerate(
             ibatch.items):  # for each item in the batch
         if _assume_osof:
             _trg_evts = [item.inst]
         else:
             # Grow one list in place; ``sum(lists, [])`` would copy the
             # accumulated result again for every sentence.
             _trg_evts = []
             for sidx, sent in enumerate(item.sents):
                 if item.center_sidx is None or sidx == item.center_sidx:
                     _trg_evts.extend(sent.events)
         # --
         _dec_offsets = item.seq_info.dec_offsets
         for _evt in _trg_evts:
             sidx = item.sents.index(_evt.sent)
             _start = _dec_offsets[sidx]
             _full_hidx = _start + _evt.mention.shead_widx
             # add new one
             res_bidxes.append(bidx)
             res_widxes.append(_full_hidx)
             _evt._tmp_sstart = _start  # todo(+N): ugly tmp value ...
             _evt._tmp_sidx = sidx
             _evt._tmp_item = item
             res_evts.append(_evt)
             # assign label?
             if _pred_evt_label:
                 _lab = int(arr_evt_labels[bidx, _full_hidx])  # label index
                 _evt.set_label_idx(_lab)
                 _evt.set_label(_voc_evt.idx2word(_lab))
                 _evt.set_score(float(arr_evt_scores[bidx, _full_hidx]))
         # --
     # --
     # return
     res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(
         res_widxes)  # [??]
     return (res_bidxes_t, res_widxes_t), res_evts
예제 #6
0
 def decode_arg(self, dec, res_evts: List, arg_scores_t: BK.Expr):
     """Attach predicted arguments to each event from top-K role scores.

     K is ``dec.conf.max_layer_arg``, the role vocabulary is
     ``dec.voc_arg`` and sentences farther than
     ``dec.conf.arg_allowed_sent_gap`` from the event's sentence are
     skipped. Role candidates of a word are visited in top-k order until
     label index 0 (NIL) appears; every remaining candidate gets its own
     freshly created single-word entity filler and becomes one argument.

     Args:
         dec: decoder object providing the configuration and vocabulary.
         res_evts: events carrying ``_tmp_item`` and ``_tmp_sidx``.
         arg_scores_t: role scores, annotated upstream as ``[??, dlen, L]``.

     Returns:
         ``((fidxes_t, widxes_t), args)``, flattened in creation order;
         the word indexes include the sentence offset.
     """
     max_gap = dec.conf.arg_allowed_sent_gap
     role_vocab = dec.voc_arg
     # --
     topk_scores_t, topk_labels_t = arg_scores_t.log_softmax(-1).topk(
         dec.conf.max_layer_arg)  # [??, dlen, K]
     score_arr = BK.get_value(topk_scores_t)  # [??, dlen, K]
     label_arr = BK.get_value(topk_labels_t)  # [??, dlen, K]
     out_fidxes, out_widxes, out_args = [], [], []
     for fidx, evt in enumerate(res_evts):
         item = evt._tmp_item  # cached on the event
         evt_sidx = evt._tmp_sidx
         offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             if abs(sidx - evt_sidx) > max_gap:
                 continue  # sentence too far from the event
             sent_start = offsets[sidx]
             sent_len = len(sent)
             sent_end = sent_start + sent_len
             sent_scores = score_arr[fidx][sent_start:sent_end]
             sent_labels = label_arr[fidx][sent_start:sent_end]
             for widx in range(sent_len):
                 for cand_score, cand_label in zip(sent_scores[widx],
                                                   sent_labels[widx]):  # [K]
                     if cand_label == 0:
                         break  # NIL: ignore lower-ranked roles
                     out_fidxes.append(fidx)
                     out_widxes.append(sent_start + widx)
                     filler = sent.make_entity_filler(widx, 1)
                     arg = evt.add_arg(filler,
                                       role=role_vocab.idx2word(cand_label),
                                       score=float(cand_score))
                     arg._tmp_sstart = sent_start  # temporary bookkeeping
                     out_args.append(arg)
     fidxes_t = BK.input_idx(out_fidxes)
     widxes_t = BK.input_idx(out_widxes)
     return (fidxes_t, widxes_t), out_args
예제 #7
0
 def decode_arg_bio(self, res_evts: List, arg_scores_t: BK.Expr,
                    pred_max_layer: int, voc_bio, arg_allowed_sent_gap: int,
                    arr_efs):
     """Decode argument spans from arg-max BIO tags and add them to events.

     For each event, every sentence of its cached item within
     ``arg_allowed_sent_gap`` sentences is tagged with the best label per
     token; ``voc_bio.tags2spans_idx`` turns those tags into
     ``(start, length, label)`` spans. Each span becomes an entity filler
     and an argument whose role comes from ``voc_bio.base_vocab`` and
     whose score is the mean log-probability over the span's tokens.

     Args:
         res_evts: events carrying ``_tmp_item`` and ``_tmp_sidx``.
         arg_scores_t: tag scores, annotated upstream as ``[??, dlen, L]``.
         pred_max_layer: must be 1 (asserted).
         voc_bio: BIO vocabulary exposing ``base_vocab`` and
             ``tags2spans_idx``.
         arg_allowed_sent_gap: maximum allowed ``abs(sidx - evt_sidx)``.
         arr_efs: must be ``None`` (asserted).

     Returns:
         ``None``; the arguments are added to the events in place.
     """
     assert pred_max_layer == 1, "Currently BIO only allow one!"
     assert arr_efs is None, "Currently BIO does not allow pick given efs!"
     role_vocab = voc_bio.base_vocab
     # --
     # best tag only, so a single BIO sequence per event
     best_scores_t, best_labels_t = arg_scores_t.log_softmax(-1).max(
         -1)  # [??, dlen]
     score_arr = BK.get_value(best_scores_t)  # [??, dlen]
     label_arr = BK.get_value(best_labels_t)  # [??, dlen]
     for fidx, evt in enumerate(res_evts):
         item = evt._tmp_item  # cached on the event
         evt_sidx = evt._tmp_sidx
         offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             if abs(sidx - evt_sidx) > arg_allowed_sent_gap:
                 continue  # sentence too far from the event
             sent_start = offsets[sidx]
             sent_end = sent_start + len(sent)
             sent_scores = score_arr[fidx][sent_start:sent_end]
             sent_labels = label_arr[fidx][sent_start:sent_end]
             # turn the tag sequence into labelled spans
             for span_widx, span_wlen, span_lab in voc_bio.tags2spans_idx(
                     sent_labels):
                 span_lab = int(span_lab)
                 assert span_lab > 0, "Error: should not extract 'O'!?"
                 filler = sent.make_entity_filler(span_widx, span_wlen)
                 role = role_vocab.idx2word(span_lab)
                 span_scores = sent_scores[span_widx:span_widx + span_wlen]
                 evt.add_arg(filler, role, score=np.mean(span_scores).item())
예제 #8
0
 def _cri_cert(self, scores: BK.Expr):  # 1.-uncertainty
     """Return a certainty value: one minus the normalized entropy.

     The entropy of ``softmax(scores)`` over the last axis is divided by
     ``log(n)``, where ``n`` is ``BK.get_shape(scores, -1)``, and the
     quotient is subtracted from 1.

     Args:
         scores: unnormalized scores; the last axis is the one normalized.

     Returns:
         ``1 - entropy / log(n)``, with the last axis of ``scores`` reduced.
     """
     entropy = -(scores.softmax(-1) * scores.log_softmax(-1)).sum(-1)
     num_classes = BK.get_shape(scores, -1)
     if num_classes == 1:
         # log(1) == 0 would make the normalization a 0/0 division. With a
         # single class the entropy is already zero, so skip the division.
         # NOTE(review): assumes get_shape(..., -1) returns an int -- confirm.
         return 1. - entropy
     return 1. - entropy / math.log(num_classes)
예제 #9
0
 def output_score(self, score: BK.Expr, local_normalize: bool):
     """Return ``score``, log-softmax-normalized over the last axis if asked.

     Args:
         score: the scores to output.
         local_normalize: whether to apply ``log_softmax(-1)``; ``None``
             falls back to ``self.conf.local_normalize``.

     Returns:
         ``score.log_softmax(-1)`` when normalization is on, otherwise
         ``score`` itself.
     """
     do_normalize = (self.conf.local_normalize
                     if local_normalize is None else local_normalize)
     return score.log_softmax(-1) if do_normalize else score