def decode_frame(self, ibatch, scores_t: BK.Expr, pred_max_layer: int, voc, pred_label: bool, pred_tag: str, pred_check_layer: int):
    """Decode up to K=pred_max_layer frames per token position from label scores.

    Args:
        ibatch: batch of items; each item exposes .sents, .center_sidx and
            .seq_info.dec_offsets (per-sentence start offsets in the packed seq).
        scores_t: label scores, shape [*, dlen, L]; label index 0 means NIL.
        pred_max_layer: K, the max number of frames kept per token.
        voc: label vocabulary (idx2word), used only when pred_label is True.
        pred_label: if False, every predicted frame gets the type "UNK".
        pred_tag: tag passed through to sent.make_frame for new frames.
        pred_check_layer: if >= 0, skip a label at a token when the dotted
            prefix of this length was already emitted there (de-dup constraint).

    Returns:
        ((bidx_t, widx_t), frames, farrs): flat index tensors of shape [??],
        the flat list of new frames, and an object array [*, dlen, K] holding
        the frames at their top-k slot (slots of skipped labels stay None, so
        the K axis can be sparse).
    """
    # --
    # first get topk for each position
    logprobs_t = scores_t.log_softmax(-1)  # [*, dlen, L]
    pred_scores, pred_labels = logprobs_t.topk(pred_max_layer)  # [*, dlen, K]
    arr_scores, arr_labels = BK.get_value(pred_scores), BK.get_value(pred_labels)  # [*, dlen, K]
    # put results
    res_bidxes, res_widxes, res_frames = [], [], []  # flattened results
    res_farrs = np.full(arr_scores.shape, None, dtype=object)  # [*, dlen, K]
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            # todo(+N): currently we only predict for center if there is!
            if item.center_sidx is not None and sidx != item.center_sidx:
                continue  # skip non-center sent in this mode!
            _start = _dec_offsets[sidx]
            _len = len(sent)
            # slice this sentence's rows out of the packed [*, dlen, K] arrays
            _arr_scores, _arr_labels = arr_scores[bidx][_start:_start + _len], arr_labels[bidx][_start:_start + _len]
            for widx in range(_len):
                _full_widx = widx + _start  # idx in the msent
                _tmp_set = set()  # dotted prefixes already emitted at this token
                for _k in range(pred_max_layer):
                    _score, _lab = float(_arr_scores[widx][_k]), int(_arr_labels[widx][_k])
                    if _lab == 0:  # note: lab=0 means NIL
                        break
                    _type_str = (voc.idx2word(_lab) if pred_label else "UNK")
                    _type_str_prefix = '.'.join(_type_str.split('.')[:pred_check_layer])
                    if pred_check_layer >= 0 and _type_str_prefix in _tmp_set:
                        continue  # ignore since constraint
                    _tmp_set.add(_type_str_prefix)
                    # add new one!
                    res_bidxes.append(bidx)
                    res_widxes.append(_full_widx)
                    _new_frame = sent.make_frame(widx, 1, tag=pred_tag, type=_type_str, score=float(_score))
                    _new_frame.set_label_idx(int(_lab))
                    _new_frame._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                    _new_frame._tmp_sidx = sidx
                    _new_frame._tmp_item = item
                    res_frames.append(_new_frame)
                    # note: indexed by _k, so skipped labels leave None holes
                    res_farrs[bidx, _full_widx, _k] = _new_frame
    # return
    res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
    return (res_bidxes_t, res_widxes_t), res_frames, res_farrs  # [??], [*, dlen, K]
def decode_frame_given(self, ibatch, scores_t: BK.Expr, pred_max_layer: int, voc, pred_label: bool, pred_tag: str, assume_osof: bool):
    """Collect already-given frames (no span prediction); optionally re-label them.

    Args:
        ibatch: batch of items (see decode_frame).
        scores_t: label scores [*, dlen, L]; used only when pred_label is True.
        pred_max_layer: K, max frames kept per token position in the output array.
        voc: label vocabulary (idx2word) for the overwritten labels.
        pred_label: if True, overwrite each frame's label/score with the argmax.
        pred_tag: frame tag used to fetch given frames from each sentence.
        assume_osof: "one seq one frame" — take item.inst as the single target
            frame instead of collecting frames from (center) sentences.

    Returns:
        ((bidx_t, widx_t), frames, farrs): flat head-position index tensors,
        the flat frame list, and an object array [*, dlen, K] of frames
        grouped by (batch idx, head idx), truncated to K per position.
    """
    if pred_label:  # if overwrite label!
        logprobs_t = scores_t.log_softmax(-1)  # [*, dlen, L]
        pred_scores, pred_labels = logprobs_t.max(-1)  # [*, dlen], note: maximum!
        arr_scores, arr_labels = BK.get_value(pred_scores), BK.get_value(pred_labels)  # [*, dlen]
    else:
        arr_scores = arr_labels = None
    # --
    # read given results
    res_bidxes, res_widxes, res_frames = [], [], []  # flattened results
    tmp_farrs = defaultdict(list)  # later assign
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _trg_frames = [item.inst] if assume_osof else \
            sum([sent.get_frames(pred_tag) for sidx,sent in enumerate(item.sents) if (item.center_sidx is None or sidx == item.center_sidx)],[])  # still only pick center ones!
        # --
        _dec_offsets = item.seq_info.dec_offsets
        for _frame in _trg_frames:
            # note: simply sort by original order!
            sidx = item.sents.index(_frame.sent)
            _start = _dec_offsets[sidx]
            # position of the frame's syntactic head inside the packed msent
            _full_hidx = _start + _frame.mention.shead_widx
            # add new one
            res_bidxes.append(bidx)
            res_widxes.append(_full_hidx)
            _frame._tmp_sstart = _start  # todo(+N): ugly tmp value ...
            _frame._tmp_sidx = sidx
            _frame._tmp_item = item
            res_frames.append(_frame)
            tmp_farrs[(bidx, _full_hidx)].append(_frame)
            # assign/rewrite label?
            if pred_label:
                _lab = int(arr_labels[bidx, _full_hidx])  # label index
                _frame.set_label_idx(_lab)
                _frame.set_label(voc.idx2word(_lab))
                _frame.set_score(float(arr_scores[bidx, _full_hidx]))
        # --
    # --
    res_farrs = np.full(BK.get_shape(scores_t)[:-1] + [pred_max_layer], None, dtype=object)  # [*, dlen, K]
    for _key, _values in tmp_farrs.items():
        bidx, widx = _key
        _values = _values[:pred_max_layer]  # truncate if more!
        res_farrs[bidx, widx, :len(_values)] = _values
    # return
    res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
    return (res_bidxes_t, res_widxes_t), res_frames, res_farrs  # [??], [*, dlen, K]
def decode_arg(self, res_evts: List, arg_scores_t: BK.Expr, pred_max_layer: int, voc, arg_allowed_sent_gap: int, arr_efs):
    """Decode arguments for each event from per-token role scores (top-K).

    Args:
        res_evts: decoded events; each carries the _tmp_item/_tmp_sidx values
            cached at event-decode time.
        arg_scores_t: role scores, shape [??, dlen, L]; role index 0 means NIL.
        pred_max_layer: K, max roles considered per token position.
        voc: role vocabulary (idx2word).
        arg_allowed_sent_gap: max |sidx - event sidx| for candidate sentences.
        arr_efs: optional object array [??, dlen, K] of given entity-fillers;
            when provided, only positions with a given ef (slot 0) are used.

    Returns:
        ((fidx_t, widx_t), args): flat index tensors and the new arg list.
    """
    # first get topk
    arg_logprobs_t = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
    pred_arg_scores, pred_arg_labels = arg_logprobs_t.topk(pred_max_layer)  # [??, dlen, K]
    arr_arg_scores, arr_arg_labels = BK.get_value(pred_arg_scores), BK.get_value(pred_arg_labels)  # [??, dlen, K]
    # put results
    res_fidxes, res_widxes, res_args = [], [], []  # flattened results
    for fidx, evt in enumerate(res_evts):  # for each evt
        item = evt._tmp_item  # cached
        _evt_sidx = evt._tmp_sidx
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            if abs(sidx - _evt_sidx) > arg_allowed_sent_gap:
                continue  # larger than allowed sentence gap
            _start = _dec_offsets[sidx]
            _len = len(sent)
            _arr_scores, _arr_labels = arr_arg_scores[fidx][_start:_start + _len], arr_arg_labels[fidx][_start:_start + _len]
            for widx in range(_len):
                _full_widx = widx + _start  # idx in the msent
                _new_ef = None
                if arr_efs is not None:  # note: arr_efs should also expand to frames!
                    _new_ef = arr_efs[fidx, _full_widx, 0]  # todo(+N): only get the first one!
                    if _new_ef is None:
                        continue  # no ef!
                for _score, _lab in zip(_arr_scores[widx], _arr_labels[widx]):  # [K]
                    if _lab == 0:  # note: idx=0 means NIL
                        break
                    # add new one!!
                    res_fidxes.append(fidx)
                    res_widxes.append(_full_widx)
                    if _new_ef is None:
                        # lazily create one ef per token; reused for all roles here
                        _new_ef = sent.make_entity_filler(widx, 1)  # share them if possible!
                    _new_arg = evt.add_arg(_new_ef, role=voc.idx2word(_lab), score=float(_score))
                    _new_arg._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                    res_args.append(_new_arg)
    # return
    res_fidxes_t, res_widxes_t = BK.input_idx(res_fidxes), BK.input_idx(res_widxes)  # [??]
    return (res_fidxes_t, res_widxes_t), res_args
def decode_evt(self, dec, ibatch, evt_scores_t: BK.Expr):
    """Decode events from per-token label scores, keeping top-K per position.

    Only the center sentence of an item is decoded when one is designated.
    Returns ((bidx_t, widx_t), events): flat index tensors and new events,
    each event carrying _tmp_* values cached for later argument decoding.
    """
    top_k = dec.conf.max_layer_evt
    voc = dec.voc_evt
    use_label = self.conf.pred_evt_label
    # top-K labels per position
    logprobs = evt_scores_t.log_softmax(-1)  # [*, dlen, L]
    topk_scores, topk_labels = logprobs.topk(top_k)  # [*, dlen, K]
    np_scores = BK.get_value(topk_scores)  # [*, dlen, K]
    np_labels = BK.get_value(topk_labels)  # [*, dlen, K]
    out_bidxes, out_widxes, out_evts = [], [], []
    for bidx, item in enumerate(ibatch.items):
        offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            # skip non-center sentences when a center is designated
            if item.center_sidx is not None and sidx != item.center_sidx:
                continue
            start = offsets[sidx]
            slen = len(sent)
            sub_scores = np_scores[bidx][start:start + slen]
            sub_labels = np_labels[bidx][start:start + slen]
            for widx in range(slen):
                for one_score, one_lab in zip(sub_scores[widx], sub_labels[widx]):  # [K]
                    if one_lab == 0:  # label idx 0 is NIL -> stop at this token
                        break
                    out_bidxes.append(bidx)
                    out_widxes.append(start + widx)  # offset into the packed msent
                    evt_type = voc.idx2word(one_lab) if use_label else "UNK"
                    new_evt = sent.make_event(widx, 1, type=evt_type, score=float(one_score))
                    new_evt.set_label_idx(int(one_lab))
                    # cache positions for later arg decoding (tmp values)
                    new_evt._tmp_sstart = start
                    new_evt._tmp_sidx = sidx
                    new_evt._tmp_item = item
                    out_evts.append(new_evt)
    bidxes_t = BK.input_idx(out_bidxes)
    widxes_t = BK.input_idx(out_widxes)
    return (bidxes_t, widxes_t), out_evts
def decode_evt_given(self, dec, ibatch, evt_scores_t: BK.Expr):
    """Collect given (gold) events; optionally overwrite their labels by argmax.

    With dec.conf.assume_osof ("one seq one frame") the single item.inst is the
    target; otherwise events are gathered from the center sentence (or all
    sentences when no center is set). Returns ((bidx_t, widx_t), events).
    """
    voc = dec.voc_evt
    one_seq_one_frame = dec.conf.assume_osof
    overwrite_label = self.conf.pred_evt_label
    if overwrite_label:
        logprobs = evt_scores_t.log_softmax(-1)  # [*, dlen, L]
        best_scores, best_labels = logprobs.max(-1)  # argmax over labels: [*, dlen]
        np_scores = BK.get_value(best_scores)
        np_labels = BK.get_value(best_labels)
    else:
        np_scores = np_labels = None
    out_bidxes, out_widxes, out_evts = [], [], []
    for bidx, item in enumerate(ibatch.items):
        if one_seq_one_frame:
            trg_evts = [item.inst]
        else:
            # only center-sentence events (all sentences when no center is set)
            trg_evts = []
            for sidx, sent in enumerate(item.sents):
                if item.center_sidx is None or sidx == item.center_sidx:
                    trg_evts.extend(sent.events)
        offsets = item.seq_info.dec_offsets
        for evt in trg_evts:
            sidx = item.sents.index(evt.sent)
            start = offsets[sidx]
            hidx = start + evt.mention.shead_widx  # head position in packed msent
            out_bidxes.append(bidx)
            out_widxes.append(hidx)
            # cache positions for later arg decoding (tmp values)
            evt._tmp_sstart = start
            evt._tmp_sidx = sidx
            evt._tmp_item = item
            out_evts.append(evt)
            if overwrite_label:  # rewrite label/score with the argmax prediction
                lab = int(np_labels[bidx, hidx])
                evt.set_label_idx(lab)
                evt.set_label(voc.idx2word(lab))
                evt.set_score(float(np_scores[bidx, hidx]))
    bidxes_t = BK.input_idx(out_bidxes)
    widxes_t = BK.input_idx(out_widxes)
    return (bidxes_t, widxes_t), out_evts
def decode_arg(self, dec, res_evts: List, arg_scores_t: BK.Expr):
    """Decode arguments for each event, keeping top-K roles per token.

    Only sentences within dec.conf.arg_allowed_sent_gap of the event's
    sentence are considered. Returns ((fidx_t, widx_t), args).
    """
    top_k = dec.conf.max_layer_arg
    sent_gap = dec.conf.arg_allowed_sent_gap
    voc = dec.voc_arg
    logprobs = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
    topk_scores, topk_labels = logprobs.topk(top_k)  # [??, dlen, K]
    np_scores = BK.get_value(topk_scores)
    np_labels = BK.get_value(topk_labels)
    out_fidxes, out_widxes, out_args = [], [], []
    for fidx, evt in enumerate(res_evts):
        item = evt._tmp_item  # cached at event-decode time
        evt_sidx = evt._tmp_sidx
        offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            if abs(sidx - evt_sidx) > sent_gap:
                continue  # beyond the allowed sentence window
            start = offsets[sidx]
            slen = len(sent)
            sub_scores = np_scores[fidx][start:start + slen]
            sub_labels = np_labels[fidx][start:start + slen]
            for widx in range(slen):
                for one_score, one_lab in zip(sub_scores[widx], sub_labels[widx]):  # [K]
                    if one_lab == 0:  # role idx 0 is NIL -> stop at this token
                        break
                    out_fidxes.append(fidx)
                    out_widxes.append(start + widx)
                    new_ef = sent.make_entity_filler(widx, 1)
                    new_arg = evt.add_arg(new_ef, role=voc.idx2word(one_lab), score=float(one_score))
                    new_arg._tmp_sstart = start  # tmp value cached for later use
                    out_args.append(new_arg)
    fidxes_t = BK.input_idx(out_fidxes)
    widxes_t = BK.input_idx(out_widxes)
    return (fidxes_t, widxes_t), out_args
def decode_arg_bio(self, res_evts: List, arg_scores_t: BK.Expr, pred_max_layer: int, voc_bio, arg_allowed_sent_gap: int, arr_efs):
    """Decode arguments as BIO spans from the top-1 tag sequence per event.

    Each decoded span becomes a new entity-filler attached as an argument;
    the span score is the mean of its per-token top-1 scores. Adds args in
    place on the events and returns nothing.
    """
    assert pred_max_layer == 1, "Currently BIO only allow one!"
    assert arr_efs is None, "Currently BIO does not allow pick given efs!"
    bio_voc = voc_bio
    role_voc = bio_voc.base_vocab
    logprobs = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
    # todo(+N): allow multiple bio seqs?
    best_scores, best_labels = logprobs.max(-1)  # top-1 only: [??, dlen]
    np_scores = BK.get_value(best_scores)
    np_labels = BK.get_value(best_labels)
    for fidx, evt in enumerate(res_evts):
        item = evt._tmp_item  # cached at event-decode time
        evt_sidx = evt._tmp_sidx
        offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            if abs(sidx - evt_sidx) > arg_allowed_sent_gap:
                continue  # beyond the allowed sentence window
            start = offsets[sidx]
            slen = len(sent)
            sub_scores = np_scores[fidx][start:start + slen]
            sub_labels = np_labels[fidx][start:start + slen]
            # turn the BIO tag sequence into (widx, wlen, label-idx) spans
            for a_widx, a_wlen, a_lab in bio_voc.tags2spans_idx(sub_labels):
                a_lab = int(a_lab)
                assert a_lab > 0, "Error: should not extract 'O'!?"
                new_ef = sent.make_entity_filler(a_widx, a_wlen)
                span_score = np.mean(sub_scores[a_widx:a_widx + a_wlen]).item()
                evt.add_arg(new_ef, role_voc.idx2word(a_lab), score=span_score)
    # --
    return  # no need to return anything here
def _cri_cert(self, scores: BK.Expr):
    """Return per-position certainty = 1 - normalized entropy of the scores.

    Entropy is normalized by log(L) (L = size of the last dim), so the
    result lies in [0, 1]: 1 for a one-hot distribution, 0 for uniform.
    """
    probs = scores.softmax(-1)
    logprobs = scores.log_softmax(-1)
    neg_entropy = (probs * logprobs).sum(-1)  # -H(p), in [-log L, 0]
    num_labels = BK.get_shape(scores, -1)
    uncertainty = neg_entropy / (-math.log(num_labels))  # H(p) / log L
    return 1. - uncertainty
def output_score(self, score: BK.Expr, local_normalize: bool):
    """Return the score, log-softmax-normalized over the last dim if requested.

    When local_normalize is None, the decision falls back to
    self.conf.local_normalize.
    """
    do_norm = self.conf.local_normalize if local_normalize is None else local_normalize
    if do_norm:
        return score.log_softmax(-1)
    return score