def decode_udep(self, ibatch, udep_logprobs_t: BK.Expr, root_logprobs_t: BK.Expr):
    conf: ZDecoderUdepConf = self.conf
    # --
    arr_udep = BK.get_value(udep_logprobs_t.transpose(-2, -3))  # [*, m, h, L]
    arr_root = None if root_logprobs_t is None else BK.get_value(root_logprobs_t)  # [*, dlen]
    _dim_label = arr_udep.shape[-1]
    _neg = -10000.  # should be enough!!
    _voc, _lab_range = self.ztask.vpack
    _idx_root = self._label_idx_root
    # --
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            if conf.msent_pred_center and (sidx != item.center_sidx):
                continue  # skip non-center sent in this mode!
            _start = _dec_offsets[sidx]
            _len = len(sent)
            _len_p1 = _len + 1
            # --
            _arr = np.full([_len_p1, _len_p1, _dim_label], _neg, dtype=np.float32)  # [1+m, 1+h, L]
            # assign label scores
            _arr[1:_len_p1, 1:_len_p1, 1:_lab_range] = \
                arr_udep[bidx, _start:_start+_len, _start:_start+_len, 1:_lab_range]
            # assign root scores
            if arr_root is not None:
                _arr[1:_len_p1, 0, _idx_root] = arr_root[bidx, _start:_start+_len]
            else:  # todo(+N): currently simply assign a smaller "neg-inf"
                _arr[1:_len_p1, 0, _idx_root] = -99.
            # --
            from msp2.tools.algo.nmst import mst_unproj  # decoding algorithm
            arr_ret_heads, arr_ret_labels, arr_ret_scores = \
                mst_unproj(_arr[None], np.asarray([_len_p1]), labeled=True)  # [*, 1+slen]
            # assign
            list_dep_heads = arr_ret_heads[0, 1:_len_p1].tolist()
            list_dep_lidxes = arr_ret_labels[0, 1:_len_p1].tolist()
            list_dep_labels = _voc.seq_idx2word(list_dep_lidxes)
            sent.build_dep_tree(list_dep_heads, list_dep_labels)
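# A tiny standalone NumPy sketch (an illustration with hypothetical sizes, not part of
# the decoder) of the root-augmented score grid built above: head index 0 stands for
# the artificial ROOT, real arc scores fill [1:, 1:, :], and attachments to ROOT only
# get mass on the dedicated root label.
import numpy as np

_len, _dim_label, _idx_root = 3, 5, 1  # hypothetical sentence length, label count, root-label idx
arr = np.full([_len + 1, _len + 1, _dim_label], -10000., dtype=np.float32)
arr[1:, 1:, :] = 0.         # real (m, h) label scores would be copied in here
arr[1:, 0, _idx_root] = 0.  # only the root label may attach to ROOT (h=0)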
def _transform_factors(self, factors: Union[List[List[int]], BK.Expr], is_orig: bool, PAD_IDX: Union[int, float]):
    if isinstance(factors, BK.Expr):  # already padded
        batched_ids = factors
    else:
        padder = DataPadder(2, pad_vals=PAD_IDX)
        batched_ids, _ = padder.pad(factors)
        batched_ids = BK.input_idx(batched_ids)  # [bsize, orig-len if is_orig else sub_len]
    if is_orig:  # map to subtoks
        final_batched_ids = batched_ids[self.arange2_t, self.batched_rev_idxes]  # [bsize, sub_len]
    else:
        final_batched_ids = batched_ids  # [bsize, sub_len]
    return final_batched_ids
def score(self, input_main: BK.Expr, input_pair: BK.Expr, input_mask: BK.Expr,
          left_constraints: BK.Expr = None, right_constraints: BK.Expr = None):
    conf: SpanExpanderConf = self.conf
    # --
    # left & right
    rets = []
    seq_shape = BK.get_shape(input_mask)
    cur_mask = input_mask
    arange_t = BK.arange_idx(seq_shape[-1]).view([1] * (len(seq_shape) - 1) + [-1])  # [*, slen]
    for scorer, cons_t in zip([self.s_left, self.s_right], [left_constraints, right_constraints]):
        mm = cur_mask if cons_t is None else (cur_mask * (arange_t <= cons_t).float())  # [*, slen]
        ss = scorer(input_main, None if input_pair is None else input_pair.unsqueeze(-2), mm).squeeze(-1)  # [*, slen]
        rets.append(ss)
    return rets[0], rets[1]  # [*, slen] (already masked)
def forward(self, input_map: Dict):
    add_bos, add_eos = self.conf.add_bos, self.conf.add_eos
    ret = OrderedDict()  # [*, len, ?]
    for key, embedder_pack in self.embedders.items():  # according to REG order!!
        embedder, input_name = embedder_pack
        one_expr = embedder(input_map[input_name], add_bos=add_bos, add_eos=add_eos)
        ret[key] = one_expr
    # mask expr
    mask_expr = input_map.get("mask")
    if mask_expr is not None:
        all_input_slices = []
        mask_slice = BK.constants(BK.get_shape(mask_expr)[:-1] + [1], 1, dtype=mask_expr.dtype)  # [*, 1]
        if add_bos:
            all_input_slices.append(mask_slice)
        all_input_slices.append(mask_expr)
        if add_eos:
            all_input_slices.append(mask_slice)
        mask_expr = BK.concat(all_input_slices, -1)  # [*, ?+len+?]
    return mask_expr, ret
def predict(self, med: ZMediator):
    conf: ZDecoderSRLConf = self.conf
    insts, mask_expr = med.insts, med.get_mask_t()
    # --
    pred_evt_labels, pred_evt_scores = self._pred_evt()
    pred_arg_labels, pred_arg_scores = self._pred_arg(mask_expr, pred_evt_labels)
    # transferring data from the gpu also counts (and makes sure gpu calculations are done)!
    all_arrs = [BK.get_value(z) for z in [pred_evt_labels, pred_evt_scores, pred_arg_labels, pred_arg_scores]]
    # =====
    # assign; also record post-processing (non-computing) time
    time0 = time.time()
    self.helper.put_results(insts, all_arrs)
    time1 = time.time()
    # --
    return {f"{self.name}_posttime": time1 - time0}
def _merge_cf_geo(self, all_cfs: List[BK.Expr]):
    _temp = self.current_temperature
    accu_cfs = []
    remainings = None
    for cf in all_cfs:
        cf_prob = cf.sigmoid() if _temp == 1. else (cf / _temp).sigmoid()  # [*]
        if remainings is None:
            accu_cfs.append(cf_prob)
            remainings = 1. - cf_prob
        else:
            accu_cfs.append(cf_prob * remainings)
            remainings = remainings * (1. - cf_prob)
    # add back to the final one
    accu_cfs[-1] += remainings
    return BK.stack(accu_cfs, -1)  # [*, NL]
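# A standalone sketch (assumption: plain NumPy stands in for the BK backend) of the
# "geometric-like" merge above: each sigmoid-ed confidence claims a share of the
# remaining probability mass, and the leftover mass is folded into the last layer,
# so the per-layer weights always sum to 1.
import numpy as np

def merge_cf_geo_sketch(cf_logits, temp=1.0):
    probs = 1. / (1. + np.exp(-np.asarray(cf_logits) / temp))  # sigmoid per layer
    accu, remaining = [], 1.0
    for p in probs:
        accu.append(p * remaining)  # share of the mass still unclaimed
        remaining *= (1. - p)
    accu[-1] += remaining  # leftover mass goes to the final layer
    return np.asarray(accu)

# e.g. merge_cf_geo_sketch([0., 1., -1.]) gives weights over 3 layers summing to 1.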
def _loss_feed_split(self, mask_expr, split_scores, pred_split_decisions,
                     cand_widxes, cand_masks, cand_expr, cand_scores, expr_seq_gaddr):
    conf: SoftExtractorConf = self.conf
    bsize, slen = BK.get_shape(mask_expr)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
    # --
    # step 2.1: split loss (only on good points (excluding -1|-1 or paddings) with dynamic oracle)
    cand_gaddr = expr_seq_gaddr[arange2_t, cand_widxes]  # [*, clen]
    cand_gaddr0, cand_gaddr1 = cand_gaddr[:, :-1], cand_gaddr[:, 1:]  # [*, clen-1]
    split_oracle = (cand_gaddr0 != cand_gaddr1).float() * cand_masks[:, 1:]  # [*, clen-1]
    split_oracle_mask = ((cand_gaddr0 >= 0) | (cand_gaddr1 >= 0)).float() * cand_masks[:, 1:]  # [*, clen-1]
    raw_split_loss = BK.loss_binary(split_scores, split_oracle,
                                    label_smoothing=conf.split_label_smoothing)  # [*, slen]
    loss_split_item = LossHelper.compile_leaf_loss(
        "split", (raw_split_loss * split_oracle_mask).sum(), split_oracle_mask.sum(),
        loss_lambda=conf.loss_split)
    # step 2.2: feed split
    # note: when teacher-forcing, only force the good points; others still use pred
    force_split_decisions = split_oracle_mask * split_oracle \
        + (1. - split_oracle_mask) * pred_split_decisions  # [*, clen-1]
    _use_force_mask = (BK.rand([bsize]) <= conf.split_feed_force_rate).float().unsqueeze(-1)  # [*, 1], seq-level
    feed_split_decisions = _use_force_mask * force_split_decisions \
        + (1. - _use_force_mask) * pred_split_decisions  # [*, clen-1]
    # next
    # *[*, seglen, MW], [*, seglen]
    seg_ext_cidxes, seg_ext_masks, seg_masks = self._split_extend(feed_split_decisions, cand_masks)
    seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr = self._split_aggregate(
        cand_expr, cand_scores, cand_widxes, seg_ext_cidxes, seg_ext_masks, conf.split_topk)  # [*, seglen, ?]
    # finally get oracles for the next steps
    # todo(+N): simply select the highest scored one as oracle
    if BK.is_zero_shape(seg_ext_scores):  # simply make them all -1
        oracle_gaddr = BK.constants_idx(seg_masks.shape, -1)  # [*, seglen]
    else:
        _, _seg_max_t = seg_ext_scores.max(-1, keepdim=True)  # [*, seglen, 1]
        oracle_widxes = seg_ext_widxes.gather(-1, _seg_max_t).squeeze(-1)  # [*, seglen]
        oracle_gaddr = expr_seq_gaddr.gather(-1, oracle_widxes)  # [*, seglen]
    oracle_gaddr[seg_masks <= 0] = -1  # (assign invalid ones) [*, seglen]
    return loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr
def do_score(self, query, key):
    conf: AttentionPlainConf = self.conf
    query_len = BK.get_shape(query, -2)
    key_len = BK.get_shape(key, -2)
    # --
    # 1. project
    query_up = self._shape_project(self.affine_q(query), conf.nh_qk)  # [*, Hin, len_q, d_qk]
    key_up = self._shape_project(self.affine_k(key), conf.nh_qk)  # [*, Hin, len_k, d_qk]
    # 2. score
    query_up = query_up / self._att_scale
    scores_t = BK.matmul(query_up, BK.transpose(key_up, -1, -2))  # [*, Hin, len_q, len_k]
    if conf.use_rposi:
        distance, distance_out, _ = self.rposi.embed_lens(query_len, key_len)
        # avoid broadcast!
        _d_bs, _d_h, _d_q, _d_d = BK.get_shape(query_up)
        query_up0 = BK.reshape(query_up.transpose(2, 1).transpose(1, 0), [_d_q, _d_bs * _d_h, _d_d])
        add_term0 = BK.matmul(query_up0, distance_out.transpose(-1, -2))  # [len_q, head*bs, len_k]
        add_term = BK.reshape(add_term0.transpose(0, 1), BK.get_shape(scores_t))
        # --
        scores_t += add_term  # [*, Hin, len_q, len_k]
    # todo(note): no dropout here; if this is used outside, an extra one is needed!!
    return scores_t  # [*, Hin, len_q, len_k]
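# A standalone NumPy sketch (assumption: numpy stands in for BK; the shapes follow
# the comments above) of the broadcast-avoiding reshape trick: the relative-position
# term is a per-query-position matmul against a shared [len_q, len_k, d] distance
# table, so queries are folded to the leading dim first.
import numpy as np

def rel_pos_scores_sketch(query_up, distance_out):
    # query_up: [bs, h, len_q, d]; distance_out: [len_q, len_k, d]
    bs, h, len_q, d = query_up.shape
    q0 = query_up.transpose(2, 0, 1, 3).reshape(len_q, bs * h, d)  # [len_q, bs*h, d]
    add0 = np.matmul(q0, distance_out.transpose(0, 2, 1))          # [len_q, bs*h, len_k]
    return add0.transpose(1, 0, 2).reshape(bs, h, len_q, -1)       # [bs, h, len_q, len_k]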
def apply_piece_pooling(t: BK.Expr, piece: int,
                        f: Union[Callable, str] = ActivationHelper.get_pool('max'), dim: int = -1):
    # first do things like chunk by piece
    if piece == 1:
        return t  # nothing to do
    # reshape
    orig_shape = BK.get_shape(t)
    if dim < 0:  # should do this!
        dim = len(orig_shape) + dim
    orig_shape[dim] = piece  # replace it with piece
    new_shape = orig_shape[:dim] + [-1] + orig_shape[dim:]  # insert the new -1 dim before it
    reshaped_t = t.view(new_shape)  # [..., -1, piece, ...]
    if isinstance(f, str):
        f = ActivationHelper.get_pool(f)
    return f(reshaped_t, dim + 1)  # +1 since we made a new dim
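# A standalone NumPy sketch (assumption: numpy stands in for the BK backend, with max
# as the pool) of the piece pooling above: a size-(k*piece) axis is split into
# [k, piece] and the pool is applied over the new piece axis.
import numpy as np

def apply_piece_pooling_sketch(t, piece, dim=-1):
    if piece == 1:
        return t
    shape = list(t.shape)
    if dim < 0:
        dim = len(shape) + dim
    shape[dim] = piece
    new_shape = shape[:dim] + [-1] + shape[dim:]  # [..., k, piece, ...]
    return t.reshape(new_shape).max(axis=dim + 1)  # pool over the piece axis

# e.g. a [2, 6] input with piece=2 becomes [2, 3, 2] and pools back to [2, 3]:
# apply_piece_pooling_sketch(np.arange(12).reshape(2, 6), 2) -> [[1, 3, 5], [7, 9, 11]]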
def forward(self, input_expr: BK.Expr, widx_expr: BK.Expr, wlen_expr: BK.Expr):
    conf: BaseSpanConf = self.conf
    # --
    # note: check empty, otherwise error
    input_item_shape = BK.get_shape(widx_expr)
    if np.prod(input_item_shape) == 0:
        return BK.zeros(input_item_shape + [self.output_dim])  # return an empty but shaped tensor
    # --
    start_idxes, end_idxes = widx_expr, widx_expr + wlen_expr  # make [start, end)
    # get sizes
    bsize, slen = BK.get_shape(input_expr)[:2]
    # num_span = BK.get_shape(start_idxes, 1)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    # --
    reprs = []
    if conf.use_starts:  # start [start,
        reprs.append(input_expr[arange2_t, start_idxes])  # [bsize, ?, D]
    if conf.use_ends:  # simply ,end-1]
        reprs.append(input_expr[arange2_t, end_idxes - 1])
    if conf.use_softhead:
        # expand range
        all_span_idxes, all_span_mask = expand_ranged_idxes(widx_expr, wlen_expr, 0, None)  # [bsize, ?, MW]
        # flatten
        flatten_all_span_idxes = all_span_idxes.view(bsize, -1)  # [bsize, ?*MW]
        flatten_all_span_mask = all_span_mask.view(bsize, -1)  # [bsize, ?*MW]
        # get softhead score (consider mask here)
        softhead_scores = self.softhead_scorer(input_expr).squeeze(-1)  # [bsize, slen]
        flatten_all_span_scores = softhead_scores[arange2_t, flatten_all_span_idxes]  # [bsize, ?*MW]
        flatten_all_span_scores += (1. - flatten_all_span_mask) * Constants.REAL_PRAC_MIN
        all_span_scores = flatten_all_span_scores.view(all_span_idxes.shape)  # [bsize, ?, MW]
        # reshape and (optionally topk) and softmax
        softhead_topk = conf.softhead_topk
        if softhead_topk > 0 and BK.get_shape(all_span_scores, -1) > softhead_topk:
            # further select topk; note: this may save mem
            final_span_score, _tmp_idxes = all_span_scores.topk(softhead_topk, dim=-1, sorted=False)  # [bsize, ?, K]
            final_span_idxes = all_span_idxes.gather(-1, _tmp_idxes)  # [bsize, ?, K]
        else:
            final_span_score, final_span_idxes = all_span_scores, all_span_idxes  # [bsize, ?, MW]
        final_prob = final_span_score.softmax(-1)  # [bsize, ?, ??]
        # [bsize, ?, ??, D]
        final_repr = input_expr[arange2_t, final_span_idxes.view(bsize, -1)].view(BK.get_shape(final_span_idxes) + [-1])
        weighted_repr = (final_repr * final_prob.unsqueeze(-1)).sum(-2)  # [bsize, ?, D]
        reprs.append(weighted_repr)
    if conf.use_width:
        cur_width_embed = self.width_embed(wlen_expr)  # [bsize, ?, DE]
        reprs.append(cur_width_embed)
    # concat
    concat_repr = BK.concat(reprs, -1)  # [bsize, ?, SUM]
    if conf.use_proj:
        ret = self.final_proj(concat_repr)  # [bsize, ?, DR]
    else:
        ret = concat_repr
    return ret
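# A minimal NumPy sketch (assumption: numpy stands in for BK; a single unbatched span
# instead of the batched gather above) of the "softhead" pooling: token scores are
# softmaxed within the span [widx, widx+wlen) and used to weight the token vectors.
import numpy as np

def softhead_repr_sketch(input_expr, scores, widx, wlen):
    # input_expr: [slen, D]; scores: [slen]
    s = scores[widx:widx + wlen]  # scores inside the span
    probs = np.exp(s - s.max())
    probs /= probs.sum()  # softmax over the span positions
    return (input_expr[widx:widx + wlen] * probs[:, None]).sum(0)  # [D]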
def predict_on_batch(self, ibatch: InputBatch, **kwargs):
    with BK.no_grad_env():
        self.refresh_batch(False)
        self._mark_active(ibatch)
        # --
        # restart
        self.encoder.restart(ibatch, self.med)
        # prepare
        self.med.do_prep_enc()
        # enc forward
        self.encoder.forward(self.med)
        # get all predictions
        pred_info = self.med.do_preds()
        # --
        info = {"inst": len(ibatch), "ff": 1}
        info.update(pred_info)
        self.med.restart(None)  # clean
        return info
def init_everything(main_conf: Conf, args: Iterable[str], add_utils=True, add_nn=True):
    list_args = list(args)  # store it!
    gconf = get_singleton_global_conf()
    # utils?
    if add_utils:
        # first we try to init a Msp2UtilsConf to allow logging!
        utils_conf = Msp2UtilsConf()
        utils_conf.update_from_args(list_args, quite=True, check=False, add_global_key='')
        init(utils_conf)
        # add to gconf!
        gconf.add_subconf("utils", Msp2UtilsConf())
    # nn?
    if add_nn:
        from msp2.nn import BK
        gconf.add_subconf("nn", BK.BKNIConf())
    # --
    # then the actual init
    all_argv = main_conf.update_from_args(list_args)
    # --
    # init utils
    if add_utils:
        # write conf?
        if utils_conf.conf_output:
            with zopen(utils_conf.conf_output, 'w') as fd:
                for k, v in all_argv.items():
                    # todo(note): do not save these ones!!
                    if k.split(".")[-1] not in ["conf_output", "log_file", "log_files"]:
                        fd.write(f"{k}:{v}\n")
        # no need to re-init
    # --
    # init nn
    if add_nn:
        from msp2.nn import init as nn_init
        nn_init(gconf.nn)
    # --
    return main_conf
def load(self, path, strict=None):
    if strict is not None:
        BK.load_model(self, path, strict=strict)
    else:
        # otherwise, first try strict, then relax if there are errors
        try:
            BK.load_model(self, path, strict=True)
        except Exception:
            import traceback
            zlog(f"#== Error in strict loading:\n{traceback.format_exc()}\n#==")
            BK.load_model(self, path, strict=False)
    zlog(f"Load {self} from {path}.", func="io")
def put_results(self, insts: List[Sent], best_evt_labs, best_evt_scores, best_arg_labs, best_arg_scores):
    conf: MySRLConf = self.conf
    _evt_pred_use_posi = conf.evt_pred_use_posi
    vocab_evt = self.vocab_evt
    vocab_arg = self.vocab_arg
    if conf.arg_use_bio:
        real_vocab_arg = vocab_arg.base_vocab
    else:
        real_vocab_arg = vocab_arg
    # --
    all_arrs = [BK.get_value(z) for z in [best_evt_labs, best_evt_scores, best_arg_labs, best_arg_scores]]
    for bidx, inst in enumerate(insts):
        inst.delete_frames(conf.arg_ftag)  # delete old args
        # --
        cur_len = len(inst)
        cur_evt_labs, cur_evt_scores, cur_arg_labs, cur_arg_scores = [z[bidx][:cur_len] for z in all_arrs]
        inst.info["evt_lab"] = [vocab_evt.idx2word(z) if z > 0 else 'O' for z in cur_evt_labs]
        # --
        if _evt_pred_use_posi:  # special mode
            for evt in inst.get_frames(conf.evt_ftag):  # reuse posi but re-assign label!
                one_widx = evt.mention.shead_widx
                one_lab, one_score = cur_evt_labs[one_widx].item(), cur_evt_scores[one_widx].item()
                evt.set_label(vocab_evt.idx2word(one_lab))
                evt.set_label_idx(one_lab)
                evt.score = one_score
                # args
                new_arg_scores = cur_arg_scores[one_widx][:cur_len]
                new_arg_label_idxes = cur_arg_labs[one_widx][:cur_len]
                self.decode_arg(evt, new_arg_label_idxes, new_arg_scores, vocab_arg, real_vocab_arg)
        else:  # pred everything!
            inst.delete_frames(conf.evt_ftag)
            for one_widx in range(cur_len):
                one_lab, one_score = cur_evt_labs[one_widx].item(), cur_evt_scores[one_widx].item()
                if one_lab == 0:
                    continue
                # make a new evt
                new_evt = inst.make_frame(one_widx, 1, conf.evt_ftag,
                                          type=vocab_evt.idx2word(one_lab), score=one_score)
                new_evt.set_label_idx(one_lab)
                self.evt_span_setter(new_evt.mention, one_widx, 1)
                # args
                new_arg_scores = cur_arg_scores[one_widx][:cur_len]
                new_arg_label_idxes = cur_arg_labs[one_widx][:cur_len]
                self.decode_arg(new_evt, new_arg_label_idxes, new_arg_scores, vocab_arg, real_vocab_arg)
def _common_prepare(self, input_shape: Tuple[int], _mask_f: Callable,
                    gold_widx_expr: BK.Expr, gold_wlen_expr: BK.Expr, gold_addr_expr: BK.Expr):
    conf: SpanExtractorConf = self.conf
    min_width, max_width = conf.min_width, conf.max_width
    diff_width = max_width - min_width + 1  # number of widths to extract
    # --
    bsize, mlen = input_shape
    # --
    # [bsize, mlen*(max_width-min_width)], mlen first (dim=1)
    # note: the spans are always sorted by (widx, wlen)
    _tmp_arange_t = BK.arange_idx(mlen * diff_width)  # [mlen*dw]
    widx_t0 = _tmp_arange_t // diff_width  # [mlen*dw]
    wlen_t0 = (_tmp_arange_t % diff_width) + min_width  # [mlen*dw]
    mask_t0 = _mask_f(widx_t0, wlen_t0)  # [bsize, mlen*dw]
    # --
    # compacting (use mask2idx and gather)
    final_idx_t, final_mask_t = BK.mask2idx(mask_t0, padding_idx=0)  # [bsize, ??]
    _tmp2_arange_t = BK.arange_idx(bsize).unsqueeze(1)  # [bsize, 1]
    # no need to make valid for mask=0, since idx=0 means (0, min_width)
    # todo(+?): do we need to deal with empty ones here?
    ret_widx = widx_t0[final_idx_t]  # [bsize, ??]
    ret_wlen = wlen_t0[final_idx_t]  # [bsize, ??]
    # --
    # prepare gold (as pointer-like addresses)
    if gold_addr_expr is not None:
        gold_t0 = BK.constants_idx((bsize, mlen * diff_width), -1)  # [bsize, mlen*diff]
        # check validity of golds (flatten all)
        gold_valid_t = ((gold_addr_expr >= 0) & (gold_wlen_expr >= min_width) & (gold_wlen_expr <= max_width))
        gold_valid_t = gold_valid_t.view(-1)  # [bsize*_glen]
        _glen = BK.get_shape(gold_addr_expr, 1)
        flattened_bsize_t = BK.arange_idx(bsize * _glen) // _glen  # [bsize*_glen]
        flattened_fidx_t = (gold_widx_expr * diff_width + gold_wlen_expr - min_width).view(-1)  # [bsize*_glen]
        flattened_gaddr_t = gold_addr_expr.view(-1)
        # mask and assign
        gold_t0[flattened_bsize_t[gold_valid_t], flattened_fidx_t[gold_valid_t]] = flattened_gaddr_t[gold_valid_t]
        ret_gaddr = gold_t0[_tmp2_arange_t, final_idx_t]  # [bsize, ??]
        ret_gaddr.masked_fill_((final_mask_t == 0), -1)  # make invalid ones -1
    else:
        ret_gaddr = None
    # --
    return ret_widx, ret_wlen, final_mask_t, ret_gaddr
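# A tiny NumPy demo of the span enumeration above: a flat index i over the mlen*dw
# candidates decodes to widx = i // dw and wlen = i % dw + min_width, so the spans
# come out sorted by (widx, wlen). Sizes here are hypothetical.
import numpy as np

mlen, min_width, max_width = 4, 1, 3
dw = max_width - min_width + 1
idxes = np.arange(mlen * dw)
widx, wlen = idxes // dw, idxes % dw + min_width
# widx: 0 0 0 1 1 1 2 2 2 3 3 3
# wlen: 1 2 3 1 2 3 1 2 3 1 2 3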
def judge(self, scores: BK.Expr, cf_scores: BK.Expr, mask_t: BK.Expr):
    conf: SeqExitHelperConf = self.conf
    # --
    if self.is_cf and conf.cf_use_seq:  # in this mode, already aggr_metrics
        return (cf_scores.squeeze(-1) / conf.cf_scale) >= conf.exit_thresh
    # --
    if self.is_cf:
        # Geometric-like qt
        # aggr_metrics = (cf_scores.squeeze(-1)).sigmoid()  # [*]
        seq_metrics = cf_scores.squeeze(-1) / conf.cf_scale  # [*, slen]
    else:
        seq_metrics = self._cri_f(scores)  # [*, slen]
    # --
    seq_metrics = (1. - mask_t) + mask_t * seq_metrics  # put 1. at masked places!
    slen = BK.get_shape(seq_metrics, -1)
    K = self._getk(slen, conf.exit_min_k)
    aggr_metrics = topk_avg(seq_metrics, mask_t, K, dim=-1, largest=False)  # [*]
    return aggr_metrics >= conf.exit_thresh
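# A standalone NumPy sketch (assumptions: numpy stands in for BK, and the repo's
# topk_avg with largest=False averages the K *smallest* entries) of the aggregation
# above: a sequence only exits early when even its K least-confident tokens clear
# the threshold; masked positions are pre-set to 1 so they never drag it down.
import numpy as np

def topk_avg_smallest_sketch(seq_metrics, k):
    # seq_metrics: [slen], with masked positions already set to 1.
    return np.sort(seq_metrics)[:k].mean()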
def score_on_batch(self, insts: List, **kwargs):
    conf: ZmtlModelConf = self.conf
    # --
    with BK.no_grad_env():
        self.refresh_batch(conf.score_training_flag)
        actual_insts = list(self._yield_insts(insts))
        # forward enc
        med = self.med
        enc_cached_input = self.enc.prepare_inputs(actual_insts * conf.score_times)  # multiple times
        self.enc.forward(None, med, cached_input=enc_cached_input)
        # do score with dec
        # note: do we need to split here?
        info_counter = med.do_scores(orig_insts=actual_insts)
        # --
        info = {"inst0": len(insts), "inst": len(actual_insts), "forw": 1}
        info.update(info_counter)
        # --
        med.restart()
        return info
def _loss_upos(self, mask_expr, expr_upos_labels):
    conf: ZDecoderUPOSConf = self.conf
    all_upos_scores = self.upos_node.buffer_scores.values()  # [*, slen, L]
    all_upos_losses = []
    for one_upos_scores in all_upos_scores:
        one_losses = BK.loss_nll(one_upos_scores, expr_upos_labels, label_smoothing=conf.upos_label_smoothing)
        all_upos_losses.append(one_losses)
    upos_loss_results = self.upos_node.helper.loss(all_losses=all_upos_losses)
    loss_items = []
    for loss_t, loss_alpha, loss_name in upos_loss_results:
        one_upos_item = LossHelper.compile_leaf_loss(
            "upos" + loss_name, (loss_t * mask_expr).sum(), mask_expr.sum(),
            loss_lambda=(loss_alpha * conf.loss_upos))
        loss_items.append(one_upos_item)
    return loss_items
def __init__(self, cons: Constrainer, src_vocab: SimpleVocab, trg_vocab: SimpleVocab,
             conf: ConstrainerNodeConf, **kwargs):
    super().__init__(conf, **kwargs)
    conf: ConstrainerNodeConf = self.conf
    # --
    # input vocab
    if src_vocab is None:  # make our own src_vocab
        cons_keys = sorted(cons.cmap.keys())  # simply get all the keys
        src_vocab = SimpleVocab.build_by_static(cons_keys, pre_list=["non"], post_list=None)  # non==0!
    # output vocab
    assert trg_vocab is not None
    out_size = len(trg_vocab)  # output size is len(trg_vocab)
    trg_is_seq_vocab = isinstance(trg_vocab, SeqVocab)
    _trg_get_f = (lambda x: trg_vocab.get_range_by_basename(x)) if trg_is_seq_vocab \
        else (lambda x: trg_vocab.get(x))
    # set it up
    _vec = np.full((len(src_vocab), out_size), 0., dtype=np.float32)
    assert src_vocab.non == 0
    _vec[0] = 1.  # by default: src-non is all valid!
    _vec[:, 0] = 1.  # by default: trg-non is all valid!
    # --
    stat = {"k_skip": 0, "k_hit": 0, "v_skip": 0, "v_hit": 0}
    for k, v in cons.cmap.items():
        idx_k = src_vocab.get(k)
        if idx_k is None:
            stat["k_skip"] += 1
            continue  # skip no_hit!
        stat["k_hit"] += 1
        for k2 in v.keys():
            idx_k2 = _trg_get_f(k2)
            if idx_k2 is None:
                stat["v_skip"] += 1
                continue
            stat["v_hit"] += 1
            if trg_is_seq_vocab:
                _vec[idx_k, idx_k2[0]:idx_k2[1]] = 1.  # hit the range
            else:
                _vec[idx_k, idx_k2] = 1.  # hit!!
    zlog(f"Setup ConstrainerNode ok: vec={_vec.shape}, stat={stat}")
    # --
    self.cons = cons
    self.src_vocab = src_vocab
    self.trg_vocab = trg_vocab
    self.vec = BK.input_real(_vec)
def decode_upos(self, ibatch, logprobs_t: BK.Expr):
    conf: ZDecoderUposConf = self.conf
    # get argmax label!
    pred_upos_scores, pred_upos_labels = logprobs_t.max(-1)  # [*, dlen]
    # arr_upos_scores, arr_upos_labels = BK.get_value(pred_upos_scores), BK.get_value(pred_upos_labels)
    arr_upos_labels = BK.get_value(pred_upos_labels)
    # put results
    voc = self.voc
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            if conf.msent_pred_center and (sidx != item.center_sidx):
                continue  # skip non-center sent in this mode!
            _start = _dec_offsets[sidx]
            _len = len(sent)
            _upos_idxes = arr_upos_labels[bidx][_start:_start + _len].tolist()
            _upos_labels = voc.seq_idx2word(_upos_idxes)
            sent.build_uposes(_upos_labels)
def predict_on_batch(self, insts: List, **kwargs):
    conf: ZsfpModelConf = self.conf
    self.refresh_batch(False)
    # --
    sents: List[Sent] = list(yield_sents(insts))
    with BK.no_grad_env():
        # batch run inside if input is doc
        sent_buckets = BatchHelper.group_buckets(
            sents, thresh_diff=conf.decode_sent_thresh_diff, thresh_all=conf.decode_sent_thresh_batch,
            size_f=lambda x: 1, sort_key=lambda x: len(x))
        for one_sents in sent_buckets:
            # emb and enc
            mask_expr, emb_expr, enc_expr = self._emb_and_enc(one_sents)
            # frame
            self.framer.predict(one_sents, enc_expr, mask_expr)
    # --
    info = {"inst": len(insts), "sent": len(sents)}
    return info
def get_label_mask(self, sels: List[str]):
    expand_sels = []
    for s in sels:
        if s in UD_CATEGORIES:
            expand_sels.extend(UD_CATEGORIES[s])
        else:
            expand_sels.append(s)
    expand_sels = sorted(set(expand_sels))
    voc = self.voc
    # --
    ret = np.zeros(len(voc))
    _cc = 0
    for s in expand_sels:
        if s in voc:
            ret[voc[s]] = 1.
            _cc += voc.word2count(s)
        else:
            zwarn(f"UNK dep label: {s}")
    _all_cc = voc.get_all_counts()
    zlog(f"Get label mask with {expand_sels}: {len(expand_sels)}=={ret.sum().item()}"
         f" -> {_cc}/{_all_cc}={_cc/(_all_cc+1e-5)}")
    return BK.input_real(ret)
def get_val(self, idx=-1, stack_dim=-2, signature=None, function=None, no_cache=False):
    _k = (idx, stack_dim, signature)  # key for cache
    ret = None
    if not no_cache:
        ret = self._cache.get(_k)
    if ret is None:
        # calculate!!
        if idx is None:
            v0 = BK.stack(self.vals, dim=stack_dim)  # [*, ilen, ND, *]
        else:
            v0 = self.vals[idx]  # [*, ilen, *]
        ret = function(v0) if function is not None else v0  # post-processing!
        if not no_cache:
            self._cache[_k] = ret  # store cache
    # --
    return ret
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    assert labeled
    with BK.no_grad_env():
        # argmax-label: [BS, m, h]
        scores_unlabeled_max, labels_argmax = scores_expr.max(-1)
        # --
        scores_unlabeled_max_arr = BK.get_value(scores_unlabeled_max)
        mst_heads_arr, _, mst_scores_arr = CPU_f(scores_unlabeled_max_arr, lengths_arr, labeled=False)  # [BS, m]
        mst_heads_expr = BK.input_idx(mst_heads_arr)
        mst_labels_expr = BK.gather_one_lastdim(labels_argmax, mst_heads_expr).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return mst_heads_arr, BK.get_value(mst_labels_expr), mst_scores_arr
        else:
            return mst_heads_expr, mst_labels_expr, BK.input_real(mst_scores_arr)
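# A standalone NumPy sketch (assumptions: numpy stands in for BK; a single unbatched
# sentence; mst_f is a placeholder for the CPU MST routine returning integer heads)
# of the label factorization above: max over the label dim first, run the unlabeled
# MST on those scores, then read each modifier's label off the argmax table at its
# chosen head.
import numpy as np

def label_after_mst_sketch(scores, mst_f):
    # scores: [m, h, L]; mst_f: unlabeled MST over [m, h] -> int heads [m]
    unlabeled, labels_argmax = scores.max(-1), scores.argmax(-1)  # both [m, h]
    heads = mst_f(unlabeled)                                      # [m]
    return heads, labels_argmax[np.arange(len(heads)), heads]     # per-modifier labels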
def go_sample(self, input_expr: BK.Expr, input_mask: BK.Expr,  # input
              widx_expr: BK.Expr, wlen_expr: BK.Expr, span_mask: BK.Expr,
              rate: float = None, count: float = None,  # span
              gaddr_expr: BK.Expr = None, add_gold_rate: float = 0.):  # gold
    lookup_res = self.go_lookup(input_expr, widx_expr, wlen_expr, span_mask, gaddr_expr)  # [bsize, NUM, *]
    # --
    # rate is according to overall input length
    _tmp_len = (input_mask.sum(-1, keepdim=True) + 1e-5)
    sample_rate = self._determine_size(_tmp_len, rate, count) / _tmp_len  # [bsize, 1]
    sample_mask = (BK.rand(span_mask.shape) < sample_rate).float()  # [bsize, NUM]
    # select and add_gold
    return self._go_common(lookup_res, sample_mask, add_gold_rate)
def __init__(self, conf: ZDecoderMlmConf, ztask, main_enc: ZEncoder, **kwargs):
    super().__init__(conf, ztask, main_enc, **kwargs)
    conf: ZDecoderMlmConf = self.conf
    # --
    # mlm
    _enc_dim, _head_dim = main_enc.get_enc_dim(), main_enc.get_head_dim()
    # --
    _W = main_enc.get_embed_w()  # get input embeddings: [nword, D]
    self.target_size = BK.get_shape(_W, 0)
    self.mask_token_id = main_enc.tokenizer.mask_token_id  # note: specific one!!
    self.repl_ranges = conf.get_repl_ranges()
    # --
    self.lab_mlm = ZlabelNode(conf.lab_mlm, _csize=self.target_size)
    self.idec_mlm = conf.idec_mlm.make_node(_isize=_enc_dim, _nhead=_head_dim,
                                            _csize=self.lab_mlm.get_core_csize())
    self.reg_idec('mlm', self.idec_mlm)
    if conf.mlm_use_input_embed:
        zlog(f"Use input embed of {_W.T.shape} for output!")
        self.lab_mlm.aff_final.put_external_ws([(lambda: _W.T)])
def _cand_score_and_select(self, input_expr: BK.Expr, mask_expr: BK.Expr):
    conf: SoftExtractorConf = self.conf
    # --
    cand_full_scores = self.cand_scorer(input_expr).squeeze(-1) \
        + (1. - mask_expr) * Constants.REAL_PRAC_MIN  # [*, slen]
    # decide topk count
    len_t = mask_expr.sum(-1)  # [*]
    topk_t = (len_t * conf.cand_topk_rate).clamp(max=conf.cand_topk_count).ceil().long().unsqueeze(-1)  # [*, 1]
    # get topk mask
    if BK.is_zero_shape(mask_expr):
        topk_mask = mask_expr.clone()  # no need to go topk since no elements
    else:
        topk_mask = select_topk(cand_full_scores, topk_t, mask_t=mask_expr, dim=-1)  # [*, slen]
    # thresh
    cand_decisions = topk_mask * (cand_full_scores >= conf.cand_thresh).float()  # [*, slen]
    return cand_full_scores, cand_decisions  # [*, slen]
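# A small NumPy sketch (assumption: numpy stands in for BK; a single unbatched
# sequence) of the select-then-threshold above: keep at most ceil(len*rate) (capped)
# top-scoring tokens, then additionally require each kept score to clear a fixed
# threshold.
import numpy as np

def cand_select_sketch(scores, rate, cap, thresh):
    # scores: [slen] (float); returns a 0/1 decision vector of the same shape
    k = min(int(np.ceil(len(scores) * rate)), cap)
    topk_mask = np.zeros_like(scores)
    topk_mask[np.argsort(-scores)[:k]] = 1.  # mark the k highest-scoring positions
    return topk_mask * (scores >= thresh)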
def forward(self, med: ZMediator, **kwargs):
    conf: IdecConnectorPlainConf = self.conf
    # --
    if self.do_seq_pool:
        # note: for pooling, use the raw emb!!
        mixed_emb0 = self._go_detach(med.get_raw_last_emb())  # [*, ??, D]
        mixed_emb = self.pool_f(mixed_emb0)  # [*, D]
    else:
        if conf.use_nlayer == 1:  # simply get the last one
            mixed_emb = self._go_detach(med.get_last_emb())
        else:  # mix them
            stacked_embs = self._go_detach(med.get_stack_emb())[:, :, :, -len(self.mixed_weights):]  # [*, slen, D, NL]
            mixed_emb = BK.matmul(stacked_embs,
                                  BK.softmax(self.mixed_weights, -1).unsqueeze(-1)).squeeze(-1)  # [*, slen, D]
        if self.do_seq_sel:
            _arange_t = BK.arange_idx(BK.get_shape(mixed_emb, 0))
            _idx_t = med.get_cache(conf.seq_sel_key)
            mixed_emb = mixed_emb[_arange_t, _idx_t]  # [*, D]
    # further affine
    if self.input_mask is not None:
        # note: special input mask!!
        mixed_emb = mixed_emb * self.input_mask.detach()  # no grad for input_mask!!
    drop_emb = self.pre_mid_drop(mixed_emb)
    if conf.mid_dim > 0:
        # gather inputs
        _r = conf.mid_extra_range
        _detached_drop_emb = drop_emb.detach()
        _inputs = []
        for ii in range(-_r, _r + 1):
            if ii < 0:
                _one = BK.pad(_detached_drop_emb[:, :ii], [0, 0, -ii, 0])
            elif ii == 0:
                _one = drop_emb  # no need to change!
            else:
                _one = BK.pad(_detached_drop_emb[:, ii:], [0, 0, 0, ii])
            _inputs.append(_one)
        # --
        ret_t = self.mid_aff(_inputs)  # [*, slen, M] or [*, M]
    else:
        ret_t = drop_emb
    return ret_t
def loss(self, input_main: BK.Expr, input_pair: BK.Expr, input_mask: BK.Expr, gold_idxes: BK.Expr,
         loss_weight_expr: BK.Expr = None, extra_score: BK.Expr = None):
    # do not normalize here!
    scores_t = self.score(input_main, input_pair, input_mask,
                          local_normalize=False, extra_score=extra_score)  # [*, L]
    # negative log likelihood
    # all_losses_t = - scores_t.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1) * input_mask  # [*]
    all_losses_t = BK.loss_nll(scores_t, gold_idxes, label_smoothing=self.conf.label_smoothing)  # [*]
    all_losses_t *= input_mask
    if loss_weight_expr is not None:
        all_losses_t *= loss_weight_expr  # [*]
    ret_loss = all_losses_t.sum()  # []
    ret_div = input_mask.sum()
    return (ret_loss, ret_div)
def nmst_greedy(scores_expr, mask_expr, lengths_arr, labeled=True, ret_arr=False):
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # mask out the diagonal
        scores_expr += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
        # combine the last two dimensions and max over them
        combined_scores_expr = scores_expr.view(scores_shape[:-2] + [-1])
        combined_max_scores, combined_max_idxes = BK.max(combined_scores_expr, dim=-1)
        # back to real idxes
        last_size = scores_shape[-1]
        greedy_heads = combined_max_idxes // last_size
        greedy_labels = combined_max_idxes % last_size
        if ret_arr:
            mst_heads_arr, mst_labels_arr, mst_scores_arr = \
                [BK.get_value(z) for z in (greedy_heads, greedy_labels, combined_max_scores)]
            return mst_heads_arr, mst_labels_arr, mst_scores_arr
        else:
            return greedy_heads, greedy_labels, combined_max_scores
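# A standalone NumPy sketch (assumption: numpy stands in for BK; a single unbatched
# sentence) of the greedy decode above: flatten the (head, label) dims, take one
# argmax per modifier, then recover head = idx // L and label = idx % L.
import numpy as np

def nmst_greedy_sketch(scores):
    # scores: [m, h, L]; self-loops are assumed to be pre-masked with a large negative
    m, h, L = scores.shape
    flat_idx = scores.reshape(m, -1).argmax(-1)  # [m], argmax over all (h, L) combos
    return flat_idx // L, flat_idx % L           # heads, labels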