def loss_cf(self, cf_scores: List[BK.Expr], insts, loss_cf: float):
    conf: SeqExitHelperConf = self.conf
    # --
    assert self.is_cf
    # get oracle
    oracles = [self.cf_oracle_f(ff) for ff in insts]  # bs*[NL, slen] or bs*[NL]
    rets = []
    mask_t = BK.input_real(DataPadder.lengths2mask([len(z.sent) for z in insts]))  # [bs, slen]
    for one_li, one_scores in enumerate(cf_scores):
        if conf.cf_use_seq:
            one_oracle_t = BK.input_real([z[one_li] for z in oracles])  # [bs]
            one_oracle_t *= conf.cf_scale
            one_mask_t = BK.zeros([len(one_oracle_t)]) + 1
        else:
            one_oracle_t = BK.input_real(DataPadder.go_batch_2d([z[one_li] for z in oracles], 1.))  # [bs, slen]
            one_mask_t = (BK.rand(one_oracle_t.shape) >= ((one_oracle_t ** conf.cf_loss_discard_curve) * conf.cf_loss_discard)) * mask_t
            one_oracle_t *= conf.cf_scale
        # simple L2 loss
        one_loss_t = (one_scores.squeeze(-1) - one_oracle_t) ** 2
        one_loss_item = LossHelper.compile_leaf_loss(
            f"cf{one_li}", (one_loss_t * one_mask_t).sum(), one_mask_t.sum(), loss_lambda=loss_cf)
        rets.append(one_loss_item)
    return rets
def __init__(self, conf: ZLabelConf, **kwargs):
    super().__init__(conf, **kwargs)
    conf: ZLabelConf = self.conf
    # --
    _csize = conf._csize
    # final affine layer?
    self.input_act = ActivationHelper.get_act(conf.input_act)
    if conf.emb_size > 0:
        self.aff_final = AffineNode(None, isize=conf.emb_size, osize=_csize, no_drop=True)
    else:
        self.aff_final = None
    # fixed_nil_mask?
    self.fixed_nil_mask, self.fixed_nil_val = None, None
    if conf.fixed_nil_val is not None:
        self.fixed_nil_val = float(conf.fixed_nil_val)
        self.fixed_nil_mask = BK.input_real([0.] + [1.] * (_csize - 1))
    # binary mode
    if conf.use_nil_as_binary or conf.loss_binary_alpha != 0:
        assert conf.use_nil_as_binary and conf.loss_binary_alpha != 0 and self.fixed_nil_mask is not None
        if conf.loss_allbinary_alpha != 0.:
            assert conf.fixed_nil_val == 0.
    # crf mode?
    self.crf = None
    if conf.crf is not None:
        assert not (conf.use_nil_as_binary or conf.loss_binary_alpha != 0)  # not binary mode!
        self.crf = ZLinearCrfNode(conf.crf, _csize=_csize)
def get_weights(weights, pad: float = 0.):
    final_weights = [pad] * 100  # note: this should be enough!
    final_weights[-len(weights):] = [float(z) for z in weights]  # assign the last ones!
    ret = BK.input_real(final_weights)
    # ret /= ret.sum(-1)  # normalize!, note: not doing this!!
    return ret
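# A minimal numpy mirror of the function above (hypothetical helper, BK-free), just to
# show the left-padding behavior: the given weights always land in the *last* slots of
# the length-100 vector, and no normalization is applied.
import numpy as np

def _demo_get_weights(weights, pad: float = 0., size: int = 100):
    final_weights = [pad] * size
    final_weights[-len(weights):] = [float(z) for z in weights]
    return np.asarray(final_weights)

_w = _demo_get_weights([0.1, 0.9])
assert _w.shape == (100,) and _w[-2:].tolist() == [0.1, 0.9] and _w[:-2].sum() == 0.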
def __init__(self, conf: MySRLConf, vocab_evt: SimpleVocab, vocab_arg: SimpleVocab, **kwargs):
    super().__init__(conf, **kwargs)
    conf: MySRLConf = self.conf
    self.vocab_evt = vocab_evt
    self.vocab_arg = vocab_arg
    # --
    self.vocab_bio_arg = None
    self.pred_cons_mat = None
    if conf.arg_use_bio:
        self.vocab_bio_arg = SeqVocab(vocab_arg)  # simply BIO vocab
        zlog(f"Use BIO vocab for srl: {self.vocab_bio_arg}")
        if conf.arg_pred_use_seq_cons:
            _m = self.vocab_bio_arg.get_allowed_transitions()
            self.pred_cons_mat = (1. - BK.input_real(_m)) * Constants.REAL_PRAC_MIN  # [L, L]
            zlog(f"Further use BIO constraints for decoding: {self.pred_cons_mat.shape}")
        helper_vocab_arg = self.vocab_bio_arg
    else:
        helper_vocab_arg = self.vocab_arg
    # --
    self.helper = MySRLHelper(conf, self.vocab_evt, helper_vocab_arg)
    # --
    # predicate
    self.evt_node = SingleIdecNode(conf.evt_conf, ndim=conf.isize, nlab=(2 if conf.binary_evt else len(vocab_evt)))
    # argument
    self.arg_node = PairwiseIdecNode(conf.arg_conf, ndim=conf.isize, nlab=len(helper_vocab_arg))
def prepare(self, insts: List[Sent], use_cache: bool):
    # get info
    if use_cache:
        zobjs = []
        attr_name = "_cache_srl"  # should be unique
        for s in insts:
            one = getattr(s, attr_name, None)
            if one is None:
                one = self._prep_sent(s)
                setattr(s, attr_name, one)  # set cache
            zobjs.append(one)
    else:
        zobjs = [self._prep_sent(s) for s in insts]
    # batch things
    bsize, mlen = len(insts), max(z.slen for z in zobjs) if len(zobjs) > 0 else 1  # at least put one as padding
    batched_shape = (bsize, mlen)
    arr_items = np.full(batched_shape, None, dtype=object)
    arr_evt_labels = np.full(batched_shape, 0, dtype=int)  # note: np.int was just builtin int (removed in newer numpy)
    arr_arg_labels = np.full(batched_shape + (mlen,), 0, dtype=int)
    for zidx, zobj in enumerate(zobjs):
        zlen = zobj.slen
        arr_items[zidx, :zlen] = zobj.evt_items
        arr_evt_labels[zidx, :zlen] = zobj.evt_arr
        arr_arg_labels[zidx, :zlen, :zlen] = zobj.arg_arr
    expr_evt_labels = BK.input_idx(arr_evt_labels)  # [*, slen]
    expr_arg_labels = BK.input_idx(arr_arg_labels)  # [*, slen, slen]
    expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
    return arr_items, expr_evt_labels, expr_arg_labels, expr_loss_weight_non
def __init__(self, conf: PlainInputEmbedderConf, voc: SimpleVocab, npvec: np.ndarray = None, name="UNK"):
    super().__init__(conf, name)
    # --
    conf: PlainInputEmbedderConf = self.conf
    self.voc = voc
    # check init embeddings
    if conf.init_from_pretrain:
        zlog(f"Try to init {self.extra_repr()} with npvec.shape={npvec.shape if (npvec is not None) else None}")
        if npvec is None:
            zwarn("warning: cannot get pre-trained embeddings to init!!")
    # get rare unk range
    voc_rare_unk_mask = []
    for w in self.voc.full_i2w:
        c = self.voc.word2count(w, df=None)
        voc_rare_unk_mask.append(float(c is not None and c <= conf.rare_unk_thr))
    self.rare_unk_mask = BK.input_real(voc_rare_unk_mask)  # stored tensor!
    # self.register_buffer()  # todo(note): do we need register buffer?
    self.use_rare_unk = (conf.rare_unk_rate > 0. and conf.rare_unk_thr > 0)
    # add the real embedding node
    self.E = EmbeddingNode(conf.econf, npvec=npvec, osize=conf.dim, n_words=len(self.voc))
def prepare(self, insts: Union[List[Sent], List[Frame]], mlen: int, use_cache: bool):
    conf: SeqExtractorConf = self.conf
    # get info
    if use_cache:
        zobjs = []
        attr_name = f"_scache_{conf.ftag}"  # should be unique
        for s in insts:
            one = getattr(s, attr_name, None)
            if one is None:
                one = self._prep_f(s)
                setattr(s, attr_name, one)  # set cache
            zobjs.append(one)
    else:
        zobjs = [self._prep_f(s) for s in insts]
    # batch things
    bsize = len(insts)
    # mlen = max(z.len for z in zobjs)  # note: fed by outside!!
    batched_shape = (bsize, mlen)
    # arr_first_items = np.full(batched_shape, None, dtype=object)
    arr_slabs = np.full(batched_shape, 0, dtype=int)  # note: np.int was just builtin int (removed in newer numpy)
    for zidx, zobj in enumerate(zobjs):
        # arr_first_items[zidx, zobj.len] = zobj.first_items
        arr_slabs[zidx, :zobj.len] = zobj.tags
    # final setup things
    expr_slabs = BK.input_idx(arr_slabs)
    expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
    # return arr_first_items, expr_slabs, expr_loss_weight_non
    return expr_slabs, expr_loss_weight_non
def __init__(self, conf: ZDecoderSRLConf, name: str, vocab_evt: SimpleVocab, vocab_arg: SimpleVocab,
             ref_enc: ZEncoder, **kwargs):
    super().__init__(conf, name, **kwargs)
    conf: ZDecoderSRLConf = self.conf
    self.vocab_evt = vocab_evt
    self.vocab_arg = vocab_arg
    _enc_dim, _head_dim = ref_enc.get_enc_dim(), ref_enc.get_head_dim()
    # --
    self.vocab_bio_arg = None
    self.pred_cons_mat = None
    if conf.arg_use_bio:
        self.vocab_bio_arg = SeqVocab(vocab_arg)  # simply BIO vocab
        zlog(f"Use BIO vocab for srl: {self.vocab_bio_arg}")
        if conf.arg_pred_use_seq_cons:
            _m = self.vocab_bio_arg.get_allowed_transitions()
            self.pred_cons_mat = (1. - BK.input_real(_m)) * Constants.REAL_PRAC_MIN  # [L, L]
            zlog(f"Further use BIO constraints for decoding: {self.pred_cons_mat.shape}")
        helper_vocab_arg = self.vocab_bio_arg
    else:
        helper_vocab_arg = self.vocab_arg
    # --
    self.helper = ZDecoderSRLHelper(conf, self.vocab_evt, helper_vocab_arg, self.vocab_arg)
    # --
    # nodes
    self.evt_node: IdecNode = conf.evt_conf.make_node(_isize=_enc_dim, _nhead=_head_dim, _csize=(2 if conf.binary_evt else len(vocab_evt)))
    self.arg_node: IdecNode = conf.arg_conf.make_node(_isize=_enc_dim, _nhead=_head_dim, _csize=len(helper_vocab_arg))
    self.arg2_node: IdecNode = conf.arg2_conf.make_node(_isize=_enc_dim, _nhead=_head_dim, _csize=len(self.vocab_arg))
    # --
    raise RuntimeError("Deprecated after MED's collecting of scores!!")
def get_weights(nl: int, weights):
    final_weights = [weights[-1]] * nl
    mlen = min(nl, len(weights))
    final_weights[:mlen] = weights[:mlen]
    ret = BK.input_real(final_weights)
    ret /= ret.sum(-1)  # normalize!
    return ret
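# Worked example for the variant above (hypothetical numbers): with nl=4 and
# weights=[3., 1.], the tail is filled with weights[-1]=1. giving [3., 1., 1., 1.],
# which normalizes to [0.5, 1/6, 1/6, 1/6]. A BK-free numpy sketch:
import numpy as np

def _demo_layer_weights(nl: int, weights):
    final_weights = [weights[-1]] * nl
    mlen = min(nl, len(weights))
    final_weights[:mlen] = weights[:mlen]
    ret = np.asarray(final_weights, dtype=float)
    return ret / ret.sum(-1)

assert np.allclose(_demo_layer_weights(4, [3., 1.]), [0.5, 1 / 6, 1 / 6, 1 / 6])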
def _input_bert(self, insts: List[Sent]):
    bi = self.berter.create_input_batch_from_sents(insts)
    mask_expr = BK.input_real(DataPadder.lengths2mask([len(z) for z in insts]))  # [bs, slen]
    vstate = self.idec_manager.new_vstate(None, mask_expr)  # todo(+N): currently ignore emb layer!
    bert_expr = self.berter.forward(bi, vstate=vstate)
    return mask_expr, bert_expr, vstate
def loss_regs(regs: List['ParamRegHelper']):
    loss_items = []
    for ii, reg in enumerate(regs):
        if reg.reg_method_loss:
            _loss, _loss_lambda = reg.compute_loss()
            _loss_item = LossHelper.compile_leaf_loss(f'reg_{ii}', _loss, BK.input_real(1.), loss_lambda=_loss_lambda)
            loss_items.append(_loss_item)
    ret_loss = LossHelper.combine_multiple_losses(loss_items)
    return ret_loss
def cons_score_frame2role(self, arg_scores: BK.Expr, evts: List):
    evt_idxes = [e.label_idx for e in evts]
    valid_mask = BK.input_real(self.role_cons[evt_idxes])  # [??, L]
    if self.cons_arg_bio_sels is not None:
        valid_mask = valid_mask[:, self.cons_arg_bio_sels]  # [??, L']
    valid_mask[:, 0] = 1.  # note: must preserve NIL!
    arg_scores = arg_scores + (1 - valid_mask).unsqueeze(-2) * Constants.REAL_PRAC_MIN  # [??, dlen, L]
    return arg_scores
def get_batched_items(self, insts: List):
    all_items = [self._get_f(z) for z in insts]
    arr_shape = len(insts), max(len(z) for z in all_items)
    arr_items = np.full(arr_shape, None, dtype=object)
    arr_masks = np.full(arr_shape, 0., dtype=np.float32)
    for zidx, zitems in enumerate(all_items):
        zlen = len(zitems)
        arr_items[zidx, :zlen] = zitems
        arr_masks[zidx, :zlen] = 1.
    return arr_items, BK.input_real(arr_masks)
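# Shape sketch for the batching above (hypothetical inputs): two instances with 2 and 3
# items respectively give a (2, 3) object array padded with None and a float mask
# marking the valid slots; BK.input_real is assumed to wrap the float32 array as-is.
import numpy as np

_all_items = [["a", "b"], ["c", "d", "e"]]
_arr = np.full((2, 3), None, dtype=object)
_msk = np.full((2, 3), 0., dtype=np.float32)
for _zi, _zz in enumerate(_all_items):
    _arr[_zi, :len(_zz)] = _zz
    _msk[_zi, :len(_zz)] = 1.
assert _msk.tolist() == [[1., 1., 0.], [1., 1., 1.]]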
def prepare(self, insts: Union[List[Sent], List[Frame]], mlen: int, use_cache: bool):
    conf: AnchorExtractorConf = self.conf
    # get info
    if use_cache:
        zobjs = []
        attr_name = f"_acache_{conf.ftag}"  # should be unique
        for s in insts:
            one = getattr(s, attr_name, None)
            if one is None:
                one = self._prep_f(s)
                setattr(s, attr_name, one)  # set cache
            zobjs.append(one)
    else:
        zobjs = [self._prep_f(s) for s in insts]
    # batch things
    bsize, mlen2 = len(insts), (max(len(z.items) for z in zobjs) if len(zobjs) > 0 else 1)
    mnum = max(len(g) for z in zobjs for g in z.group_widxes) if len(zobjs) > 0 else 1
    arr_items = np.full((bsize, mlen2), None, dtype=object)  # [*, ?]
    arr_seq_iidxes = np.full((bsize, mlen), -1, dtype=int)  # note: np.int/np.float were just the builtins (removed in newer numpy)
    arr_seq_labs = np.full((bsize, mlen), 0, dtype=int)
    arr_group_widxes = np.full((bsize, mlen, mnum), 0, dtype=int)
    arr_group_masks = np.full((bsize, mlen, mnum), 0., dtype=float)
    for zidx, zobj in enumerate(zobjs):
        arr_items[zidx, :len(zobj.items)] = zobj.items
        iidx_offset = zidx * mlen2  # note: offset for valid ones!
        arr_seq_iidxes[zidx, :len(zobj.seq_iidxes)] = [(iidx_offset + ii) if ii >= 0 else ii for ii in zobj.seq_iidxes]
        arr_seq_labs[zidx, :len(zobj.seq_labs)] = zobj.seq_labs
        for zidx2, zwidxes in enumerate(zobj.group_widxes):
            arr_group_widxes[zidx, zidx2, :len(zwidxes)] = zwidxes
            arr_group_masks[zidx, zidx2, :len(zwidxes)] = 1.
    # final setup things
    expr_seq_iidxes = BK.input_idx(arr_seq_iidxes)  # [*, slen]
    expr_seq_labs = BK.input_idx(arr_seq_labs)  # [*, slen]
    expr_group_widxes = BK.input_idx(arr_group_widxes)  # [*, slen, MW]
    expr_group_masks = BK.input_real(arr_group_masks)  # [*, slen, MW]
    expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
    return arr_items, expr_seq_iidxes, expr_seq_labs, expr_group_widxes, expr_group_masks, expr_loss_weight_non
def prepare_inputs(self, insts: List):
    ret = OrderedDict()
    # first basic masks
    mask_arr = DataPadder.lengths2mask([len(z) for z in insts])
    ret["mask"] = BK.input_real(mask_arr)
    ret["mask_arr"] = mask_arr
    # then the rest
    for key, inputter in self.inputters.items():
        ret[key] = inputter.prepare(insts)
    return ret
def __init__(self, ibatch: InputBatch, IDX_PAD: int):
    # preps
    self.bsize = len(ibatch)
    self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
    self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
    self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
    # batch them
    all_seq_infos = [z.seq_info for z in ibatch.items]
    # enc: [*, len_enc]: ids(pad IDX_PAD), masks, segids(pad 0)
    self.enc_input_ids = BK.input_idx(DataPadder.go_batch_2d([z.enc_input_ids for z in all_seq_infos], int(IDX_PAD)))
    self.enc_input_masks = BK.input_real(DataPadder.lengths2mask([len(z.enc_input_ids) for z in all_seq_infos]))
    self.enc_input_segids = BK.input_idx(DataPadder.go_batch_2d([z.enc_input_segids for z in all_seq_infos], 0))
    # dec: [*, len_dec]: sel_idxes(pad 0), sel_lens(pad 1), masks, sent_idxes(pad ??)
    self.dec_sel_idxes = BK.input_idx(DataPadder.go_batch_2d([z.dec_sel_idxes for z in all_seq_infos], 0))
    self.dec_sel_lens = BK.input_idx(DataPadder.go_batch_2d([z.dec_sel_lens for z in all_seq_infos], 1))
    self.dec_sel_masks = BK.input_real(DataPadder.lengths2mask([len(z.dec_sel_idxes) for z in all_seq_infos]))
    _max_dec_len = BK.get_shape(self.dec_sel_masks, 1)
    _dec_offsets = BK.input_idx(DataPadder.go_batch_2d([z.dec_offsets for z in all_seq_infos], _max_dec_len))
    # note: CLS as -1, then 0,1,2,..., PAD gets -2!
    self.dec_sent_idxes = \
        (BK.arange_idx(_max_dec_len).unsqueeze(0).unsqueeze(-1) >= _dec_offsets.unsqueeze(-2)).sum(-1).long() - 1
    self.dec_sent_idxes[self.dec_sel_masks <= 0.] = -2
    # dec -> enc: [*, len_enc] (calculated when needed!)
    # note: require 1-to-1 mapping (except pads)!!
    self._enc_back_hits = None
    self._enc_back_sel_idxes = None
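# Numpy sketch of the sentence-index trick above (hypothetical offsets): counting how
# many sentence start-offsets each position has passed, minus one, yields -1 for CLS
# and 0,1,2,... for the sentences; padded positions are overwritten with -2 afterwards.
import numpy as np

_dec_offsets = np.array([1, 4])  # sentence starts in a length-6 dec seq
_pos = np.arange(6)[:, None]  # [len_dec, 1]
_sent_idxes = (_pos >= _dec_offsets[None, :]).sum(-1) - 1
assert _sent_idxes.tolist() == [-1, 0, 0, 0, 1, 1]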
def __init__(self, conf: SrlInferenceHelperConf, dec: 'ZDecoderSrl', **kwargs):
    super().__init__(conf, **kwargs)
    conf: SrlInferenceHelperConf = self.conf
    # --
    self.setattr_borrow('dec', dec)
    self.arg_pp = PostProcessor(conf.arg_pp)
    # --
    self.lu_cons, self.role_cons = None, None
    if conf.frames_name:  # currently only frame->role
        from msp2.data.resources import get_frames_label_budgets
        flb = get_frames_label_budgets(conf.frames_name)
        _voc_ef, _voc_evt, _voc_arg = dec.ztask.vpack
        _role_cons = fchelper.build_constraint_arrs(flb, _voc_arg, _voc_evt)
        self.role_cons = BK.input_real(_role_cons)
    if conf.frames_file:
        _voc_ef, _voc_evt, _voc_arg = dec.ztask.vpack
        _fc = default_pickle_serializer.from_file(conf.frames_file)
        _lu_cons = fchelper.build_constraint_arrs(fchelper.build_lu_map(_fc), _voc_evt, warning=False)  # lexicon->frame
        _role_cons = fchelper.build_constraint_arrs(fchelper.build_role_map(_fc), _voc_arg, _voc_evt)  # frame->role
        self.lu_cons, self.role_cons = _lu_cons, BK.input_real(_role_cons)
    # --
    self.cons_evt_tok_f = conf.get_cons_evt_tok()
    self.cons_evt_frame_f = conf.get_cons_evt_frame()
    if self.dec.conf.arg_use_bio:  # extend for bio!
        self.cons_arg_bio_sels = BK.input_idx(self.dec.vocab_bio_arg.get_bio2origin())
    else:
        self.cons_arg_bio_sels = None
    # --
    from msp2.data.resources.frames import KBP17_TYPES
    self.pred_evt_filter = {'kbp17': KBP17_TYPES}.get(conf.pred_evt_filter, None)
def nmarginal_proj(scores_expr, mask_expr, lengths_arr, labeled=True):
    assert labeled
    with BK.no_grad_env():
        # first make it unlabeled by sum-exp
        scores_unlabeled = BK.logsumexp(scores_expr, dim=-1)  # [BS, m, h]
        # marginal for unlabeled
        scores_unlabeled_arr = BK.get_value(scores_unlabeled)
        marginals_unlabeled_arr = marginal_proj(scores_unlabeled_arr, lengths_arr, False)
        # back to labeled values
        marginals_unlabeled_expr = BK.input_real(marginals_unlabeled_arr)
        marginals_labeled_expr = marginals_unlabeled_expr.unsqueeze(-1) * BK.exp(scores_expr - scores_unlabeled.unsqueeze(-1))  # [BS, m, h, L]
        return _ensure_margins_norm(marginals_labeled_expr)
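# The labeled-marginal expression above uses the factorization p(m,h,l) = p(m,h) * p(l|m,h):
# the projective unlabeled marginal is redistributed over labels by a per-arc softmax,
#   marginals_labeled[m,h,l] = marginals_unlabeled[m,h] * exp(s[m,h,l] - logsumexp_l' s[m,h,l'])
# which is exactly the unsqueeze(-1) * BK.exp(...) line.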
def __init__(self, berter: BertEncoder, seq_subs: List[InputSubwordSeqField]):
    self.seq_subs = seq_subs
    self.berter = berter
    self.bsize = len(seq_subs)
    self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
    self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
    self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
    # --
    tokenizer = self.berter.tokenizer
    PAD_IDX = tokenizer.pad_token_id
    # MASK_IDX = tokenizer.mask_token_id
    # CLS_IDX_l = [tokenizer.cls_token_id]
    # SEP_IDX_l = [tokenizer.sep_token_id]
    # make batched idxes
    padder = DataPadder(2, pad_vals=PAD_IDX, mask_range=2)
    batched_sublens = [len(s.idxes) for s in seq_subs]  # [bsize]
    batched_input_ids, batched_input_mask = padder.pad([s.idxes for s in seq_subs])  # [bsize, sub_len]
    self.batched_sublens_p1 = BK.input_idx(batched_sublens) + 1  # also the idx of EOS (if counting including BOS)
    self.batched_input_ids = BK.input_idx(batched_input_ids)
    self.batched_input_mask = BK.input_real(batched_input_mask)
    # make batched mappings (sub->orig)
    padder2 = DataPadder(2, pad_vals=0, mask_range=2)  # pad as 0 to avoid out-of-range
    batched_first_idxes, batched_first_mask = padder2.pad([s.align_info.orig2begin for s in seq_subs])  # [bsize, orig_len]
    self.batched_first_idxes = BK.input_idx(batched_first_idxes)
    self.batched_first_mask = BK.input_real(batched_first_mask)
    # reversed batched_mappings (orig->sub) (created when needed)
    self._batched_rev_idxes = None  # [bsize, sub_len]
    # --
    self.batched_repl_masks = None  # [bsize, sub_len], to replace with MASK
    self.batched_token_type_ids = None  # [bsize, 1+sub_len+1]
    self.batched_position_ids = None  # [bsize, 1+sub_len+1]
    self.other_factors = {}  # name -> aug_batched_ids
def select_topk_non_overlapping(score_t: BK.Expr, topk_t: Union[int, BK.Expr], widx_t: BK.Expr, wlen_t: BK.Expr,
                                input_mask_t: BK.Expr, mask_t: BK.Expr = None, dim=-1):
    score_shape = BK.get_shape(score_t)
    # note: fixed from "len(score_shape - 1)", which raised a TypeError on the list
    assert dim == -1 or dim == len(score_shape) - 1, "Currently only support last-dim!!"  # todo(+2): we can permute to allow any dim!
    # --
    # prepare K
    if isinstance(topk_t, int):
        tmp_shape = score_shape.copy()
        tmp_shape[dim] = 1  # set it as 1
        topk_t = BK.constants_idx(tmp_shape, topk_t)
    # --
    reshape_trg = [np.prod(score_shape[:-1]).item(), -1]  # [*, ?]
    _, sorted_idxes_t = score_t.sort(dim, descending=True)
    # --
    # put it as CPU and use loop; todo(+N): more efficient ways?
    arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t = \
        [BK.get_value(z.reshape(reshape_trg)) for z in [sorted_idxes_t, topk_t, widx_t, wlen_t, input_mask_t, mask_t]]
    _bsize, _cnum = BK.get_shape(arr_sorted_idxes_t)  # [bsize, NUM]
    arr_topk_mask = np.full([_bsize, _cnum], 0.)  # [bsize, NUM]
    _bidx = 0
    for aslice_sorted_idxes_t, aslice_topk_t, aslice_widx_t, aslice_wlen_t, aslice_input_mask_t, aslice_mask_t \
            in zip(arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t):
        aslice_topk_mask = arr_topk_mask[_bidx]
        # --
        cur_ok_mask = np.copy(aslice_input_mask_t)
        cur_budget = aslice_topk_t.item()
        for _cidx in aslice_sorted_idxes_t:
            _cidx = _cidx.item()
            if cur_budget <= 0:
                break  # no budget left
            if not aslice_mask_t[_cidx].item():
                continue  # non-valid candidate
            one_widx, one_wlen = aslice_widx_t[_cidx].item(), aslice_wlen_t[_cidx].item()
            if np.prod(cur_ok_mask[one_widx:one_widx + one_wlen]).item() == 0.:  # any hit one?
                continue
            # ok! add it!
            cur_budget -= 1
            cur_ok_mask[one_widx:one_widx + one_wlen] = 0.
            aslice_topk_mask[_cidx] = 1.
        _bidx += 1
    # note: no need to *=mask_t again since already checked in the loop
    return BK.input_real(arr_topk_mask).reshape(score_shape)
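# Tiny BK-free check of the greedy scheme above (hypothetical spans): candidates
# (widx, wlen) = (0,2), (1,2), (3,1) with scores [0.9, 0.8, 0.7] and budget k=2;
# (0,2) is taken first, (1,2) overlaps token 1 and is skipped, then (3,1) is taken.
import numpy as np

def _demo_greedy_topk(scores, widxs, wlens, k: int, n_tok: int):
    ok_mask = np.ones(n_tok)
    taken = []
    for _ci in np.argsort(-np.asarray(scores)):  # best score first
        if len(taken) >= k:
            break
        _w, _l = widxs[_ci], wlens[_ci]
        if ok_mask[_w:_w + _l].prod() == 0.:  # overlaps an already-taken span
            continue
        ok_mask[_w:_w + _l] = 0.
        taken.append(int(_ci))
    return taken

assert _demo_greedy_topk([0.9, 0.8, 0.7], [0, 1, 3], [2, 2, 1], k=2, n_tok=5) == [0, 2]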
def prepare(self, insts: List, use_cache: bool):
    # get info
    zobjs = ZDecoderHelper.get_zobjs(insts, self._prep_inst, use_cache, "_cache_udep")
    # then
    bsize, mlen = len(insts), max(z.slen for z in zobjs) if len(zobjs) > 0 else 1
    batched_shape = (bsize, mlen)
    arr_depth = np.full(batched_shape, 0., dtype=float)  # [*, slen]  (note: np.float/np.int were just the builtins, removed in newer numpy)
    arr_udep = np.full(batched_shape + (mlen,), 0, dtype=int)  # [*, slen_m, slen_h]
    for zidx, zobj in enumerate(zobjs):
        zlen = zobj.slen
        arr_depth[zidx, :zlen] = zobj.depth_arr
        arr_udep[zidx, :zlen, :zlen] = zobj.udep_arr
    expr_depth = BK.input_real(arr_depth)  # [*, slen]
    expr_udep = BK.input_idx(arr_udep)  # [*, slen_m, slen_h]
    return expr_depth, expr_udep
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    assert labeled
    with BK.no_grad_env():
        # argmax-label: [BS, m, h]
        scores_unlabeled_max, labels_argmax = scores_expr.max(-1)
        scores_unlabeled_max_arr = BK.get_value(scores_unlabeled_max)
        mst_heads_arr, _, mst_scores_arr = CPU_f(scores_unlabeled_max_arr, lengths_arr, labeled=False)  # [BS, m]
        mst_heads_expr = BK.input_idx(mst_heads_arr)
        mst_labels_expr = BK.gather_one_lastdim(labels_argmax, mst_heads_expr).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return mst_heads_arr, BK.get_value(mst_labels_expr), mst_scores_arr
        else:
            return mst_heads_expr, mst_labels_expr, BK.input_real(mst_scores_arr)
def __init__(self, cons: Constrainer, src_vocab: SimpleVocab, trg_vocab: SimpleVocab, conf: ConstrainerNodeConf, **kwargs):
    super().__init__(conf, **kwargs)
    conf: ConstrainerNodeConf = self.conf
    # --
    # input vocab
    if src_vocab is None:  # make our own src_vocab
        cons_keys = sorted(cons.cmap.keys())  # simply get all the keys
        src_vocab = SimpleVocab.build_by_static(cons_keys, pre_list=["non"], post_list=None)  # non==0!
    # output vocab
    assert trg_vocab is not None
    out_size = len(trg_vocab)  # output size is len(trg_vocab)
    trg_is_seq_vocab = isinstance(trg_vocab, SeqVocab)
    _trg_get_f = (lambda x: trg_vocab.get_range_by_basename(x)) if trg_is_seq_vocab else (lambda x: trg_vocab.get(x))
    # set it up
    _vec = np.full((len(src_vocab), out_size), 0., dtype=np.float32)
    assert src_vocab.non == 0
    _vec[0] = 1.  # by default: src-non is all valid!
    _vec[:, 0] = 1.  # by default: trg-non is all valid!
    # --
    stat = {"k_skip": 0, "k_hit": 0, "v_skip": 0, "v_hit": 0}  # note: all counters start at 0
    for k, v in cons.cmap.items():
        idx_k = src_vocab.get(k)
        if idx_k is None:
            stat["k_skip"] += 1
            continue  # skip no_hit!
        stat["k_hit"] += 1
        for k2 in v.keys():
            idx_k2 = _trg_get_f(k2)
            if idx_k2 is None:
                stat["v_skip"] += 1
                continue
            stat["v_hit"] += 1
            if trg_is_seq_vocab:
                _vec[idx_k, idx_k2[0]:idx_k2[1]] = 1.  # hit range
            else:
                _vec[idx_k, idx_k2] = 1.  # hit!!
    zlog(f"Setup ConstrainerNode ok: vec={_vec.shape}, stat={stat}")
    # --
    self.cons = cons
    self.src_vocab = src_vocab
    self.trg_vocab = trg_vocab
    self.vec = BK.input_real(_vec)
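# Semantics of self.vec in the node above: vec[src, trg] == 1. marks an allowed pairing.
# Row 0 ("non" source) and column 0 ("non"/NIL target) are fully allowed by default, and
# each constraint entry switches on one cell (or one basename range for a SeqVocab
# target); everything else stays 0. and can later be turned into -inf score masks.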
def prepare(self, insts: Union[List[Sent], List[Frame]], use_cache: bool):
    conf: DirectExtractorConf = self.conf
    # get info
    if use_cache:
        zobjs = []
        attr_name = f"_dcache_{conf.ftag}"  # should be unique
        for s in insts:
            one = getattr(s, attr_name, None)
            if one is None:
                one = self._prep_f(s)
                setattr(s, attr_name, one)  # set cache
            zobjs.append(one)
    else:
        zobjs = [self._prep_f(s) for s in insts]
    # batch things
    bsize, mlen = len(insts), max(z.len for z in zobjs) if len(zobjs) > 0 else 1  # at least put one as padding
    batched_shape = (bsize, mlen)
    arr_items = np.full(batched_shape, None, dtype=object)
    arr_gaddrs = np.arange(bsize * mlen).reshape(batched_shape)  # gold address
    arr_core_widxes = np.full(batched_shape, 0, dtype=int)  # note: np.int was just builtin int (removed in newer numpy)
    arr_core_wlens = np.full(batched_shape, 1, dtype=int)
    # arr_ext_widxes = np.full(batched_shape, 0, dtype=int)
    # arr_ext_wlens = np.full(batched_shape, 1, dtype=int)
    for zidx, zobj in enumerate(zobjs):
        zlen = zobj.len
        arr_items[zidx, :zlen] = zobj.items
        arr_core_widxes[zidx, :zlen] = zobj.core_widxes
        arr_core_wlens[zidx, :zlen] = zobj.core_wlens
        # arr_ext_widxes[zidx, :zlen] = zobj.ext_widxes
        # arr_ext_wlens[zidx, :zlen] = zobj.ext_wlens
    arr_gaddrs[arr_items == None] = -1  # set -1 as gaddr (elementwise None check on the object array)
    # final setup things
    expr_gaddr = BK.input_idx(arr_gaddrs)  # [*, GLEN]
    expr_core_widxes = BK.input_idx(arr_core_widxes)
    expr_core_wlens = BK.input_idx(arr_core_wlens)
    # expr_ext_widxes = BK.input_idx(arr_ext_widxes)
    # expr_ext_wlens = BK.input_idx(arr_ext_wlens)
    expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
    # return arr_flatten_items, expr_gaddr, expr_core_widxes, expr_core_wlens, \
    #     expr_ext_widxes, expr_ext_wlens, expr_loss_weight_non
    return arr_items, expr_gaddr, expr_core_widxes, expr_core_wlens, expr_loss_weight_non
def get_label_mask(self, sels: List[str]):
    expand_sels = []
    for s in sels:
        if s in UD_CATEGORIES:
            expand_sels.extend(UD_CATEGORIES[s])
        else:
            expand_sels.append(s)
    expand_sels = sorted(set(expand_sels))
    voc = self.voc
    # --
    ret = np.zeros(len(voc))
    _cc = 0
    for s in expand_sels:
        if s in voc:
            ret[voc[s]] = 1.
            _cc += voc.word2count(s)
        else:
            zwarn(f"UNK dep label: {s}")
    _all_cc = voc.get_all_counts()
    zlog(f"Get label mask with {expand_sels}: {len(expand_sels)}=={ret.sum().item()} -> {_cc}/{_all_cc}={_cc/(_all_cc+1e-5)}")
    return BK.input_real(ret)
def prepare(self, insts: List[Sent], use_cache: bool):
    # get info
    zobjs = ZDecoderHelper.get_zobjs(insts, self._prep_sent, use_cache, "_cache_srl")
    # batch things
    bsize, mlen = len(insts), max(z.slen for z in zobjs) if len(zobjs) > 0 else 1  # at least put one as padding
    batched_shape = (bsize, mlen)
    arr_items = np.full(batched_shape, None, dtype=object)
    arr_evt_labels = np.full(batched_shape, 0, dtype=int)  # note: np.int was just builtin int (removed in newer numpy)
    arr_arg_labels = np.full(batched_shape + (mlen,), 0, dtype=int)
    arr_arg2_labels = np.full(batched_shape + (mlen,), 0, dtype=int)
    for zidx, zobj in enumerate(zobjs):
        zlen = zobj.slen
        arr_items[zidx, :zlen] = zobj.evt_items
        arr_evt_labels[zidx, :zlen] = zobj.evt_arr
        arr_arg_labels[zidx, :zlen, :zlen] = zobj.arg_arr
        arr_arg2_labels[zidx, :zlen, :zlen] = zobj.arg2_arr
    expr_evt_labels = BK.input_idx(arr_evt_labels)  # [*, slen]
    expr_arg_labels = BK.input_idx(arr_arg_labels)  # [*, slen, slen]
    expr_arg2_labels = BK.input_idx(arr_arg2_labels)  # [*, slen, slen]
    expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
    return arr_items, expr_evt_labels, expr_arg_labels, expr_arg2_labels, expr_loss_weight_non
def cons_score_lu2frame(self, evt_scores: BK.Expr, ibatch, given_f=None):
    dlen, llen = BK.get_shape(evt_scores)[-2:]
    _tok_f, _frame_f = self.cons_evt_tok_f, self.cons_evt_frame_f
    # --
    res = []
    for bidx, item in enumerate(ibatch.items):
        _dec_offsets = item.seq_info.dec_offsets
        one_res = [np.zeros(llen) for _ in range(dlen)]  # dlen*[L]
        for sidx, sent in enumerate(item.sents):
            _start = _dec_offsets[sidx]
            # tok
            if _tok_f is not None:
                for widx, tok in enumerate(sent.get_tokens()):
                    _key = _tok_f(tok)
                    _arr = self.lu_cons.get(_key)
                    if _arr is not None:
                        one_res[_start + widx] += _arr
            # frame
            if _frame_f is not None and given_f is not None:
                _frames = given_f(sent)
                for ff in _frames:
                    _key = _frame_f(ff)
                    _arr = self.lu_cons.get(_key)
                    if _arr is not None:
                        one_res[_start + ff.mention.shead_widx] += _arr
        # --
        res.extend(one_res)
    # --
    valid_mask = BK.input_real(res).view(evt_scores.shape)  # [*, dlen, L]
    valid_mask.clamp_(max=1)
    exclude_mask = (1 - valid_mask) * (valid_mask.sum(-1, keepdim=True) > 0).float()  # no effects if no-hit!
    exclude_mask[:, :, 0] = 0.  # note: still do not exclude NIL, use "pred_evt_nil_add" for special mode!
    evt_scores = evt_scores + exclude_mask * Constants.REAL_PRAC_MIN  # [*, dlen, L]
    return evt_scores
def _input_bert(self, insts: List[Sent]):
    bi = self.berter.create_input_batch_from_sents(insts)
    mask_expr = BK.input_real(DataPadder.lengths2mask([len(z) for z in insts]))  # [bs, slen]
    bert_expr = self.berter.forward(bi)
    return mask_expr, bert_expr
def __init__(self, vocab: SimpleVocab, conf: SeqLabelerConf, **kwargs):
    super().__init__(conf, **kwargs)
    conf: SeqLabelerConf = self.conf
    is_pairwise = (conf.psize > 0)
    self.is_pairwise = is_pairwise
    # --
    # 0. pre mlp
    isize, psize = conf.isize, conf.psize
    self.main_mlp = MLPNode(conf.main_mlp, isize=isize, osize=-1, use_out=False)
    isize = self.main_mlp.get_output_dims()[0]
    if is_pairwise:
        self.pair_mlp = MLPNode(conf.pair_mlp, isize=psize, osize=-1, use_out=False)
        psize = self.pair_mlp.get_output_dims()[0]
    else:
        self.pair_mlp = lambda x: x
    # 1/2. decoder & laber
    if conf.use_seqdec:
        # extra for seq-decoder
        dec_hid = conf.seqdec_conf.dec_hidden
        # setup labeler to get embedding dim
        self.laber = SimpleLabelerNode(vocab, conf.labeler_conf, isize=dec_hid, psize=psize)
        laber_embed_dim = self.laber.lookup_dim
        # init starting hidden; note: choose different according to 'is_pairwise'
        self.sd_init_aff = AffineNode(conf.sd_init_aff, isize=(psize if is_pairwise else isize), osize=dec_hid)
        self.sd_init_pool_f = ActivationHelper.get_pool(conf.sd_init_pool)
        # sd input: one_repr + one_idx_embed
        self.sd_input_aff = AffineNode(conf.sd_init_aff, isize=[isize, laber_embed_dim], osize=dec_hid)
        # sd output: cur_expr + hidden
        self.sd_output_aff = AffineNode(conf.sd_output_aff, isize=[isize, dec_hid], osize=dec_hid)
        # sd itself
        self.seqdec = PlainDecoder(conf.seqdec_conf, input_dim=dec_hid)
    else:
        # directly using the scorer (overwrite some values)
        self.laber = SimpleLabelerNode(vocab, conf.labeler_conf, isize=isize, psize=psize)
    # 3. bigram
    # todo(note): bigram does not consider skip_non
    if conf.use_bigram:
        self.bigram = BigramNode(conf.bigram_conf, osize=self.laber.output_dim)
    else:
        self.bigram = None
    # special decoding
    if conf.pred_use_seq_cons_from_file:
        assert not conf.pred_use_seq_cons
        _m = default_pickle_serializer.from_file(conf.pred_use_seq_cons_from_file)
        zlog(f"Load weights from {conf.pred_use_seq_cons_from_file}")
        self.pred_cons_mat = BK.input_real(_m)
    elif conf.pred_use_seq_cons:
        _m = vocab.get_allowed_transitions()
        self.pred_cons_mat = (1. - BK.input_real(_m)) * Constants.REAL_PRAC_MIN
    else:
        self.pred_cons_mat = None
    # =====
    # loss
    self.loss_mle, self.loss_crf = [conf.loss_mode == z for z in ["mle", "crf"]]
    if self.loss_mle:
        if conf.use_seqdec or conf.use_bigram:
            zlog("Setup SeqLabelerNode with Local complex mode!")
        else:
            zlog("Setup SeqLabelerNode with Local simple mode!")
    elif self.loss_crf:
        assert conf.use_bigram and (not conf.use_seqdec), "Wrong mode for crf"
        zlog("Setup SeqLabelerNode with CRF mode!")
    else:
        raise NotImplementedError(f"UNK loss mode: {conf.loss_mode}")