def _get_loss_mask(self, pos_t: BK.Expr, valid_t: BK.Expr, loss_neg_sample: float = None):
    """Build a 0/1 loss mask over valid positions, sub-sampling negatives.

    Positives are always kept. Negatives are kept with a rate that is either
    a direct fraction of valid positions (neg_sample >= 0) or, for negative
    values, |neg_sample| times the pos/neg ratio.
    """
    conf: ZLabelConf = self.conf
    # fall back to the config value unless an override is passed in
    neg_sample = conf.loss_neg_sample if loss_neg_sample is None else loss_neg_sample
    if neg_sample >= 1.:
        return valid_t  # keep every valid position, no sampling needed
    pos_t = pos_t * valid_t  # restrict positives to valid positions
    if neg_sample >= 0.:
        keep_rate = neg_sample  # direct fraction of valid positions
    else:
        # ratio mode: |neg_sample| * (#pos / #neg); add-1 keeps the rate > 0
        n_pos = pos_t.sum()
        n_valid = valid_t.sum()
        keep_rate = (-neg_sample) * ((n_pos + 1) / (n_valid - n_pos + 1))
    # random sample among valid positions, then force-include the positives
    ret_mask = (BK.rand(valid_t.shape) <= keep_rate).float() * valid_t
    ret_mask += pos_t
    ret_mask.clamp_(max=1.)
    return ret_mask
def forward(self, input_t: BK.Expr, edges: BK.Expr, mask_t: BK.Expr):
    """Gated, edge-conditioned message passing over pairwise edge labels.

    Args:
        input_t: node representations; presumably [*, L, D] — TODO confirm.
        edges: integer edge labels between positions; sign is collapsed to
            {-1,0,1} for weight selection, full label indexes the biases.
        mask_t: validity mask over positions.
    Returns:
        Updated node representations (residual + LayerNorm when `self.ln` is set).
    """
    _isize = self.conf._isize
    _ntype = self.conf.type_num
    _slen = BK.get_shape(edges, -1)
    # --
    # edges3 in {0,1,2}: selects one of the 3 direction-dependent transforms
    edges3 = edges.clamp(min=-1, max=1) + 1
    edgesF = edges + _ntype  # offset to positive!
    # get hid: per-node, 3 candidate transformed vectors
    hid0 = BK.matmul(input_t, self.W_hid).view(
        BK.get_shape(input_t)[:-1] + [3, _isize])  # [*, L, 3, D]
    hid1 = hid0.unsqueeze(-4).expand(-1, _slen, -1, -1, -1)  # [*, L, L, 3, D]
    # pick, for every (i, j), the transform chosen by edges3
    hid2 = BK.gather_first_dims(hid1.contiguous(), edges3.unsqueeze(-1), -2).squeeze(-2)  # [*, L, L, D]
    hidB = self.b_hid[edgesF]  # per-edge-label bias, [*, L, L, D]
    _hid = hid2 + hidB
    # get gate: scalar gate per (i, j), same selection scheme
    gate0 = BK.matmul(input_t, self.W_gate)  # [*, L, 3]
    gate1 = gate0.unsqueeze(-3).expand(-1, _slen, -1, -1)  # [*, L, L, 3]
    gate2 = gate1.gather(-1, edges3.unsqueeze(-1))  # [*, L, L, 1]
    gateB = self.b_gate[edgesF].unsqueeze(-1)  # [*, L, L, 1]
    _gate0 = BK.sigmoid(gate2 + gateB)
    # only real edges (label != 0) or self-loops participate; also apply the mask
    _gmask0 = ((edges != 0) | (BK.eye(_slen) > 0)).float() * mask_t.unsqueeze(-2)  # [*,L,L]
    _gate = _gate0 * _gmask0.unsqueeze(-1)  # [*,L,L,1]
    # combine: gated sum over neighbors, then nonlinearity and dropout
    h0 = BK.relu((_hid * _gate).sum(-2))  # [*, L, D]
    h1 = self.drop_node(h0)
    # add & norm?
    if self.ln is not None:
        h1 = self.ln(h1 + input_t)
    return h1
def loss(self, unary_scores: BK.Expr, input_mask: BK.Expr, gold_idxes: BK.Expr):
    """CRF negative log-likelihood: partition minus gold path score.

    Returns (per-sequence losses [*], normalizing count: tokens or sequences).
    """
    mat_t = self.bigram.get_matrix()  # transition matrix [L, L]
    # partition term (forward algorithm); guard against empty inputs
    if BK.is_zero_shape(unary_scores):  # note: avoid empty
        potential_t = BK.zeros(BK.get_shape(unary_scores)[:-2])  # [*]
    else:
        potential_t = BigramInferenceHelper.inference_forward(unary_scores, mat_t, input_mask, self.conf.crf_beam)  # [*]
    # gold path score = masked unary scores + masked transition scores
    gold_unary_t = unary_scores.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1) * input_mask  # [*, slen]
    gold_bigram_t = mat_t[gold_idxes[:, :-1], gold_idxes[:, 1:]] * input_mask[:, 1:]  # [*, slen-1]
    gold_score_t = gold_unary_t.sum(-1) + gold_bigram_t.sum(-1)  # [*]
    all_losses_t = potential_t - gold_score_t  # [*]
    # normalizer: token count vs. non-empty-sequence count
    if self.conf.loss_by_tok:
        ret_count = input_mask.sum()  # []
    else:
        ret_count = (input_mask.sum(-1) > 0).float()  # [*]
    return all_losses_t, ret_count
def forward(self, expr_t: BK.Expr, fixed_scores_t: BK.Expr = None, feed_output=False, mask_t: BK.Expr = None):
    """Score labels for each position and optionally feed a soft label embedding back.

    Returns (score_t [*, nlab], cf_t [*] or None, final_t [*, ndim]).
    """
    conf: SingleBlockConf = self.conf
    # --
    # pred: either use externally fixed scores or compute them here
    if fixed_scores_t is None:
        hid1_t = self.hid_in(expr_t)  # [*, hid]
        score_t = self.pred_in(hid1_t)  # [*, nlab]
        cf_t = self.aff_cf(hid1_t).squeeze(-1)  # confidence score [*]
    else:
        score_t, cf_t = fixed_scores_t, None
    # --
    if mask_t is not None:
        # broadcast the mask onto the label dim if needed
        if len(BK.get_shape(mask_t)) < len(BK.get_shape(expr_t)):
            mask_t = mask_t.unsqueeze(-1)  # [*, 1]
        score_t += Constants.REAL_PRAC_MIN * (1. - mask_t)  # [*, nlab]
    # --
    # output: feed a probability-weighted label embedding back (add & norm)
    if feed_output:
        W = self.W_getf()  # [nlab, hid]
        prob_t = score_t.softmax(-1)  # [*, nlab]
        hid2_t = BK.matmul(prob_t, W) * self.e_mul_scale  # [*, hid], todo(+W): need dropout here?
        out_t = self.hid_out(hid2_t)  # [*, ndim]
        final_t = self.norm(out_t + expr_t)  # [*, ndim]
    else:
        final_t = expr_t  # no change, pass the input through
    return score_t, cf_t, final_t  # [*, nlab], [*], [*, ndim]
def predict(self, insts: List[Sent], input_expr: BK.Expr, mask_expr: BK.Expr):
    """Predict events and their arguments; writes results back into `insts`.

    Argument decoding only runs on positions predicted as predicates
    (pred_evt_labels > 0); other rows stay zero.
    """
    conf: MySRLConf = self.conf
    slen = BK.get_shape(mask_expr, -1)
    # --
    # =====
    # evt
    _, all_evt_cfs, all_evt_raw_scores = self.evt_node.get_all_values()  # [*, slen, Le]
    all_evt_scores = [z.log_softmax(-1) for z in all_evt_raw_scores]
    final_evt_scores = self.evt_node.helper.pred(all_logprobs=all_evt_scores, all_cfs=all_evt_cfs)  # [*, slen, Le]
    if conf.evt_pred_use_all or conf.evt_pred_use_posi:  # todo(+W): not an elegant way...
        # suppress the NIL label so every position predicts something
        final_evt_scores[:,:,0] += Constants.REAL_PRAC_MIN  # all pred sth!!
    pred_evt_scores, pred_evt_labels = final_evt_scores.max(-1)  # [*, slen]
    # =====
    # arg
    _, all_arg_cfs, all_arg_raw_score = self.arg_node.get_all_values()  # [*, slen, slen, La]
    all_arg_scores = [z.log_softmax(-1) for z in all_arg_raw_score]
    final_arg_scores = self.arg_node.helper.pred(all_logprobs=all_arg_scores, all_cfs=all_arg_cfs)  # [*, slen, slen, La]
    # slightly more efficient by masking valid evts??
    full_pred_shape = BK.get_shape(final_arg_scores)[:-1]  # [*, slen, slen]
    # default zeros for rows without a predicted event
    pred_arg_scores, pred_arg_labels = BK.zeros(full_pred_shape), BK.zeros(full_pred_shape).long()
    arg_flat_mask = (pred_evt_labels > 0)  # [*, slen]
    flat_arg_scores = final_arg_scores[arg_flat_mask]  # [??, slen, La]
    if not BK.is_zero_shape(flat_arg_scores):  # at least one predicate!
        if self.pred_cons_mat is not None:
            # constrained decoding with the bigram constraint matrix
            flat_mask_expr = mask_expr.unsqueeze(-2).expand(-1, slen, slen)[arg_flat_mask]  # [*, 1->slen, slen] => [??, slen]
            flat_pred_arg_labels, flat_pred_arg_scores = BigramInferenceHelper.inference_search(
                flat_arg_scores, self.pred_cons_mat, flat_mask_expr, conf.arg_beam_k)  # [??, slen]
        else:
            flat_pred_arg_scores, flat_pred_arg_labels = flat_arg_scores.max(-1)  # [??, slen]
        pred_arg_scores[arg_flat_mask] = flat_pred_arg_scores
        pred_arg_labels[arg_flat_mask] = flat_pred_arg_labels
    # =====
    # assign
    self.helper.put_results(insts, pred_evt_labels, pred_evt_scores, pred_arg_labels, pred_arg_scores)
def select_topk(score_t: BK.Expr, topk_t: Union[int, BK.Expr], mask_t: BK.Expr = None, dim=-1):
    """Return a 0/1 float mask selecting the top-k scores along `dim`.

    `topk_t` may be a single int or a per-row idx tensor (k can vary by row);
    `mask_t`, when given, both pre-masks the scores and post-masks the result.
    """
    # normalize K: an int budget becomes a constant per-row tensor
    if isinstance(topk_t, int):
        K = topk_t
        k_shape = BK.get_shape(score_t)
        k_shape[dim] = 1
        topk_t = BK.constants_idx(k_shape, K)
    else:
        K = topk_t.max().item()
    # rank of the per-row threshold element, clamped into [0, K-1]
    exact_rank_t = topk_t - 1  # [bsize, 1]
    exact_rank_t.clamp_(min=0, max=K - 1)
    # push masked-out scores to practically -inf
    if mask_t is not None:
        score_t = score_t + Constants.REAL_PRAC_MIN * (1. - mask_t)
    # take the global K best, then read off each row's own threshold value
    topk_vals, _ = score_t.topk(K, dim, largest=True, sorted=True)  # [*, K, *]
    sel_thresh = topk_vals.gather(dim, exact_rank_t)  # [*, 1, *]
    # everything at or above the threshold is selected
    topk_mask = (score_t >= sel_thresh).float()  # [*, D, *]
    if mask_t is not None:
        topk_mask *= mask_t
    return topk_mask
def _extend_cand_score(self, cand_score: BK.Expr):
    """Broadcast candidate extraction scores onto the label dim, or return None.

    Only active when `lab_add_extract_score` is configured and a score exists.
    """
    if cand_score is None or not self.conf.lab_add_extract_score:
        return None
    # zero out the special (index-0) label while adding the span score elsewhere
    non0_mask = self.lab_node.laber.speical_mask_non0
    return non0_mask * cand_score.unsqueeze(-1)  # [*, slen, L]
def decode_frame(self, ibatch, scores_t: BK.Expr, pred_max_layer: int, voc, pred_label: bool,
                 pred_tag: str, pred_check_layer: int):
    """Decode up to `pred_max_layer` frames per token from label scores.

    Label 0 is NIL and stops the per-token layer loop. When `pred_check_layer`
    >= 0, labels sharing the same dotted-name prefix (up to that depth) are
    deduplicated per token.

    Returns ((bidx_t, widx_t), flat list of new frames, object array [*, dlen, K]).
    """
    # --
    # first get topk for each position
    logprobs_t = scores_t.log_softmax(-1)  # [*, dlen, L]
    pred_scores, pred_labels = logprobs_t.topk(pred_max_layer)  # [*, dlen, K]
    arr_scores, arr_labels = BK.get_value(pred_scores), BK.get_value(pred_labels)  # [*, dlen, K]
    # put results
    res_bidxes, res_widxes, res_frames = [], [], []  # flattened results
    res_farrs = np.full(arr_scores.shape, None, dtype=object)  # [*, dlen, K]
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            # todo(+N): currently we only predict for center if there is!
            if item.center_sidx is not None and sidx != item.center_sidx:
                continue  # skip non-center sent in this mode!
            _start = _dec_offsets[sidx]
            _len = len(sent)
            _arr_scores, _arr_labels = arr_scores[bidx][_start:_start + _len], arr_labels[bidx][_start:_start + _len]
            for widx in range(_len):
                _full_widx = widx + _start  # idx in the msent
                _tmp_set = set()  # seen label prefixes for this token
                for _k in range(pred_max_layer):
                    _score, _lab = float(_arr_scores[widx][_k]), int(_arr_labels[widx][_k])
                    if _lab == 0:  # note: lab=0 means NIL
                        break
                    _type_str = (voc.idx2word(_lab) if pred_label else "UNK")
                    _type_str_prefix = '.'.join(_type_str.split('.')[:pred_check_layer])
                    if pred_check_layer >= 0 and _type_str_prefix in _tmp_set:
                        continue  # ignore since constraint
                    _tmp_set.add(_type_str_prefix)  # add new one!
                    res_bidxes.append(bidx)
                    res_widxes.append(_full_widx)
                    _new_frame = sent.make_frame(widx, 1, tag=pred_tag, type=_type_str, score=float(_score))
                    _new_frame.set_label_idx(int(_lab))
                    _new_frame._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                    _new_frame._tmp_sidx = sidx
                    _new_frame._tmp_item = item
                    res_frames.append(_new_frame)
                    res_farrs[bidx, _full_widx, _k] = _new_frame
    # return
    res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
    return (res_bidxes_t, res_widxes_t), res_frames, res_farrs  # [??], [*, dlen, K]
def prepare_with_lengths(self, input_shape: Tuple[int], length_expr: BK.Expr, gold_widx_expr: BK.Expr,
                         gold_wlen_expr: BK.Expr, gold_addr_expr: BK.Expr):
    """Prepare span candidates, treating a span as valid iff it ends within its sequence length."""
    def _valid_f(_widx, _wlen):
        # span [widx, widx+wlen) must fit inside the sequence: [bsize, mlen*dw]
        return ((_widx + _wlen).unsqueeze(0) <= length_expr.unsqueeze(-1)).float()
    return self._common_prepare(input_shape, _valid_f, gold_widx_expr, gold_wlen_expr, gold_addr_expr)
def expand_ranged_idxes(widx_t: BK.Expr, wlen_t: BK.Expr, pad: int = 0, max_width: int = None):
    """Expand (start, length) pairs into index sequences of width MW plus a validity mask.

    Returns (idxes [*, MW] with `pad` at invalid slots, float mask [*, MW]).
    """
    if max_width is None:  # derive from the batch if not provided
        max_width = 1 if BK.is_zero_shape(wlen_t) else wlen_t.max().item()
    # --
    base_shape = BK.get_shape(widx_t)  # [*]
    # offsets 0..MW-1, broadcastable against widx_t
    offset_t = BK.arange_idx(max_width).view([1] * len(base_shape) + [-1])  # [*, MW]
    idxes_t = widx_t.unsqueeze(-1) + offset_t  # [*, MW]
    valid_b = (offset_t < wlen_t.unsqueeze(-1))  # [*, MW], inside the span?
    idxes_t.masked_fill_(~valid_b, pad)  # out-of-span positions get the pad idx
    return idxes_t, valid_b.float()
def judge(self, scores: BK.Expr, cf_scores: BK.Expr, mask_t: BK.Expr):
    """Early-exit decision per batch element: threshold an aggregated confidence metric."""
    conf: SeqExitHelperConf = self.conf
    # --
    # cf + seq mode: cf_scores are already sequence-level aggregated metrics
    if self.is_cf and conf.cf_use_seq:
        return (cf_scores.squeeze(-1) / conf.cf_scale) >= conf.exit_thresh
    # --
    # otherwise build per-token metrics first
    if self.is_cf:
        seq_metrics = cf_scores.squeeze(-1) / conf.cf_scale  # [*, slen]
    else:
        seq_metrics = self._cri_f(scores)  # [*, slen]
    # --
    # padded positions count as fully confident (metric = 1.)
    seq_metrics = (1. - mask_t) + mask_t * seq_metrics
    slen = BK.get_shape(seq_metrics, -1)
    K = self._getk(slen, conf.exit_min_k)
    # aggregate over the K least-confident tokens
    aggr_metrics = topk_avg(seq_metrics, mask_t, K, dim=-1, largest=False)  # [*]
    return aggr_metrics >= conf.exit_thresh
def _prepare_loss_weights(self, mask_expr: BK.Expr, must_include_t: BK.Expr, neg_rate: float, extra_weights=None):
    """Token loss weights: must-include positions plus a `neg_rate` sample of the rest."""
    if neg_rate >= 1.:  # everything in
        ret_weights = mask_expr  # simply as it is
    elif neg_rate > 0.:  # must_include + random sample
        ret_weights = ((BK.rand(mask_expr.shape) < neg_rate) | must_include_t).float() * mask_expr
    else:  # only must_include
        ret_weights = (must_include_t.float()) * mask_expr
    return ret_weights if extra_weights is None else ret_weights * extra_weights
def loss(self, insts: List[Sent], input_expr: BK.Expr, mask_expr: BK.Expr): conf: MySRLConf = self.conf # -- slen = BK.get_shape(mask_expr, -1) arr_items, expr_evt_labels, expr_arg_labels, expr_loss_weight_non = self.helper.prepare(insts, True) if conf.binary_evt: expr_evt_labels = (expr_evt_labels>0).long() # either 0 or 1 loss_items = [] # ===== # evt # -- prepare weights and masks evt_not_nil = (expr_evt_labels>0) # [*, slen] evt_extra_weights = BK.where(evt_not_nil, mask_expr, expr_loss_weight_non.unsqueeze(-1)*conf.evt_loss_weight_non) evt_weights = self._prepare_loss_weights(mask_expr, evt_not_nil, conf.evt_loss_sample_neg, evt_extra_weights) # -- get losses _, all_evt_cfs, all_evt_scores = self.evt_node.get_all_values() # [*, slen] all_evt_losses = [] for one_evt_scores in all_evt_scores: one_losses = BK.loss_nll(one_evt_scores, expr_evt_labels, label_smoothing=conf.evt_label_smoothing) all_evt_losses.append(one_losses) evt_loss_results = self.evt_node.helper.loss(all_losses=all_evt_losses, all_cfs=all_evt_cfs) for loss_t, loss_alpha, loss_name in evt_loss_results: one_evt_item = LossHelper.compile_leaf_loss("evt"+loss_name, (loss_t*evt_weights).sum(), evt_weights.sum(), loss_lambda=conf.loss_evt*loss_alpha, gold=evt_not_nil.float().sum()) loss_items.append(one_evt_item) # ===== # arg _arg_loss_evt_sample_neg = conf.arg_loss_evt_sample_neg if _arg_loss_evt_sample_neg > 0: arg_evt_masks = ((BK.rand(mask_expr.shape)<_arg_loss_evt_sample_neg) | evt_not_nil).float() * mask_expr else: arg_evt_masks = evt_not_nil.float() # [*, slen] # expand/flat the dims arg_flat_mask = (arg_evt_masks > 0) # [*, slen] flat_mask_expr = mask_expr.unsqueeze(-2).expand(-1, slen, slen)[arg_flat_mask] # [*, 1->slen, slen] => [??, slen] flat_arg_labels = expr_arg_labels[arg_flat_mask] # [??, slen] flat_arg_not_nil = (flat_arg_labels > 0) # [??, slen] flat_arg_weights = self._prepare_loss_weights(flat_mask_expr, flat_arg_not_nil, conf.arg_loss_sample_neg) # -- get losses _, all_arg_cfs, 
all_arg_scores = self.arg_node.get_all_values() # [*, slen, slen] all_arg_losses = [] for one_arg_scores in all_arg_scores: one_flat_arg_scores = one_arg_scores[arg_flat_mask] # [??, slen] one_losses = BK.loss_nll(one_flat_arg_scores, flat_arg_labels, label_smoothing=conf.evt_label_smoothing) all_arg_losses.append(one_losses) all_arg_cfs = [z[arg_flat_mask] for z in all_arg_cfs] # [??, slen] arg_loss_results = self.arg_node.helper.loss(all_losses=all_arg_losses, all_cfs=all_arg_cfs) for loss_t, loss_alpha, loss_name in arg_loss_results: one_arg_item = LossHelper.compile_leaf_loss("arg"+loss_name, (loss_t*flat_arg_weights).sum(), flat_arg_weights.sum(), loss_lambda=conf.loss_arg*loss_alpha, gold=flat_arg_not_nil.float().sum()) loss_items.append(one_arg_item) # ===== # return loss ret_loss = LossHelper.combine_multiple_losses(loss_items) return ret_loss
def prepare_sd_init(self, expr_main: BK.Expr, expr_pair: BK.Expr):
    """Build the initial hidden state for the sequential decoder."""
    if self.is_pairwise:
        return self.sd_init_aff(expr_pair)  # [*, hid]
    # pool the main expr at dim -2; `sum` just produces the right shape when empty
    if BK.is_zero_shape(expr_main):
        pooled_t = expr_main.sum(-2)
    else:
        pooled_t = self.sd_init_pool_f(expr_main, -2)  # [*, Dm']
    return self.sd_init_aff(pooled_t)  # [*, hid]
def decode_frame_given(self, ibatch, scores_t: BK.Expr, pred_max_layer: int, voc, pred_label: bool,
                       pred_tag: str, assume_osof: bool):
    """Collect GIVEN frames from the batch (optionally overwriting their labels by argmax).

    When `assume_osof` ("one seq one frame"), each item's own inst is the frame;
    otherwise frames with `pred_tag` are gathered from (center) sentences.

    Returns ((bidx_t, widx_t), flat frame list, object array [*, dlen, K]).
    """
    if pred_label:  # if overwrite label!
        logprobs_t = scores_t.log_softmax(-1)  # [*, dlen, L]
        pred_scores, pred_labels = logprobs_t.max(-1)  # [*, dlen], note: maximum!
        arr_scores, arr_labels = BK.get_value(pred_scores), BK.get_value(pred_labels)  # [*, dlen]
    else:
        arr_scores = arr_labels = None
    # --
    # read given results
    res_bidxes, res_widxes, res_frames = [], [], []  # flattened results
    tmp_farrs = defaultdict(list)  # later assign
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _trg_frames = [item.inst] if assume_osof else \
            sum([sent.get_frames(pred_tag) for sidx,sent in enumerate(item.sents)
                 if (item.center_sidx is None or sidx == item.center_sidx)],[])  # still only pick center ones!
        # --
        _dec_offsets = item.seq_info.dec_offsets
        for _frame in _trg_frames:
            # note: simply sort by original order!
            sidx = item.sents.index(_frame.sent)
            _start = _dec_offsets[sidx]
            _full_hidx = _start + _frame.mention.shead_widx
            # add new one
            res_bidxes.append(bidx)
            res_widxes.append(_full_hidx)
            _frame._tmp_sstart = _start  # todo(+N): ugly tmp value ...
            _frame._tmp_sidx = sidx
            _frame._tmp_item = item
            res_frames.append(_frame)
            tmp_farrs[(bidx, _full_hidx)].append(_frame)
            # assign/rewrite label?
            if pred_label:
                _lab = int(arr_labels[bidx, _full_hidx])  # label index
                _frame.set_label_idx(_lab)
                _frame.set_label(voc.idx2word(_lab))
                _frame.set_score(float(arr_scores[bidx, _full_hidx]))
        # --
    # --
    res_farrs = np.full(BK.get_shape(scores_t)[:-1] + [pred_max_layer], None, dtype=object)  # [*, dlen, K]
    for _key, _values in tmp_farrs.items():
        bidx, widx = _key
        _values = _values[:pred_max_layer]  # truncate if more!
        res_farrs[bidx, widx, :len(_values)] = _values
    # return
    res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
    return (res_bidxes_t, res_widxes_t), res_frames, res_farrs  # [??], [*, dlen, K]
def _go_common(self, res: SpanExtractorOutput, sel_mask: BK.Expr, add_gold_rate: float):
    """Finalize span selection: optionally OR-in sampled gold spans, then compact.

    Mutates `sel_mask` and `res` in place; returns the rearranged `res`.
    """
    gaddr_expr, span_mask = res.gaddr_expr, res.mask_expr
    bsize = BK.get_shape(span_mask, 0)
    # add gold?
    if add_gold_rate > 0.:  # inplace
        # sample gold spans to force-include; note: gaddr==-1 means nope
        gold_mask = ((gaddr_expr >= 0) & (BK.rand(sel_mask.shape) < add_gold_rate)).float()
        sel_mask += gold_mask
        sel_mask.clamp_(max=1.)  # OR
    sel_mask *= span_mask  # must be valid
    # select masked: compact selected spans to the front
    final_idx_t, final_mask_t = BK.mask2idx(sel_mask, padding_idx=0)  # [bsize, ??]
    _tmp_arange_t = BK.arange_idx(bsize).unsqueeze(1)  # [bsize, 1]
    res.arrange(_tmp_arange_t, final_idx_t, final_mask_t)
    if res.gaddr_expr is not None:
        res.gaddr_expr.masked_fill_(final_mask_t == 0., -1)  # make invalid ones -1
    return res  # [bsize, SNUM, *]
def decode_with_scores(left_scores: BK.Expr, right_scores: BK.Expr, normalize: bool):
    """Pick the best (left, right) boundary pair (left <= right) by summed score.

    Returns (max scores [*], left idxes [*], right idxes [*]).
    """
    if normalize:
        left_scores = BK.log_softmax(left_scores, -1)
        right_scores = BK.log_softmax(right_scores, -1)
    shape = BK.get_shape(left_scores)
    slen = shape[-1]
    # pairwise sums: [*, slen_L, slen_R], flattened to [*, slen*slen]
    pair_scores = left_scores.unsqueeze(-1) + right_scores.unsqueeze(-2)
    flt_pair_scores = pair_scores.view(shape[:-1] + [-1])
    # keep only pairs with left <= right
    arange_t = BK.arange_idx(slen)
    lr_mask = (arange_t.unsqueeze(-1) <= arange_t.unsqueeze(-2)).float().view(-1)  # [slen_L*slen_R]
    penalized = flt_pair_scores + (1. - lr_mask) * Constants.REAL_PRAC_MIN
    max_scores, max_idxes = penalized.max(-1)  # [*]
    # unflatten the argmax back to (left, right)
    left_idxes, right_idxes = max_idxes // slen, max_idxes % slen  # [*]
    return max_scores, left_idxes, right_idxes
def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
         pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr=None):
    """Training loss for the direct span extractor: candidate loss + labeling loss."""
    conf: DirectExtractorConf = self.conf
    # step 0: prepare golds
    arr_gold_items, expr_gold_gaddr, expr_gold_widxes, expr_gold_wlens, expr_loss_weight_non = \
        self.helper.prepare(insts, use_cache=True)
    # step 1: extract cands
    if conf.loss_use_posi:
        # use the gold positions directly
        cand_res = self.extract_node.go_lookup(
            input_expr, expr_gold_widxes, expr_gold_wlens, (expr_gold_gaddr>=0).float(), gaddr_expr=expr_gold_gaddr)
    else:  # todo(note): assume no in-middle mask!!
        cand_widx, cand_wlen, cand_mask, cand_gaddr = self.extract_node.prepare_with_lengths(
            BK.get_shape(mask_expr), mask_expr.sum(-1).long(), expr_gold_widxes, expr_gold_wlens, expr_gold_gaddr)
        if conf.span_train_sample:  # simply do sampling
            cand_res = self.extract_node.go_sample(
                input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
                rate=conf.span_train_sample_rate, count=conf.span_train_sample_count,
                gaddr_expr=cand_gaddr, add_gold_rate=1.0)  # note: always fully add gold for sampling!!
        else:  # beam pruner using topk
            cand_res = self.extract_node.go_topk(
                input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
                rate=conf.span_topk_rate, count=conf.span_topk_count,
                gaddr_expr=cand_gaddr, add_gold_rate=conf.span_train_topk_add_gold_rate)
    # step 1+: prepare for labeling
    cand_gold_mask = (cand_res.gaddr_expr>=0).float() * cand_res.mask_expr  # [*, cand_len]
    # todo(note): add a 0 as idx=-1 to make NEG ones as 0!!
    flatten_gold_label_idxes = BK.input_idx([(0 if z is None else z.label_idx) for z in arr_gold_items.flatten()] + [0])
    gold_label_idxes = flatten_gold_label_idxes[cand_res.gaddr_expr]
    # NEG candidates get the down-weighted non weight
    cand_loss_weights = BK.where(gold_label_idxes==0, expr_loss_weight_non.unsqueeze(-1)*conf.loss_weight_non,
                                 cand_res.mask_expr)
    final_loss_weights = cand_loss_weights * cand_res.mask_expr
    # cand loss (only meaningful when candidates were scored, not looked up)
    if conf.loss_cand > 0. and not conf.loss_use_posi:
        loss_cand0 = BK.loss_binary(cand_res.score_expr, cand_gold_mask, label_smoothing=conf.cand_label_smoothing)
        loss_cand = (loss_cand0 * final_loss_weights).sum()
        loss_cand_item = LossHelper.compile_leaf_loss(f"cand", loss_cand, final_loss_weights.sum(),
                                                      loss_lambda=conf.loss_cand)
    else:
        loss_cand_item = None
    # extra score
    cand_extra_score = self._get_extra_score(
        cand_res.score_expr, insts, cand_res, arr_gold_items, conf.loss_use_cons, conf.loss_use_lu)
    final_extra_score = self._sum_scores(external_extra_score, cand_extra_score)
    # step 2: label; with special weights
    loss_lab, loss_count = self.lab_node.loss(
        cand_res.span_expr, pair_expr, cand_res.mask_expr, gold_label_idxes,
        loss_weight_expr=final_loss_weights, extra_score=final_extra_score)
    loss_lab_item = LossHelper.compile_leaf_loss(f"lab", loss_lab, loss_count,
                                                 loss_lambda=conf.loss_lab, gold=cand_gold_mask.sum())
    # ==
    # return loss
    ret_loss = LossHelper.combine_multiple_losses([loss_cand_item, loss_lab_item])
    return self._finish_loss(ret_loss, insts, input_expr, mask_expr, pair_expr, lookup_flatten)
def _cand_score_and_select(self, input_expr: BK.Expr, mask_expr: BK.Expr):
    """Score each token as a candidate, then select by per-seq topk AND threshold.

    Returns (full scores [*, slen], 0/1 decisions [*, slen]).
    """
    conf: SoftExtractorConf = self.conf
    # --
    # masked-out positions get a practically-minimal score
    cand_full_scores = self.cand_scorer(input_expr).squeeze(-1) \
        + (1. - mask_expr) * Constants.REAL_PRAC_MIN  # [*, slen]
    # per-seq budget: ceil(min(len * rate, count))
    len_t = mask_expr.sum(-1)  # [*]
    topk_t = (len_t * conf.cand_topk_rate).clamp(max=conf.cand_topk_count).ceil().long().unsqueeze(-1)  # [*, 1]
    # get topk mask (skip topk entirely for empty batches)
    if BK.is_zero_shape(mask_expr):
        topk_mask = mask_expr.clone()
    else:
        topk_mask = select_topk(cand_full_scores, topk_t, mask_t=mask_expr, dim=-1)  # [*, slen]
    # additionally require scores above the absolute threshold
    cand_decisions = topk_mask * (cand_full_scores >= conf.cand_thresh).float()  # [*, slen]
    return cand_full_scores, cand_decisions  # [*, slen]
def _split_loss(self, stack_loss: BK.Expr):
    """Split the stacked per-layer losses along the last dim into group sums."""
    remaining = BK.get_shape(stack_loss, -1)  # NL
    sizes = []
    # the appended `remaining` sentinel guarantees the whole budget is consumed
    for group in self.loss_groups + [remaining]:
        if group >= remaining:  # dump whatever is left into the final piece
            sizes.append(remaining)
            break
        sizes.append(group)
        remaining -= group
    # --
    pieces = stack_loss.split(sizes, dim=-1)  # *[*, ??]
    return [piece.sum(-1) for piece in pieces]
def _common_prepare(self, input_shape: Tuple[int], _mask_f: Callable, gold_widx_expr: BK.Expr,
                    gold_wlen_expr: BK.Expr, gold_addr_expr: BK.Expr):
    """Enumerate all spans of width [min_width, max_width], compact valid ones, map golds.

    `_mask_f(widx, wlen)` decides span validity. Returns
    (widx [bsize, ??], wlen [bsize, ??], mask [bsize, ??], gaddr [bsize, ??] or None),
    where gaddr holds gold addresses with -1 for non-gold slots.
    """
    conf: SpanExtractorConf = self.conf
    min_width, max_width = conf.min_width, conf.max_width
    diff_width = max_width - min_width + 1  # number of width to extract
    # --
    bsize, mlen = input_shape
    # --
    # [bsize, mlen*(max_width-min_width)], mlen first (dim=1)
    # note: the spans are always sorted by (widx, wlen)
    _tmp_arange_t = BK.arange_idx(mlen * diff_width)  # [mlen*dw]
    widx_t0 = (_tmp_arange_t // diff_width)  # [mlen*dw]
    wlen_t0 = (_tmp_arange_t % diff_width) + min_width  # [mlen*dw]
    mask_t0 = _mask_f(widx_t0, wlen_t0)  # [bsize, mlen*dw]
    # --
    # compacting (use mask2idx and gather)
    final_idx_t, final_mask_t = BK.mask2idx(mask_t0, padding_idx=0)  # [bsize, ??]
    _tmp2_arange_t = BK.arange_idx(bsize).unsqueeze(1)  # [bsize, 1]
    # no need to make valid for mask=0, since idx=0 means (0, min_width)
    # todo(+?): do we need to deal with empty ones here?
    ret_widx = widx_t0[final_idx_t]  # [bsize, ??]
    ret_wlen = wlen_t0[final_idx_t]  # [bsize, ??]
    # --
    # prepare gold (as pointer-like addresses)
    if gold_addr_expr is not None:
        gold_t0 = BK.constants_idx((bsize, mlen * diff_width), -1)  # [bsize, mlen*diff]
        # check valid of golds (flatten all): valid iff addressed and width in range
        gold_valid_t = ((gold_addr_expr >= 0) & (gold_wlen_expr >= min_width) & (gold_wlen_expr <= max_width))
        gold_valid_t = gold_valid_t.view(-1)  # [bsize*_glen]
        _glen = BK.get_shape(gold_addr_expr, 1)
        flattened_bsize_t = BK.arange_idx(bsize * _glen) // _glen  # [bsize*_glen]
        # flat index of each gold span in the (widx, wlen) enumeration above
        flattened_fidx_t = (gold_widx_expr * diff_width + gold_wlen_expr - min_width).view(-1)  # [bsize*_glen]
        flattened_gaddr_t = gold_addr_expr.view(-1)
        # mask and assign
        gold_t0[flattened_bsize_t[gold_valid_t], flattened_fidx_t[gold_valid_t]] = flattened_gaddr_t[gold_valid_t]
        ret_gaddr = gold_t0[_tmp2_arange_t, final_idx_t]  # [bsize, ??]
        ret_gaddr.masked_fill_((final_mask_t == 0), -1)  # make invalid ones -1
    else:
        ret_gaddr = None
    # --
    return ret_widx, ret_wlen, final_mask_t, ret_gaddr
def decode_arg(self, res_evts: List, arg_scores_t: BK.Expr, pred_max_layer: int, voc,
               arg_allowed_sent_gap: int, arr_efs):
    """Decode arguments for already-decoded events, within a sentence-gap window.

    When `arr_efs` is given, only positions that already carry an entity-filler
    get arguments (first slot only); otherwise new entity-fillers are created.
    Returns ((fidx_t, widx_t), flat list of new args).
    """
    # first get topk
    arg_logprobs_t = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
    pred_arg_scores, pred_arg_labels = arg_logprobs_t.topk(pred_max_layer)  # [??, dlen, K]
    arr_arg_scores, arr_arg_labels = BK.get_value(pred_arg_scores), BK.get_value(pred_arg_labels)  # [??, dlen, K]
    # put results
    res_fidxes, res_widxes, res_args = [], [], []  # flattened results
    for fidx, evt in enumerate(res_evts):  # for each evt
        item = evt._tmp_item  # cached
        _evt_sidx = evt._tmp_sidx
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            if abs(sidx - _evt_sidx) > arg_allowed_sent_gap:
                continue  # larger than allowed sentence gap
            _start = _dec_offsets[sidx]
            _len = len(sent)
            _arr_scores, _arr_labels = arr_arg_scores[fidx][_start:_start + _len], arr_arg_labels[fidx][_start:_start + _len]
            for widx in range(_len):
                _full_widx = widx + _start  # idx in the msent
                _new_ef = None
                if arr_efs is not None:  # note: arr_efs should also expand to frames!
                    _new_ef = arr_efs[fidx, _full_widx, 0]  # todo(+N): only get the first one!
                    if _new_ef is None:
                        continue  # no ef!
                for _score, _lab in zip(_arr_scores[widx], _arr_labels[widx]):  # [K]
                    if _lab == 0:  # note: idx=0 means NIL
                        break
                    # add new one!!
                    res_fidxes.append(fidx)
                    res_widxes.append(_full_widx)
                    if _new_ef is None:
                        _new_ef = sent.make_entity_filler(widx, 1)  # share them if possible!
                    _new_arg = evt.add_arg(_new_ef, role=voc.idx2word(_lab), score=float(_score))
                    _new_arg._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                    res_args.append(_new_arg)
    # return
    res_fidxes_t, res_widxes_t = BK.input_idx(res_fidxes), BK.input_idx(res_widxes)  # [??]
    return (res_fidxes_t, res_widxes_t), res_args
def select_topk_non_overlapping(score_t: BK.Expr, topk_t: Union[int, BK.Expr], widx_t: BK.Expr, wlen_t: BK.Expr,
                                input_mask_t: BK.Expr, mask_t: BK.Expr = None, dim=-1):
    """Greedily select (up to) top-k spans per row by score, skipping overlaps.

    Args:
        score_t: candidate span scores.
        topk_t: per-row budget, an int or an idx tensor broadcastable over rows.
        widx_t, wlen_t: span start / length per candidate.
        input_mask_t: token-level mask of positions a span may cover.
        mask_t: candidate-level validity mask. NOTE(review): despite the None
            default, it is unconditionally reshaped below, so callers must pass
            it — confirm intended usage.
        dim: only the last dim is supported.
    Returns:
        A 0/1 float mask over candidates, same shape as `score_t`.
    """
    score_shape = BK.get_shape(score_t)
    # bugfix: was `len(score_shape - 1)` (TypeError: list minus int); compare against the last axis properly
    assert dim == -1 or dim == len(score_shape) - 1, "Currently only support last-dim!!"  # todo(+2): we can permute to allow any dim!
    # --
    # prepare K: an int budget becomes a constant per-row tensor
    if isinstance(topk_t, int):
        tmp_shape = score_shape.copy()
        tmp_shape[dim] = 1  # set it as 1
        topk_t = BK.constants_idx(tmp_shape, topk_t)
    # --
    reshape_trg = [np.prod(score_shape[:-1]).item(), -1]  # flatten batch dims: [*, ?]
    _, sorted_idxes_t = score_t.sort(dim, descending=True)
    # --
    # put it as CPU and use loop; todo(+N): more efficient ways?
    arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t = \
        [BK.get_value(z.reshape(reshape_trg)) for z in [sorted_idxes_t, topk_t, widx_t, wlen_t, input_mask_t, mask_t]]
    _bsize, _cnum = BK.get_shape(arr_sorted_idxes_t)  # [bsize, NUM]
    arr_topk_mask = np.full([_bsize, _cnum], 0.)  # [bsize, NUM]
    _bidx = 0
    for aslice_sorted_idxes_t, aslice_topk_t, aslice_widx_t, aslice_wlen_t, aslice_input_mask_t, aslice_mask_t \
            in zip(arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t):
        aslice_topk_mask = arr_topk_mask[_bidx]
        # --
        cur_ok_mask = np.copy(aslice_input_mask_t)  # tokens still free to be covered
        cur_budget = aslice_topk_t.item()
        for _cidx in aslice_sorted_idxes_t:  # best score first
            _cidx = _cidx.item()
            if cur_budget <= 0:
                break  # no budget left
            if not aslice_mask_t[_cidx].item():
                continue  # non-valid candidate
            one_widx, one_wlen = aslice_widx_t[_cidx].item(), aslice_wlen_t[_cidx].item()
            if np.prod(cur_ok_mask[one_widx:one_widx + one_wlen]).item() == 0.:  # any hit one?
                continue  # overlaps an already-selected span
            # ok! add it!
            cur_budget -= 1
            cur_ok_mask[one_widx:one_widx + one_wlen] = 0.  # claim the covered tokens
            aslice_topk_mask[_cidx] = 1.
        _bidx += 1
    # note: no need to *=mask_t again since already check in the loop
    return BK.input_real(arr_topk_mask).reshape(score_shape)
def decode_evt(self, dec, ibatch, evt_scores_t: BK.Expr):
    """Decode up to K events per token from label scores (label 0 = NIL stops).

    Returns ((bidx_t, widx_t) flat index tensors, flat list of new events).
    """
    _pred_max_layer_evt = dec.conf.max_layer_evt
    _voc_evt = dec.voc_evt
    _pred_evt_label = self.conf.pred_evt_label
    # --
    evt_logprobs_t = evt_scores_t.log_softmax(-1)  # [*, dlen, L]
    pred_evt_scores, pred_evt_labels = evt_logprobs_t.topk(_pred_max_layer_evt)  # [*, dlen, K]
    arr_evt_scores, arr_evt_labels = BK.get_value(pred_evt_scores), BK.get_value(pred_evt_labels)  # [*, dlen, K]
    # put results
    res_bidxes, res_widxes, res_evts = [], [], []  # flattened results
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            # note: here we only predict for center if there is!
            if item.center_sidx is not None and sidx != item.center_sidx:
                continue  # skip non-center sent in this mode!
            _start = _dec_offsets[sidx]
            _len = len(sent)
            _arr_scores, _arr_labels = arr_evt_scores[bidx][_start:_start + _len], arr_evt_labels[bidx][_start:_start + _len]
            for widx in range(_len):
                for _score, _lab in zip(_arr_scores[widx], _arr_labels[widx]):  # [K]
                    if _lab == 0:  # note: idx=0 means NIL
                        break
                    # add new one!!
                    res_bidxes.append(bidx)
                    res_widxes.append(_start + widx)  # note: remember to add offset!
                    _new_evt = sent.make_event(widx, 1,
                                               type=(_voc_evt.idx2word(_lab) if _pred_evt_label else "UNK"),
                                               score=float(_score))
                    _new_evt.set_label_idx(int(_lab))
                    _new_evt._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                    _new_evt._tmp_sidx = sidx
                    _new_evt._tmp_item = item
                    res_evts.append(_new_evt)
    # return
    res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
    return (res_bidxes_t, res_widxes_t), res_evts
def step_end(self, cache: DecCache, slice_main: BK.Expr, slice_mask: BK.Expr, pred_idxes: BK.Expr):
    """Finish one decoding step: feed the predicted labels through the optional
    sequential decoder and record them on the cache.

    Returns the same `cache` (modified inplace).
    """
    # we do possible decoder step here
    conf: SeqLabelerConf = self.conf
    # --
    if conf.use_seqdec:
        embed_t = self.laber.lookup(pred_idxes)  # [*, E]
        input_t = self.sd_input_aff([slice_main, embed_t])  # [*, hid]
        if conf.sd_skip_non:  # further mask, todo(note): fixed non as 0!
            slice_mask = slice_mask * (pred_idxes > 0).float()
        hid_t = self.seqdec.go_feed(cache, input_t.unsqueeze(-2), slice_mask.unsqueeze(-1))  # [*, 1, hid]
    # add here for possible bigram usage
    # NOTE(review): collapsed source is ambiguous about whether this line sits
    # inside the `use_seqdec` branch; placed at function level — confirm.
    cache.last_idxes = pred_idxes  # [*]
    return cache  # cache modified inplace
def decode_evt_given(self, dec, ibatch, evt_scores_t: BK.Expr):
    """Collect GIVEN events from the batch, optionally overwriting labels by argmax.

    With `assume_osof` ("one seq one frame"), each item's own inst is the event;
    otherwise events are gathered from (center) sentences.
    Returns ((bidx_t, widx_t) flat index tensors, flat event list).
    """
    _voc_evt = dec.voc_evt
    _assume_osof = dec.conf.assume_osof  # one seq one frame
    _pred_evt_label = self.conf.pred_evt_label
    # --
    if _pred_evt_label:
        evt_logprobs_t = evt_scores_t.log_softmax(-1)  # [*, dlen, L]
        pred_evt_scores, pred_evt_labels = evt_logprobs_t.max(-1)  # [*, dlen], note: maximum!
        arr_evt_scores, arr_evt_labels = BK.get_value(pred_evt_scores), BK.get_value(pred_evt_labels)  # [*, dlen]
    else:
        arr_evt_scores = arr_evt_labels = None
    # --
    # read given results
    res_bidxes, res_widxes, res_evts = [], [], []  # flattened results
    for bidx, item in enumerate(ibatch.items):  # for each item in the batch
        _trg_evts = [item.inst] if _assume_osof else \
            sum([sent.events for sidx,sent in enumerate(item.sents)
                 if (item.center_sidx is None or sidx == item.center_sidx)],[])
        # --
        _dec_offsets = item.seq_info.dec_offsets
        for _evt in _trg_evts:
            sidx = item.sents.index(_evt.sent)
            _start = _dec_offsets[sidx]
            _full_hidx = _start + _evt.mention.shead_widx
            # add new one
            res_bidxes.append(bidx)
            res_widxes.append(_full_hidx)
            _evt._tmp_sstart = _start  # todo(+N): ugly tmp value ...
            _evt._tmp_sidx = sidx
            _evt._tmp_item = item
            res_evts.append(_evt)
            # assign label?
            if _pred_evt_label:
                _lab = int(arr_evt_labels[bidx, _full_hidx])  # label index
                _evt.set_label_idx(_lab)
                _evt.set_label(_voc_evt.idx2word(_lab))
                _evt.set_score(float(arr_evt_scores[bidx, _full_hidx]))
        # --
    # --
    # return
    res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
    return (res_bidxes_t, res_widxes_t), res_evts
def apply_piece_pooling(t: BK.Expr, piece: int, f: Union[Callable, str] = ActivationHelper.get_pool('max'),
                        dim: int = -1):
    """Pool every `piece` consecutive elements along `dim` with pooling function `f`.

    `f` may also be the name of a pool registered in ActivationHelper.
    """
    if piece == 1:
        return t  # nothing to pool
    # normalize the axis to a non-negative index
    new_shape = BK.get_shape(t)
    if dim < 0:
        dim = len(new_shape) + dim
    # split axis `dim` into (-1, piece): [..., -1, piece, ...]
    new_shape[dim] = piece
    new_shape.insert(dim, -1)
    reshaped_t = t.view(new_shape)
    pool_f = ActivationHelper.get_pool(f) if isinstance(f, str) else f
    return pool_f(reshaped_t, dim + 1)  # pool over the inserted piece axis
def go_topk(
        self, input_expr: BK.Expr, input_mask: BK.Expr,  # input
        widx_expr: BK.Expr, wlen_expr: BK.Expr, span_mask: BK.Expr,
        rate: float = None, count: float = None,  # span
        gaddr_expr: BK.Expr = None, add_gold_rate: float = 0.,  # gold
        non_overlapping=False, score_prune: float = None):  # non-overlapping!
    """Look up span representations, keep the top-k scored spans (optionally
    non-overlapping, optionally score-pruned), then add sampled gold spans.

    The topk budget is derived from the ORIGINAL input length via rate/count.
    """
    lookup_res = self.go_lookup(input_expr, widx_expr, wlen_expr, span_mask, gaddr_expr)  # [bsize, NUM, *]
    # --
    with BK.no_grad_env():  # no need grad here!
        all_score_expr = lookup_res.score_expr
        # get topk score: again rate is to the original input length
        if BK.is_zero_shape(lookup_res.mask_expr):
            topk_mask = lookup_res.mask_expr.clone()  # no need to go topk since no elements
        else:
            topk_expr = self._determine_size(input_mask.sum(-1, keepdim=True), rate, count).long()  # [bsize, 1]
            if non_overlapping:
                topk_mask = select_topk_non_overlapping(all_score_expr, topk_expr, widx_expr, wlen_expr,
                                                        input_mask, mask_t=span_mask, dim=-1)
            else:
                topk_mask = select_topk(all_score_expr, topk_expr, mask_t=span_mask, dim=-1)
        # further score_prune?
        if score_prune is not None:
            topk_mask *= (all_score_expr >= score_prune).float()
    # select and add_gold
    # NOTE(review): collapsed source is ambiguous about whether this return sits
    # inside the no-grad block; placed outside so gradients can flow — confirm.
    return self._go_common(lookup_res, topk_mask, add_gold_rate)
def decode_arg(self, dec, res_evts: List, arg_scores_t: BK.Expr):
    """Decode up to K arguments per token for each decoded event, within the
    allowed sentence gap; creates a new entity-filler for every argument.

    Returns ((fidx_t, widx_t) flat index tensors, flat list of new args).
    """
    _pred_max_layer_arg = dec.conf.max_layer_arg
    _arg_allowed_sent_gap = dec.conf.arg_allowed_sent_gap
    _voc_arg = dec.voc_arg
    # --
    arg_logprobs_t = arg_scores_t.log_softmax(-1)  # [??, dlen, L]
    pred_arg_scores, pred_arg_labels = arg_logprobs_t.topk(_pred_max_layer_arg)  # [??, dlen, K]
    arr_arg_scores, arr_arg_labels = BK.get_value(pred_arg_scores), BK.get_value(pred_arg_labels)  # [??, dlen, K]
    # put results
    res_fidxes, res_widxes, res_args = [], [], []  # flattened results
    for fidx, evt in enumerate(res_evts):  # for each evt
        item = evt._tmp_item  # cached
        _evt_sidx = evt._tmp_sidx
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            if abs(sidx - _evt_sidx) > _arg_allowed_sent_gap:
                continue  # larger than allowed sentence gap
            _start = _dec_offsets[sidx]
            _len = len(sent)
            _arr_scores, _arr_labels = arr_arg_scores[fidx][_start:_start + _len], arr_arg_labels[fidx][_start:_start + _len]
            for widx in range(_len):
                for _score, _lab in zip(_arr_scores[widx], _arr_labels[widx]):  # [K]
                    if _lab == 0:  # note: idx=0 means NIL
                        break
                    # add new one!!
                    res_fidxes.append(fidx)
                    res_widxes.append(_start + widx)
                    _new_ef = sent.make_entity_filler(widx, 1)
                    _new_arg = evt.add_arg(_new_ef, role=_voc_arg.idx2word(_lab), score=float(_score))
                    _new_arg._tmp_sstart = _start  # todo(+N): ugly tmp value ...
                    res_args.append(_new_arg)
    # return
    res_fidxes_t, res_widxes_t = BK.input_idx(res_fidxes), BK.input_idx(res_widxes)  # [??]
    return (res_fidxes_t, res_widxes_t), res_args
def decode_upos(self, ibatch, logprobs_t: BK.Expr):
    """Assign argmax UPOS tags to the (center) sentences of each batch item."""
    conf: ZDecoderUposConf = self.conf
    # argmax over the label dim
    pred_upos_scores, pred_upos_labels = logprobs_t.max(-1)  # [*, dlen]
    arr_upos_labels = BK.get_value(pred_upos_labels)
    # put results
    voc = self.voc
    for bidx, item in enumerate(ibatch.items):  # each item in the batch
        _dec_offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            if conf.msent_pred_center and (sidx != item.center_sidx):
                continue  # only the center sentence is predicted in this mode
            _start = _dec_offsets[sidx]
            upos_idxes = arr_upos_labels[bidx][_start:_start + len(sent)].tolist()
            sent.build_uposes(voc.seq_idx2word(upos_idxes))