def prepare(self, insts: List[Sent], use_cache: bool):
    """Batch per-sentence SRL prep results into padded arrays/tensors.

    Returns (arr_items [*, slen] object array of event items,
             expr_evt_labels [*, slen], expr_arg_labels [*, slen, slen],
             expr_loss_weight_non [*]).
    """
    # get info, optionally caching the per-sentence prep result on the Sent
    if use_cache:
        zobjs = []
        attr_name = "_cache_srl"  # should be unique
        for s in insts:
            one = getattr(s, attr_name, None)
            if one is None:
                one = self._prep_sent(s)
                setattr(s, attr_name, one)  # set cache
            zobjs.append(one)
    else:
        zobjs = [self._prep_sent(s) for s in insts]
    # batch things
    bsize, mlen = len(insts), max(z.slen for z in zobjs) if len(zobjs) > 0 else 1  # at least put one as padding
    batched_shape = (bsize, mlen)
    arr_items = np.full(batched_shape, None, dtype=object)
    # note: np.int was removed in numpy>=1.24; use an explicit integer dtype
    arr_evt_labels = np.full(batched_shape, 0, dtype=np.int64)
    arr_arg_labels = np.full(batched_shape + (mlen,), 0, dtype=np.int64)
    for zidx, zobj in enumerate(zobjs):
        zlen = zobj.slen
        arr_items[zidx, :zlen] = zobj.evt_items
        arr_evt_labels[zidx, :zlen] = zobj.evt_arr
        arr_arg_labels[zidx, :zlen, :zlen] = zobj.arg_arr
    expr_evt_labels = BK.input_idx(arr_evt_labels)  # [*, slen]
    expr_arg_labels = BK.input_idx(arr_arg_labels)  # [*, slen, slen]
    expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
    return arr_items, expr_evt_labels, expr_arg_labels, expr_loss_weight_non
def decode_frame(self, ibatch, scores_t: BK.Expr, pred_max_layer: int, voc,
                 pred_label: bool, pred_tag: str, pred_check_layer: int):
    """Decode up to `pred_max_layer` frames per token from label scores.

    scores_t: [*, dlen, L] raw label scores over the decoder sequence.
    voc: label vocabulary (idx2word); pred_label=False emits "UNK" types.
    pred_check_layer: if >=0, dedup predictions per token whose dotted type
        shares the same first `pred_check_layer` components.
    Returns ((bidxes_t, widxes_t) flat index tensors, flat frame list,
             res_farrs [*, dlen, K] object array of frames by slot).
    """
    # --
    # first get topk for each position
    logprobs_t = scores_t.log_softmax(-1)  # [*, dlen, L]
    pred_scores, pred_labels = logprobs_t.topk(pred_max_layer)  # [*, dlen, K]
    arr_scores, arr_labels = BK.get_value(pred_scores), BK.get_value(pred_labels)  # [*, dlen, K]
    # put results
    res_bidxes, res_widxes, res_frames = [], [], []  # flattened results
    res_farrs = np.full(arr_scores.shape, None, dtype=object)  # [*, dlen, K]
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            # todo(+N): currently we only predict for center if there is!
            if item.center_sidx is not None and sidx != item.center_sidx:
                continue  # skip non-center sent in this mode!
            _start = _dec_offsets[sidx]  # this sent's offset in the decoder seq
            _len = len(sent)
            _arr_scores, _arr_labels = arr_scores[bidx][_start:_start+_len], arr_labels[bidx][_start:_start+_len]
            for widx in range(_len):
                _full_widx = widx + _start  # idx in the msent
                _tmp_set = set()  # seen type-prefixes for this token
                for _k in range(pred_max_layer):
                    _score, _lab = float(_arr_scores[widx][_k]), int(_arr_labels[widx][_k])
                    if _lab == 0:  # note: lab=0 means NIL
                        break
                    _type_str = (voc.idx2word(_lab) if pred_label else "UNK")
                    _type_str_prefix = '.'.join(_type_str.split('.')[:pred_check_layer])
                    if pred_check_layer >= 0 and _type_str_prefix in _tmp_set:
                        continue  # ignore since constraint
                    _tmp_set.add(_type_str_prefix)
                    # add new one!
                    res_bidxes.append(bidx)
                    res_widxes.append(_full_widx)
                    _new_frame = sent.make_frame(widx, 1, tag=pred_tag, type=_type_str, score=float(_score))
                    _new_frame.set_label_idx(int(_lab))
                    _new_frame._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                    _new_frame._tmp_sidx = sidx
                    _new_frame._tmp_item = item
                    res_frames.append(_new_frame)
                    # note: placed at its topk-rank slot _k (gaps possible after dedup)
                    res_farrs[bidx, _full_widx, _k] = _new_frame
    # return
    res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
    return (res_bidxes_t, res_widxes_t), res_frames, res_farrs  # [??], [*, dlen, K]
def decode_frame_given(self, ibatch, scores_t: BK.Expr, pred_max_layer: int, voc,
                       pred_label: bool, pred_tag: str, assume_osof: bool):
    """Collect pre-given frames (gold/lookup mode) instead of predicting spans.

    If pred_label, overwrite each frame's label/score with the argmax of
    scores_t at its head position. assume_osof: one-seq-one-frame, i.e. the
    target frame is item.inst itself. Returns the same triple layout as
    decode_frame: ((bidxes_t, widxes_t), frames, res_farrs [*, dlen, K]).
    """
    if pred_label:  # if overwrite label!
        logprobs_t = scores_t.log_softmax(-1)  # [*, dlen, L]
        pred_scores, pred_labels = logprobs_t.max(-1)  # [*, dlen], note: maximum!
        arr_scores, arr_labels = BK.get_value(pred_scores), BK.get_value(pred_labels)  # [*, dlen]
    else:
        arr_scores = arr_labels = None
    # --
    # read given results
    res_bidxes, res_widxes, res_frames = [], [], []  # flattened results
    tmp_farrs = defaultdict(list)  # (bidx, full_hidx) -> frames; later assign
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _trg_frames = [item.inst] if assume_osof else \
            sum([sent.get_frames(pred_tag) for sidx,sent in enumerate(item.sents)
                 if (item.center_sidx is None or sidx == item.center_sidx)],[])  # still only pick center ones!
        # --
        _dec_offsets = item.seq_info.dec_offsets
        for _frame in _trg_frames:
            # note: simply sort by original order!
            sidx = item.sents.index(_frame.sent)
            _start = _dec_offsets[sidx]
            _full_hidx = _start + _frame.mention.shead_widx  # head idx in decoder seq
            # add new one
            res_bidxes.append(bidx)
            res_widxes.append(_full_hidx)
            _frame._tmp_sstart = _start  # todo(+N): ugly tmp value ...
            _frame._tmp_sidx = sidx
            _frame._tmp_item = item
            res_frames.append(_frame)
            tmp_farrs[(bidx, _full_hidx)].append(_frame)
            # assign/rewrite label?
            if pred_label:
                _lab = int(arr_labels[bidx, _full_hidx])  # label index
                _frame.set_label_idx(_lab)
                _frame.set_label(voc.idx2word(_lab))
                _frame.set_score(float(arr_scores[bidx, _full_hidx]))
        # --
    # --
    # scatter frames into the [*, dlen, K] object array
    res_farrs = np.full(BK.get_shape(scores_t)[:-1] + [pred_max_layer], None, dtype=object)  # [*, dlen, K]
    for _key, _values in tmp_farrs.items():
        bidx, widx = _key
        _values = _values[:pred_max_layer]  # truncate if more!
        res_farrs[bidx, widx, :len(_values)] = _values
    # return
    res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
    return (res_bidxes_t, res_widxes_t), res_frames, res_farrs  # [??], [*, dlen, K]
def decode_arg(self, res_evts: List, arg_scores_t: BK.Expr, pred_max_layer: int, voc,
               arg_allowed_sent_gap: int, arr_efs):
    """Decode arguments for each event from per-event role scores.

    arg_scores_t: [num_evts, dlen, L]; arr_efs: optional [num_evts, dlen, K]
    object array of candidate entity-fillers (None entries skip the token).
    Only sentences within arg_allowed_sent_gap of the event's sentence are
    considered. Returns ((fidxes_t, widxes_t), flat arg list).
    """
    # first get topk
    arg_logprobs_t = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
    pred_arg_scores, pred_arg_labels = arg_logprobs_t.topk(pred_max_layer)  # [??, dlen, K]
    arr_arg_scores, arr_arg_labels = BK.get_value(pred_arg_scores), BK.get_value(pred_arg_labels)  # [??, dlen, K]
    # put results
    res_fidxes, res_widxes, res_args = [], [], []  # flattened results
    for fidx, evt in enumerate(res_evts):  # for each evt
        item = evt._tmp_item  # cached
        _evt_sidx = evt._tmp_sidx
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            if abs(sidx - _evt_sidx) > arg_allowed_sent_gap:
                continue  # larger than allowed sentence gap
            _start = _dec_offsets[sidx]
            _len = len(sent)
            _arr_scores, _arr_labels = arr_arg_scores[fidx][_start:_start+_len], arr_arg_labels[fidx][_start:_start+_len]
            for widx in range(_len):
                _full_widx = widx + _start  # idx in the msent
                _new_ef = None
                if arr_efs is not None:  # note: arr_efs should also expand to frames!
                    _new_ef = arr_efs[fidx, _full_widx, 0]  # todo(+N): only get the first one!
                    if _new_ef is None:
                        continue  # no ef!
                for _score, _lab in zip(_arr_scores[widx], _arr_labels[widx]):  # [K]
                    if _lab == 0:  # note: idx=0 means NIL
                        break
                    # add new one!!
                    res_fidxes.append(fidx)
                    res_widxes.append(_full_widx)
                    if _new_ef is None:
                        # lazily create one ef per token and reuse across roles
                        _new_ef = sent.make_entity_filler(widx, 1)  # share them if possible!
                    _new_arg = evt.add_arg(_new_ef, role=voc.idx2word(_lab), score=float(_score))
                    _new_arg._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                    res_args.append(_new_arg)
    # return
    res_fidxes_t, res_widxes_t = BK.input_idx(res_fidxes), BK.input_idx(res_widxes)  # [??]
    return (res_fidxes_t, res_widxes_t), res_args
def decode_evt(self, dec, ibatch, evt_scores_t: BK.Expr):
    """Decode up to max_layer_evt events per token from event label scores.

    evt_scores_t: [*, dlen, L]. Only the center sentence is decoded when the
    item declares one. Returns ((bidxes_t, widxes_t) flat index tensors,
    flat list of newly created events with _tmp_* bookkeeping attached).
    """
    _pred_max_layer_evt = dec.conf.max_layer_evt
    _voc_evt = dec.voc_evt
    _pred_evt_label = self.conf.pred_evt_label
    # --
    evt_logprobs_t = evt_scores_t.log_softmax(-1)  # [*, dlen, L]
    pred_evt_scores, pred_evt_labels = evt_logprobs_t.topk(_pred_max_layer_evt)  # [*, dlen, K]
    arr_evt_scores, arr_evt_labels = BK.get_value(pred_evt_scores), BK.get_value(pred_evt_labels)  # [*, dlen, K]
    # put results
    res_bidxes, res_widxes, res_evts = [], [], []  # flattened results
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            # note: here we only predict for center if there is!
            if item.center_sidx is not None and sidx != item.center_sidx:
                continue  # skip non-center sent in this mode!
            _start = _dec_offsets[sidx]  # this sent's offset in the decoder seq
            _len = len(sent)
            _arr_scores, _arr_labels = arr_evt_scores[bidx][_start:_start+_len], arr_evt_labels[bidx][_start:_start+_len]
            for widx in range(_len):
                for _score, _lab in zip(_arr_scores[widx], _arr_labels[widx]):  # [K]
                    if _lab == 0:  # note: idx=0 means NIL
                        break
                    # add new one!!
                    res_bidxes.append(bidx)
                    res_widxes.append(_start + widx)  # note: remember to add offset!
                    _new_evt = sent.make_event(widx, 1,
                                               type=(_voc_evt.idx2word(_lab) if _pred_evt_label else "UNK"),
                                               score=float(_score))
                    _new_evt.set_label_idx(int(_lab))
                    _new_evt._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                    _new_evt._tmp_sidx = sidx
                    _new_evt._tmp_item = item
                    res_evts.append(_new_evt)
    # return
    res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
    return (res_bidxes_t, res_widxes_t), res_evts
def forward(self, inputs, add_bos=False, add_eos=False):
    """Embed token idxes, optionally replacing rare tokens with UNK during
    training and wrapping the sequence with BOS/EOS ids.

    inputs: token idxes [*, len]; returns embeddings [*, 1?+len+1?, dim].
    """
    conf: PlainInputEmbedderConf = self.conf
    voc = self.voc
    idx_t = BK.input_idx(inputs)  # [*, len]
    # randomly map rare tokens to UNK while training (regularization)
    if self.is_training() and self.use_rare_unk:
        _rate = conf.rare_unk_rate
        unk_hit = (self.rare_unk_mask[idx_t] * (BK.rand(BK.get_shape(idx_t)) < _rate)).long()
        idx_t = idx_t * (1 - unk_hit) + voc.unk * unk_hit
    # assemble [bos?] + tokens + [eos?]
    edge_shape = BK.get_shape(idx_t)[:-1] + [1]
    pieces = []
    if add_bos:
        pieces.append(BK.constants(edge_shape, voc.bos, dtype=idx_t.dtype))
    pieces.append(idx_t)  # [*, len]
    if add_eos:
        pieces.append(BK.constants(edge_shape, voc.eos, dtype=idx_t.dtype))
    full_idx_t = BK.concat(pieces, -1)  # [*, 1?+len+1?]
    return self.E(full_idx_t)  # [*, ??, dim]
def prepare(self, insts: Union[List[Sent], List[Frame]], mlen: int, use_cache: bool):
    """Batch per-instance sequence-tagging prep results.

    mlen: max sequence length, fed from outside (not recomputed here).
    Returns (expr_slabs [*, mlen] tag ids, expr_loss_weight_non [*]).
    """
    conf: SeqExtractorConf = self.conf
    # get info, optionally cached on the instance
    if use_cache:
        zobjs = []
        attr_name = f"_scache_{conf.ftag}"  # should be unique
        for s in insts:
            one = getattr(s, attr_name, None)
            if one is None:
                one = self._prep_f(s)
                setattr(s, attr_name, one)  # set cache
            zobjs.append(one)
    else:
        zobjs = [self._prep_f(s) for s in insts]
    # batch things
    bsize = len(insts)
    # mlen = max(z.len for z in zobjs)  # note: fed by outside!!
    batched_shape = (bsize, mlen)
    # note: np.int was removed in numpy>=1.24; use an explicit integer dtype
    arr_slabs = np.full(batched_shape, 0, dtype=np.int64)
    for zidx, zobj in enumerate(zobjs):
        arr_slabs[zidx, :zobj.len] = zobj.tags
    # final setup things
    expr_slabs = BK.input_idx(arr_slabs)
    expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
    return expr_slabs, expr_loss_weight_non
def forward(self, inputs, add_bos=False, add_eos=False):
    """Positional embeddings for either a (batch_size, max_len) shape pair
    (as prepared by "PosiHelper") or an explicit idx tensor [*, len].

    Returns embeddings [*, 1?+len+1?, dim].
    """
    conf: PosiInputEmbedderConf = self.conf
    # --
    try:
        # input is a shape as prepared by "PosiHelper"
        batch_size, max_len = inputs
        if add_bos:
            max_len += 1
        if add_eos:
            max_len += 1
        posi_idxes = BK.arange_idx(max_len)  # [?len?]
        ret = self.E(posi_idxes).unsqueeze(0).expand(batch_size, -1, -1)
    except Exception:
        # fix: was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
        # input is tensor of explicit position idxes
        posi_idxes = BK.input_idx(inputs)  # [*, len]
        cur_maxlen = BK.get_shape(posi_idxes, -1)
        # --
        all_input_slices = []
        slice_shape = BK.get_shape(posi_idxes)[:-1] + [1]
        if add_bos:  # add 0 and offset
            all_input_slices.append(BK.constants(slice_shape, 0, dtype=posi_idxes.dtype))
            cur_maxlen += 1
            # fix: avoid in-place `+=`, which could mutate the caller's tensor
            posi_idxes = posi_idxes + 1
        all_input_slices.append(posi_idxes)  # [*, len]
        if add_eos:
            all_input_slices.append(BK.constants(slice_shape, cur_maxlen, dtype=posi_idxes.dtype))
        final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
        # finally
        ret = self.E(final_input_t)  # [*, ??, dim]
    return ret
def get_dec_mask(self, ibatch, center_only: bool):
    """Float mask [bs, dlen] over the decoder sequence.

    center_only=True keeps only tokens of each item's center sentence;
    otherwise keeps all real-sentence tokens (drops CLS/PAD, which carry
    negative sentence idxes).
    """
    sent_idxes = ibatch.seq_info.dec_sent_idxes  # [bs, dlen]
    if center_only:
        centers_t = BK.input_idx([it.center_sidx for it in ibatch.items]).unsqueeze(-1)  # [bs, 1]
        return (sent_idxes == centers_t).float()  # [bs, dlen]
    # CLS is -1 and PAD is -2, so >=0 selects real tokens
    return (sent_idxes >= 0).float()  # [bs, dlen]
    # note: no need to further multiply by dec_sel_masks here
def _get_arg_external_extra_score(self, flt_items):
    """Build a constraint-penalty score [bs, 1, L] from the arg constraints.

    Invalid roles (mask 0) get REAL_PRAC_MIN added; valid roles get 0.
    Returns None when no arg constraint table is configured.
    """
    if self.cons_arg is None:
        return None
    evt_idxes = [0 if it is None else it.label_idx for it in flt_items]
    valid_masks = self.cons_arg.lookup(BK.input_idx(evt_idxes))  # [*, L]
    penalty_t = Constants.REAL_PRAC_MIN * (1. - valid_masks)  # [*, L]
    return penalty_t.unsqueeze(-2)  # [bs, 1, L], let later broadcast!
def batched_rev_idxes(self):
    """Lazily build and cache the sub->orig alignment idxes [bsize, sub_len]."""
    if self._batched_rev_idxes is None:
        rev_arr, _ = DataPadder(2, pad_vals=0).pad(  # again pad 0
            [sub.align_info.split2orig for sub in self.seq_subs])  # [bsize, sub_len]
        self._batched_rev_idxes = BK.input_idx(rev_arr)
    return self._batched_rev_idxes  # [bsize, sub_len]
def decode_evt_given(self, dec, ibatch, evt_scores_t: BK.Expr):
    """Collect pre-given events (gold/lookup mode) instead of predicting.

    If pred_evt_label, overwrite each event's label/score with the argmax of
    evt_scores_t at its head position. assume_osof: one seq one frame, i.e.
    the target event is item.inst itself. Returns ((bidxes_t, widxes_t),
    flat event list with _tmp_* bookkeeping attached).
    """
    _voc_evt = dec.voc_evt
    _assume_osof = dec.conf.assume_osof  # one seq one frame
    _pred_evt_label = self.conf.pred_evt_label
    # --
    if _pred_evt_label:
        evt_logprobs_t = evt_scores_t.log_softmax(-1)  # [*, dlen, L]
        pred_evt_scores, pred_evt_labels = evt_logprobs_t.max(-1)  # [*, dlen], note: maximum!
        arr_evt_scores, arr_evt_labels = BK.get_value(pred_evt_scores), BK.get_value(pred_evt_labels)  # [*, dlen]
    else:
        arr_evt_scores = arr_evt_labels = None
    # --
    # read given results
    res_bidxes, res_widxes, res_evts = [], [], []  # flattened results
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _trg_evts = [item.inst] if _assume_osof else \
            sum([sent.events for sidx,sent in enumerate(item.sents)
                 if (item.center_sidx is None or sidx == item.center_sidx)],[])
        # --
        _dec_offsets = item.seq_info.dec_offsets
        for _evt in _trg_evts:
            sidx = item.sents.index(_evt.sent)
            _start = _dec_offsets[sidx]
            _full_hidx = _start + _evt.mention.shead_widx  # head idx in decoder seq
            # add new one
            res_bidxes.append(bidx)
            res_widxes.append(_full_hidx)
            _evt._tmp_sstart = _start  # todo(+N): ugly tmp value ...
            _evt._tmp_sidx = sidx
            _evt._tmp_item = item
            res_evts.append(_evt)
            # assign label?
            if _pred_evt_label:
                _lab = int(arr_evt_labels[bidx, _full_hidx])  # label index
                _evt.set_label_idx(_lab)
                _evt.set_label(_voc_evt.idx2word(_lab))
                _evt.set_score(float(arr_evt_scores[bidx, _full_hidx]))
        # --
    # --
    # return
    res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
    return (res_bidxes_t, res_widxes_t), res_evts
def prepare(self, ibatch):
    """Build UDep supervision aligned to the decoder sequence.

    Returns (expr_udep_labels [bs, dlen, dlen] label matrix indexed (h, m),
             expr_isroot [bs, dlen] 1 where the token's head is the root).
    """
    b_seq_info = ibatch.seq_info
    bsize, dlen = BK.get_shape(b_seq_info.dec_sel_masks)
    # note: np.int was removed in numpy>=1.24; use an explicit integer dtype
    arr_udep_labels = np.full([bsize, dlen, dlen], 0, dtype=np.int64)  # by default 0
    arr_head = np.full([bsize, dlen], -1, dtype=np.int64)  # the 0 ones are root
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):  # for each sent in the msent item
            tree = sent.tree_dep
            _start = _dec_offsets[sidx]
            _slen = len(sent)
            # note: here transpose it: (h,m), arti-root not included!
            arr_udep_labels[bidx, _start:_start+_slen, _start:_start+_slen] = tree.label_matrix[:, 1:].T
            arr_head[bidx, _start:_start+_slen] = tree.seq_head.vals
    # --
    expr_udep_labels = BK.input_idx(arr_udep_labels)  # [bs, dlen, dlen]
    expr_isroot = (BK.input_idx(arr_head) == 0).long()  # [bs, dlen]
    return expr_udep_labels, expr_isroot
def prepare(self, insts: List, use_cache: bool = None):
    """Batch gold UPOS label ids into a padded [bs, mlen] tensor.

    use_cache is accepted for interface parity but unused here.
    """
    bsize, mlen = len(insts), max(len(z.sent) for z in insts) if len(insts) > 0 else 1
    batched_shape = (bsize, mlen)
    # note: np.int was removed in numpy>=1.24; use an explicit integer dtype
    arr_upos_labels = np.full(batched_shape, 0, dtype=np.int64)
    for bidx, inst in enumerate(insts):
        zlen = len(inst.sent)
        arr_upos_labels[bidx, :zlen] = inst.sent.seq_upos.idxes
    expr_upos_labels = BK.input_idx(arr_upos_labels)
    return expr_upos_labels
def prepare(self, insts: Union[List[Sent], List[Frame]], mlen: int, use_cache: bool):
    """Batch anchor-extractor prep results into padded arrays/tensors.

    mlen: sequence length fed from outside. Returns
    (arr_items [*, mlen2] object array, expr_seq_iidxes [*, mlen] flattened
     item idxes (offset by zidx*mlen2, -1 for none), expr_seq_labs [*, mlen],
     expr_group_widxes/expr_group_masks [*, mlen, MW], expr_loss_weight_non [*]).
    """
    conf: AnchorExtractorConf = self.conf
    # get info, optionally cached on the instance
    if use_cache:
        zobjs = []
        attr_name = f"_acache_{conf.ftag}"  # should be unique
        for s in insts:
            one = getattr(s, attr_name, None)
            if one is None:
                one = self._prep_f(s)
                setattr(s, attr_name, one)  # set cache
            zobjs.append(one)
    else:
        zobjs = [self._prep_f(s) for s in insts]
    # batch things
    bsize = len(insts)
    mlen2 = max(len(z.items) for z in zobjs) if len(zobjs) > 0 else 1
    mnum = max(len(g) for z in zobjs for g in z.group_widxes) if len(zobjs) > 0 else 1
    arr_items = np.full((bsize, mlen2), None, dtype=object)  # [*, ?]
    # note: np.int/np.float were removed in numpy>=1.24; use explicit dtypes
    arr_seq_iidxes = np.full((bsize, mlen), -1, dtype=np.int64)
    arr_seq_labs = np.full((bsize, mlen), 0, dtype=np.int64)
    arr_group_widxes = np.full((bsize, mlen, mnum), 0, dtype=np.int64)
    arr_group_masks = np.full((bsize, mlen, mnum), 0., dtype=np.float64)
    for zidx, zobj in enumerate(zobjs):
        arr_items[zidx, :len(zobj.items)] = zobj.items
        iidx_offset = zidx * mlen2  # note: offset for valid ones!
        arr_seq_iidxes[zidx, :len(zobj.seq_iidxes)] = [(iidx_offset + ii) if ii >= 0 else ii for ii in zobj.seq_iidxes]
        arr_seq_labs[zidx, :len(zobj.seq_labs)] = zobj.seq_labs
        for zidx2, zwidxes in enumerate(zobj.group_widxes):
            arr_group_widxes[zidx, zidx2, :len(zwidxes)] = zwidxes
            arr_group_masks[zidx, zidx2, :len(zwidxes)] = 1.
    # final setup things
    expr_seq_iidxes = BK.input_idx(arr_seq_iidxes)  # [*, slen]
    expr_seq_labs = BK.input_idx(arr_seq_labs)  # [*, slen]
    expr_group_widxes = BK.input_idx(arr_group_widxes)  # [*, slen, MW]
    expr_group_masks = BK.input_real(arr_group_masks)  # [*, slen, MW]
    expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
    return arr_items, expr_seq_iidxes, expr_seq_labs, expr_group_widxes, expr_group_masks, expr_loss_weight_non
def decode_arg(self, dec, res_evts: List, arg_scores_t: BK.Expr):
    """Decode arguments for each event from per-event role scores.

    arg_scores_t: [num_evts, dlen, L]. Only sentences within
    arg_allowed_sent_gap of the event's sentence are considered; a new
    entity-filler is created for every predicted (token, role).
    Returns ((fidxes_t, widxes_t), flat arg list).
    """
    _pred_max_layer_arg = dec.conf.max_layer_arg
    _arg_allowed_sent_gap = dec.conf.arg_allowed_sent_gap
    _voc_arg = dec.voc_arg
    # --
    arg_logprobs_t = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
    pred_arg_scores, pred_arg_labels = arg_logprobs_t.topk(_pred_max_layer_arg)  # [??, dlen, K]
    arr_arg_scores, arr_arg_labels = BK.get_value(pred_arg_scores), BK.get_value(pred_arg_labels)  # [??, dlen, K]
    # put results
    res_fidxes, res_widxes, res_args = [], [], []  # flattened results
    for fidx, evt in enumerate(res_evts):  # for each evt
        item = evt._tmp_item  # cached
        _evt_sidx = evt._tmp_sidx
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            if abs(sidx - _evt_sidx) > _arg_allowed_sent_gap:
                continue  # larger than allowed sentence gap
            _start = _dec_offsets[sidx]
            _len = len(sent)
            _arr_scores, _arr_labels = arr_arg_scores[fidx][_start:_start+_len], arr_arg_labels[fidx][_start:_start+_len]
            for widx in range(_len):
                for _score, _lab in zip(_arr_scores[widx], _arr_labels[widx]):  # [K]
                    if _lab == 0:  # note: idx=0 means NIL
                        break
                    # add new one!!
                    res_fidxes.append(fidx)
                    res_widxes.append(_start + widx)
                    _new_ef = sent.make_entity_filler(widx, 1)
                    _new_arg = evt.add_arg(_new_ef, role=_voc_arg.idx2word(_lab), score=float(_score))
                    _new_arg._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                    res_args.append(_new_arg)
    # return
    res_fidxes_t, res_widxes_t = BK.input_idx(res_fidxes), BK.input_idx(res_widxes)  # [??]
    return (res_fidxes_t, res_widxes_t), res_args
def prepare(self, insts: Union[List[Sent], List[Frame]], use_cache: bool):
    """Batch direct-extractor prep results into padded arrays/tensors.

    Returns (arr_items [*, GLEN] object array, expr_gaddr [*, GLEN] flat
    gold addresses (-1 for padding), expr_core_widxes, expr_core_wlens,
    expr_loss_weight_non [*]).
    """
    conf: DirectExtractorConf = self.conf
    # get info, optionally cached on the instance
    if use_cache:
        zobjs = []
        attr_name = f"_dcache_{conf.ftag}"  # should be unique
        for s in insts:
            one = getattr(s, attr_name, None)
            if one is None:
                one = self._prep_f(s)
                setattr(s, attr_name, one)  # set cache
            zobjs.append(one)
    else:
        zobjs = [self._prep_f(s) for s in insts]
    # batch things
    bsize, mlen = len(insts), max(z.len for z in zobjs) if len(zobjs) > 0 else 1  # at least put one as padding
    batched_shape = (bsize, mlen)
    arr_items = np.full(batched_shape, None, dtype=object)
    arr_gaddrs = np.arange(bsize*mlen).reshape(batched_shape)  # gold address
    # note: np.int was removed in numpy>=1.24; use an explicit integer dtype
    arr_core_widxes = np.full(batched_shape, 0, dtype=np.int64)
    arr_core_wlens = np.full(batched_shape, 1, dtype=np.int64)
    # note: ext_widxes/ext_wlens intentionally disabled (extended spans unused)
    for zidx, zobj in enumerate(zobjs):
        zlen = zobj.len
        arr_items[zidx, :zlen] = zobj.items
        arr_core_widxes[zidx, :zlen] = zobj.core_widxes
        arr_core_wlens[zidx, :zlen] = zobj.core_wlens
    arr_gaddrs[arr_items == None] = -1  # elementwise compare on object array; set -1 as gaddr for padding
    # final setup things
    expr_gaddr = BK.input_idx(arr_gaddrs)  # [*, GLEN]
    expr_core_widxes = BK.input_idx(arr_core_widxes)
    expr_core_wlens = BK.input_idx(arr_core_wlens)
    expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
    return arr_items, expr_gaddr, expr_core_widxes, expr_core_wlens, expr_loss_weight_non
def prepare(self, insts: List[Sent], use_cache: bool):
    """Batch per-sentence SRL prep results (with a second arg label set).

    Returns (arr_items [*, slen] object array, expr_evt_labels [*, slen],
             expr_arg_labels / expr_arg2_labels [*, slen, slen],
             expr_loss_weight_non [*]).
    """
    # get info (helper handles the per-sentence cache)
    zobjs = ZDecoderHelper.get_zobjs(insts, self._prep_sent, use_cache, "_cache_srl")
    # batch things
    bsize, mlen = len(insts), max(z.slen for z in zobjs) if len(zobjs) > 0 else 1  # at least put one as padding
    batched_shape = (bsize, mlen)
    arr_items = np.full(batched_shape, None, dtype=object)
    # note: np.int was removed in numpy>=1.24; use an explicit integer dtype
    arr_evt_labels = np.full(batched_shape, 0, dtype=np.int64)
    arr_arg_labels = np.full(batched_shape + (mlen,), 0, dtype=np.int64)
    arr_arg2_labels = np.full(batched_shape + (mlen,), 0, dtype=np.int64)
    for zidx, zobj in enumerate(zobjs):
        zlen = zobj.slen
        arr_items[zidx, :zlen] = zobj.evt_items
        arr_evt_labels[zidx, :zlen] = zobj.evt_arr
        arr_arg_labels[zidx, :zlen, :zlen] = zobj.arg_arr
        arr_arg2_labels[zidx, :zlen, :zlen] = zobj.arg2_arr
    expr_evt_labels = BK.input_idx(arr_evt_labels)  # [*, slen]
    expr_arg_labels = BK.input_idx(arr_arg_labels)  # [*, slen, slen]
    expr_arg2_labels = BK.input_idx(arr_arg2_labels)  # [*, slen, slen]
    expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
    return arr_items, expr_evt_labels, expr_arg_labels, expr_arg2_labels, expr_loss_weight_non
def loss(self, flt_items, flt_input_expr, flt_pair_expr, flt_full_expr, flt_mask_expr, flt_extra_weights=None):
    """NLL loss for span extension: left and right boundary predictions.

    flt_items: gold frames whose mention spans supply the targets.
    flt_extra_weights: optional per-item weights (also used as divisor).
    Returns a combined LossHelper loss over the "left" and "right" parts.
    """
    conf: ExtenderConf = self.conf
    _lambda = conf._loss_lambda
    # --
    # score boundaries
    enc_t = self._forward_eenc(flt_input_expr, flt_full_expr, flt_mask_expr)  # [*, slen, D]
    pair_in = flt_pair_expr if conf.ext_use_finput else None
    s_left, s_right = self.enode.score(enc_t, pair_in, flt_mask_expr)  # [*, slen]
    # --
    # gold boundaries from the mention spans
    gold_spans = [self.ext_span_getter(it.mention) for it in flt_items]  # List[(widx, wlen)]
    gold_widx_t = BK.input_idx([p[0] for p in gold_spans])  # [*]
    gold_wlen_t = BK.input_idx([p[1] for p in gold_spans])
    nll_left = BK.loss_nll(s_left, gold_widx_t)  # [*]
    nll_right = BK.loss_nll(s_right, gold_widx_t + gold_wlen_t - 1)  # [*]
    if flt_extra_weights is None:
        denom = BK.constants([len(flt_items)], value=1.).sum()
    else:
        nll_left *= flt_extra_weights
        nll_right *= flt_extra_weights
        denom = flt_extra_weights.sum()  # note: also use this!
    item_left = LossHelper.compile_leaf_loss("left", nll_left.sum(), denom, loss_lambda=_lambda)
    item_right = LossHelper.compile_leaf_loss("right", nll_right.sum(), denom, loss_lambda=_lambda)
    return LossHelper.combine_multiple_losses([item_left, item_right])
def forward(self, med: ZMediator):
    """Dependency-aware layer: select decoder positions from the encoder
    hidden states, build signed head<->modifier relation labels, then run
    the relation-conditioned nodes.

    Returns the updated hidden states [*, dlen, D] (also pushed to med).
    """
    # --
    # get hid_t
    hid_t0 = med.get_enc_cache_val("hid")
    sinfo = med.ibatch.seq_info
    _arange_t, _sel_t = sinfo.arange2_t, sinfo.dec_sel_idxes
    hid_t = hid_t0[_arange_t, _sel_t]  # [*, dlen, D]
    # --
    # prepare relations
    bsize, dlen = BK.get_shape(sinfo.dec_sel_masks)
    # note: np.int was removed in numpy>=1.24; use an explicit integer dtype
    arr_rels = np.full([bsize, dlen, dlen], 0, dtype=np.int64)  # by default 0
    arr_labs = np.full([bsize, dlen], 0, dtype=np.int64)
    for bidx, item in enumerate(med.ibatch.items):  # for each item in the batch
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):  # for each sent in the msent item
            tree = sent.tree_dep
            _start = _dec_offsets[sidx]
            _slen = len(sent)
            _arr_ms = np.arange(_slen) + _start  # [??] modifier positions
            _arr_hs = np.asarray(tree.seq_head.vals) + (_start - 1)  # note(+N): need to do more if msent!!
            _arr_labs = np.asarray(tree.seq_label.idxes)  # [??]
            arr_labs[bidx, _start:_start+_slen] = _arr_labs
            # signed labels encode direction: +lab for (h,m), -lab for (m,h)
            arr_rels[bidx, _arr_hs, _arr_ms] = _arr_labs
            arr_rels[bidx, _arr_ms, _arr_hs] = -_arr_labs
    expr_labs = BK.input_idx(arr_rels)  # [*, dlen, dlen]
    # --
    # go through
    res_t = hid_t
    if self.type_emb is not None:
        expr_seq_labs = BK.input_idx(arr_labs)  # [*, dlen]
        lab_t = self.type_emb(expr_seq_labs)
        res_t = res_t + lab_t
    for node in self.nodes:
        res_t = node.forward(res_t, expr_labs, sinfo.dec_sel_masks)
    med.layer_end({'hid': res_t})  # step once!
    return res_t
def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
         pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr=None):
    """Training loss for the direct extractor: candidate span scoring plus
    labeling, with down-weighted NIL positions.

    Candidates come from gold positions (loss_use_posi), sampling, or topk
    pruning. Returns whatever self._finish_loss produces from the combined
    cand+lab loss.
    """
    conf: DirectExtractorConf = self.conf
    # step 0: prepare golds
    arr_gold_items, expr_gold_gaddr, expr_gold_widxes, expr_gold_wlens, expr_loss_weight_non = \
        self.helper.prepare(insts, use_cache=True)
    # step 1: extract cands
    if conf.loss_use_posi:
        # gold positions only; mask is valid-gold (gaddr>=0)
        cand_res = self.extract_node.go_lookup(
            input_expr, expr_gold_widxes, expr_gold_wlens, (expr_gold_gaddr>=0).float(),
            gaddr_expr=expr_gold_gaddr)
    else:
        # todo(note): assume no in-middle mask!!
        cand_widx, cand_wlen, cand_mask, cand_gaddr = self.extract_node.prepare_with_lengths(
            BK.get_shape(mask_expr), mask_expr.sum(-1).long(),
            expr_gold_widxes, expr_gold_wlens, expr_gold_gaddr)
        if conf.span_train_sample:  # simply do sampling
            cand_res = self.extract_node.go_sample(
                input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
                rate=conf.span_train_sample_rate, count=conf.span_train_sample_count,
                gaddr_expr=cand_gaddr, add_gold_rate=1.0)  # note: always fully add gold for sampling!!
        else:  # beam pruner using topk
            cand_res = self.extract_node.go_topk(
                input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
                rate=conf.span_topk_rate, count=conf.span_topk_count,
                gaddr_expr=cand_gaddr, add_gold_rate=conf.span_train_topk_add_gold_rate)
    # step 1+: prepare for labeling
    cand_gold_mask = (cand_res.gaddr_expr>=0).float() * cand_res.mask_expr  # [*, cand_len]
    # todo(note): add a 0 as idx=-1 to make NEG ones as 0!!
    flatten_gold_label_idxes = BK.input_idx(
        [(0 if z is None else z.label_idx) for z in arr_gold_items.flatten()] + [0])
    gold_label_idxes = flatten_gold_label_idxes[cand_res.gaddr_expr]
    # NIL positions get the per-inst non-weight scaled by conf.loss_weight_non
    cand_loss_weights = BK.where(gold_label_idxes==0,
                                 expr_loss_weight_non.unsqueeze(-1)*conf.loss_weight_non,
                                 cand_res.mask_expr)
    final_loss_weights = cand_loss_weights * cand_res.mask_expr
    # cand loss (binary span-existence loss; skipped in gold-posi mode)
    if conf.loss_cand > 0. and not conf.loss_use_posi:
        loss_cand0 = BK.loss_binary(cand_res.score_expr, cand_gold_mask,
                                    label_smoothing=conf.cand_label_smoothing)
        loss_cand = (loss_cand0 * final_loss_weights).sum()
        loss_cand_item = LossHelper.compile_leaf_loss(f"cand", loss_cand, final_loss_weights.sum(),
                                                      loss_lambda=conf.loss_cand)
    else:
        loss_cand_item = None
    # extra score
    cand_extra_score = self._get_extra_score(
        cand_res.score_expr, insts, cand_res, arr_gold_items, conf.loss_use_cons, conf.loss_use_lu)
    final_extra_score = self._sum_scores(external_extra_score, cand_extra_score)
    # step 2: label; with special weights
    loss_lab, loss_count = self.lab_node.loss(
        cand_res.span_expr, pair_expr, cand_res.mask_expr, gold_label_idxes,
        loss_weight_expr=final_loss_weights, extra_score=final_extra_score)
    loss_lab_item = LossHelper.compile_leaf_loss(f"lab", loss_lab, loss_count,
                                                 loss_lambda=conf.loss_lab, gold=cand_gold_mask.sum())
    # ==
    # return loss
    ret_loss = LossHelper.combine_multiple_losses([loss_cand_item, loss_lab_item])
    return self._finish_loss(ret_loss, insts, input_expr, mask_expr, pair_expr, lookup_flatten)
def __init__(self, ibatch: InputBatch, IDX_PAD: int):
    """Batch the per-item seq_info fields into padded tensors.

    Builds encoder-side inputs (ids/masks/segids), decoder-side selection
    indexes/lengths/masks, and per-position sentence idxes (CLS=-1,
    sentences 0,1,2,..., PAD=-2).
    """
    # preps: cached arange tensors of increasing rank for fancy indexing
    self.bsize = len(ibatch)
    self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
    self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
    self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
    # batched them all
    all_seq_infos = [z.seq_info for z in ibatch.items]
    # enc: [*, len_enc]: ids(pad IDX_PAD), masks, segids(pad 0)
    self.enc_input_ids = BK.input_idx(
        DataPadder.go_batch_2d([z.enc_input_ids for z in all_seq_infos], int(IDX_PAD)))
    self.enc_input_masks = BK.input_real(
        DataPadder.lengths2mask([len(z.enc_input_ids) for z in all_seq_infos]))
    self.enc_input_segids = BK.input_idx(
        DataPadder.go_batch_2d([z.enc_input_segids for z in all_seq_infos], 0))
    # dec: [*, len_dec]: sel_idxes(pad 0), sel_lens(pad 1), masks, sent_idxes(pad ??)
    self.dec_sel_idxes = BK.input_idx(
        DataPadder.go_batch_2d([z.dec_sel_idxes for z in all_seq_infos], 0))
    self.dec_sel_lens = BK.input_idx(
        DataPadder.go_batch_2d([z.dec_sel_lens for z in all_seq_infos], 1))
    self.dec_sel_masks = BK.input_real(
        DataPadder.lengths2mask([len(z.dec_sel_idxes) for z in all_seq_infos]))
    _max_dec_len = BK.get_shape(self.dec_sel_masks, 1)
    # pad offsets with _max_dec_len so padded "sentences" never start
    _dec_offsets = BK.input_idx(
        DataPadder.go_batch_2d([z.dec_offsets for z in all_seq_infos], _max_dec_len))
    # note: CLS as -1, then 0,1,2,..., PAD gets -2!
    # count how many sentence offsets each position has passed, minus 1
    self.dec_sent_idxes = \
        (BK.arange_idx(_max_dec_len).unsqueeze(0).unsqueeze(-1) >= _dec_offsets.unsqueeze(-2)).sum(-1).long() - 1
    self.dec_sent_idxes[self.dec_sel_masks <= 0.] = -2  # mark PAD positions
    # dec -> enc: [*, len_enc] (calculated on needed!)
    # note: require 1-to-1 mapping (except pads)!!
    self._enc_back_hits = None
    self._enc_back_sel_idxes = None
def __init__(self, berter: BertEncoder, seq_subs: List[InputSubwordSeqField]):
    """Batch subword sequences for a BERT encoder: padded ids/masks plus
    orig<->subword alignment indexes.
    """
    self.seq_subs = seq_subs
    self.berter = berter
    self.bsize = len(seq_subs)
    # cached arange tensors of increasing rank for fancy indexing
    self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
    self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
    self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
    # --
    tokenizer = self.berter.tokenizer
    PAD_IDX = tokenizer.pad_token_id
    # MASK_IDX = tokenizer.mask_token_id
    # CLS_IDX_l = [tokenizer.cls_token_id]
    # SEP_IDX_l = [tokenizer.sep_token_id]
    # make batched idxes
    padder = DataPadder(2, pad_vals=PAD_IDX, mask_range=2)
    batched_sublens = [len(s.idxes) for s in seq_subs]  # [bsize]
    batched_input_ids, batched_input_mask = padder.pad([s.idxes for s in seq_subs])  # [bsize, sub_len]
    self.batched_sublens_p1 = BK.input_idx(batched_sublens) + 1  # also the idx of EOS (if counting including BOS)
    self.batched_input_ids = BK.input_idx(batched_input_ids)
    self.batched_input_mask = BK.input_real(batched_input_mask)
    # make batched mappings (sub->orig)
    padder2 = DataPadder(2, pad_vals=0, mask_range=2)  # pad as 0 to avoid out-of-range
    batched_first_idxes, batched_first_mask = padder2.pad(
        [s.align_info.orig2begin for s in seq_subs])  # [bsize, orig_len]
    self.batched_first_idxes = BK.input_idx(batched_first_idxes)
    self.batched_first_mask = BK.input_real(batched_first_mask)
    # reversed batched_mappings (orig->sub) (created when needed)
    self._batched_rev_idxes = None  # [bsize, sub_len]
    # --
    # optional factors filled in later by callers
    self.batched_repl_masks = None  # [bsize, sub_len], to replace with MASK
    self.batched_token_type_ids = None  # [bsize, 1+sub_len+1]
    self.batched_position_ids = None  # [bsize, 1+sub_len+1]
    self.other_factors = {}  # name -> aug_batched_ids
def prepare(self, ibatch):
    """Build gold UPOS label ids aligned to the decoder sequence.

    Returns expr_upos_labels [bs, dlen]; non-sentence positions stay 0.
    """
    b_seq_info = ibatch.seq_info
    # note: np.int was removed in numpy>=1.24; use an explicit integer dtype
    arr_upos_labels = np.full(BK.get_shape(b_seq_info.dec_sel_masks), 0, dtype=np.int64)  # by default 0
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):  # for each sent in the msent item
            _start = _dec_offsets[sidx]
            arr_upos_labels[bidx, _start:_start + len(sent)] = sent.seq_upos.idxes
    expr_upos_labels = BK.input_idx(arr_upos_labels)  # [bs, dlen]
    return expr_upos_labels
def prepare(self, insts: List, use_cache: bool):
    """Batch per-instance UDep prep results.

    Returns (expr_depth [*, slen] tree depths,
             expr_udep [*, slen_m, slen_h] dependency label matrix).
    """
    # get info (helper handles the per-instance cache)
    zobjs = ZDecoderHelper.get_zobjs(insts, self._prep_inst, use_cache, "_cache_udep")
    # then
    bsize, mlen = len(insts), max(z.slen for z in zobjs) if len(zobjs) > 0 else 1
    batched_shape = (bsize, mlen)
    # note: np.int/np.float were removed in numpy>=1.24; use explicit dtypes
    arr_depth = np.full(batched_shape, 0., dtype=np.float64)  # [*, slen]
    arr_udep = np.full(batched_shape + (mlen,), 0, dtype=np.int64)  # [*, slen_m, slen_h]
    for zidx, zobj in enumerate(zobjs):
        zlen = zobj.slen
        arr_depth[zidx, :zlen] = zobj.depth_arr
        arr_udep[zidx, :zlen, :zlen] = zobj.udep_arr
    expr_depth = BK.input_real(arr_depth)  # [*, slen]
    expr_udep = BK.input_idx(arr_udep)  # [*, slen_m, slen_h]
    return expr_depth, expr_udep
def _transform_factors(self, factors: Union[List[List[int]], BK.Expr],
                       is_orig: bool, PAD_IDX: Union[int, float]):
    """Turn per-sequence factor lists into a batched idx tensor, and map
    orig-token factors onto subword positions when requested."""
    if isinstance(factors, BK.Expr):
        # already a padded tensor: take it as-is
        batched_ids = factors
    else:
        # pad the raw python lists, then lift to an idx tensor
        padded, _ = DataPadder(2, pad_vals=PAD_IDX).pad(factors)
        batched_ids = BK.input_idx(padded)  # [bsize, orig-len if is_orig else sub_len]
    if not is_orig:
        return batched_ids  # already at subword granularity
    # orig -> subword: gather through the reversed (orig->sub) mapping
    return batched_ids[self.arange2_t, self.batched_rev_idxes]  # [bsize, sub_len]
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    """Shared labeled-MST decode: pick the best label per arc, then run the
    CPU head-decoding routine on the resulting unlabeled scores.

    Returns (heads, labels, scores) as numpy arrays when ``ret_arr`` else as
    backend tensors.
    """
    assert labeled  # only the labeled path is supported here
    with BK.no_grad_env():
        # best label (and its score) for every candidate arc: [BS, m, h]
        best_label_scores, best_labels = scores_expr.max(-1)
        # decode heads on CPU over the per-arc max scores: [BS, m]
        heads_arr, _, scores_arr = CPU_f(
            BK.get_value(best_label_scores), lengths_arr, labeled=False)
        heads_t = BK.input_idx(heads_arr)
        # pick the label of each chosen head arc
        labels_t = BK.gather_one_lastdim(best_labels, heads_t).squeeze(-1)
        # prepare outputs in the requested representation
        if ret_arr:
            return heads_arr, BK.get_value(labels_t), scores_arr
        return heads_t, labels_t, BK.input_real(scores_arr)
def __init__(self, conf: SrlInferenceHelperConf, dec: 'ZDecoderSrl', **kwargs):
    """Set up SRL inference-time helpers: argument post-processor,
    lexicon->frame and frame->role constraint tables, constraint functions,
    BIO expansion indices, and an optional predicted-event-type filter."""
    super().__init__(conf, **kwargs)
    conf: SrlInferenceHelperConf = self.conf
    # --
    self.setattr_borrow('dec', dec)  # borrowed (not owned) reference to the decoder
    self.arg_pp = PostProcessor(conf.arg_pp)
    # --
    self.lu_cons, self.role_cons = None, None
    if conf.frames_name:  # currently only frame->role
        from msp2.data.resources import get_frames_label_budgets
        flb = get_frames_label_budgets(conf.frames_name)
        _voc_ef, _voc_evt, _voc_arg = dec.ztask.vpack
        _role_cons = fchelper.build_constraint_arrs(
            flb, _voc_arg, _voc_evt)
        self.role_cons = BK.input_real(_role_cons)
    if conf.frames_file:
        # NOTE(review): if both frames_name and frames_file are set, the
        # file-based role_cons overwrites the name-based one above.
        _voc_ef, _voc_evt, _voc_arg = dec.ztask.vpack
        _fc = default_pickle_serializer.from_file(conf.frames_file)
        _lu_cons = fchelper.build_constraint_arrs(
            fchelper.build_lu_map(_fc), _voc_evt, warning=False)  # lexicon->frame
        _role_cons = fchelper.build_constraint_arrs(
            fchelper.build_role_map(_fc), _voc_arg, _voc_evt)  # frame->role
        self.lu_cons, self.role_cons = _lu_cons, BK.input_real(_role_cons)
    # --
    self.cons_evt_tok_f = conf.get_cons_evt_tok()
    self.cons_evt_frame_f = conf.get_cons_evt_frame()
    if self.dec.conf.arg_use_bio:  # extend for bio!
        # mapping from BIO tag ids back to original label ids
        self.cons_arg_bio_sels = BK.input_idx(
            self.dec.vocab_bio_arg.get_bio2origin())
    else:
        self.cons_arg_bio_sels = None
    # --
    from msp2.data.resources.frames import KBP17_TYPES
    # only 'kbp17' currently maps to a filter set; anything else -> None
    self.pred_evt_filter = {
        'kbp17': KBP17_TYPES
    }.get(conf.pred_evt_filter, None)
def forward(self, inputs, add_bos=False, add_eos=False):
    """Embed char-idx inputs (optionally wrapped with BOS/EOS rows) and
    run them through the char CNN."""
    conf: CharCnnInputEmbedderConf = self.conf
    # --
    voc = self.voc
    char_idx_t = BK.input_idx(inputs)  # [*, len]
    # todo(note): no need for replacing to unk for char!!
    # a single boundary row has shape [*, 1, clen]
    boundary_shape = BK.get_shape(char_idx_t)
    boundary_shape[-2] = 1
    # assemble [bos?] + input + [eos?] along the token axis
    pieces = []
    if add_bos:
        pieces.append(BK.constants(boundary_shape, voc.bos, dtype=char_idx_t.dtype))
    pieces.append(char_idx_t)  # [*, len, clen]
    if add_eos:
        pieces.append(BK.constants(boundary_shape, voc.eos, dtype=char_idx_t.dtype))
    full_input_t = BK.concat(pieces, -2)  # [*, 1?+len+1?, clen]
    # char embeddings -> char cnn
    embedded_t = self.E(full_input_t)  # [*, ??, dim]
    return self.cnn(embedded_t)
def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr,
         mask_expr: BK.Expr, pair_expr: BK.Expr = None, lookup_flatten=False,
         external_extra_score: BK.Expr = None):
    """Training loss of the soft extractor, combining four stages:
    candidate scoring, segment splitting, segment labeling, and extension.

    Returns (combined_loss, None).
    """
    conf: SoftExtractorConf = self.conf
    assert not lookup_flatten  # flatten-lookup mode not supported in loss
    bsize, slen = BK.get_shape(mask_expr)
    # --
    # step 0: prepare gold arrays/tensors (cached on the insts)
    arr_items, expr_seq_gaddr, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
        self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
    # --
    # step 1: cand -- score every token and select candidates
    cand_full_scores, pred_cand_decisions = self._cand_score_and_select(
        input_expr, mask_expr)  # [*, slen]
    loss_cand_items, cand_widxes, cand_masks = self._loss_feed_cand(
        mask_expr, cand_full_scores, pred_cand_decisions, expr_seq_gaddr,
        expr_group_widxes, expr_group_masks, expr_loss_weight_non,
    )  # ~, [*, clen]
    # --
    # step 2: split -- gather candidate reprs/scores, score split points
    cand_expr, cand_scores = input_expr[
        arange2_t, cand_widxes], cand_full_scores[arange2_t, cand_widxes]  # [*, clen]
    split_scores, pred_split_decisions = self._split_score(
        cand_expr, cand_masks)  # [*, clen-1]
    loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr = self._loss_feed_split(
        mask_expr, split_scores, pred_split_decisions,
        cand_widxes, cand_masks, cand_expr, cand_scores, expr_seq_gaddr,
    )  # ~, [*, seglen, *?]
    # --
    # step 3: lab -- label each segment against its oracle gold item
    # todo(note): add a 0 as idx=-1 to make NEG ones as 0!!
    flatten_gold_label_idxes = BK.input_idx(
        [(0 if z is None else z.label_idx) for z in arr_items.flatten()] + [0])
    gold_label_idxes = flatten_gold_label_idxes[
        oracle_gaddr]  # [*, seglen]
    # NOTE(review): weights branch on oracle_gaddr>=0; exact semantics of the
    # negative-gaddr case depend on _loss_feed_split -- confirm there.
    lab_loss_weights = BK.where(oracle_gaddr >= 0,
                                expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                seg_masks)  # [*, seglen]
    final_lab_loss_weights = lab_loss_weights * seg_masks  # [*, seglen]
    # go
    loss_lab, loss_count = self.lab_node.loss(
        seg_weighted_expr, pair_expr, seg_masks, gold_label_idxes,
        loss_weight_expr=final_lab_loss_weights,
        extra_score=external_extra_score)
    loss_lab_item = LossHelper.compile_leaf_loss(
        f"lab", loss_lab, loss_count, loss_lambda=conf.loss_lab,
        gold=(gold_label_idxes > 0).float().sum())
    # step 4: extend -- only segments with positive gold labels get extended
    flt_mask = ((gold_label_idxes > 0) & (seg_masks > 0.))  # [*, seglen]
    flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[
        flt_mask]  # [?]  batch index of each kept segment
    flt_expr = seg_weighted_expr[flt_mask]  # [?, D]
    flt_full_expr = self._prepare_full_expr(seg_ext_widxes[flt_mask],
                                            seg_ext_masks[flt_mask],
                                            slen)  # [?, slen, D]
    flt_items = arr_items.flatten()[BK.get_value(
        oracle_gaddr[flt_mask])]  # [?]  gold items aligned to kept segments
    loss_ext_item = self.ext_node.loss(flt_items, input_expr[flt_sidx],
                                       flt_expr, flt_full_expr,
                                       mask_expr[flt_sidx])
    # --
    # return loss: combine all stage losses into one compound loss
    ret_loss = LossHelper.combine_multiple_losses(
        loss_cand_items + [loss_split_item, loss_lab_item, loss_ext_item])
    return ret_loss, None