Code example #1
 def loss_cf(self, cf_scores: List[BK.Expr], insts, loss_cf: float):
     conf: SeqExitHelperConf = self.conf
     # --
     assert self.is_cf
     # get oracle
     oracles = [self.cf_oracle_f(ff)
                for ff in insts]  # bs*[NL, slen] or bs*[NL]
     rets = []
     mask_t = BK.input_real(
         DataPadder.lengths2mask([len(z.sent)
                                  for z in insts]))  # [bs, slen]
     for one_li, one_scores in enumerate(cf_scores):
         if conf.cf_use_seq:
             one_oracle_t = BK.input_real([z[one_li]
                                           for z in oracles])  # [bs]
             one_oracle_t *= conf.cf_scale
             one_mask_t = BK.zeros([len(one_oracle_t)]) + 1
         else:
             one_oracle_t = BK.input_real(
                 DataPadder.go_batch_2d([z[one_li] for z in oracles],
                                        1.))  # [bs, slen]
             one_mask_t = (BK.rand(one_oracle_t.shape) >=
                           ((one_oracle_t**conf.cf_loss_discard_curve) *
                            conf.cf_loss_discard)) * mask_t
             one_oracle_t *= conf.cf_scale
         # simple L2 loss
         one_loss_t = (one_scores.squeeze(-1) - one_oracle_t)**2
         one_loss_item = LossHelper.compile_leaf_loss(f"cf{one_li}",
                                                      (one_loss_t *
                                                       one_mask_t).sum(),
                                                      one_mask_t.sum(),
                                                      loss_lambda=loss_cf)
         rets.append(one_loss_item)
     return rets
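
All of these snippets center on BK.input_real. As a rough sketch, and assuming BK is zmsp's thin backend wrapper over PyTorch (the device/dtype details here are an assumption, not the project's actual implementation), it behaves like:

import torch

def input_real_sketch(data):
    # assumed behavior: wrap a Python list or numpy array as a float32 tensor
    return torch.as_tensor(data, dtype=torch.float32)

mask_t = input_real_sketch([[1., 1., 0.], [1., 0., 0.]])  # e.g. a [bs, slen] mask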
Code example #2
 def __init__(self, conf: ZLabelConf, **kwargs):
     super().__init__(conf, **kwargs)
     conf: ZLabelConf = self.conf
     # --
     _csize = conf._csize
     # final affine layer?
     self.input_act = ActivationHelper.get_act(conf.input_act)
     if conf.emb_size > 0:
         self.aff_final = AffineNode(None,
                                     isize=conf.emb_size,
                                     osize=_csize,
                                     no_drop=True)
     else:
         self.aff_final = None
     # fixed_nil_mask?
     self.fixed_nil_mask, self.fixed_nil_val = None, None
     if conf.fixed_nil_val is not None:
         self.fixed_nil_val = float(conf.fixed_nil_val)
         self.fixed_nil_mask = BK.input_real([0.] + [1.] * (_csize - 1))
     # binary mode
     if conf.use_nil_as_binary or conf.loss_binary_alpha != 0:
         assert conf.use_nil_as_binary and conf.loss_binary_alpha != 0 and self.fixed_nil_mask is not None
     if conf.loss_allbinary_alpha != 0.:
         assert conf.fixed_nil_val == 0.
     # crf mode?
     self.crf = None
     if conf.crf is not None:
         assert not (conf.use_nil_as_binary or conf.loss_binary_alpha != 0)  # not binary mode!
         self.crf = ZLinearCrfNode(conf.crf, _csize=_csize)
Code example #3
 def get_weights(weights, pad: float = 0.):
     final_weights = [pad] * 100  # note: this should be enough!
     final_weights[-len(weights):] = [float(z) for z in weights]  # assign the last ones!
     ret = BK.input_real(final_weights)
     # ret /= ret.sum(-1)  # normalize!, note: not doing this!!
     return ret
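
To see the padding behavior concretely, here is a quick numpy re-implementation with a usage check (numpy stands in for the BK backend; that substitution is an assumption):

import numpy as np

def get_weights_np(weights, pad: float = 0.):
    # same logic as above: a fixed-size buffer with the given weights at the end
    final_weights = [pad] * 100
    final_weights[-len(weights):] = [float(z) for z in weights]
    return np.asarray(final_weights)

w = get_weights_np([1, 2, 3])
assert w.shape == (100,) and w[-3:].tolist() == [1., 2., 3.]  # last slots filled, rest padded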
Code example #4
File: srl.py Project: zzsfornlp/zmsp
 def __init__(self, conf: MySRLConf, vocab_evt: SimpleVocab, vocab_arg: SimpleVocab, **kwargs):
     super().__init__(conf, **kwargs)
     conf: MySRLConf = self.conf
     self.vocab_evt = vocab_evt
     self.vocab_arg = vocab_arg
     # --
     self.vocab_bio_arg = None
     self.pred_cons_mat = None
     if conf.arg_use_bio:
         self.vocab_bio_arg = SeqVocab(vocab_arg)  # simply BIO vocab
         zlog(f"Use BIO vocab for srl: {self.vocab_bio_arg}")
         if conf.arg_pred_use_seq_cons:
             _m = self.vocab_bio_arg.get_allowed_transitions()
             self.pred_cons_mat = (1. - BK.input_real(_m)) * Constants.REAL_PRAC_MIN  # [L, L]
             zlog(f"Further use BIO constraints for decoding: {self.pred_cons_mat.shape}")
         helper_vocab_arg = self.vocab_bio_arg
     else:
         helper_vocab_arg = self.vocab_arg
     # --
     self.helper = MySRLHelper(conf, self.vocab_evt, helper_vocab_arg)
     # --
     # predicate
     self.evt_node = SingleIdecNode(conf.evt_conf, ndim=conf.isize, nlab=(2 if conf.binary_evt else len(vocab_evt)))
     # argument
     self.arg_node = PairwiseIdecNode(conf.arg_conf, ndim=conf.isize, nlab=len(helper_vocab_arg))
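
The pred_cons_mat construction above is a common masking trick: transitions that the BIO vocabulary disallows receive a huge negative additive penalty, so an argmax or Viterbi decode never selects them. A small numpy sketch (the exact value of Constants.REAL_PRAC_MIN is an assumption):

import numpy as np

REAL_PRAC_MIN = -1e4  # assumed "practically -inf" constant
allowed = np.array([[1., 1., 0.],
                    [1., 0., 1.],
                    [1., 1., 1.]])  # [L, L]: 1 where a transition is legal
cons_mat = (1. - allowed) * REAL_PRAC_MIN  # 0 where legal, -1e4 where forbidden
scores = np.zeros((3, 3)) + cons_mat  # adding the matrix suppresses forbidden transitions
assert scores[0, 2] == REAL_PRAC_MIN and scores[0, 1] == 0.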
Code example #5
File: srl.py Project: zzsfornlp/zmsp
 def prepare(self, insts: List[Sent], use_cache: bool):
     # get info
     if use_cache:
         zobjs = []
         attr_name = f"_cache_srl"  # should be unique
         for s in insts:
             one = getattr(s, attr_name, None)
             if one is None:
                 one = self._prep_sent(s)
                 setattr(s, attr_name, one)  # set cache
             zobjs.append(one)
     else:
         zobjs = [self._prep_sent(s) for s in insts]
     # batch things
     bsize, mlen = len(insts), max(z.slen for z in zobjs) if len(zobjs) > 0 else 1  # at least put one as padding
     batched_shape = (bsize, mlen)
     arr_items = np.full(batched_shape, None, dtype=object)
     arr_evt_labels = np.full(batched_shape, 0, dtype=np.int64)
     arr_arg_labels = np.full(batched_shape + (mlen,), 0, dtype=np.int64)
     for zidx, zobj in enumerate(zobjs):
         zlen = zobj.slen
         arr_items[zidx, :zlen] = zobj.evt_items
         arr_evt_labels[zidx, :zlen] = zobj.evt_arr
         arr_arg_labels[zidx, :zlen, :zlen] = zobj.arg_arr
     expr_evt_labels = BK.input_idx(arr_evt_labels)  # [*, slen]
     expr_arg_labels = BK.input_idx(arr_arg_labels)  # [*, slen, slen]
     expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
     return arr_items, expr_evt_labels, expr_arg_labels, expr_loss_weight_non
Code example #6
 def __init__(self,
              conf: PlainInputEmbedderConf,
              voc: SimpleVocab,
              npvec: np.ndarray = None,
              name="UNK"):
     super().__init__(conf, name)
     # --
     conf: PlainInputEmbedderConf = self.conf
     self.voc = voc
     # check init embeddings
     if conf.init_from_pretrain:
         zlog(f"Try to init {self.extra_repr()} with npvec.shape={npvec.shape if (npvec is not None) else None}")
         if npvec is None:
             zwarn("warning: cannot get pre-trained embeddings to init!!")
     # get rare unk range
     voc_rare_unk_mask = []
     for w in self.voc.full_i2w:
         c = self.voc.word2count(w, df=None)
         voc_rare_unk_mask.append(
             float(c is not None and c <= conf.rare_unk_thr))
     self.rare_unk_mask = BK.input_real(voc_rare_unk_mask)  # stored tensor!
     # self.register_buffer()  # todo(note): do we need register buffer?
     self.use_rare_unk = (conf.rare_unk_rate > 0. and conf.rare_unk_thr > 0)
     # add the real embedding node
     self.E = EmbeddingNode(conf.econf,
                            npvec=npvec,
                            osize=conf.dim,
                            n_words=len(self.voc))
Code example #7
File: seq.py Project: zzsfornlp/zmsp
 def prepare(self, insts: Union[List[Sent], List[Frame]], mlen: int,
             use_cache: bool):
     conf: SeqExtractorConf = self.conf
     # get info
     if use_cache:
         zobjs = []
         attr_name = f"_scache_{conf.ftag}"  # should be unique
         for s in insts:
             one = getattr(s, attr_name, None)
             if one is None:
                 one = self._prep_f(s)
                 setattr(s, attr_name, one)  # set cache
             zobjs.append(one)
     else:
         zobjs = [self._prep_f(s) for s in insts]
     # batch things
     bsize = len(insts)
     # mlen = max(z.len for z in zobjs)  # note: fed by outside!!
     batched_shape = (bsize, mlen)
     # arr_first_items = np.full(batched_shape, None, dtype=object)
     arr_slabs = np.full(batched_shape, 0, dtype=np.int64)
     for zidx, zobj in enumerate(zobjs):
         # arr_first_items[zidx, zobj.len] = zobj.first_items
         arr_slabs[zidx, :zobj.len] = zobj.tags
     # final setup things
     expr_slabs = BK.input_idx(arr_slabs)
     expr_loss_weight_non = BK.input_real(
         [z.loss_weight_non for z in zobjs])  # [*]
     # return arr_first_items, expr_slabs, expr_loss_weight_non
     return expr_slabs, expr_loss_weight_non
Code example #8
File: srl.py Project: zzsfornlp/zmsp
 def __init__(self, conf: ZDecoderSRLConf, name: str,
              vocab_evt: SimpleVocab, vocab_arg: SimpleVocab, ref_enc: ZEncoder, **kwargs):
     super().__init__(conf, name, **kwargs)
     conf: ZDecoderSRLConf = self.conf
     self.vocab_evt = vocab_evt
     self.vocab_arg = vocab_arg
     _enc_dim, _head_dim = ref_enc.get_enc_dim(), ref_enc.get_head_dim()
     # --
     self.vocab_bio_arg = None
     self.pred_cons_mat = None
     if conf.arg_use_bio:
         self.vocab_bio_arg = SeqVocab(vocab_arg)  # simply BIO vocab
         zlog(f"Use BIO vocab for srl: {self.vocab_bio_arg}")
         if conf.arg_pred_use_seq_cons:
             _m = self.vocab_bio_arg.get_allowed_transitions()
             self.pred_cons_mat = (1. - BK.input_real(_m)) * Constants.REAL_PRAC_MIN  # [L, L]
             zlog(f"Further use BIO constraints for decoding: {self.pred_cons_mat.shape}")
         helper_vocab_arg = self.vocab_bio_arg
     else:
         helper_vocab_arg = self.vocab_arg
     # --
     self.helper = ZDecoderSRLHelper(conf, self.vocab_evt, helper_vocab_arg, self.vocab_arg)
     # --
     # nodes
     self.evt_node: IdecNode = conf.evt_conf.make_node(_isize=_enc_dim, _nhead=_head_dim, _csize=(2 if conf.binary_evt else len(vocab_evt)))
     self.arg_node: IdecNode = conf.arg_conf.make_node(_isize=_enc_dim, _nhead=_head_dim, _csize=len(helper_vocab_arg))
     self.arg2_node: IdecNode = conf.arg2_conf.make_node(_isize=_enc_dim, _nhead=_head_dim, _csize=len(self.vocab_arg))
     # --
     raise RuntimeError("Deprecated after MED's collecting of scores!!")
Code example #9
File: idec.py Project: zzsfornlp/zmsp
 def get_weights(nl: int, weights):
     final_weights = [weights[-1]] * nl
     mlen = min(nl, len(weights))
     final_weights[:mlen] = weights[:mlen]
     ret = BK.input_real(final_weights)
     ret /= ret.sum(-1)  # normalize!
     return ret
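
Unlike the get_weights in code example #3, this variant truncates or extends to exactly nl entries and then normalizes them to sum to 1 (e.g. for mixing per-layer outputs). A numpy sketch of the same logic (numpy in place of BK is an assumption):

import numpy as np

def get_weights_np(nl: int, weights):
    final_weights = [weights[-1]] * nl  # repeat the last weight if too few are given
    mlen = min(nl, len(weights))
    final_weights[:mlen] = weights[:mlen]
    ret = np.asarray(final_weights, dtype=float)
    return ret / ret.sum(-1)  # normalize to a convex combination

print(get_weights_np(4, [1., 1.]))  # -> [0.25 0.25 0.25 0.25]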
Code example #10
 def _input_bert(self, insts: List[Sent]):
     bi = self.berter.create_input_batch_from_sents(insts)
     mask_expr = BK.input_real(
         DataPadder.lengths2mask([len(z) for z in insts]))  # [bs, slen]
     vstate = self.idec_manager.new_vstate(
         None, mask_expr)  # todo(+N): currently ignore emb layer!
     bert_expr = self.berter.forward(bi, vstate=vstate)
     return mask_expr, bert_expr, vstate
Code example #11
File: param_reg.py Project: zzsfornlp/zmsp
 def loss_regs(regs: List['ParamRegHelper']):
     loss_items = []
     for ii, reg in enumerate(regs):
         if reg.reg_method_loss:
             _loss, _loss_lambda = reg.compute_loss()
             _loss_item = LossHelper.compile_leaf_loss(f'reg_{ii}', _loss, BK.input_real(1.), loss_lambda=_loss_lambda)
             loss_items.append(_loss_item)
     ret_loss = LossHelper.combine_multiple_losses(loss_items)
     return ret_loss
Code example #12
 def cons_score_frame2role(self, arg_scores: BK.Expr, evts: List):
     evt_idxes = [e.label_idx for e in evts]
     valid_mask = BK.input_real(self.role_cons[evt_idxes])  # [??, L]
     if self.cons_arg_bio_sels is not None:
         valid_mask = valid_mask[:, self.cons_arg_bio_sels]  # [??, L']
     valid_mask[:, 0] = 1.  # note: must preserve NIL!
     arg_scores = arg_scores + (1 - valid_mask).unsqueeze(-2) * Constants.REAL_PRAC_MIN  # [??, dlen, L]
     return arg_scores
Code example #13
 def get_batched_items(self, insts: List):
     all_items = [self._get_f(z) for z in insts]
     arr_shape = len(insts), max(len(z) for z in all_items)
     arr_items = np.full(arr_shape, None, dtype=object)
     arr_masks = np.full(arr_shape, 0., dtype=np.float32)
     for zidx, zitems in enumerate(all_items):
         zlen = len(zitems)
         arr_items[zidx, :zlen] = zitems
         arr_masks[zidx, :zlen] = 1.
     return arr_items, BK.input_real(arr_masks)
Code example #14
 def prepare(self, insts: Union[List[Sent], List[Frame]], mlen: int,
             use_cache: bool):
     conf: AnchorExtractorConf = self.conf
     # get info
     if use_cache:
         zobjs = []
         attr_name = f"_acache_{conf.ftag}"  # should be unique
         for s in insts:
             one = getattr(s, attr_name, None)
             if one is None:
                 one = self._prep_f(s)
                 setattr(s, attr_name, one)  # set cache
             zobjs.append(one)
     else:
         zobjs = [self._prep_f(s) for s in insts]
     # batch things
     bsize = len(insts)
     mlen2 = max(len(z.items) for z in zobjs) if len(zobjs) > 0 else 1
     mnum = max(len(g) for z in zobjs for g in z.group_widxes) if len(zobjs) > 0 else 1
     arr_items = np.full((bsize, mlen2), None, dtype=object)  # [*, ?]
     arr_seq_iidxes = np.full((bsize, mlen), -1, dtype=np.int64)
     arr_seq_labs = np.full((bsize, mlen), 0, dtype=np.int64)
     arr_group_widxes = np.full((bsize, mlen, mnum), 0, dtype=np.int64)
     arr_group_masks = np.full((bsize, mlen, mnum), 0., dtype=np.float32)
     for zidx, zobj in enumerate(zobjs):
         arr_items[zidx, :len(zobj.items)] = zobj.items
         iidx_offset = zidx * mlen2  # note: offset for valid ones!
         arr_seq_iidxes[zidx, :len(zobj.seq_iidxes)] = [
             (iidx_offset + ii) if ii >= 0 else ii for ii in zobj.seq_iidxes
         ]
         arr_seq_labs[zidx, :len(zobj.seq_labs)] = zobj.seq_labs
         for zidx2, zwidxes in enumerate(zobj.group_widxes):
             arr_group_widxes[zidx, zidx2, :len(zwidxes)] = zwidxes
             arr_group_masks[zidx, zidx2, :len(zwidxes)] = 1.
     # final setup things
     expr_seq_iidxes = BK.input_idx(arr_seq_iidxes)  # [*, slen]
     expr_seq_labs = BK.input_idx(arr_seq_labs)  # [*, slen]
     expr_group_widxes = BK.input_idx(arr_group_widxes)  # [*, slen, MW]
     expr_group_masks = BK.input_real(arr_group_masks)  # [*, slen, MW]
     expr_loss_weight_non = BK.input_real(
         [z.loss_weight_non for z in zobjs])  # [*]
     return arr_items, expr_seq_iidxes, expr_seq_labs, expr_group_widxes, expr_group_masks, expr_loss_weight_non
Code example #15
 def prepare_inputs(self, insts: List):
     ret = OrderedDict()
     # first basic masks
     mask_arr = DataPadder.lengths2mask([len(z) for z in insts])
     ret["mask"] = BK.input_real(mask_arr)
     ret["mask_arr"] = mask_arr
     # then the rest
     for key, inputter in self.inputters.items():
         ret[key] = inputter.prepare(insts)
     return ret
Code example #16
 def __init__(self, ibatch: InputBatch, IDX_PAD: int):
     # preps
     self.bsize = len(ibatch)
     self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
     self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
     self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
     # batched them
     all_seq_infos = [z.seq_info for z in ibatch.items]
     # enc: [*, len_enc]: ids(pad IDX_PAD), masks, segids(pad 0)
     self.enc_input_ids = BK.input_idx(
         DataPadder.go_batch_2d([z.enc_input_ids for z in all_seq_infos],
                                int(IDX_PAD)))
     self.enc_input_masks = BK.input_real(
         DataPadder.lengths2mask(
             [len(z.enc_input_ids) for z in all_seq_infos]))
     self.enc_input_segids = BK.input_idx(
         DataPadder.go_batch_2d([z.enc_input_segids for z in all_seq_infos],
                                0))
     # dec: [*, len_dec]: sel_idxes(pad 0), sel_lens(pad 1), masks, sent_idxes(pad ??)
     self.dec_sel_idxes = BK.input_idx(
         DataPadder.go_batch_2d([z.dec_sel_idxes for z in all_seq_infos],
                                0))
     self.dec_sel_lens = BK.input_idx(
         DataPadder.go_batch_2d([z.dec_sel_lens for z in all_seq_infos], 1))
     self.dec_sel_masks = BK.input_real(
         DataPadder.lengths2mask(
             [len(z.dec_sel_idxes) for z in all_seq_infos]))
     _max_dec_len = BK.get_shape(self.dec_sel_masks, 1)
     _dec_offsets = BK.input_idx(
         DataPadder.go_batch_2d([z.dec_offsets for z in all_seq_infos],
                                _max_dec_len))
     # note: CLS as -1, then 0,1,2,..., PAD gets -2!
     self.dec_sent_idxes = \
         (BK.arange_idx(_max_dec_len).unsqueeze(0).unsqueeze(-1) >= _dec_offsets.unsqueeze(-2)).sum(-1).long() - 1
     self.dec_sent_idxes[self.dec_sel_masks <= 0.] = -2
     # dec -> enc: [*, len_enc] (calculated on needed!)
     # note: require 1-to-1 mapping (except pads)!!
     self._enc_back_hits = None
     self._enc_back_sel_idxes = None
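
The dec_sent_idxes line packs a useful trick: for each decoder position, counting how many sentence offsets it has already passed yields its sentence index, with CLS (before the first offset) falling out as -1. A small numpy illustration with made-up offsets:

import numpy as np

dec_offsets = np.array([[1, 4, 7]])  # one batch item; sentences start at positions 1, 4, 7
positions = np.arange(9)[None, :, None]  # [1, len_dec, 1]
sent_idxes = (positions >= dec_offsets[:, None, :]).sum(-1) - 1  # count offsets passed
print(sent_idxes)  # [[-1  0  0  0  1  1  1  2  2]]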
Code example #17
 def __init__(self, conf: SrlInferenceHelperConf, dec: 'ZDecoderSrl',
              **kwargs):
     super().__init__(conf, **kwargs)
     conf: SrlInferenceHelperConf = self.conf
     # --
     self.setattr_borrow('dec', dec)
     self.arg_pp = PostProcessor(conf.arg_pp)
     # --
     self.lu_cons, self.role_cons = None, None
     if conf.frames_name:  # currently only frame->role
         from msp2.data.resources import get_frames_label_budgets
         flb = get_frames_label_budgets(conf.frames_name)
         _voc_ef, _voc_evt, _voc_arg = dec.ztask.vpack
         _role_cons = fchelper.build_constraint_arrs(
             flb, _voc_arg, _voc_evt)
         self.role_cons = BK.input_real(_role_cons)
     if conf.frames_file:
         _voc_ef, _voc_evt, _voc_arg = dec.ztask.vpack
         _fc = default_pickle_serializer.from_file(conf.frames_file)
         _lu_cons = fchelper.build_constraint_arrs(
             fchelper.build_lu_map(_fc), _voc_evt,
             warning=False)  # lexicon->frame
         _role_cons = fchelper.build_constraint_arrs(
             fchelper.build_role_map(_fc), _voc_arg,
             _voc_evt)  # frame->role
         self.lu_cons, self.role_cons = _lu_cons, BK.input_real(_role_cons)
     # --
     self.cons_evt_tok_f = conf.get_cons_evt_tok()
     self.cons_evt_frame_f = conf.get_cons_evt_frame()
     if self.dec.conf.arg_use_bio:  # extend for bio!
         self.cons_arg_bio_sels = BK.input_idx(
             self.dec.vocab_bio_arg.get_bio2origin())
     else:
         self.cons_arg_bio_sels = None
     # --
     from msp2.data.resources.frames import KBP17_TYPES
     self.pred_evt_filter = {
         'kbp17': KBP17_TYPES
     }.get(conf.pred_evt_filter, None)
Code example #18
File: nmst.py Project: zzsfornlp/zmsp
def nmarginal_proj(scores_expr, mask_expr, lengths_arr, labeled=True):
    assert labeled
    with BK.no_grad_env():
        # first make it unlabeled by sum-exp
        scores_unlabeled = BK.logsumexp(scores_expr, dim=-1)  # [BS, m, h]
        # marginal for unlabeled
        scores_unlabeled_arr = BK.get_value(scores_unlabeled)
        marginals_unlabeled_arr = marginal_proj(scores_unlabeled_arr, lengths_arr, False)
        # back to labeled values
        marginals_unlabeled_expr = BK.input_real(marginals_unlabeled_arr)
        marginals_labeled_expr = marginals_unlabeled_expr.unsqueeze(-1) * BK.exp(scores_expr - scores_unlabeled.unsqueeze(-1))
        # [BS, m, h, L]
        return _ensure_margins_norm(marginals_labeled_expr)
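
The last two steps rely on the identity p(m, h, l) = p(m, h) * softmax_l(scores): the labeled marginal is the unlabeled marginal redistributed by a per-label softmax. A numpy sanity check of that identity, with toy shapes and a random stand-in for marginal_proj's output:

import numpy as np

scores = np.random.randn(2, 3, 3, 5)  # [BS, m, h, L]
lse = np.log(np.exp(scores).sum(-1))  # logsumexp over the label dim
marg_unlabeled = np.random.rand(2, 3, 3)  # stand-in for marginal_proj's output
marg_labeled = marg_unlabeled[..., None] * np.exp(scores - lse[..., None])
assert np.allclose(marg_labeled.sum(-1), marg_unlabeled)  # labels sum back to the unlabeled mass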
Code example #19
 def __init__(self, berter: BertEncoder,
              seq_subs: List[InputSubwordSeqField]):
     self.seq_subs = seq_subs
     self.berter = berter
     self.bsize = len(seq_subs)
     self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
     self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
     self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
     # --
     tokenizer = self.berter.tokenizer
     PAD_IDX = tokenizer.pad_token_id
     # MASK_IDX = tokenizer.mask_token_id
     # CLS_IDX_l = [tokenizer.cls_token_id]
     # SEP_IDX_l = [tokenizer.sep_token_id]
     # make batched idxes
     padder = DataPadder(2, pad_vals=PAD_IDX, mask_range=2)
     batched_sublens = [len(s.idxes) for s in seq_subs]  # [bsize]
     batched_input_ids, batched_input_mask = padder.pad(
         [s.idxes for s in seq_subs])  # [bsize, sub_len]
     self.batched_sublens_p1 = BK.input_idx(batched_sublens) + 1  # also the idx of EOS (if counting including BOS)
     self.batched_input_ids = BK.input_idx(batched_input_ids)
     self.batched_input_mask = BK.input_real(batched_input_mask)
     # make batched mappings (sub->orig)
     padder2 = DataPadder(2, pad_vals=0, mask_range=2)  # pad as 0 to avoid out-of-range
     batched_first_idxes, batched_first_mask = padder2.pad(
         [s.align_info.orig2begin for s in seq_subs])  # [bsize, orig_len]
     self.batched_first_idxes = BK.input_idx(batched_first_idxes)
     self.batched_first_mask = BK.input_real(batched_first_mask)
     # reversed batched_mappings (orig->sub) (created when needed)
     self._batched_rev_idxes = None  # [bsize, sub_len]
     # --
     self.batched_repl_masks = None  # [bsize, sub_len], to replace with MASK
     self.batched_token_type_ids = None  # [bsize, 1+sub_len+1]
     self.batched_position_ids = None  # [bsize, 1+sub_len+1]
     self.other_factors = {}  # name -> aug_batched_ids
Code example #20
File: helper.py Project: zzsfornlp/zmsp
def select_topk_non_overlapping(score_t: BK.Expr,
                                topk_t: Union[int, BK.Expr],
                                widx_t: BK.Expr,
                                wlen_t: BK.Expr,
                                input_mask_t: BK.Expr,
                                mask_t: BK.Expr = None,
                                dim=-1):
    score_shape = BK.get_shape(score_t)
    assert dim == -1 or dim == len(score_shape) - 1, \
        "Currently only support last-dim!!"  # todo(+2): we can permute to allow any dim!
    # --
    # prepare K
    if isinstance(topk_t, int):
        tmp_shape = score_shape.copy()
        tmp_shape[dim] = 1  # set it as 1
        topk_t = BK.constants_idx(tmp_shape, topk_t)
    # --
    reshape_trg = [np.prod(score_shape[:-1]).item(), -1]  # [*, ?]
    _, sorted_idxes_t = score_t.sort(dim, descending=True)
    # --
    # put it as CPU and use loop; todo(+N): more efficient ways?
    arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t = \
        [BK.get_value(z.reshape(reshape_trg)) for z in [sorted_idxes_t, topk_t, widx_t, wlen_t, input_mask_t, mask_t]]
    _bsize, _cnum = BK.get_shape(arr_sorted_idxes_t)  # [bsize, NUM]
    arr_topk_mask = np.full([_bsize, _cnum], 0.)  # [bsize, NUM]
    _bidx = 0
    for aslice_sorted_idxes_t, aslice_topk_t, aslice_widx_t, aslice_wlen_t, aslice_input_mask_t, aslice_mask_t \
            in zip(arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t):
        aslice_topk_mask = arr_topk_mask[_bidx]
        # --
        cur_ok_mask = np.copy(aslice_input_mask_t)
        cur_budget = aslice_topk_t.item()
        for _cidx in aslice_sorted_idxes_t:
            _cidx = _cidx.item()
            if cur_budget <= 0: break  # no budget left
            if not aslice_mask_t[_cidx].item(): continue  # non-valid candidate
            one_widx, one_wlen = aslice_widx_t[_cidx].item(), aslice_wlen_t[_cidx].item()
            if np.prod(cur_ok_mask[one_widx:one_widx + one_wlen]).item() == 0.:  # any hit one?
                continue
            # ok! add it!
            cur_budget -= 1
            cur_ok_mask[one_widx:one_widx + one_wlen] = 0.
            aslice_topk_mask[_cidx] = 1.
        _bidx += 1
    # note: no need to *=mask_t again since already check in the loop
    return BK.input_real(arr_topk_mask).reshape(score_shape)
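
Stripped of batching and backend details, the inner loop is a greedy non-maximum suppression over spans: visit candidates in descending score order and skip any that overlap an already-taken span. A minimal numpy walk-through with hypothetical spans:

import numpy as np

scores = np.array([0.9, 0.8, 0.7])
widx, wlen = np.array([0, 1, 3]), np.array([2, 2, 1])  # spans [0,2), [1,3), [3,4)
ok = np.ones(5)  # per-token availability over a length-5 sequence
picked = []
for c in np.argsort(-scores):  # descending score order
    if ok[widx[c]:widx[c] + wlen[c]].prod() == 0.:
        continue  # overlaps an earlier pick
    ok[widx[c]:widx[c] + wlen[c]] = 0.
    picked.append(int(c))
print(picked)  # [0, 2]: the middle span overlaps the first and is skipped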
Code example #21
 def prepare(self, insts: List, use_cache: bool):
     # get info
     zobjs = ZDecoderHelper.get_zobjs(insts, self._prep_inst, use_cache, f"_cache_udep")
     # then
     bsize, mlen = len(insts), max(z.slen for z in zobjs) if len(zobjs) > 0 else 1
     batched_shape = (bsize, mlen)
     arr_depth = np.full(batched_shape, 0., dtype=np.float32)  # [*, slen]
     arr_udep = np.full(batched_shape + (mlen,), 0, dtype=np.int64)  # [*, slen_m, slen_h]
     for zidx, zobj in enumerate(zobjs):
         zlen = zobj.slen
         arr_depth[zidx, :zlen] = zobj.depth_arr
         arr_udep[zidx, :zlen, :zlen] = zobj.udep_arr
     expr_depth = BK.input_real(arr_depth)  # [*, slen]
     expr_udep = BK.input_idx(arr_udep)  # [*, slen_m, slen_h]
     return expr_depth, expr_udep
Code example #22
File: nmst.py Project: zzsfornlp/zmsp
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    assert labeled
    with BK.no_grad_env():
        # argmax-label: [BS, m, h]
        scores_unlabeled_max, labels_argmax = scores_expr.max(-1)
        #
        scores_unlabeled_max_arr = BK.get_value(scores_unlabeled_max)
        mst_heads_arr, _, mst_scores_arr = CPU_f(scores_unlabeled_max_arr, lengths_arr, labeled=False)
        # [BS, m]
        mst_heads_expr = BK.input_idx(mst_heads_arr)
        mst_labels_expr = BK.gather_one_lastdim(labels_argmax, mst_heads_expr).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return mst_heads_arr, BK.get_value(mst_labels_expr), mst_scores_arr
        else:
            return mst_heads_expr, mst_labels_expr, BK.input_real(mst_scores_arr)
Code example #23
 def __init__(self, cons: Constrainer, src_vocab: SimpleVocab, trg_vocab: SimpleVocab, conf: ConstrainerNodeConf, **kwargs):
     super().__init__(conf, **kwargs)
     conf: ConstrainerNodeConf = self.conf
     # --
     # input vocab
     if src_vocab is None:  # make our own src_vocab
         cons_keys = sorted(cons.cmap.keys())  # simply get all the keys
         src_vocab = SimpleVocab.build_by_static(cons_keys, pre_list=["non"], post_list=None)  # non==0!
     # output vocab
     assert trg_vocab is not None
     out_size = len(trg_vocab)  # output size is len(trg_vocab)
     trg_is_seq_vocab = isinstance(trg_vocab, SeqVocab)
     _trg_get_f = (lambda x: trg_vocab.get_range_by_basename(x)) if trg_is_seq_vocab else (lambda x: trg_vocab.get(x))
     # set it up
     _vec = np.full((len(src_vocab), out_size), 0., dtype=np.float32)
     assert src_vocab.non == 0
     _vec[0] = 1.  # by default: src-non is all valid!
     _vec[:,0] = 1.  # by default: trg-non is all valid!
     # --
     stat = {"k_skip": 0, "k_hit": 0, "v_skip": 0, "v_hit": 0}
     for k, v in cons.cmap.items():
         idx_k = src_vocab.get(k)
         if idx_k is None:
             stat["k_skip"] += 1
             continue  # skip no_hit!
         stat["k_hit"] += 1
         for k2 in v.keys():
             idx_k2 = _trg_get_f(k2)
             if idx_k2 is None:
                 stat["v_skip"] += 1
                 continue
             stat["v_hit"] += 1
             if trg_is_seq_vocab:
                 _vec[idx_k, idx_k2[0]:idx_k2[1]] = 1.  # hit range
             else:
                 _vec[idx_k, idx_k2] = 1.  # hit!!
     zlog(f"Setup ConstrainerNode ok: vec={_vec.shape}, stat={stat}")
     # --
     self.cons = cons
     self.src_vocab = src_vocab
     self.trg_vocab = trg_vocab
     self.vec = BK.input_real(_vec)
Code example #24
File: direct.py Project: zzsfornlp/zmsp
 def prepare(self, insts: Union[List[Sent], List[Frame]], use_cache: bool):
     conf: DirectExtractorConf = self.conf
     # get info
     if use_cache:
         zobjs = []
         attr_name = f"_dcache_{conf.ftag}"  # should be unique
         for s in insts:
             one = getattr(s, attr_name, None)
             if one is None:
                 one = self._prep_f(s)
                 setattr(s, attr_name, one)  # set cache
             zobjs.append(one)
     else:
         zobjs = [self._prep_f(s) for s in insts]
     # batch things
     bsize, mlen = len(insts), max(z.len for z in zobjs) if len(zobjs) > 0 else 1  # at least put one as padding
     batched_shape = (bsize, mlen)
     arr_items = np.full(batched_shape, None, dtype=object)
     arr_gaddrs = np.arange(bsize * mlen).reshape(batched_shape)  # gold address
     arr_core_widxes = np.full(batched_shape, 0, dtype=np.int64)
     arr_core_wlens = np.full(batched_shape, 1, dtype=np.int64)
     # arr_ext_widxes = np.full(batched_shape, 0, dtype=np.int)
     # arr_ext_wlens = np.full(batched_shape, 1, dtype=np.int)
     for zidx, zobj in enumerate(zobjs):
         zlen = zobj.len
         arr_items[zidx, :zlen] = zobj.items
         arr_core_widxes[zidx, :zlen] = zobj.core_widxes
         arr_core_wlens[zidx, :zlen] = zobj.core_wlens
         # arr_ext_widxes[zidx, :zlen] = zobj.ext_widxes
         # arr_ext_wlens[zidx, :zlen] = zobj.ext_wlens
     arr_gaddrs[arr_items == None] = -1  # set -1 as gaddr (note: elementwise compare over the object array)
     # final setup things
     expr_gaddr = BK.input_idx(arr_gaddrs)  # [*, GLEN]
     expr_core_widxes = BK.input_idx(arr_core_widxes)
     expr_core_wlens = BK.input_idx(arr_core_wlens)
     # expr_ext_widxes = BK.input_idx(arr_ext_widxes)
     # expr_ext_wlens = BK.input_idx(arr_ext_wlens)
     expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
     # return arr_flatten_items, expr_gaddr, expr_core_widxes, expr_core_wlens, \
     #        expr_ext_widxes, expr_ext_wlens, expr_loss_weight_non
     return arr_items, expr_gaddr, expr_core_widxes, expr_core_wlens, expr_loss_weight_non
Code example #25
File: dec_udep.py Project: zzsfornlp/zmsp
 def get_label_mask(self, sels: List[str]):
     expand_sels = []
     for s in sels:
         if s in UD_CATEGORIES:
             expand_sels.extend(UD_CATEGORIES[s])
         else:
             expand_sels.append(s)
     expand_sels = sorted(set(expand_sels))
     voc = self.voc
     # --
     ret = np.zeros(len(voc))
     _cc = 0
     for s in expand_sels:
         if s in voc:
             ret[voc[s]] = 1.
             _cc += voc.word2count(s)
         else:
             zwarn(f"UNK dep label: {s}")
     _all_cc = voc.get_all_counts()
     zlog(f"Get label mask with {expand_sels}: {len(expand_sels)}=={ret.sum().item()} -> {_cc}/{_all_cc}={_cc/(_all_cc+1e-5)}")
     return BK.input_real(ret)
Code example #26
File: srl.py Project: zzsfornlp/zmsp
 def prepare(self, insts: List[Sent], use_cache: bool):
     # get info
     zobjs = ZDecoderHelper.get_zobjs(insts, self._prep_sent, use_cache, f"_cache_srl")
     # batch things
     bsize, mlen = len(insts), max(z.slen for z in zobjs) if len(zobjs) > 0 else 1  # at least put one as padding
     batched_shape = (bsize, mlen)
     arr_items = np.full(batched_shape, None, dtype=object)
     arr_evt_labels = np.full(batched_shape, 0, dtype=np.int64)
     arr_arg_labels = np.full(batched_shape + (mlen,), 0, dtype=np.int64)
     arr_arg2_labels = np.full(batched_shape + (mlen,), 0, dtype=np.int64)
     for zidx, zobj in enumerate(zobjs):
         zlen = zobj.slen
         arr_items[zidx, :zlen] = zobj.evt_items
         arr_evt_labels[zidx, :zlen] = zobj.evt_arr
         arr_arg_labels[zidx, :zlen, :zlen] = zobj.arg_arr
         arr_arg2_labels[zidx, :zlen, :zlen] = zobj.arg2_arr
     expr_evt_labels = BK.input_idx(arr_evt_labels)  # [*, slen]
     expr_arg_labels = BK.input_idx(arr_arg_labels)  # [*, slen, slen]
     expr_arg2_labels = BK.input_idx(arr_arg2_labels)  # [*, slen, slen]
     expr_loss_weight_non = BK.input_real([z.loss_weight_non for z in zobjs])  # [*]
     return arr_items, expr_evt_labels, expr_arg_labels, expr_arg2_labels, expr_loss_weight_non
Code example #27
 def cons_score_lu2frame(self, evt_scores: BK.Expr, ibatch, given_f=None):
     dlen, llen = BK.get_shape(evt_scores)[-2:]
     _tok_f, _frame_f = self.cons_evt_tok_f, self.cons_evt_frame_f
     # --
     res = []
     for bidx, item in enumerate(ibatch.items):
         _dec_offsets = item.seq_info.dec_offsets
         one_res = [np.zeros(llen) for _ in range(dlen)]  # dlen*[L]
         for sidx, sent in enumerate(item.sents):
             _start = _dec_offsets[sidx]
             # tok
             if _tok_f is not None:
                 for widx, tok in enumerate(sent.get_tokens()):
                     _key = _tok_f(tok)
                     _arr = self.lu_cons.get(_key)
                     if _arr is not None:
                         one_res[_start + widx] += _arr
             # frame
             if _frame_f is not None and given_f is not None:
                 _frames = given_f(sent)
                 for ff in _frames:
                     _key = _frame_f(ff)
                     _arr = self.lu_cons.get(_key)
                     if _arr is not None:
                         one_res[_start + ff.mention.shead_widx] += _arr
         # --
         res.extend(one_res)
     # --
     valid_mask = BK.input_real(res).view(evt_scores.shape)  # [*, dlen, L]
     valid_mask.clamp_(max=1)
     exclude_mask = (1 - valid_mask) * (valid_mask.sum(-1, keepdim=True) > 0).float()  # no effects if no-hit!
     exclude_mask[:, :, 0] = 0.  # note: still do not exclude NIL, use "pred_evt_nil_add" for special mode!
     evt_scores = evt_scores + exclude_mask * Constants.REAL_PRAC_MIN  # [*, dlen, L]
     return evt_scores
Code example #28
File: model.py Project: zzsfornlp/zmsp
 def _input_bert(self, insts: List[Sent]):
     bi = self.berter.create_input_batch_from_sents(insts)
     mask_expr = BK.input_real(
         DataPadder.lengths2mask([len(z) for z in insts]))  # [bs, slen]
     bert_expr = self.berter.forward(bi)
     return mask_expr, bert_expr
Code example #29
 def __init__(self, vocab: SimpleVocab, conf: SeqLabelerConf, **kwargs):
     super().__init__(conf, **kwargs)
     conf: SeqLabelerConf = self.conf
     is_pairwise = (conf.psize > 0)
     self.is_pairwise = is_pairwise
     # --
     # 0. pre mlp
     isize, psize = conf.isize, conf.psize
     self.main_mlp = MLPNode(conf.main_mlp,
                             isize=isize,
                             osize=-1,
                             use_out=False)
     isize = self.main_mlp.get_output_dims()[0]
     if is_pairwise:
         self.pair_mlp = MLPNode(conf.pair_mlp,
                                 isize=psize,
                                 osize=-1,
                                 use_out=False)
         psize = self.pair_mlp.get_output_dims()[0]
     else:
         self.pair_mlp = lambda x: x
     # 1/2. decoder & laber
     if conf.use_seqdec:
         # extra for seq-decoder
         dec_hid = conf.seqdec_conf.dec_hidden
         # setup labeler to get embedding dim
         self.laber = SimpleLabelerNode(vocab,
                                        conf.labeler_conf,
                                        isize=dec_hid,
                                        psize=psize)
         laber_embed_dim = self.laber.lookup_dim
         # init starting hidden; note: choose different according to 'is_pairwise'
         self.sd_init_aff = AffineNode(
             conf.sd_init_aff,
             isize=(psize if is_pairwise else isize),
             osize=dec_hid)
         self.sd_init_pool_f = ActivationHelper.get_pool(conf.sd_init_pool)
         # sd input: one_repr + one_idx_embed
         self.sd_input_aff = AffineNode(conf.sd_init_aff,
                                        isize=[isize, laber_embed_dim],
                                        osize=dec_hid)
         # sd output: cur_expr + hidden
         self.sd_output_aff = AffineNode(conf.sd_output_aff,
                                         isize=[isize, dec_hid],
                                         osize=dec_hid)
         # sd itself
         self.seqdec = PlainDecoder(conf.seqdec_conf, input_dim=dec_hid)
     else:
         # directly using the scorer (overwrite some values)
         self.laber = SimpleLabelerNode(vocab,
                                        conf.labeler_conf,
                                        isize=isize,
                                        psize=psize)
     # 3. bigram
     # todo(note): bigram does not consider skip_non
     if conf.use_bigram:
         self.bigram = BigramNode(conf.bigram_conf,
                                  osize=self.laber.output_dim)
     else:
         self.bigram = None
     # special decoding
     if conf.pred_use_seq_cons_from_file:
         assert not conf.pred_use_seq_cons
         _m = default_pickle_serializer.from_file(
             conf.pred_use_seq_cons_from_file)
         zlog(f"Load weights from {conf.pred_use_seq_cons_from_file}")
         self.pred_cons_mat = BK.input_real(_m)
     elif conf.pred_use_seq_cons:
         _m = vocab.get_allowed_transitions()
         self.pred_cons_mat = (1. - BK.input_real(_m)) * Constants.REAL_PRAC_MIN
     else:
         self.pred_cons_mat = None
     # =====
     # loss
     self.loss_mle, self.loss_crf = [
         conf.loss_mode == z for z in ["mle", "crf"]
     ]
     if self.loss_mle:
         if conf.use_seqdec or conf.use_bigram:
             zlog("Setup SeqLabelerNode with Local complex mode!")
         else:
             zlog("Setup SeqLabelerNode with Local simple mode!")
     elif self.loss_crf:
         assert conf.use_bigram and (not conf.use_seqdec), "Wrong mode for crf"
         zlog("Setup SeqLabelerNode with CRF mode!")
     else:
         raise NotImplementedError(f"UNK loss mode: {conf.loss_mode}")