コード例 #1
0
ファイル: dec_udep.py プロジェクト: zzsfornlp/zmsp
 def decode_udep(self, ibatch, udep_logprobs_t: BK.Expr, root_logprobs_t: BK.Expr):
     """Decode labeled dependency trees for every sentence in the batch.

     udep_logprobs_t: label log-probs, transposed below to [*, m, h, L].
     root_logprobs_t: optional root scores [*, dlen]; when absent a softer
     "neg-inf" (-99.) is used for the root attachment instead.
     Results are written back via sent.build_dep_tree; no return value.
     """
     conf: ZDecoderUdepConf = self.conf
     # --
     arr_udep = BK.get_value(udep_logprobs_t.transpose(-2,-3))  # [*, m, h, L]
     arr_root = None if root_logprobs_t is None else BK.get_value(root_logprobs_t)  # [*, dlen]
     _dim_label = arr_udep.shape[-1]
     _neg = -10000.  # should be enough!!
     _voc, _lab_range = self.ztask.vpack
     _idx_root = self._label_idx_root
     # --
     for bidx, item in enumerate(ibatch.items):  # for each item in the batch
         _dec_offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             if conf.msent_pred_center and (sidx != item.center_sidx):
                 continue  # skip non-center sent in this mode!
             _start = _dec_offsets[sidx]
             _len = len(sent)
             _len_p1 = _len + 1  # +1 for the artificial ROOT node at index 0
             # --
             # fill with large negative so unassigned (m, h, L) cells never win
             _arr = np.full([_len_p1, _len_p1, _dim_label], _neg, dtype=np.float32)  # [1+m, 1+h, L]
             # assign label scores
             _arr[1:_len_p1, 1:_len_p1, 1:_lab_range] = arr_udep[bidx, _start:_start+_len, _start:_start+_len, 1:_lab_range]
             # assign root scores
             if arr_root is not None:
                 _arr[1:_len_p1, 0, _idx_root] = arr_root[bidx, _start:_start+_len]
             else:  # todo(+N): currently simply assign a smaller "neg-inf"
                 _arr[1:_len_p1, 0, _idx_root] = -99.
             # --
             from msp2.tools.algo.nmst import mst_unproj  # decoding algorithm
             arr_ret_heads, arr_ret_labels, arr_ret_scores = \
                 mst_unproj(_arr[None], np.asarray([_len_p1]), labeled=True)  # [*, 1+slen]
             # assign (drop the artificial ROOT at position 0)
             list_dep_heads = arr_ret_heads[0, 1:_len_p1].tolist()
             list_dep_lidxes = arr_ret_labels[0, 1:_len_p1].tolist()
             list_dep_labels = _voc.seq_idx2word(list_dep_lidxes)
             sent.build_dep_tree(list_dep_heads, list_dep_labels)
コード例 #2
0
 def _transform_factors(self, factors: Union[List[List[int]], BK.Expr],
                        is_orig: bool, PAD_IDX: Union[int, float]):
     """Pad raw factor lists (if needed) and optionally re-index from
     original-token positions to sub-token positions."""
     if isinstance(factors, BK.Expr):
         # a tensor means padding was already done upstream
         padded_t = factors
     else:
         # pad nested python lists into a rectangular idx tensor
         raw_padded, _ = DataPadder(2, pad_vals=PAD_IDX).pad(factors)
         padded_t = BK.input_idx(raw_padded)  # [bsize, orig-len if is_orig else sub_len]
     if not is_orig:
         return padded_t  # already at sub-token granularity: [bsize, sub_len]
     # map orig-token ids onto sub-token slots: [bsize, sub_len]
     return padded_t[self.arange2_t, self.batched_rev_idxes]
コード例 #3
0
ファイル: expand.py プロジェクト: zzsfornlp/zmsp
 def score(self,
           input_main: BK.Expr,
           input_pair: BK.Expr,
           input_mask: BK.Expr,
           left_constraints: BK.Expr = None,
           right_constraints: BK.Expr = None):
     """Score left and right boundary positions.

     Returns (left_scores, right_scores), each [*, slen] and already masked.
     """
     conf: SpanExpanderConf = self.conf
     # --
     shape = BK.get_shape(input_mask)
     # position indices broadcastable against the mask: [*, slen]
     posi_t = BK.arange_idx(shape[-1]).view([1] * (len(shape) - 1) + [-1])
     pair_t = None if input_pair is None else input_pair.unsqueeze(-2)
     outputs = []
     for one_scorer, one_cons in ((self.s_left, left_constraints),
                                  (self.s_right, right_constraints)):
         if one_cons is None:
             one_mask = input_mask  # [*, slen]
         else:
             # zero out positions beyond the given constraint
             one_mask = input_mask * (posi_t <= one_cons).float()
         outputs.append(one_scorer(input_main, pair_t, one_mask).squeeze(-1))
     return outputs[0], outputs[1]  # [*, slen] (already masked)
コード例 #4
0
 def forward(self, input_map: Dict):
     """Run all registered embedders; return (mask_expr, OrderedDict of embeddings).

     The mask (if present) is padded with 1-columns for bos/eos to keep it
     aligned with the embedded sequences.
     """
     add_bos, add_eos = self.conf.add_bos, self.conf.add_eos
     ret = OrderedDict()  # [*, len, ?]
     # iterate embedders according to REG order!!
     for key, (embedder, input_name) in self.embedders.items():
         ret[key] = embedder(input_map[input_name],
                             add_bos=add_bos,
                             add_eos=add_eos)
     # mask expr
     mask_expr = input_map.get("mask")
     if mask_expr is not None:
         pad_slice = BK.constants(BK.get_shape(mask_expr)[:-1] + [1],
                                  1, dtype=mask_expr.dtype)  # [*, 1]
         pieces = ([pad_slice] if add_bos else []) \
                  + [mask_expr] \
                  + ([pad_slice] if add_eos else [])
         mask_expr = BK.concat(pieces, -1)  # [*, ?+len+?]
     return mask_expr, ret
コード例 #5
0
ファイル: srl.py プロジェクト: zzsfornlp/zmsp
 def predict(self, med: ZMediator):
     """Predict events and arguments and write results onto the instances.

     Returns a dict with this decoder's post-processing time.
     """
     conf: ZDecoderSRLConf = self.conf
     insts, mask_expr = med.insts, med.get_mask_t()
     # --
     evt_labels, evt_scores = self._pred_evt()
     arg_labels, arg_scores = self._pred_arg(mask_expr, evt_labels)
     # transfer data from gpu also counts (also make sure gpu calculations are done)!
     all_arrs = [BK.get_value(t) for t in (evt_labels, evt_scores, arg_labels, arg_scores)]
     # =====
     # assign; also record post-processing (non-computing) time
     _t0 = time.time()
     self.helper.put_results(insts, all_arrs)
     _t1 = time.time()
     # --
     return {f"{self.name}_posttime": _t1 - _t0}
コード例 #6
0
ファイル: idec.py プロジェクト: zzsfornlp/zmsp
 def _merge_cf_geo(self, all_cfs: List[BK.Expr]):
     """Turn per-layer cf logits into a geometric-like distribution over layers.

     Each layer's probability is its (temperature-scaled) sigmoid weighted by
     the mass not yet claimed by earlier layers; leftover mass goes to the last.
     """
     _temp = self.current_temperature
     probs = []
     leftover = None
     for one_cf in all_cfs:
         # temperature-scaled sigmoid (skip the division when temp is exactly 1.)
         p = (one_cf if _temp == 1. else (one_cf / _temp)).sigmoid()  # [*]
         if leftover is None:
             probs.append(p)
             leftover = 1. - p
         else:
             probs.append(p * leftover)
             leftover = leftover * (1. - p)
     # fold the remaining mass into the final layer so the result sums to 1
     probs[-1] = probs[-1] + leftover
     return BK.stack(probs, -1)  # [*, NL]
コード例 #7
0
ファイル: soft.py プロジェクト: zzsfornlp/zmsp
 def _loss_feed_split(self, mask_expr, split_scores, pred_split_decisions,
                      cand_widxes, cand_masks, cand_expr, cand_scores,
                      expr_seq_gaddr):
     """Step 2 of the soft extractor: split loss + feeding split decisions.

     Computes a binary split loss against a dynamic oracle (2.1), mixes
     oracle/predicted decisions via scheduled forcing (2.2), then extends and
     aggregates segments for the next step.
     Returns (loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks,
     seg_weighted_expr, oracle_gaddr).
     """
     conf: SoftExtractorConf = self.conf
     bsize, slen = BK.get_shape(mask_expr)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
     # --
     # step 2.1: split loss (only on good points (excluding -1|-1 or paddings) with dynamic oracle)
     cand_gaddr = expr_seq_gaddr[arange2_t, cand_widxes]  # [*, clen]
     # adjacent pairs of gold addresses: a split is "gold" where they differ
     cand_gaddr0, cand_gaddr1 = cand_gaddr[:, :
                                           -1], cand_gaddr[:,
                                                           1:]  # [*, clen-1]
     split_oracle = (cand_gaddr0 !=
                     cand_gaddr1).float() * cand_masks[:, 1:]  # [*, clen-1]
     # only supervise where at least one side maps to a real gold item (>= 0)
     split_oracle_mask = (
         (cand_gaddr0 >= 0) |
         (cand_gaddr1 >= 0)).float() * cand_masks[:, 1:]  # [*, clen-1]
     raw_split_loss = BK.loss_binary(
         split_scores,
         split_oracle,
         label_smoothing=conf.split_label_smoothing)  # [*, slen]
     loss_split_item = LossHelper.compile_leaf_loss(
         f"split", (raw_split_loss * split_oracle_mask).sum(),
         split_oracle_mask.sum(),
         loss_lambda=conf.loss_split)
     # step 2.2: feed split
     # note: when teacher-forcing, only forcing good points, others still use pred
     force_split_decisions = split_oracle_mask * split_oracle + (
         1. - split_oracle_mask) * pred_split_decisions  # [*, clen-1]
     # sequence-level coin flip on whether to teacher-force this sequence
     _use_force_mask = (BK.rand([bsize])
                        <= conf.split_feed_force_rate).float().unsqueeze(
                            -1)  # [*, 1], seq-level
     feed_split_decisions = (_use_force_mask * force_split_decisions +
                             (1. - _use_force_mask) * pred_split_decisions
                             )  # [*, clen-1]
     # next
     # *[*, seglen, MW], [*, seglen]
     seg_ext_cidxes, seg_ext_masks, seg_masks = self._split_extend(
         feed_split_decisions, cand_masks)
     seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr = self._split_aggregate(
         cand_expr, cand_scores, cand_widxes, seg_ext_cidxes, seg_ext_masks,
         conf.split_topk)  # [*, seglen, ?]
     # finally get oracles for next steps
     # todo(+N): simply select the highest scored one as oracle
     if BK.is_zero_shape(seg_ext_scores):  # simply make them all -1
         oracle_gaddr = BK.constants_idx(seg_masks.shape, -1)  # [*, seglen]
     else:
         _, _seg_max_t = seg_ext_scores.max(-1,
                                            keepdim=True)  # [*, seglen, 1]
         oracle_widxes = seg_ext_widxes.gather(-1, _seg_max_t).squeeze(
             -1)  # [*, seglen]
         oracle_gaddr = expr_seq_gaddr.gather(-1,
                                              oracle_widxes)  # [*, seglen]
     oracle_gaddr[seg_masks <= 0] = -1  # (assign invalid ones) [*, seglen]
     return loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr
コード例 #8
0
ファイル: common.py プロジェクト: zzsfornlp/zmsp
 def do_score(self, query, key):
     """Compute raw (pre-softmax) multi-head attention scores.

     Projects query/key, scales the query, and takes a dot product; when
     conf.use_rposi, adds a relative-position score term computed without
     broadcasting. Returns scores of shape [*, Hin, len_q, len_k].
     """
     conf: AttentionPlainConf = self.conf
     query_len = BK.get_shape(query, -2)
     key_len = BK.get_shape(key, -2)
     # --
     # 1. project
     query_up = self._shape_project(self.affine_q(query), conf.nh_qk)  # [*, Hin, len_q, d_qk]
     key_up = self._shape_project(self.affine_k(key), conf.nh_qk)  # [*, Hin, len_k, d_qk]
     # 2. score
     query_up = query_up / self._att_scale
     scores_t = BK.matmul(query_up, BK.transpose(key_up, -1, -2))  # [*, Hin, len_q, len_k]
     if conf.use_rposi:
         # relative-position embeddings for the (len_q, len_k) distance matrix
         distance, distance_out, _ = self.rposi.embed_lens(query_len, key_len)
         # avoid broadcast!
         # flatten (bs, head) so the distance matmul works on 3-d tensors
         _d_bs, _d_h, _d_q, _d_d = BK.get_shape(query_up)
         query_up0 = BK.reshape(query_up.transpose(2, 1).transpose(1, 0), [_d_q, _d_bs * _d_h, _d_d])
         add_term0 = BK.matmul(query_up0, distance_out.transpose(-1, -2))  # [len_q, head*bs, len_k]
         add_term = BK.reshape(add_term0.transpose(0, 1), BK.get_shape(scores_t))
         # --
         scores_t += add_term  # [*, Hin, len_q, len_k]
     # todo(note): no dropout here, if use this at outside, need extra one!!
     return scores_t  # [*, Hin, len_q, len_k]
コード例 #9
0
ファイル: helper.py プロジェクト: zzsfornlp/zmsp
def apply_piece_pooling(t: BK.Expr,
                        piece: int,
                        f: Union[Callable,
                                 str] = ActivationHelper.get_pool('max'),
                        dim: int = -1):
    """Pool every `piece` consecutive entries of `t` along `dim` with `f`.

    `f` may be a pooling callable or a name understood by
    ActivationHelper.get_pool. With piece == 1 the input is returned as-is.
    """
    if piece == 1:
        return t  # nothing to do
    # split the target axis into (-1, piece) so pooling runs over `piece`
    shape = BK.get_shape(t)
    pos_dim = dim if dim >= 0 else len(shape) + dim  # normalize negative dim
    shape[pos_dim] = piece
    reshaped = t.view(shape[:pos_dim] + [-1] + shape[pos_dim:])  # [..., -1, piece, ...]
    pool_f = ActivationHelper.get_pool(f) if isinstance(f, str) else f
    return pool_f(reshaped, pos_dim + 1)  # +1: pool over the new piece axis
コード例 #10
0
ファイル: base.py プロジェクト: zzsfornlp/zmsp
 def forward(self, input_expr: BK.Expr, widx_expr: BK.Expr, wlen_expr: BK.Expr):
     """Build span representations from token representations.

     Concatenates the pieces enabled in conf (start token, end token,
     softhead-weighted average, width embedding) and optionally projects.
     input_expr: [bsize, slen, D]; widx_expr/wlen_expr: span starts/lengths.
     Returns [bsize, num_span, output_dim-or-SUM].
     """
     conf: BaseSpanConf = self.conf
     # --
     # note: check empty, otherwise error
     input_item_shape = BK.get_shape(widx_expr)
     if np.prod(input_item_shape) == 0:
         return BK.zeros(input_item_shape + [self.output_dim])  # return an empty but shaped tensor
     # --
     start_idxes, end_idxes = widx_expr, widx_expr+wlen_expr  # make [start, end)
     # get sizes
     bsize, slen = BK.get_shape(input_expr)[:2]
     # num_span = BK.get_shape(start_idxes, 1)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
     # --
     reprs = []
     if conf.use_starts:  # start [start,
         reprs.append(input_expr[arange2_t, start_idxes])  # [bsize, ?, D]
     if conf.use_ends:  # simply ,end-1]
         reprs.append(input_expr[arange2_t, end_idxes-1])
     if conf.use_softhead:
         # expand range
         all_span_idxes, all_span_mask = expand_ranged_idxes(widx_expr, wlen_expr, 0, None)  # [bsize, ?, MW]
         # flatten
         flatten_all_span_idxes = all_span_idxes.view(bsize, -1)  # [bsize, ?*MW]
         flatten_all_span_mask = all_span_mask.view(bsize, -1)  # [bsize, ?*MW]
         # get softhead score (consider mask here)
         softhead_scores = self.softhead_scorer(input_expr).squeeze(-1)  # [bsize, slen]
         flatten_all_span_scores = softhead_scores[arange2_t, flatten_all_span_idxes]  # [bsize, ?*MW]
         # push masked positions to a huge negative so softmax ignores them
         flatten_all_span_scores += (1.-flatten_all_span_mask) * Constants.REAL_PRAC_MIN
         all_span_scores = flatten_all_span_scores.view(all_span_idxes.shape)  # [bsize, ?, MW]
         # reshape and (optionally topk) and softmax
         softhead_topk = conf.softhead_topk
         if softhead_topk>0 and BK.get_shape(all_span_scores,-1)>softhead_topk:  # further select topk; note: this may save mem
             final_span_score, _tmp_idxes = all_span_scores.topk(softhead_topk, dim=-1, sorted=False)  # [bsize, ?, K]
             final_span_idxes = all_span_idxes.gather(-1, _tmp_idxes)  # [bsize, ?, K]
         else:
             final_span_score, final_span_idxes = all_span_scores, all_span_idxes  # [bsize, ?, MW]
         final_prob = final_span_score.softmax(-1)  # [bsize, ?, ??]
         # [bsize, ?, ??, D]
         final_repr = input_expr[arange2_t, final_span_idxes.view(bsize, -1)].view(BK.get_shape(final_span_idxes)+[-1])
         weighted_repr = (final_repr * final_prob.unsqueeze(-1)).sum(-2)  # [bsize, ?, D]
         reprs.append(weighted_repr)
     if conf.use_width:
         cur_width_embed = self.width_embed(wlen_expr)  # [bsize, ?, DE]
         reprs.append(cur_width_embed)
     # concat
     concat_repr = BK.concat(reprs, -1)  # [bsize, ?, SUM]
     if conf.use_proj:
         ret = self.final_proj(concat_repr)  # [bsize, ?, DR]
     else:
         ret = concat_repr
     return ret
コード例 #11
0
ファイル: model.py プロジェクト: zzsfornlp/zmsp
 def predict_on_batch(self, ibatch: InputBatch, **kwargs):
     """Run a prediction-only forward pass over one batch.

     Returns an info dict with the instance count and decoder outputs.
     """
     with BK.no_grad_env():
         self.refresh_batch(False)
         self._mark_active(ibatch)
         # --
         # restart mediator, prepare, then encode
         self.encoder.restart(ibatch, self.med)
         self.med.do_prep_enc()
         self.encoder.forward(self.med)
         # collect predictions
         pred_info = self.med.do_preds()
         # --
         info = {"inst": len(ibatch), "ff": 1}
         info.update(pred_info)
         self.med.restart(None)  # clean
         return info
コード例 #12
0
ファイル: __init__.py プロジェクト: zzsfornlp/zmsp
def init_everything(main_conf: Conf,
                    args: Iterable[str],
                    add_utils=True,
                    add_nn=True):
    """Parse command-line args into main_conf and initialize utils/nn.

    Optionally writes the parsed argv to utils_conf.conf_output, excluding
    log-related keys. Returns main_conf after updating it from args.
    """
    list_args = list(args)  # store it!
    gconf = get_singleton_global_conf()
    # utils?
    if add_utils:
        # first we try to init a Msp2UtilsConf to allow logging!
        utils_conf = Msp2UtilsConf()
        utils_conf.update_from_args(list_args,
                                    quite=True,
                                    check=False,
                                    add_global_key='')
        init(utils_conf)
        # add to gconf!
        # NOTE(review): this registers a *fresh* Msp2UtilsConf rather than the
        # already-updated utils_conf above -- confirm this is intended.
        gconf.add_subconf("utils", Msp2UtilsConf())
    # nn?
    if add_nn:
        from msp2.nn import BK
        gconf.add_subconf("nn", BK.BKNIConf())
    # --
    # then actual init
    all_argv = main_conf.update_from_args(list_args)
    # --
    # init utils
    if add_utils:
        # write conf?
        if utils_conf.conf_output:
            with zopen(utils_conf.conf_output, 'w') as fd:
                for k, v in all_argv.items():
                    # todo(note): do not save this one!!
                    if k.split(".")[-1] not in [
                            "conf_output", "log_file", "log_files"
                    ]:
                        fd.write(f"{k}:{v}\n")
        # no need to re-init
    # --
    # init nn
    if add_nn:
        from msp2.nn import init as nn_init
        nn_init(gconf.nn)
    # --
    return main_conf
コード例 #13
0
ファイル: base.py プロジェクト: zzsfornlp/zmsp
 def load(self, path, strict=None):
     """Load model parameters from `path`.

     If `strict` is given, it is passed through directly; otherwise strict
     loading is tried first and, on failure, retried with strict=False.
     """
     if strict is not None:
         BK.load_model(self, path, strict=strict)
     else:  # otherwise, first try strict, then relax if there are errors
         try:
             BK.load_model(self, path, strict=True)
         # fix: narrowed from a bare `except` so KeyboardInterrupt/SystemExit
         # still propagate instead of silently triggering a non-strict retry
         except Exception:
             import traceback
             zlog(
                 f"#== Error in strict loading:\n{traceback.format_exc()}\n#=="
             )
             BK.load_model(self, path, strict=False)
     zlog(f"Load {self} from {path}.", func="io")
コード例 #14
0
ファイル: srl.py プロジェクト: zzsfornlp/zmsp
 def put_results(self, insts: List[Sent], best_evt_labs, best_evt_scores, best_arg_labs, best_arg_scores):
     """Write predicted events and their arguments back onto each sentence.

     Two modes: with conf.evt_pred_use_posi, existing event positions are kept
     and only labels/args are re-assigned; otherwise old events are deleted and
     a new single-token event is created for every non-O token prediction.
     """
     conf: MySRLConf = self.conf
     _evt_pred_use_posi = conf.evt_pred_use_posi
     vocab_evt = self.vocab_evt
     vocab_arg = self.vocab_arg
     if conf.arg_use_bio:
         real_vocab_arg = vocab_arg.base_vocab
     else:
         real_vocab_arg = vocab_arg
     # --
     all_arrs = [BK.get_value(z) for z in [best_evt_labs, best_evt_scores, best_arg_labs, best_arg_scores]]
     for bidx, inst in enumerate(insts):
         inst.delete_frames(conf.arg_ftag)  # delete old args
         # --
         cur_len = len(inst)
         cur_evt_labs, cur_evt_scores, cur_arg_labs, cur_arg_scores = [z[bidx][:cur_len] for z in all_arrs]
         # label idx 0 is treated as "no event" -> 'O'
         inst.info["evt_lab"] = [vocab_evt.idx2word(z) if z>0 else 'O' for z in cur_evt_labs]
         # --
         if _evt_pred_use_posi:  # special mode
             for evt in inst.get_frames(conf.evt_ftag):
                 # reuse posi but re-assign label!
                 one_widx = evt.mention.shead_widx
                 one_lab, one_score = cur_evt_labs[one_widx].item(), cur_evt_scores[one_widx].item()
                 evt.set_label(vocab_evt.idx2word(one_lab))
                 evt.set_label_idx(one_lab)
                 evt.score = one_score
                 # args
                 new_arg_scores = cur_arg_scores[one_widx][:cur_len]
                 new_arg_label_idxes = cur_arg_labs[one_widx][:cur_len]
                 self.decode_arg(evt, new_arg_label_idxes, new_arg_scores, vocab_arg, real_vocab_arg)
         else:  # pred everything!
             inst.delete_frames(conf.evt_ftag)
             for one_widx in range(cur_len):
                 one_lab, one_score = cur_evt_labs[one_widx].item(), cur_evt_scores[one_widx].item()
                 if one_lab == 0:
                     continue  # skip non-event tokens
                 # make new evt
                 new_evt = inst.make_frame(one_widx, 1, conf.evt_ftag, type=vocab_evt.idx2word(one_lab), score=one_score)
                 new_evt.set_label_idx(one_lab)
                 self.evt_span_setter(new_evt.mention, one_widx, 1)
                 # args
                 new_arg_scores = cur_arg_scores[one_widx][:cur_len]
                 new_arg_label_idxes = cur_arg_labs[one_widx][:cur_len]
                 self.decode_arg(new_evt, new_arg_label_idxes, new_arg_scores, vocab_arg, real_vocab_arg)
コード例 #15
0
 def _common_prepare(self, input_shape: Tuple[int], _mask_f: Callable,
                     gold_widx_expr: BK.Expr, gold_wlen_expr: BK.Expr,
                     gold_addr_expr: BK.Expr):
     """Enumerate all (widx, wlen) spans in [min_width, max_width], compact
     them by validity mask, and align gold addresses onto the kept spans.

     _mask_f: maps (widx_t, wlen_t) -> validity mask over the candidates.
     Returns (ret_widx, ret_wlen, final_mask_t, ret_gaddr); ret_gaddr is None
     when gold_addr_expr is None, otherwise -1 marks spans without gold.
     """
     conf: SpanExtractorConf = self.conf
     min_width, max_width = conf.min_width, conf.max_width
     diff_width = max_width - min_width + 1  # number of width to extract
     # --
     bsize, mlen = input_shape
     # --
     # [bsize, mlen*(max_width-min_width)], mlen first (dim=1)
     # note: the spans are always sorted by (widx, wlen)
     _tmp_arange_t = BK.arange_idx(mlen * diff_width)  # [mlen*dw]
     widx_t0 = (_tmp_arange_t // diff_width)  # [mlen*dw]
     wlen_t0 = (_tmp_arange_t % diff_width) + min_width  # [mlen*dw]
     mask_t0 = _mask_f(widx_t0, wlen_t0)  # [bsize, mlen*dw]
     # --
     # compacting (use mask2idx and gather)
     final_idx_t, final_mask_t = BK.mask2idx(mask_t0,
                                             padding_idx=0)  # [bsize, ??]
     _tmp2_arange_t = BK.arange_idx(bsize).unsqueeze(1)  # [bsize, 1]
     # no need to make valid for mask=0, since idx=0 means (0, min_width)
     # todo(+?): do we need to deal with empty ones here?
     ret_widx = widx_t0[final_idx_t]  # [bsize, ??]
     ret_wlen = wlen_t0[final_idx_t]  # [bsize, ??]
     # --
     # prepare gold (as pointer-like addresses)
     if gold_addr_expr is not None:
         gold_t0 = BK.constants_idx((bsize, mlen * diff_width),
                                    -1)  # [bsize, mlen*diff]
         # check valid of golds (flatten all)
         # a gold is usable only if its address is real and its width in range
         gold_valid_t = ((gold_addr_expr >= 0) &
                         (gold_wlen_expr >= min_width) &
                         (gold_wlen_expr <= max_width))
         gold_valid_t = gold_valid_t.view(-1)  # [bsize*_glen]
         _glen = BK.get_shape(gold_addr_expr, 1)
         flattened_bsize_t = BK.arange_idx(
             bsize * _glen) // _glen  # [bsize*_glen]
         # flat candidate index of each gold span: widx*dw + (wlen - min_width)
         flattened_fidx_t = (gold_widx_expr * diff_width + gold_wlen_expr -
                             min_width).view(-1)  # [bsize*_glen]
         flattened_gaddr_t = gold_addr_expr.view(-1)
         # mask and assign
         gold_t0[flattened_bsize_t[gold_valid_t],
                 flattened_fidx_t[gold_valid_t]] = flattened_gaddr_t[
                     gold_valid_t]
         ret_gaddr = gold_t0[_tmp2_arange_t, final_idx_t]  # [bsize, ??]
         ret_gaddr.masked_fill_((final_mask_t == 0),
                                -1)  # make invalid ones -1
     else:
         ret_gaddr = None
     # --
     return ret_widx, ret_wlen, final_mask_t, ret_gaddr
コード例 #16
0
 def judge(self, scores: BK.Expr, cf_scores: BK.Expr, mask_t: BK.Expr):
     """Decide (per sequence) whether to early-exit; returns a bool tensor [*]."""
     conf: SeqExitHelperConf = self.conf
     # --
     # seq-level cf mode: cf scores already act as aggregated metrics
     if self.is_cf and conf.cf_use_seq:
         return (cf_scores.squeeze(-1) / conf.cf_scale) >= conf.exit_thresh
     # --
     # otherwise obtain per-token metrics: [*, slen]
     if self.is_cf:  # Geometric-like qt
         token_metrics = cf_scores.squeeze(-1) / conf.cf_scale
     else:
         token_metrics = self._cri_f(scores)
     # put 1. at mask place so padded tokens never drag the aggregate down
     token_metrics = (1. - mask_t) + mask_t * token_metrics
     slen = BK.get_shape(token_metrics, -1)
     K = self._getk(slen, conf.exit_min_k)
     # aggregate = average of the K smallest per-token metrics: [*]
     aggr = topk_avg(token_metrics, mask_t, K, dim=-1, largest=False)
     return aggr >= conf.exit_thresh
コード例 #17
0
 def score_on_batch(self, insts: List, **kwargs):
     """Score (not train) one batch of instances; returns an info dict."""
     conf: ZmtlModelConf = self.conf
     # --
     with BK.no_grad_env():
         self.refresh_batch(conf.score_training_flag)
         actual_insts = list(self._yield_insts(insts))
         med = self.med
         # forward enc: inputs are repeated score_times for multiple passes
         enc_cached_input = self.enc.prepare_inputs(actual_insts * conf.score_times)
         self.enc.forward(None, med, cached_input=enc_cached_input)
         # do score with dec
         # note: do we need to split here?
         info_counter = med.do_scores(orig_insts=actual_insts)
         # --
         info = {"inst0": len(insts), "inst": len(actual_insts), "forw": 1}
         info.update(info_counter)
         # --
         med.restart()
         return info
コード例 #18
0
ファイル: upos.py プロジェクト: zzsfornlp/zmsp
 def _loss_upos(self, mask_expr, expr_upos_labels):
     """Compute UPOS losses over all buffered score layers.

     Returns a list of compiled leaf-loss items, one per (layer, name) pair.
     """
     conf: ZDecoderUPOSConf = self.conf
     # per-layer NLL against the gold labels; scores are [*, slen, L]
     per_layer_losses = [
         BK.loss_nll(one_scores, expr_upos_labels,
                     label_smoothing=conf.upos_label_smoothing)
         for one_scores in self.upos_node.buffer_scores.values()
     ]
     loss_results = self.upos_node.helper.loss(all_losses=per_layer_losses)
     # wrap each (loss, alpha, name) triple into a masked, normalized leaf item
     return [
         LossHelper.compile_leaf_loss(
             "upos" + loss_name, (loss_t * mask_expr).sum(),
             mask_expr.sum(),
             loss_lambda=(loss_alpha * conf.loss_upos))
         for loss_t, loss_alpha, loss_name in loss_results
     ]
コード例 #19
0
 def __init__(self, cons: Constrainer, src_vocab: SimpleVocab, trg_vocab: SimpleVocab, conf: ConstrainerNodeConf, **kwargs):
     """Build a src->trg validity matrix from a Constrainer's cmap.

     If src_vocab is None, one is built from the constrainer's keys (with
     "non" at index 0). Row 0 (src-non) and column 0 (trg-non) are always
     fully valid. The resulting matrix is stored as self.vec.
     """
     super().__init__(conf, **kwargs)
     conf: ConstrainerNodeConf = self.conf
     # --
     # input vocab
     if src_vocab is None:  # make our own src_vocab
         cons_keys = sorted(cons.cmap.keys())  # simply get all the keys
         src_vocab = SimpleVocab.build_by_static(cons_keys, pre_list=["non"], post_list=None)  # non==0!
     # output vocab
     assert trg_vocab is not None
     out_size = len(trg_vocab)  # output size is len(trg_vocab)
     trg_is_seq_vocab = isinstance(trg_vocab, SeqVocab)
     _trg_get_f = (lambda x: trg_vocab.get_range_by_basename(x)) if trg_is_seq_vocab else (lambda x: trg_vocab.get(x))
     # set it up
     _vec = np.full((len(src_vocab), out_size), 0., dtype=np.float32)
     assert src_vocab.non == 0
     _vec[0] = 1.  # by default: src-non is all valid!
     _vec[:,0] = 1.  # by default: trg-non is all valid!
     # --
     # fix: "v_hit" previously started at 1 (all other counters start at 0),
     # which inflated the reported hit count by one
     stat = {"k_skip": 0, "k_hit": 0, "v_skip": 0, "v_hit": 0}
     for k, v in cons.cmap.items():
         idx_k = src_vocab.get(k)
         if idx_k is None:
             stat["k_skip"] += 1
             continue  # skip no_hit!
         stat["k_hit"] += 1
         for k2 in v.keys():
             idx_k2 = _trg_get_f(k2)
             if idx_k2 is None:
                 stat["v_skip"] += 1
                 continue
             stat["v_hit"] += 1
             if trg_is_seq_vocab:
                 _vec[idx_k, idx_k2[0]:idx_k2[1]] = 1.  # hit range
             else:
                 _vec[idx_k, idx_k2] = 1.  # hit!!
     zlog(f"Setup ConstrainerNode ok: vec={_vec.shape}, stat={stat}")
     # --
     self.cons = cons
     self.src_vocab = src_vocab
     self.trg_vocab = trg_vocab
     self.vec = BK.input_real(_vec)
コード例 #20
0
ファイル: dec_upos.py プロジェクト: zzsfornlp/zmsp
 def decode_upos(self, ibatch, logprobs_t: BK.Expr):
     """Argmax-decode UPOS tags and write them onto each sentence."""
     conf: ZDecoderUposConf = self.conf
     # get argmax label!
     _, pred_labels_t = logprobs_t.max(-1)  # [*, dlen]
     arr_labels = BK.get_value(pred_labels_t)
     # put results
     voc = self.voc
     for bidx, item in enumerate(ibatch.items):  # for each item in the batch
         offsets = item.seq_info.dec_offsets
         for sidx, sent in enumerate(item.sents):
             # skip non-center sent in this mode!
             if conf.msent_pred_center and (sidx != item.center_sidx):
                 continue
             start = offsets[sidx]
             upos_idxes = arr_labels[bidx][start:start + len(sent)].tolist()
             sent.build_uposes(voc.seq_idx2word(upos_idxes))
コード例 #21
0
ファイル: model.py プロジェクト: zzsfornlp/zmsp
 def predict_on_batch(self, insts: List, **kwargs):
     """Predict frames for all sentences in insts, grouped into size buckets."""
     conf: ZsfpModelConf = self.conf
     self.refresh_batch(False)
     # --
     sents: List[Sent] = list(yield_sents(insts))
     with BK.no_grad_env():
         # batch run inside if input is doc
         sent_buckets = BatchHelper.group_buckets(
             sents,
             thresh_diff=conf.decode_sent_thresh_diff,
             thresh_all=conf.decode_sent_thresh_batch,
             size_f=lambda x: 1,
             sort_key=lambda x: len(x))
         for bucket_sents in sent_buckets:
             # emb and enc, then frame prediction
             mask_expr, emb_expr, enc_expr = self._emb_and_enc(bucket_sents)
             self.framer.predict(bucket_sents, enc_expr, mask_expr)
     # --
     return {"inst": len(insts), "sent": len(sents)}
コード例 #22
0
ファイル: dec_udep.py プロジェクト: zzsfornlp/zmsp
 def get_label_mask(self, sels: List[str]):
     """Build a 0/1 mask over the dep-label vocab from category/label selectors."""
     # expand category names into their member labels
     expanded = []
     for one_sel in sels:
         if one_sel in UD_CATEGORIES:
             expanded.extend(UD_CATEGORIES[one_sel])
         else:
             expanded.append(one_sel)
     expand_sels = sorted(set(expanded))
     voc = self.voc
     # --
     ret = np.zeros(len(voc))
     _cc = 0
     for s in expand_sels:
         if s not in voc:
             zwarn(f"UNK dep label: {s}")
             continue
         ret[voc[s]] = 1.  # mark this label as selected
         _cc += voc.word2count(s)  # accumulate corpus coverage
     _all_cc = voc.get_all_counts()
     zlog(f"Get label mask with {expand_sels}: {len(expand_sels)}=={ret.sum().item()} -> {_cc}/{_all_cc}={_cc/(_all_cc+1e-5)}")
     return BK.input_real(ret)
コード例 #23
0
ファイル: med.py プロジェクト: zzsfornlp/zmsp
 def get_val(self,
             idx=-1,
             stack_dim=-2,
             signature=None,
             function=None,
             no_cache=False):
     """Fetch (and cache) one stored value, or a stack of all of them.

     idx: index into self.vals, or None to stack all vals along stack_dim.
     signature: extra cache-key component. NOTE: the cache key is
     (idx, stack_dim, signature) and does NOT include `function`, so callers
     passing different `function`s must supply distinct signatures (or set
     no_cache=True) to avoid stale hits.
     """
     _k = (idx, stack_dim, signature)  # key for cache
     ret = None
     if not no_cache:
         ret = self._cache.get(_k)
     if ret is None:  # calculate!!
         if idx is None:
             v0 = BK.stack(self.vals, dim=stack_dim)  # [*, ilen, ND, *]
         else:
             v0 = self.vals[idx]  # [*, ilen, *]
         ret = function(
             v0) if function is not None else v0  # post-processing!
         if not no_cache:
             self._cache[_k] = ret  # store cache
     # --
     return ret
コード例 #24
0
ファイル: nmst.py プロジェクト: zzsfornlp/zmsp
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    """Shared driver for labeled MST decoding.

    Reduces the label dim by argmax, runs the unlabeled CPU decoder `CPU_f`
    on the label-maxed scores, then reads the winning labels back off the
    chosen heads. Returns (heads, labels, scores) as numpy arrays when
    `ret_arr` else as backend tensors.
    """
    assert labeled
    with BK.no_grad_env():
        # best label score and its index per (m, h) pair: [BS, m, h]
        best_scores, best_labels = scores_expr.max(-1)
        # unlabeled MST decode on CPU over the label-maxed scores
        heads_arr, _, scores_arr = CPU_f(
            BK.get_value(best_scores), lengths_arr, labeled=False)
        # [BS, m]
        heads_t = BK.input_idx(heads_arr)
        labels_t = BK.gather_one_lastdim(best_labels, heads_t).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return heads_arr, BK.get_value(labels_t), scores_arr
        return heads_t, labels_t, BK.input_real(scores_arr)
コード例 #25
0
 def go_sample(
         self,
         input_expr: BK.Expr,
         input_mask: BK.Expr,  # input
         widx_expr: BK.Expr,
         wlen_expr: BK.Expr,
         span_mask: BK.Expr,
         rate: float = None,
         count: float = None,  # span
         gaddr_expr: BK.Expr = None,
         add_gold_rate: float = 0.):  # gold
     """Randomly sample candidate spans at a per-example rate derived from
     the overall input length, then merge with gold spans via _go_common."""
     lookup_res = self.go_lookup(
         input_expr, widx_expr, wlen_expr, span_mask, gaddr_expr)  # [bsize, NUM, *]
     # --
     # per-example keep-probability: target size over (smoothed) input length
     valid_len = input_mask.sum(-1, keepdim=True) + 1e-5  # [bsize, 1]
     keep_rate = self._determine_size(valid_len, rate, count) / valid_len
     # bernoulli-style sampling mask over the candidate spans
     keep_mask = (BK.rand(span_mask.shape) < keep_rate).float()  # [bsize, NUM]
     # select and add_gold
     return self._go_common(lookup_res, keep_mask, add_gold_rate)
コード例 #26
0
ファイル: dec_mlm.py プロジェクト: zzsfornlp/zmsp
 def __init__(self, conf: ZDecoderMlmConf, ztask, main_enc: ZEncoder,
              **kwargs):
     """Masked-LM decoder head over `main_enc`'s word vocabulary, optionally
     tying the output projection to the encoder's input embeddings."""
     super().__init__(conf, ztask, main_enc, **kwargs)
     conf: ZDecoderMlmConf = self.conf
     # --
     # mlm
     _enc_dim, _head_dim = main_enc.get_enc_dim(), main_enc.get_head_dim()
     # --
     _W = main_enc.get_embed_w()  # get input embeddings: [nword, D]
     self.target_size = BK.get_shape(_W, 0)  # predict over the full word vocab
     self.mask_token_id = main_enc.tokenizer.mask_token_id  # note: specific one!!
     self.repl_ranges = conf.get_repl_ranges()  # id ranges used when corrupting inputs
     # --
     # label node + intermediate decoder wired onto the encoder's dims
     self.lab_mlm = ZlabelNode(conf.lab_mlm, _csize=self.target_size)
     self.idec_mlm = conf.idec_mlm.make_node(
         _isize=_enc_dim,
         _nhead=_head_dim,
         _csize=self.lab_mlm.get_core_csize())
     self.reg_idec('mlm', self.idec_mlm)
     if conf.mlm_use_input_embed:
         # weight tying: reuse the (transposed) input embedding matrix as the output weights
         zlog(f"Use input embed of {_W.T.shape} for output!")
         self.lab_mlm.aff_final.put_external_ws([(lambda: _W.T)])
コード例 #27
0
ファイル: soft.py プロジェクト: zzsfornlp/zmsp
 def _cand_score_and_select(self, input_expr: BK.Expr, mask_expr: BK.Expr):
     """Score every token as a candidate and keep a rate-capped topk subset
     that also clears the absolute threshold.

     Returns (cand_full_scores, cand_decisions), both [*, slen]; decisions
     are a 0/1 float mask.
     """
     conf: SoftExtractorConf = self.conf
     # --
     # raw scores with padded positions pushed to practical -inf: [*, slen]
     raw_scores = self.cand_scorer(input_expr).squeeze(-1)
     cand_full_scores = raw_scores + (1. - mask_expr) * Constants.REAL_PRAC_MIN
     # how many to keep per example: rate * length, capped by an absolute count
     valid_len = mask_expr.sum(-1)  # [*]
     k_t = (valid_len * conf.cand_topk_rate).clamp(max=conf.cand_topk_count)
     k_t = k_t.ceil().long().unsqueeze(-1)  # [*, 1]
     # topk membership mask
     if BK.is_zero_shape(mask_expr):
         # degenerate zero-sized batch/seq: nothing to select
         topk_mask = mask_expr.clone()
     else:
         topk_mask = select_topk(
             cand_full_scores, k_t, mask_t=mask_expr, dim=-1)  # [*, slen]
     # intersect with the score threshold
     cand_decisions = topk_mask * (cand_full_scores >= conf.cand_thresh).float()
     return cand_full_scores, cand_decisions  # [*, slen]
コード例 #28
0
 def forward(self, med: ZMediator, **kwargs):
     """Pull an embedding from the mediator (last layer, learned layer mix,
     pooled, or position-selected), then optionally mask, dropout, and affine
     it (with shifted-context copies) into the mid-dim space."""
     conf: IdecConnectorPlainConf = self.conf
     # --
     if self.do_seq_pool:
         # note: for pooling, use the raw emb!!
         mixed_emb0 = self._go_detach(med.get_raw_last_emb())  # [*, ??, D]
         mixed_emb = self.pool_f(mixed_emb0)  # [*, D]
     else:
         if conf.use_nlayer == 1:  # simply get the last one
             mixed_emb = self._go_detach(med.get_last_emb())
         else:  # mix them
             # keep only the last len(mixed_weights) layers of the stack
             stacked_embs = self._go_detach(med.get_stack_emb(
             ))[:, :, :, -len(self.mixed_weights):]  # [*, slen, D, NL]
             # softmax-weighted sum over the layer dim
             mixed_emb = BK.matmul(
                 stacked_embs,
                 BK.softmax(self.mixed_weights,
                            -1).unsqueeze(-1)).squeeze(-1)  # [*, slen, D]
         if self.do_seq_sel:
             # select one position per batch element; idx comes from the mediator cache
             _arange_t = BK.arange_idx(BK.get_shape(mixed_emb, 0))
             _idx_t = med.get_cache(conf.seq_sel_key)
             mixed_emb = mixed_emb[_arange_t, _idx_t]  # [*, D]
     # further affine
     if self.input_mask is not None:  # note: special input mask!!
         mixed_emb = mixed_emb * self.input_mask.detach(
         )  # no grad for input_mask!!
     drop_emb = self.pre_mid_drop(mixed_emb)
     if conf.mid_dim > 0:
         # gather inputs
         # build shifted (detached) copies within +-mid_extra_range as extra context;
         # only the ii==0 copy keeps gradients
         _r = conf.mid_extra_range
         _detached_drop_emb = drop_emb.detach()
         _inputs = []
         for ii in range(-_r, _r + 1):
             if ii < 0:
                 # shift left; pad at the front of the seq dim
                 _one = BK.pad(_detached_drop_emb[:, :ii], [0, 0, -ii, 0])
             elif ii == 0:
                 _one = drop_emb  # no need to change!
             else:
                 # shift right; pad at the end of the seq dim
                 _one = BK.pad(_detached_drop_emb[:, ii:], [0, 0, 0, ii])
             _inputs.append(_one)
         # --
         ret_t = self.mid_aff(_inputs)  # [*, slen, M] or [*, M]
     else:
         ret_t = drop_emb
     return ret_t
コード例 #29
0
 def loss(self,
          input_main: BK.Expr,
          input_pair: BK.Expr,
          input_mask: BK.Expr,
          gold_idxes: BK.Expr,
          loss_weight_expr: BK.Expr = None,
          extra_score: BK.Expr = None):
     """Masked (optionally reweighted) NLL loss over unnormalized scores.

     Returns (loss_sum, divider): the summed per-position loss and the
     valid-token count to divide by.
     """
     # score WITHOUT local normalization; BK.loss_nll handles the softmax
     scores_t = self.score(input_main,
                           input_pair,
                           input_mask,
                           local_normalize=False,
                           extra_score=extra_score)  # [*, L]
     # per-position negative log likelihood (with label smoothing)
     nll_t = BK.loss_nll(
         scores_t, gold_idxes,
         label_smoothing=self.conf.label_smoothing)  # [*]
     nll_t = nll_t * input_mask  # zero out padded positions
     if loss_weight_expr is not None:
         nll_t = nll_t * loss_weight_expr  # [*]
     return (nll_t.sum(), input_mask.sum())
コード例 #30
0
ファイル: nmst.py プロジェクト: zzsfornlp/zmsp
def nmst_greedy(scores_expr, mask_expr, lengths_arr, labeled=True, ret_arr=False):
    """Greedy (per-modifier argmax) labeled head selection.

    For each modifier, independently picks the (head, label) pair with the
    highest score; no tree constraint is enforced. Self-loops (the diagonal)
    are masked out before the argmax.

    Args:
        scores_expr: [BS, m, h, L] arc-label scores (NOT modified in place).
        mask_expr: unused here; kept for API parity with the MST decoders.
        lengths_arr: unused here; same reason.
        labeled: must be True; only labeled decoding is implemented.
        ret_arr: if True return numpy arrays, else backend tensors.
    Returns:
        (heads, labels, scores), each of shape [BS, m].
    """
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # mask out diag; NOTE: bug fix — the original used `+=`, which mutates
        # the CALLER's tensor in place; use out-of-place `+` and rebind instead
        scores_expr = scores_expr + BK.diagflat(
            BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
        # combine the last two dims (h, L) and take the max over them
        combined_scores_expr = scores_expr.view(scores_shape[:-2] + [-1])
        combine_max_scores, combined_max_idxes = BK.max(combined_scores_expr, dim=-1)
        # decompose the flat idx back into (head, label)
        last_size = scores_shape[-1]
        greedy_heads = combined_max_idxes // last_size
        greedy_labels = combined_max_idxes % last_size
        if ret_arr:
            mst_heads_arr, mst_labels_arr, mst_scores_arr = \
                [BK.get_value(z) for z in (greedy_heads, greedy_labels, combine_max_scores)]
            return mst_heads_arr, mst_labels_arr, mst_scores_arr
        else:
            return greedy_heads, greedy_labels, combine_max_scores