Example #1
 def forward(self, inputs, add_bos=False, add_eos=False):
     conf: PosiInputEmbedderConf = self.conf
     # --
     try:
         # input is a shape as prepared by "PosiHelper"
         batch_size, max_len = inputs
         if add_bos:
             max_len += 1
         if add_eos:
             max_len += 1
         posi_idxes = BK.arange_idx(max_len)  # [?len?]
         ret = self.E(posi_idxes).unsqueeze(0).expand(batch_size, -1, -1)
     except:
         # input is tensor
         posi_idxes = BK.input_idx(inputs)  # [*, len]
         cur_maxlen = BK.get_shape(posi_idxes, -1)
         # --
         all_input_slices = []
         slice_shape = BK.get_shape(posi_idxes)[:-1] + [1]
         if add_bos:  # add 0 and offset
             all_input_slices.append(
                 BK.constants(slice_shape, 0, dtype=posi_idxes.dtype))
             cur_maxlen += 1
             posi_idxes += 1
         all_input_slices.append(posi_idxes)  # [*, len]
         if add_eos:
             all_input_slices.append(
                 BK.constants(slice_shape,
                              cur_maxlen,
                              dtype=posi_idxes.dtype))
         final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
         # finally
         ret = self.E(final_input_t)  # [*, ??, dim]
     return ret
Example #2
 def predict(self, insts: List[Sent], input_expr: BK.Expr, mask_expr: BK.Expr):
     conf: MySRLConf = self.conf
     slen = BK.get_shape(mask_expr, -1)
     # --
     # =====
     # evt
     _, all_evt_cfs, all_evt_raw_scores = self.evt_node.get_all_values()  # [*, slen, Le]
     all_evt_scores = [z.log_softmax(-1) for z in all_evt_raw_scores]
     final_evt_scores = self.evt_node.helper.pred(all_logprobs=all_evt_scores, all_cfs=all_evt_cfs)  # [*, slen, Le]
     if conf.evt_pred_use_all or conf.evt_pred_use_posi:  # todo(+W): not an elegant way...
         final_evt_scores[:,:,0] += Constants.REAL_PRAC_MIN  # all pred sth!!
     pred_evt_scores, pred_evt_labels = final_evt_scores.max(-1)  # [*, slen]
     # =====
     # arg
     _, all_arg_cfs, all_arg_raw_score = self.arg_node.get_all_values()  # [*, slen, slen, La]
     all_arg_scores = [z.log_softmax(-1) for z in all_arg_raw_score]
     final_arg_scores = self.arg_node.helper.pred(all_logprobs=all_arg_scores, all_cfs=all_arg_cfs)  # [*, slen, slen, La]
     # slightly more efficient by masking valid evts??
     full_pred_shape = BK.get_shape(final_arg_scores)[:-1]  # [*, slen, slen]
     pred_arg_scores, pred_arg_labels = BK.zeros(full_pred_shape), BK.zeros(full_pred_shape).long()
     arg_flat_mask = (pred_evt_labels > 0)  # [*, slen]
     flat_arg_scores = final_arg_scores[arg_flat_mask]  # [??, slen, La]
     if not BK.is_zero_shape(flat_arg_scores):  # at least one predicate!
         if self.pred_cons_mat is not None:
             flat_mask_expr = mask_expr.unsqueeze(-2).expand(-1, slen, slen)[arg_flat_mask]  # [*, 1->slen, slen] => [??, slen]
             flat_pred_arg_labels, flat_pred_arg_scores = BigramInferenceHelper.inference_search(
                 flat_arg_scores, self.pred_cons_mat, flat_mask_expr, conf.arg_beam_k)  # [??, slen]
         else:
             flat_pred_arg_scores, flat_pred_arg_labels = flat_arg_scores.max(-1)  # [??, slen]
         pred_arg_scores[arg_flat_mask] = flat_pred_arg_scores
         pred_arg_labels[arg_flat_mask] = flat_pred_arg_labels
     # =====
     # assign
     self.helper.put_results(insts, pred_evt_labels, pred_evt_scores, pred_arg_labels, pred_arg_scores)
Example #3
 def forward(self, inputs, add_bos=False, add_eos=False):
     conf: PlainInputEmbedderConf = self.conf
     # --
     voc = self.voc
     input_t = BK.input_idx(inputs)  # [*, len]
     # rare unk in training
     if self.is_training() and self.use_rare_unk:
         rare_unk_rate = conf.rare_unk_rate
         cur_unk_imask = (
             self.rare_unk_mask[input_t] *
             (BK.rand(BK.get_shape(input_t)) < rare_unk_rate)).long()
         input_t = input_t * (1 - cur_unk_imask) + voc.unk * cur_unk_imask
     # bos and eos
     all_input_slices = []
     slice_shape = BK.get_shape(input_t)[:-1] + [1]
     if add_bos:
         all_input_slices.append(
             BK.constants(slice_shape, voc.bos, dtype=input_t.dtype))
     all_input_slices.append(input_t)  # [*, len]
     if add_eos:
         all_input_slices.append(
             BK.constants(slice_shape, voc.eos, dtype=input_t.dtype))
     final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
     # finally
     ret = self.E(final_input_t)  # [*, ??, dim]
     return ret
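The bos/eos handling above is just a concatenation of constant id columns before the embedding lookup. A minimal standalone sketch of the same trick in plain PyTorch (not the BK wrapper; the bos/eos ids are made-up values):

import torch

def add_bos_eos(input_t: torch.Tensor, bos_id: int = 1, eos_id: int = 2) -> torch.Tensor:
    # input_t: [*, len] integer ids -> [*, 1+len+1] with constant columns on both sides
    slice_shape = list(input_t.shape[:-1]) + [1]
    bos_col = torch.full(slice_shape, bos_id, dtype=input_t.dtype)
    eos_col = torch.full(slice_shape, eos_id, dtype=input_t.dtype)
    return torch.cat([bos_col, input_t, eos_col], dim=-1)

ids = torch.tensor([[3, 4, 5], [6, 7, 8]])
print(add_bos_eos(ids))  # tensor([[1, 3, 4, 5, 2], [1, 6, 7, 8, 2]])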
Example #4
 def forward(self, input_t: BK.Expr, edges: BK.Expr, mask_t: BK.Expr):
     _isize = self.conf._isize
     _ntype = self.conf.type_num
     _slen = BK.get_shape(edges, -1)
     # --
     edges3 = edges.clamp(min=-1, max=1) + 1
     edgesF = edges + _ntype  # offset to positive!
     # get hid
     hid0 = BK.matmul(input_t, self.W_hid).view(BK.get_shape(input_t)[:-1] + [3, _isize])  # [*, L, 3, D]
     hid1 = hid0.unsqueeze(-4).expand(-1, _slen, -1, -1, -1)  # [*, L, L, 3, D]
     hid2 = BK.gather_first_dims(hid1.contiguous(), edges3.unsqueeze(-1), -2).squeeze(-2)  # [*, L, L, D]
     hidB = self.b_hid[edgesF]  # [*, L, L, D]
     _hid = hid2 + hidB
     # get gate
     gate0 = BK.matmul(input_t, self.W_gate)  # [*, L, 3]
     gate1 = gate0.unsqueeze(-3).expand(-1, _slen, -1, -1)  # [*, L, L, 3]
     gate2 = gate1.gather(-1, edges3.unsqueeze(-1))  # [*, L, L, 1]
     gateB = self.b_gate[edgesF].unsqueeze(-1)  # [*, L, L, 1]
     _gate0 = BK.sigmoid(gate2 + gateB)
     _gmask0 = ((edges != 0) | (BK.eye(_slen) > 0)).float() * mask_t.unsqueeze(-2)  # [*,L,L]
     _gate = _gate0 * _gmask0.unsqueeze(-1)  # [*,L,L,1]
     # combine
     h0 = BK.relu((_hid * _gate).sum(-2))  # [*, L, D]
     h1 = self.drop_node(h0)
     # add & norm?
     if self.ln is not None:
         h1 = self.ln(h1 + input_t)
     return h1
Example #5
 def forward(self, med: ZMediator):
     ibatch_seq_info = med.ibatch.seq_info
     # prepare input, truncate if too long
     _input_ids, _input_masks, _input_segids = \
         ibatch_seq_info.enc_input_ids, ibatch_seq_info.enc_input_masks, ibatch_seq_info.enc_input_segids
     _eff_input_ids = med.get_cache('eff_input_ids')  # note: special name!!
     if _eff_input_ids is not None:
         _input_ids = _eff_input_ids
     # --
     if BK.get_shape(_input_ids, -1) > self.tokenizer.model_max_length:
         _full_len = BK.get_shape(_input_ids, -1)
         _max_len = self.tokenizer.model_max_length
         zwarn(
             f"Input too long for bert, truncate it: {BK.get_shape(_input_ids)} => {_max_len}"
         )
         _input_ids, _input_masks, _input_segids = \
             _input_ids[:,:_max_len], _input_masks[:,:_max_len], _input_segids[:,:_max_len]
         # todo(+W+N): how to handle decoders for these cases?
     # forward
     ret = self.bert.forward(_input_ids,
                             _input_masks,
                             _input_segids,
                             med=med)
     # extra
     if self.gcn:
         ret = self.gcn.forward(med)
     # --
     return ret
Example #6
 def lookup_flatten(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
                    pair_expr: BK.Expr = None):
     arr_items, mlp_expr, zmask_expr, extra_info = self.lookup(insts, input_expr, mask_expr, pair_expr)
     expr_widxes, expr_wlens = extra_info
     # flatten
     ret_items, ret_sidx, ret_expr, fl_widxes, fl_wlens = LookupNode.flatten_results(
         arr_items, zmask_expr, mlp_expr, expr_widxes, expr_wlens)
     # --
     # also make full expr
     # full_masks = ((_arange_t>=fl_widxes.unsqueeze(-1)) & (_arange_t<(fl_widxes+fl_wlens).unsqueeze(-1))).float()  # [??, slen]
     # ret_full_expr = full_masks.unsqueeze(-1) * ret_expr.unsqueeze(-2)  # [??, slen, D]
     if self.conf.flatten_lookup_use_dist:  # use posi again: [...,-2,-1,0,0,0,1,2,...]
         left_widxes = fl_widxes.unsqueeze(-1)  # [??, 1]
         right_widxes = (fl_widxes+fl_wlens-1).unsqueeze(-1)  # [??, 1]
         _arange_t = BK.arange_idx(BK.get_shape(mask_expr, 1)).unsqueeze(0)  # [1, slen]
         dist0 = _arange_t - left_widxes  # [??, slen]
         dist1 = _arange_t - right_widxes  # [??, slen]
         full_dist = (_arange_t < left_widxes).long() * dist0 + (_arange_t > right_widxes).long() * dist1
         ret_full_expr = self.indicator_norm(self.indicator_embed(full_dist))  # [??, slen, D]
         # # ret_full_expr = self.indicator_embed(full_dist)  # [??, slen, D]
     else:  # otherwise 0/1
         _arange_t = BK.arange_idx(BK.get_shape(mask_expr, 1)).unsqueeze(0)  # [1, slen]
         full_ind = ((_arange_t>=fl_widxes.unsqueeze(-1)) & (_arange_t<(fl_widxes+fl_wlens).unsqueeze(-1))).long()  # [??, slen]
         ret_full_expr = self.indicator_norm(self.indicator_embed(full_ind))  # [??, slen, D]
     # --
     return ret_items, ret_sidx, ret_expr, ret_full_expr  # [??, D]
Example #7
 def forward(self, expr_t: BK.Expr, fixed_scores_t: BK.Expr = None, feed_output=False, mask_t: BK.Expr = None):
     conf: SingleBlockConf = self.conf
     # --
     # pred
     if fixed_scores_t is not None:
         score_t = fixed_scores_t
         cf_t = None
     else:
         hid1_t = self.hid_in(expr_t)  # [*, hid]
         score_t = self.pred_in(hid1_t)  # [*, nlab]
         cf_t = self.aff_cf(hid1_t).squeeze(-1)  # [*]
     # --
     if mask_t is not None:
         shape0 = BK.get_shape(expr_t)
         shape1 = BK.get_shape(mask_t)
         if len(shape1) < len(shape0):
             mask_t = mask_t.unsqueeze(-1)  # [*, 1]
         score_t += Constants.REAL_PRAC_MIN * (1. - mask_t)  # [*, nlab]
     # --
     # output
     if feed_output:
         W = self.W_getf()  # [nlab, hid]
         prob_t = score_t.softmax(-1)  # [*, nlab]
         hid2_t = BK.matmul(prob_t, W) * self.e_mul_scale  # [*, hid], todo(+W): need dropout here?
         out_t = self.hid_out(hid2_t)  # [*, ndim]
         final_t = self.norm(out_t + expr_t)  # [*, ndim], add and norm
     else:
         final_t = expr_t  # [*, ndim], simply no change and use input!
     return score_t, cf_t, final_t  # [*, nlab], [*], [*, ndim]
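The mask_t handling is the usual additive-masking trick: forbidden entries get a huge negative constant so they vanish after softmax. A tiny standalone sketch in plain PyTorch; NEG_INF stands in for Constants.REAL_PRAC_MIN:

import torch

NEG_INF = -1e30  # stand-in for Constants.REAL_PRAC_MIN

def mask_logits(score_t: torch.Tensor, mask_t: torch.Tensor) -> torch.Tensor:
    # mask_t: 1 = keep, 0 = forbid; unsqueeze so a per-item mask broadcasts over labels
    if mask_t.dim() < score_t.dim():
        mask_t = mask_t.unsqueeze(-1)
    return score_t + NEG_INF * (1. - mask_t)

scores = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mask = torch.tensor([[1., 0.], [1., 1.]])     # forbid label 1 of the first item
print(mask_logits(scores, mask).softmax(-1))  # first row ~[1, 0], second row unchanged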
Example #8
 def put_labels(self, arr_items, best_labs, best_scores):
     assert BK.get_shape(best_labs) == BK.get_shape(arr_items)
     # --
     all_arrs = [BK.get_value(z) for z in [best_labs, best_scores]]
     for cur_items, cur_labs, cur_scores in zip(arr_items, *all_arrs):
         for one_item, one_lab, one_score in zip(cur_items, cur_labs, cur_scores):
             if one_item is None: continue
             one_lab, one_score = int(one_lab), float(one_score)
             one_item.score = one_score
             one_item.set_label_idx(one_lab)
             one_item.set_label(self.vocab.idx2word(one_lab))
Example #9
 def forward(self, input_expr: BK.Expr, widx_expr: BK.Expr, wlen_expr: BK.Expr):
     conf: BaseSpanConf = self.conf
     # --
     # note: check empty, otherwise error
     input_item_shape = BK.get_shape(widx_expr)
     if np.prod(input_item_shape) == 0:
         return BK.zeros(input_item_shape + [self.output_dim])  # return an empty but shaped tensor
     # --
     start_idxes, end_idxes = widx_expr, widx_expr+wlen_expr  # make [start, end)
     # get sizes
     bsize, slen = BK.get_shape(input_expr)[:2]
     # num_span = BK.get_shape(start_idxes, 1)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
     # --
     reprs = []
     if conf.use_starts:  # start [start,
         reprs.append(input_expr[arange2_t, start_idxes])  # [bsize, ?, D]
     if conf.use_ends:  # simply ,end-1]
         reprs.append(input_expr[arange2_t, end_idxes-1])
     if conf.use_softhead:
         # expand range
         all_span_idxes, all_span_mask = expand_ranged_idxes(widx_expr, wlen_expr, 0, None)  # [bsize, ?, MW]
         # flatten
         flatten_all_span_idxes = all_span_idxes.view(bsize, -1)  # [bsize, ?*MW]
         flatten_all_span_mask = all_span_mask.view(bsize, -1)  # [bsize, ?*MW]
         # get softhead score (consider mask here)
         softhead_scores = self.softhead_scorer(input_expr).squeeze(-1)  # [bsize, slen]
         flatten_all_span_scores = softhead_scores[arange2_t, flatten_all_span_idxes]  # [bsize, ?*MW]
         flatten_all_span_scores += (1.-flatten_all_span_mask) * Constants.REAL_PRAC_MIN
         all_span_scores = flatten_all_span_scores.view(all_span_idxes.shape)  # [bsize, ?, MW]
         # reshape and (optionally topk) and softmax
         softhead_topk = conf.softhead_topk
         if softhead_topk>0 and BK.get_shape(all_span_scores,-1)>softhead_topk:  # further select topk; note: this may save mem
             final_span_score, _tmp_idxes = all_span_scores.topk(softhead_topk, dim=-1, sorted=False)  # [bsize, ?, K]
             final_span_idxes = all_span_idxes.gather(-1, _tmp_idxes)  # [bsize, ?, K]
         else:
             final_span_score, final_span_idxes = all_span_scores, all_span_idxes  # [bsize, ?, MW]
         final_prob = final_span_score.softmax(-1)  # [bsize, ?, ??]
         # [bsize, ?, ??, D]
         final_repr = input_expr[arange2_t, final_span_idxes.view(bsize, -1)].view(BK.get_shape(final_span_idxes)+[-1])
         weighted_repr = (final_repr * final_prob.unsqueeze(-1)).sum(-2)  # [bsize, ?, D]
         reprs.append(weighted_repr)
     if conf.use_width:
         cur_width_embed = self.width_embed(wlen_expr)  # [bsize, ?, DE]
         reprs.append(cur_width_embed)
     # concat
     concat_repr = BK.concat(reprs, -1)  # [bsize, ?, SUM]
     if conf.use_proj:
         ret = self.final_proj(concat_repr)  # [bsize, ?, DR]
     else:
         ret = concat_repr
     return ret
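The use_softhead branch scores every token, softmaxes within each span, and takes the weighted sum. A simplified standalone sketch of that pooling in plain PyTorch, without the top-k shortcut and with illustrative names:

import torch

def softhead_pool(hidden, scores, widx, wlen):
    # hidden: [bsize, slen, D], scores: [bsize, slen] (per-token head scores)
    # widx/wlen: [bsize, nspan] span starts and lengths -> [bsize, nspan, D]
    bsize, slen, _ = hidden.shape
    max_w = int(wlen.max())
    offsets = torch.arange(max_w)                                  # [MW]
    span_mask = (offsets < wlen.unsqueeze(-1)).float()             # [bsize, nspan, MW]
    span_idxes = widx.unsqueeze(-1) + offsets * span_mask.long()   # pad positions reuse the start idx
    b_idx = torch.arange(bsize).view(-1, 1, 1)                     # [bsize, 1, 1]
    span_scores = scores[b_idx, span_idxes] + (1. - span_mask) * -1e30
    weights = span_scores.softmax(-1)                              # [bsize, nspan, MW]
    span_repr = hidden[b_idx, span_idxes]                          # [bsize, nspan, MW, D]
    return (span_repr * weights.unsqueeze(-1)).sum(-2)             # [bsize, nspan, D]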
Example #10
def select_topk_non_overlapping(score_t: BK.Expr,
                                topk_t: Union[int, BK.Expr],
                                widx_t: BK.Expr,
                                wlen_t: BK.Expr,
                                input_mask_t: BK.Expr,
                                mask_t: BK.Expr = None,
                                dim=-1):
    score_shape = BK.get_shape(score_t)
    assert dim == -1 or dim == len(score_shape) - 1, "Currently only support last-dim!!"  # todo(+2): we can permute to allow any dim!
    # --
    # prepare K
    if isinstance(topk_t, int):
        tmp_shape = score_shape.copy()
        tmp_shape[dim] = 1  # set it as 1
        topk_t = BK.constants_idx(tmp_shape, topk_t)
    # --
    reshape_trg = [np.prod(score_shape[:-1]).item(), -1]  # [*, ?]
    _, sorted_idxes_t = score_t.sort(dim, descending=True)
    # --
    # put it as CPU and use loop; todo(+N): more efficient ways?
    arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t = \
        [BK.get_value(z.reshape(reshape_trg)) for z in [sorted_idxes_t, topk_t, widx_t, wlen_t, input_mask_t, mask_t]]
    _bsize, _cnum = BK.get_shape(arr_sorted_idxes_t)  # [bsize, NUM]
    arr_topk_mask = np.full([_bsize, _cnum], 0.)  # [bsize, NUM]
    _bidx = 0
    for aslice_sorted_idxes_t, aslice_topk_t, aslice_widx_t, aslice_wlen_t, aslice_input_mask_t, aslice_mask_t \
            in zip(arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t):
        aslice_topk_mask = arr_topk_mask[_bidx]
        # --
        cur_ok_mask = np.copy(aslice_input_mask_t)
        cur_budget = aslice_topk_t.item()
        for _cidx in aslice_sorted_idxes_t:
            _cidx = _cidx.item()
            if cur_budget <= 0: break  # no budget left
            if not aslice_mask_t[_cidx].item(): continue  # non-valid candidate
            one_widx, one_wlen = aslice_widx_t[_cidx].item(), aslice_wlen_t[_cidx].item()
            if np.prod(cur_ok_mask[one_widx:one_widx + one_wlen]).item() == 0.:  # any hit one?
                continue
            # ok! add it!
            cur_budget -= 1
            cur_ok_mask[one_widx:one_widx + one_wlen] = 0.
            aslice_topk_mask[_cidx] = 1.
        _bidx += 1
    # note: no need to *=mask_t again since already check in the loop
    return BK.input_real(arr_topk_mask).reshape(score_shape)
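The heart of the routine is the greedy CPU loop: visit candidates from best to worst score and keep a span only if none of its tokens is already covered and budget remains. A small numpy-only sketch of that step for a single sequence (made-up inputs):

import numpy as np

def greedy_non_overlap(widx, wlen, order, budget, seq_len):
    # widx/wlen: span starts/lengths; order: candidate indices sorted by score (desc)
    free = np.ones(seq_len)            # 1. = token not covered yet
    keep = np.zeros(len(widx))
    for c in order:
        if budget <= 0:
            break
        s, e = widx[c], widx[c] + wlen[c]
        if free[s:e].prod() == 0.:     # overlaps an already-kept span -> skip
            continue
        free[s:e] = 0.
        keep[c] = 1.
        budget -= 1
    return keep

# spans (0,2), (1,2), (3,1); the second overlaps the first and gets skipped
print(greedy_non_overlap(np.array([0, 1, 3]), np.array([2, 2, 1]), [0, 1, 2], budget=2, seq_len=5))
# -> [1. 0. 1.]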
Example #11
 def _aggregate_subtoks(self, repr_t: BK.Expr, dsel_seq_info):
     conf: DSelectorConf = self.conf
     _arange_t, _sel_t, _len_t = dsel_seq_info.arange2_t, dsel_seq_info.dec_sel_idxes, dsel_seq_info.dec_sel_lens
     _max_len = 1 if BK.is_zero_shape(_len_t) else _len_t.max().item()
     _max_len = max(1, min(conf.dsel_max_subtoks, _max_len))  # truncate
     # --
     _tmp_arange_t = BK.arange_idx(_max_len)  # [M]
     _all_valids_t = (_tmp_arange_t < _len_t.unsqueeze(-1)).float()  # [*, dlen, M]
     _tmp_arange_t = _tmp_arange_t * _all_valids_t.long()  # note: pad as 0
     _all_idxes_t = _sel_t.unsqueeze(-1) + _tmp_arange_t  # [*, dlen, M]
     _all_repr_t = repr_t[_arange_t.unsqueeze(-1), _all_idxes_t]  # [*, dlen, M, D]
     while len(BK.get_shape(_all_valids_t)) < len(BK.get_shape(_all_repr_t)):
         _all_valids_t = _all_valids_t.unsqueeze(-1)
     _all_repr_t = _all_repr_t * _all_valids_t
     return _all_repr_t, _all_valids_t
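What this does is gather, for each output position, a fixed-width window of subtoken vectors and zero out the padding. A standalone sketch of that gather in plain PyTorch; the names are illustrative, not the DSelector API:

import torch

def gather_subtok_blocks(repr_t, sel_idxes, sel_lens, max_m):
    # repr_t: [bsize, src_len, D]; sel_idxes/sel_lens: [bsize, dlen] start idx / subtok count
    bsize = repr_t.shape[0]
    offs = torch.arange(max_m)                              # [M]
    valid = (offs < sel_lens.unsqueeze(-1)).float()         # [bsize, dlen, M]
    idxes = sel_idxes.unsqueeze(-1) + offs * valid.long()   # pad positions stay at the start idx
    b_idx = torch.arange(bsize).view(-1, 1, 1)              # [bsize, 1, 1]
    block = repr_t[b_idx, idxes]                            # [bsize, dlen, M, D]
    return block * valid.unsqueeze(-1), valid               # zero out the padded slots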
Example #12
def select_topk(score_t: BK.Expr,
                topk_t: Union[int, BK.Expr],
                mask_t: BK.Expr = None,
                dim=-1):
    # prepare K
    if isinstance(topk_t, int):
        K = topk_t
        tmp_shape = BK.get_shape(score_t)
        tmp_shape[dim] = 1  # set it as 1
        topk_t = BK.constants_idx(tmp_shape, K)
    else:
        K = topk_t.max().item()
    exact_rank_t = topk_t - 1  # [bsize, 1]
    exact_rank_t.clamp_(min=0, max=K - 1)  # make it in valid range!
    # mask values
    if mask_t is not None:
        score_t = score_t + Constants.REAL_PRAC_MIN * (1. - mask_t)
    # topk
    topk_vals, _ = score_t.topk(K, dim, largest=True, sorted=True)  # [*, K, *]
    # gather score
    sel_thresh = topk_vals.gather(dim, exact_rank_t)  # [*, 1, *]
    # get topk_mask
    topk_mask = (score_t >= sel_thresh).float()  # [*, D, *]
    if mask_t is not None:
        topk_mask *= mask_t
    return topk_mask
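select_topk builds a 0/1 mask by thresholding against the K-th largest score rather than scattering indices. The same trick as a standalone PyTorch sketch (fixed integer K, last dim only):

import torch

def topk_mask(score_t, k, mask_t=None):
    # last-dim only, mirroring the "only support last-dim" note above
    if mask_t is not None:
        score_t = score_t + (1. - mask_t) * -1e30   # push invalid entries below any threshold
    vals, _ = score_t.topk(k, dim=-1)               # sorted descending
    thresh = vals[..., k - 1:k]                     # K-th largest value, kept as a size-1 dim
    out = (score_t >= thresh).float()
    if mask_t is not None:
        out = out * mask_t
    return out

print(topk_mask(torch.tensor([[0.1, 0.9, 0.5, 0.7]]), 2))  # tensor([[0., 1., 0., 1.]])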
Example #13
 def loss(self,
          insts: Union[List[Sent], List[Frame]],
          input_expr: BK.Expr,
          mask_expr: BK.Expr,
          pair_expr: BK.Expr = None,
          lookup_flatten=False,
          external_extra_score: BK.Expr = None):
     conf: SeqExtractorConf = self.conf
     # step 0: prepare golds
     expr_gold_slabs, expr_loss_weight_non = self.helper.prepare(
         insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
     final_loss_weights = BK.where(
         expr_gold_slabs == 0,
         expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
         mask_expr)
     # step 1: label; with special weights
     loss_lab, loss_count = self.lab_node.loss(
         input_expr,
         pair_expr,
         mask_expr,
         expr_gold_slabs,
         loss_weight_expr=final_loss_weights,
         extra_score=external_extra_score)
     loss_lab_item = LossHelper.compile_leaf_loss(
         f"lab",
         loss_lab,
         loss_count,
         loss_lambda=conf.loss_lab,
         gold=(expr_gold_slabs > 0).float().sum())
     # ==
     # return loss
     ret_loss = LossHelper.combine_multiple_losses([loss_lab_item])
     return self._finish_loss(ret_loss, insts, input_expr, mask_expr,
                              pair_expr, lookup_flatten)
Example #14
 def inference_forward(scores_t: BK.Expr,
                       mat_t: BK.Expr,
                       mask_t: BK.Expr,
                       beam_k: int = 0):
     scores_shape = BK.get_shape(scores_t)  # [*, slen, L]
     need_topk = (beam_k > 0) and (beam_k < scores_shape[-1])  # whether we need topk
     # --
     score_slices = split_at_dim(scores_t, -2, True)  # List[*, 1, L]
     mask_slices = split_at_dim(mask_t, -1, True)  # List[*, 1]
     # the loop on slen
     start_shape = scores_shape[:-2] + [1]  # [*, 1]
     last_labs_t = BK.constants_idx(start_shape, 0)  # [*, K], todo(note): start with 0!
     last_accu_scores = BK.zeros(start_shape)  # accumulated scores: [*, K]
     last_potential = BK.zeros(start_shape)  # accumulated potentials: [*, K]
     full_labs_t = BK.arange_idx(scores_shape[-1]).view([1] * (len(scores_shape) - 2) + [-1])  # [*, L]
     cur_step = 0
     for one_score_slice, one_mask_slice in zip(score_slices, mask_slices):  # [*,L],[*,1]
         one_mask_slice_neg = 1. - one_mask_slice  # [*,1]
         # get current scores
         if cur_step == 0:  # no transition at start!
             one_cur_scores = one_score_slice  # [*, 1, L]
         else:
             one_cur_scores = one_score_slice + mat_t[last_labs_t]  # [*, K, L]
         # first for potentials
         expanded_potentials = last_potential.unsqueeze(-1) + one_cur_scores  # [*, K, L]
         merged_potentials = log_sum_exp(expanded_potentials, -2)  # [*, L]
         # optional for topk with merging; note: not really topk!!
         if need_topk:
             # todo(+W): another option is to directly select with potentials rather than accu_scores
             expanded_scores = last_accu_scores.unsqueeze(-1) + one_cur_scores  # [*, K, L]
             # max at -2, merge same current label
             max_scores, max_idxes = expanded_scores.max(-2)  # [*, L]
             # topk at current step, no need to sort!
             new_accu_scores, new_labs_t = max_scores.topk(beam_k, -1, sorted=False)  # [*, K]
             new_potential = merged_potentials.gather(-1, new_labs_t)  # [*, K]
             # mask and update
             last_potential = last_potential * one_mask_slice_neg + new_potential * one_mask_slice  # [*, K]
             last_accu_scores = last_accu_scores * one_mask_slice_neg + new_accu_scores * one_mask_slice  # [*, K]
             last_labs_t = last_labs_t * one_mask_slice_neg.long() + new_labs_t * one_mask_slice.long()  # [*, K]
         else:
             # mask and update
             last_potential = last_potential * one_mask_slice_neg + merged_potentials * one_mask_slice  # [*, L(K)]
             # note: still need to mask this!
             last_labs_t = last_labs_t * one_mask_slice_neg.long() + full_labs_t * one_mask_slice.long()
         cur_step += 1
     # finally sum all
     ret_potential = log_sum_exp(last_potential, -1)  # [*]
     return ret_potential
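Without the beam (beam_k=0) and with all positions unmasked, inference_forward reduces to the standard forward algorithm of a linear-chain model. A compact exact reference version in plain PyTorch (not the BK/split_at_dim API):

import torch

def forward_logZ(emit: torch.Tensor, trans: torch.Tensor) -> torch.Tensor:
    # emit: [bsize, slen, L] per-step label scores; trans: [L, L], trans[prev, cur]
    bsize, slen, L = emit.shape
    alpha = emit[:, 0]                                     # [bsize, L]; no transition at step 0
    for t in range(1, slen):
        # alpha[prev] + trans[prev, cur] + emit[t][cur], then log-sum-exp over prev
        scores = alpha.unsqueeze(-1) + trans.unsqueeze(0) + emit[:, t].unsqueeze(-2)
        alpha = torch.logsumexp(scores, dim=-2)            # [bsize, L]
    return torch.logsumexp(alpha, dim=-1)                  # [bsize], the log-partition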
Example #15
 def predict(self,
             insts: Union[List[Sent], List[Frame]],
             input_expr: BK.Expr,
             mask_expr: BK.Expr,
             pair_expr: BK.Expr = None,
             lookup_flatten=False,
             external_extra_score: BK.Expr = None):
     conf: AnchorExtractorConf = self.conf
     assert not lookup_flatten
     bsize, slen = BK.get_shape(mask_expr)
     # --
     for inst in insts:  # first clear things
         self.helper._clear_f(inst)
     # --
     # step 1: simply labeling!
     best_labs, best_scores = self.lab_node.predict(
         input_expr, pair_expr, mask_expr, extra_score=external_extra_score)
      flt_items = self.helper.put_results(insts, best_labs, best_scores)  # [?]
     # --
     # step 2: final extend (in a flattened way)
     if len(flt_items) > 0 and conf.pred_ext:
         flt_mask = ((best_labs > 0) & (mask_expr > 0))  # [*, slen]
          flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?]
         flt_expr = input_expr[flt_mask]  # [?, D]
         flt_full_expr = self._prepare_full_expr(flt_mask)  # [?, slen, D]
          self.ext_node.predict(flt_items, input_expr[flt_sidx], flt_expr, flt_full_expr, mask_expr[flt_sidx])
     # --
     # extra:
     self.pp_node.prune(insts)
     return None
Example #16
 def forward(self,
             expr_t: BK.Expr,
             mask_t: BK.Expr,
             scores_t=None,
             **kwargs):
     conf: IdecConnectorAttConf = self.conf
     # --
     # prepare input
     _d_bs, _dq, _dk, _d_nl, _d_nh = BK.get_shape(scores_t)
      in1_t = scores_t[:, :, :, self.lstart:, :self.head_end].reshape([_d_bs, _dq, _dk, self.d_in])  # [*, lenq, lenk, din]
     in2_t = in1_t.transpose(-3, -2)  # [*, lenk, lenq, din]
     final_input_t = BK.concat([in1_t, in2_t], -1)  # [*, lenk, lenq, din*2]
     # forward
      node_ret_t = self.node.forward(final_input_t, mask_t, self.feed_output, self.lidx, **kwargs)  # [*, lenq, lenk, head_end]
     if self.feed_output:
         # pad zeros if necessary
         if self.head_end < _d_nh:
              pad_t = BK.zeros([_d_bs, _dq, _dk, _d_nh - self.head_end])
              node_ret_t = BK.concat([node_ret_t, pad_t], -1)  # pad along the head dim => [*, lenq, lenk, Hin]
         return node_ret_t
     else:
         return None
Example #17
 def _get_extra_score(self, cand_score, insts, cand_res, arr_gold_items, use_cons: bool, use_lu: bool):
     # conf: DirectExtractorConf = self.conf
     # --
     # first cand score
     cand_score = self._extend_cand_score(cand_score)
     # then cons_lex score
     cons_lex_node = self.cons_lex_node
     if use_cons and cons_lex_node is not None:
         cons_lex = cons_lex_node.cons
         flt_arr_gold_items = arr_gold_items.flatten()
         _shape = BK.get_shape(cand_res.mask_expr)
         if cand_res.gaddr_expr is None:
             gaddr_expr = BK.constants(_shape, -1, dtype=BK.long)
         else:
             gaddr_expr = cand_res.gaddr_expr
         all_arrs = [BK.get_value(z) for z in [cand_res.widx_expr, cand_res.wlen_expr, cand_res.mask_expr, gaddr_expr]]
         arr_feats = np.full(_shape, None, dtype=object)
         for bidx, inst in enumerate(insts):
             one_arr_feats = arr_feats[bidx]
             _ii = -1
             for one_widx, one_wlen, one_mask, one_gaddr in zip(*[z[bidx] for z in all_arrs]):
                 _ii += 1
                  if one_mask == 0.: continue  # skip invalid ones
                 if use_lu and one_gaddr>=0:
                     one_feat = cons_lex.lu2feat(flt_arr_gold_items[one_gaddr].info["luName"])
                 else:
                     one_feat = cons_lex.span2feat(inst, one_widx, one_wlen)
                 one_arr_feats[_ii] = one_feat
         cons_valids = cons_lex_node.lookup_with_feats(arr_feats)
         cons_score = (1.-cons_valids) * Constants.REAL_PRAC_MIN
     else:
         cons_score = None
     # sum
     return self._sum_scores(cand_score, cons_score)
Example #18
 def go_feed(self, cache: TransformerDecCache, input_expr: BK.Expr, mask_expr: BK.Expr = None):
     conf: TransformerConf = self.conf
     n_layers = conf.n_layers
     # --
     if n_layers == 0:  # only add the input
         cache.update([input_expr], mask_expr)  # one call of update is enough
         return input_expr  # change nothing
     # --
     # first prepare inputs
     input_shape = BK.get_shape(input_expr)  # [bsize, ssize, D]
     bsize, ssize = input_shape[:2]
     cache.s0_open_new_steps(bsize, ssize, mask_expr)  # open one
     # --
     if conf.use_posi:
         positions = cache.positions[:, -ssize:]  # [*, step]
         posi_embed = self.PE(positions)
         input_emb = self.input_f(input_expr+posi_embed)  # add absolute positional embeddings
     else:
         input_emb = self.input_f(input_expr)  # process input, [*, step, D]
     cur_q, cur_kv = cache.ss_add_new_layer(0, input_emb)
     # --
      # prepare rposi and causal mask
     all_posi = cache.positions  # [*, old+new]
     rposi = all_posi[:, -ssize:].unsqueeze(-1) - all_posi.unsqueeze(-2)  # Q-KV [*, new(query), new+old(kv)]
     mask_qk = (rposi>=0).float()  # q must be later than kv
     # go!
     for ti, tnode in enumerate(self.tnodes):
         cur_q = tnode(cur_q, cur_kv, cache.mask, mask_qk, rposi=rposi)
         cur_q, cur_kv = cache.ss_add_new_layer(ti+1, cur_q)  # kv for next layer
     cache.sz_close()
     return cur_q  # [*, ssize, D]
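The causal masking above comes straight from the cached absolute positions: a query may only attend to keys whose position is not after its own. A standalone sketch of that construction in plain PyTorch:

import torch

def causal_mask_from_positions(q_posi: torch.Tensor, kv_posi: torch.Tensor):
    # q_posi: [bsize, new]; kv_posi: [bsize, old+new] absolute positions
    rposi = q_posi.unsqueeze(-1) - kv_posi.unsqueeze(-2)   # [bsize, new, old+new]
    return (rposi >= 0).float(), rposi                     # 1. where the key is at or before the query

posi = torch.arange(5).unsqueeze(0)                        # cached positions 0..4
mask, _ = causal_mask_from_positions(posi[:, -2:], posi)
print(mask)  # the two newest steps (3 and 4) see prefixes of length 4 and 5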
Example #19
 def _prepare_full_expr(self, flt_mask: BK.Expr):
     bsize, slen = BK.get_shape(flt_mask)
     arange2_t = BK.arange_idx(slen).unsqueeze(0)  # [1, slen]
     all_widxes = arange2_t.expand_as(flt_mask)[flt_mask]  # [?]
     tmp_idxes = BK.zeros([len(all_widxes), slen]).long()  # [?, slen]
     tmp_idxes.scatter_(-1, all_widxes.unsqueeze(-1), 1)  # [?, slen]
     tmp_embs = self.indicator_embed(tmp_idxes)  # [?, slen, D]
     return tmp_embs
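The indicator built here is a one-hot row per selected token (scatter_ into zeros, then embed). A minimal standalone version of the scatter step in plain PyTorch:

import torch

def onehot_indicator(widxes: torch.Tensor, slen: int) -> torch.Tensor:
    # widxes: [n] selected token positions -> [n, slen] one-hot rows
    out = torch.zeros(len(widxes), slen, dtype=torch.long)
    out.scatter_(-1, widxes.unsqueeze(-1), 1)
    return out

print(onehot_indicator(torch.tensor([2, 0]), 4))
# tensor([[0, 0, 1, 0],
#         [1, 0, 0, 0]])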
Example #20
 def get_batched_features(items: np.ndarray, df_val: Union[int, float], attr_f: Union[str, Callable], dtype=None):
     if isinstance(attr_f, str):
         _local_attr_str = str(attr_f)
         attr_f = lambda x: getattr(x, _local_attr_str)
     # --
     flattened_vals = BK.input_tensor([df_val if z is None else attr_f(z) for z in items.flatten()], dtype=dtype)
     ret = flattened_vals.view(BK.get_shape(items))
     return ret
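get_batched_features reads one attribute from every item of an object array, filling None holes with a default. A plain numpy/PyTorch sketch with a made-up Tok class:

import numpy as np
import torch

def batched_attr(items: np.ndarray, default, attr: str) -> torch.Tensor:
    vals = [default if z is None else getattr(z, attr) for z in items.flatten()]
    return torch.tensor(vals).view(items.shape)

class Tok:  # hypothetical item type
    def __init__(self, upos): self.upos = upos

arr = np.array([[Tok(3), None], [Tok(5), Tok(7)]], dtype=object)
print(batched_attr(arr, 0, "upos"))  # tensor([[3, 0], [5, 7]])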
Example #21
 def flatten_results(arr_items: np.ndarray, mask_expr: BK.Expr, *other_exprs: BK.Expr):
     sel_mask_expr = (mask_expr > 0.)
     # flatten first dims
     ret_items = [z for z in arr_items.flatten() if z is not None]
     ret_other_exprs = [z[sel_mask_expr] for z in other_exprs]  # [?(flat), D]
     ret_sidx = BK.arange_idx(BK.get_shape(mask_expr, 0)).unsqueeze(-1).expand_as(sel_mask_expr)[sel_mask_expr]
     assert all(len(ret_items) == len(z) for z in ret_other_exprs), "Error: dim0 not matched after flatten!"
     return ret_items, ret_sidx, *ret_other_exprs  # [?(flat), *]
Example #22
 def __init__(self, model: SeqLabelerNode, expr_main: BK.Expr,
              expr_pair: BK.Expr, input_mask: BK.Expr,
              extra_score: BK.Expr):
     self.model = model
     self.all_steps = BK.get_shape(expr_main, -2)  # slen
     # --
     # store them
     self.expr_main, self.expr_pair, self.input_mask, self.extra_score = expr_main, expr_pair, input_mask, extra_score
     # currently we only need repeat 1 & k
     self.contents: List[ZObject] = [None] * 1000  # this should be enough!
Example #23
def _ensure_margins_norm(marginals_expr):
    full_shape = BK.get_shape(marginals_expr)
    combined_marginals_expr = marginals_expr.view(full_shape[:-2] + [-1])       # [BS, Len, Len*L]
    # should be 1., but for no-solution situation there can be small values (+1 for all in this case, norm later)
    # make 0./0. = 0.
    combined_marginals_sum = combined_marginals_expr.sum(dim=-1, keepdim=True)
    combined_marginals_sum += (combined_marginals_sum < 1e-5).float() * 1e-5
    # then norm
    combined_marginals_expr /= combined_marginals_sum
    return combined_marginals_expr.view(full_shape)
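The small additive guard avoids NaN when a whole marginal row sums to ~0 (the no-solution case), so 0./0. comes out as 0. The same normalisation as a standalone PyTorch sketch:

import torch

def safe_normalize(marginals: torch.Tensor) -> torch.Tensor:
    # marginals: [..., Len, L]; normalise jointly over the last two dims
    shape = marginals.shape
    flat = marginals.reshape(*shape[:-2], -1)          # [..., Len*L]
    denom = flat.sum(dim=-1, keepdim=True)
    denom = denom + (denom < 1e-5).float() * 1e-5      # all-zero rows divide by 1e-5 and stay 0
    return (flat / denom).reshape(shape)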
Example #24
 def _loss_feed_split(self, mask_expr, split_scores, pred_split_decisions,
                      cand_widxes, cand_masks, cand_expr, cand_scores,
                      expr_seq_gaddr):
     conf: SoftExtractorConf = self.conf
     bsize, slen = BK.get_shape(mask_expr)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1, 1]
     # --
     # step 2.1: split loss (only on good points (excluding -1|-1 or paddings) with dynamic oracle)
     cand_gaddr = expr_seq_gaddr[arange2_t, cand_widxes]  # [*, clen]
      cand_gaddr0, cand_gaddr1 = cand_gaddr[:, :-1], cand_gaddr[:, 1:]  # [*, clen-1]
      split_oracle = (cand_gaddr0 != cand_gaddr1).float() * cand_masks[:, 1:]  # [*, clen-1]
      split_oracle_mask = ((cand_gaddr0 >= 0) | (cand_gaddr1 >= 0)).float() * cand_masks[:, 1:]  # [*, clen-1]
      raw_split_loss = BK.loss_binary(split_scores, split_oracle, label_smoothing=conf.split_label_smoothing)  # [*, slen]
      loss_split_item = LossHelper.compile_leaf_loss(
          f"split", (raw_split_loss * split_oracle_mask).sum(), split_oracle_mask.sum(), loss_lambda=conf.loss_split)
     # step 2.2: feed split
     # note: when teacher-forcing, only forcing good points, others still use pred
      force_split_decisions = split_oracle_mask * split_oracle + (1. - split_oracle_mask) * pred_split_decisions  # [*, clen-1]
      _use_force_mask = (BK.rand([bsize]) <= conf.split_feed_force_rate).float().unsqueeze(-1)  # [*, 1], seq-level
      feed_split_decisions = _use_force_mask * force_split_decisions + (1. - _use_force_mask) * pred_split_decisions  # [*, clen-1]
     # next
     # *[*, seglen, MW], [*, seglen]
      seg_ext_cidxes, seg_ext_masks, seg_masks = self._split_extend(feed_split_decisions, cand_masks)
      seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr = self._split_aggregate(
          cand_expr, cand_scores, cand_widxes, seg_ext_cidxes, seg_ext_masks, conf.split_topk)  # [*, seglen, ?]
     # finally get oracles for next steps
     # todo(+N): simply select the highest scored one as oracle
     if BK.is_zero_shape(seg_ext_scores):  # simply make them all -1
         oracle_gaddr = BK.constants_idx(seg_masks.shape, -1)  # [*, seglen]
     else:
          _, _seg_max_t = seg_ext_scores.max(-1, keepdim=True)  # [*, seglen, 1]
          oracle_widxes = seg_ext_widxes.gather(-1, _seg_max_t).squeeze(-1)  # [*, seglen]
          oracle_gaddr = expr_seq_gaddr.gather(-1, oracle_widxes)  # [*, seglen]
     oracle_gaddr[seg_masks <= 0] = -1  # (assign invalid ones) [*, seglen]
     return loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr
Example #25
 def loss(self, insts: List[Sent], input_expr: BK.Expr, mask_expr: BK.Expr):
     conf: MySRLConf = self.conf
     # --
     slen = BK.get_shape(mask_expr, -1)
     arr_items, expr_evt_labels, expr_arg_labels, expr_loss_weight_non = self.helper.prepare(insts, True)
     if conf.binary_evt:
         expr_evt_labels = (expr_evt_labels>0).long()  # either 0 or 1
     loss_items = []
     # =====
     # evt
     # -- prepare weights and masks
     evt_not_nil = (expr_evt_labels>0)  # [*, slen]
     evt_extra_weights = BK.where(evt_not_nil, mask_expr, expr_loss_weight_non.unsqueeze(-1)*conf.evt_loss_weight_non)
     evt_weights = self._prepare_loss_weights(mask_expr, evt_not_nil, conf.evt_loss_sample_neg, evt_extra_weights)
     # -- get losses
     _, all_evt_cfs, all_evt_scores = self.evt_node.get_all_values()  # [*, slen]
     all_evt_losses = []
     for one_evt_scores in all_evt_scores:
         one_losses = BK.loss_nll(one_evt_scores, expr_evt_labels, label_smoothing=conf.evt_label_smoothing)
         all_evt_losses.append(one_losses)
     evt_loss_results = self.evt_node.helper.loss(all_losses=all_evt_losses, all_cfs=all_evt_cfs)
     for loss_t, loss_alpha, loss_name in evt_loss_results:
         one_evt_item = LossHelper.compile_leaf_loss("evt"+loss_name, (loss_t*evt_weights).sum(), evt_weights.sum(),
                                                     loss_lambda=conf.loss_evt*loss_alpha, gold=evt_not_nil.float().sum())
         loss_items.append(one_evt_item)
     # =====
     # arg
     _arg_loss_evt_sample_neg = conf.arg_loss_evt_sample_neg
     if _arg_loss_evt_sample_neg > 0:
         arg_evt_masks = ((BK.rand(mask_expr.shape)<_arg_loss_evt_sample_neg) | evt_not_nil).float() * mask_expr
     else:
         arg_evt_masks = evt_not_nil.float()  # [*, slen]
     # expand/flat the dims
     arg_flat_mask = (arg_evt_masks > 0)  # [*, slen]
     flat_mask_expr = mask_expr.unsqueeze(-2).expand(-1, slen, slen)[arg_flat_mask]  # [*, 1->slen, slen] => [??, slen]
     flat_arg_labels = expr_arg_labels[arg_flat_mask]  # [??, slen]
     flat_arg_not_nil = (flat_arg_labels > 0)  # [??, slen]
     flat_arg_weights = self._prepare_loss_weights(flat_mask_expr, flat_arg_not_nil, conf.arg_loss_sample_neg)
     # -- get losses
     _, all_arg_cfs, all_arg_scores = self.arg_node.get_all_values()  # [*, slen, slen]
     all_arg_losses = []
     for one_arg_scores in all_arg_scores:
         one_flat_arg_scores = one_arg_scores[arg_flat_mask]  # [??, slen]
         one_losses = BK.loss_nll(one_flat_arg_scores, flat_arg_labels, label_smoothing=conf.evt_label_smoothing)
         all_arg_losses.append(one_losses)
     all_arg_cfs = [z[arg_flat_mask] for z in all_arg_cfs]  # [??, slen]
     arg_loss_results = self.arg_node.helper.loss(all_losses=all_arg_losses, all_cfs=all_arg_cfs)
     for loss_t, loss_alpha, loss_name in arg_loss_results:
         one_arg_item = LossHelper.compile_leaf_loss("arg"+loss_name, (loss_t*flat_arg_weights).sum(), flat_arg_weights.sum(),
                                                     loss_lambda=conf.loss_arg*loss_alpha, gold=flat_arg_not_nil.float().sum())
         loss_items.append(one_arg_item)
     # =====
     # return loss
     ret_loss = LossHelper.combine_multiple_losses(loss_items)
     return ret_loss
Example #26
 def decode_frame_given(self, ibatch, scores_t: BK.Expr,
                        pred_max_layer: int, voc, pred_label: bool,
                        pred_tag: str, assume_osof: bool):
     if pred_label:  # if overwrite label!
          logprobs_t = scores_t.log_softmax(-1)  # [*, dlen, L]
          pred_scores, pred_labels = logprobs_t.max(-1)  # [*, dlen], note: maximum!
          arr_scores, arr_labels = BK.get_value(pred_scores), BK.get_value(pred_labels)  # [*, dlen]
     else:
         arr_scores = arr_labels = None
     # --
     # read given results
     res_bidxes, res_widxes, res_frames = [], [], []  # flattened results
     tmp_farrs = defaultdict(list)  # later assign
      for bidx, item in enumerate(ibatch.items):  # for each item in the batch
         _trg_frames = [item.inst] if assume_osof else \
             sum([sent.get_frames(pred_tag) for sidx,sent in enumerate(item.sents)
                  if (item.center_sidx is None or sidx == item.center_sidx)],[])  # still only pick center ones!
         # --
         _dec_offsets = item.seq_info.dec_offsets
         for _frame in _trg_frames:  # note: simply sort by original order!
             sidx = item.sents.index(_frame.sent)
             _start = _dec_offsets[sidx]
             _full_hidx = _start + _frame.mention.shead_widx
             # add new one
             res_bidxes.append(bidx)
             res_widxes.append(_full_hidx)
             _frame._tmp_sstart = _start  # todo(+N): ugly tmp value ...
             _frame._tmp_sidx = sidx
             _frame._tmp_item = item
             res_frames.append(_frame)
             tmp_farrs[(bidx, _full_hidx)].append(_frame)
             # assign/rewrite label?
             if pred_label:
                 _lab = int(arr_labels[bidx, _full_hidx])  # label index
                 _frame.set_label_idx(_lab)
                 _frame.set_label(voc.idx2word(_lab))
                 _frame.set_score(float(arr_scores[bidx, _full_hidx]))
         # --
     # --
      res_farrs = np.full(BK.get_shape(scores_t)[:-1] + [pred_max_layer], None, dtype=object)  # [*, dlen, K]
     for _key, _values in tmp_farrs.items():
         bidx, widx = _key
         _values = _values[:pred_max_layer]  # truncate if more!
         res_farrs[bidx, widx, :len(_values)] = _values
     # return
      res_bidxes_t, res_widxes_t = BK.input_idx(res_bidxes), BK.input_idx(res_widxes)  # [??]
      return (res_bidxes_t, res_widxes_t), res_frames, res_farrs  # [??], [*, dlen, K]
Example #27
 def get_last_emb(self):
     k = "last_emb"
     ret = self.l_caches.get(k)
     if ret is None:
         ret = self.embs[-1]
         valid_idxes_t = self.valid_idxes_t
         if valid_idxes_t is not None:
             arange2_t = BK.arange_idx(BK.get_shape(valid_idxes_t, 0)).unsqueeze(-1)  # [bsize, 1]
             ret = ret[arange2_t, valid_idxes_t]  # select!
         self.l_caches[k] = ret  # cache
     return ret
Example #28
 def get_stack_att(self):
     k = "stack_att"
     ret = self.l_caches.get(k)
     if ret is None:
         ret = BK.stack(self.attns, -1).permute(0,2,3,4,1)  # NL*[*, H, lenq, lenk] -> [*, lenq, lenk, NL, H]
         valid_idxes_t = self.valid_idxes_t
         if valid_idxes_t is not None:
             arange3_t = BK.arange_idx(BK.get_shape(valid_idxes_t, 0)).unsqueeze(-1).unsqueeze(-1)  # [bsize, 1, 1]
             ret = ret[arange3_t, valid_idxes_t.unsqueeze(-1), valid_idxes_t.unsqueeze(-2)]  # select!
         self.l_caches[k] = ret  # cache
     return ret
Example #29
 def _prepare_full_expr(self, flt_ext_widxes, flt_ext_masks, slen: int):
     tmp_bsize = BK.get_shape(flt_ext_widxes, 0)
     tmp_idxes = BK.zeros([tmp_bsize, slen + 1])  # [?, slen+1]
     # note: (once a bug) should get rid of paddings!!
     _mask_lt = flt_ext_masks.long()  # [?, N]
      tmp_idxes.scatter_(-1, flt_ext_widxes * _mask_lt + slen * (1 - _mask_lt), 1)  # [?, slen+1]
     tmp_idxes = tmp_idxes[:, :-1].long()  # [?, slen]
     tmp_embs = self.indicator_embed(tmp_idxes)  # [?, slen, D]
     return tmp_embs
Example #30
 def get_stack_emb(self):
     k = "stack_emb"
     ret = self.l_caches.get(k)
     if ret is None:
         # note: excluding embeddings here to make it convenient!!
         ret = BK.stack(self.embs[1:], -1)  # [*, slen, D, NL]
         valid_idxes_t = self.valid_idxes_t
         if valid_idxes_t is not None:
             arange2_t = BK.arange_idx(BK.get_shape(valid_idxes_t, 0)).unsqueeze(-1)  # [bsize, 1]
             ret = ret[arange2_t, valid_idxes_t]  # select!
         self.l_caches[k] = ret  # cache
     return ret