Code Example #1
File: gcn.py Project: zzsfornlp/zmsp
 def forward(self, input_t: BK.Expr, edges: BK.Expr, mask_t: BK.Expr):
     _isize = self.conf._isize
     _ntype = self.conf.type_num
     _slen = BK.get_shape(edges, -1)
     # --
     edges3 = edges.clamp(min=-1, max=1) + 1
     edgesF = edges + _ntype  # offset to positive!
     # get hid
     hid0 = BK.matmul(input_t, self.W_hid).view(
         BK.get_shape(input_t)[:-1] + [3, _isize])  # [*, L, 3, D]
     hid1 = hid0.unsqueeze(-4).expand(-1, _slen, -1, -1,
                                      -1)  # [*, L, L, 3, D]
     hid2 = BK.gather_first_dims(hid1.contiguous(), edges3.unsqueeze(-1),
                                 -2).squeeze(-2)  # [*, L, L, D]
     hidB = self.b_hid[edgesF]  # [*, L, L, D]
     _hid = hid2 + hidB
     # get gate
     gate0 = BK.matmul(input_t, self.W_gate)  # [*, L, 3]
     gate1 = gate0.unsqueeze(-3).expand(-1, _slen, -1, -1)  # [*, L, L, 3]
     gate2 = gate1.gather(-1, edges3.unsqueeze(-1))  # [*, L, L, 1]
     gateB = self.b_gate[edgesF].unsqueeze(-1)  # [*, L, L, 1]
     _gate0 = BK.sigmoid(gate2 + gateB)
     _gmask0 = (
         (edges != 0) |
         (BK.eye(_slen) > 0)).float() * mask_t.unsqueeze(-2)  # [*,L,L]
     _gate = _gate0 * _gmask0.unsqueeze(-1)  # [*,L,L,1]
     # combine
     h0 = BK.relu((_hid * _gate).sum(-2))  # [*, L, D]
     h1 = self.drop_node(h0)
     # add & norm?
     if self.ln is not None:
         h1 = self.ln(h1 + input_t)
     return h1
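This layer computes an edge-typed message plus an edge-typed bias, gates each message with a sigmoid scalar, sums over neighbors, and applies ReLU with an optional residual layer-norm. A minimal plain-PyTorch sketch of the same gate-and-aggregate pattern, assuming BK thinly wraps torch and dropping edge types and biases for brevity (all names below are illustrative, not from zmsp):

import torch

def gated_gcn_step(h, adj, W, w_gate):
    # h: [B, L, D] node states; adj: [B, L, L] 0/1 adjacency
    msg = h @ W                                  # [B, L, D] transformed states
    gate = torch.sigmoid(h @ w_gate)             # [B, L, 1] scalar gate per source
    gated = (msg * gate).unsqueeze(1) * adj.unsqueeze(-1)  # [B, L, L, D] gated msgs
    return torch.relu(gated.sum(-2))             # [B, L, D] sum over neighbors

B, L, D = 2, 5, 8
out = gated_gcn_step(torch.randn(B, L, D), (torch.rand(B, L, L) > 0.5).float(),
                     torch.randn(D, D), torch.randn(D, 1))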
Code Example #2
File: block.py Project: zzsfornlp/zmsp
 def forward(self, expr_t: BK.Expr, fixed_scores_t: BK.Expr = None, feed_output=False, mask_t: BK.Expr = None):
     conf: SingleBlockConf = self.conf
     # --
     # pred
     if fixed_scores_t is not None:
         score_t = fixed_scores_t
         cf_t = None
     else:
         hid1_t = self.hid_in(expr_t)  # [*, hid]
         score_t = self.pred_in(hid1_t)  # [*, nlab]
         cf_t = self.aff_cf(hid1_t).squeeze(-1)  # [*]
     # --
     if mask_t is not None:
         shape0 = BK.get_shape(expr_t)
         shape1 = BK.get_shape(mask_t)
         if len(shape1) < len(shape0):
             mask_t = mask_t.unsqueeze(-1)  # [*, 1]
         score_t += Constants.REAL_PRAC_MIN * (1. - mask_t)  # [*, nlab]
     # --
     # output
     if feed_output:
         W = self.W_getf()  # [nlab, hid]
         prob_t = score_t.softmax(-1)  # [*, nlab]
         hid2_t = BK.matmul(prob_t, W) * self.e_mul_scale  # [*, hid], todo(+W): need dropout here?
         out_t = self.hid_out(hid2_t)  # [*, ndim]
         final_t = self.norm(out_t + expr_t)  # [*, ndim], add and norm
     else:
         final_t = expr_t  # [*, ndim], simply no change and use input!
     return score_t, cf_t, final_t  # [*, nlab], [*], [*, ndim]
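When feed_output is set, the block feeds its own prediction back: the label scores become a softmax distribution, which mixes the rows of the label-embedding matrix W into an expected embedding that is projected and added back with layer-norm. A hedged plain-torch sketch of just that feedback step (shapes and module names are illustrative):

import torch

nlab, hid, ndim = 7, 16, 32
score = torch.randn(4, nlab)            # [*, nlab] raw label scores
expr = torch.randn(4, ndim)             # [*, ndim] block input
W = torch.randn(nlab, hid)              # label-embedding matrix
proj = torch.nn.Linear(hid, ndim)
norm = torch.nn.LayerNorm(ndim)
prob = score.softmax(-1)                # [*, nlab] soft label distribution
hid2 = prob @ W                         # [*, hid] expected label embedding
final = norm(proj(hid2) + expr)         # [*, ndim] add & norm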
Code Example #3
File: srl.py Project: zzsfornlp/zmsp
 def predict(self, insts: List[Sent], input_expr: BK.Expr, mask_expr: BK.Expr):
     conf: MySRLConf = self.conf
     slen = BK.get_shape(mask_expr, -1)
     # --
     # =====
     # evt
     _, all_evt_cfs, all_evt_raw_scores = self.evt_node.get_all_values()  # [*, slen, Le]
     all_evt_scores = [z.log_softmax(-1) for z in all_evt_raw_scores]
     final_evt_scores = self.evt_node.helper.pred(all_logprobs=all_evt_scores, all_cfs=all_evt_cfs)  # [*, slen, Le]
     if conf.evt_pred_use_all or conf.evt_pred_use_posi:  # todo(+W): not an elegant way...
         final_evt_scores[:, :, 0] += Constants.REAL_PRAC_MIN  # suppress NIL so every position predicts something
     pred_evt_scores, pred_evt_labels = final_evt_scores.max(-1)  # [*, slen]
     # =====
     # arg
     _, all_arg_cfs, all_arg_raw_score = self.arg_node.get_all_values()  # [*, slen, slen, La]
     all_arg_scores = [z.log_softmax(-1) for z in all_arg_raw_score]
     final_arg_scores = self.arg_node.helper.pred(all_logprobs=all_arg_scores, all_cfs=all_arg_cfs)  # [*, slen, slen, La]
     # slightly more efficient: only decode args at positions predicted as events
     full_pred_shape = BK.get_shape(final_arg_scores)[:-1]  # [*, slen, slen]
     pred_arg_scores, pred_arg_labels = BK.zeros(full_pred_shape), BK.zeros(full_pred_shape).long()
     arg_flat_mask = (pred_evt_labels > 0)  # [*, slen]
     flat_arg_scores = final_arg_scores[arg_flat_mask]  # [??, slen, La]
     if not BK.is_zero_shape(flat_arg_scores):  # at least one predicate!
         if self.pred_cons_mat is not None:
             flat_mask_expr = mask_expr.unsqueeze(-2).expand(-1, slen, slen)[arg_flat_mask]  # [*, 1->slen, slen] => [??, slen]
             flat_pred_arg_labels, flat_pred_arg_scores = BigramInferenceHelper.inference_search(
                 flat_arg_scores, self.pred_cons_mat, flat_mask_expr, conf.arg_beam_k)  # [??, slen]
         else:
             flat_pred_arg_scores, flat_pred_arg_labels = flat_arg_scores.max(-1)  # [??, slen]
         pred_arg_scores[arg_flat_mask] = flat_pred_arg_scores
         pred_arg_labels[arg_flat_mask] = flat_pred_arg_labels
     # =====
     # assign
     self.helper.put_results(insts, pred_evt_labels, pred_evt_scores, pred_arg_labels, pred_arg_scores)
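The argument decode only scores rows whose position was predicted as an event: boolean indexing flattens those rows out, they are decoded, and the results are scattered back into the full [*, slen, slen] tensors. A toy plain-torch version of that flatten/scatter pattern:

import torch

B, slen, La = 2, 5, 4
final_arg_scores = torch.randn(B, slen, slen, La)
pred_evt_labels = torch.tensor([[0, 2, 0, 0, 1], [0, 0, 0, 0, 0]])
arg_flat_mask = pred_evt_labels > 0                   # [B, slen] predicate positions
pred_arg_scores = torch.zeros(B, slen, slen)
pred_arg_labels = torch.zeros(B, slen, slen, dtype=torch.long)
flat = final_arg_scores[arg_flat_mask]                # [num_pred, slen, La]
if flat.numel() > 0:                                  # at least one predicate
    s, lab = flat.max(-1)                             # greedy decode per row
    pred_arg_scores[arg_flat_mask] = s
    pred_arg_labels[arg_flat_mask] = lab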
Code Example #4
File: direct.py Project: zzsfornlp/zmsp
 def _extend_cand_score(self, cand_score: BK.Expr):
     if self.conf.lab_add_extract_score and cand_score is not None:
         non0_mask = self.lab_node.laber.speical_mask_non0
         ret = non0_mask * cand_score.unsqueeze(-1)  # [*, slen, L]
     else:
         ret = None
     return ret
Code Example #5
 def prepare_with_lengths(self, input_shape: Tuple[int],
                          length_expr: BK.Expr, gold_widx_expr: BK.Expr,
                          gold_wlen_expr: BK.Expr, gold_addr_expr: BK.Expr):
     _f = lambda _widx, _wlen: ((_widx + _wlen).unsqueeze(0) <= length_expr.unsqueeze(-1)).float()  # [bsize, mlen*dw]
     return self._common_prepare(input_shape, _f, gold_widx_expr,
                                 gold_wlen_expr, gold_addr_expr)
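The lambda marks a candidate span (widx, wlen) as valid iff widx + wlen <= sequence length, broadcast over the batch. Concretely, in plain torch:

import torch

length_expr = torch.tensor([3, 5])                    # [bsize] sentence lengths
widx = torch.tensor([0, 2, 4])                        # flattened span starts
wlen = torch.tensor([2, 2, 2])                        # span widths
valid = ((widx + wlen).unsqueeze(0) <= length_expr.unsqueeze(-1)).float()
# -> [[1., 0., 0.], [1., 1., 0.]]                     # [bsize, mlen*dw]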
Code Example #6
File: helper.py Project: zzsfornlp/zmsp
def expand_ranged_idxes(widx_t: BK.Expr,
                        wlen_t: BK.Expr,
                        pad: int = 0,
                        max_width: int = None):
    if max_width is None:  # if not provided
        if BK.is_zero_shape(wlen_t):
            max_width = 1  # at least one
        else:
            max_width = wlen_t.max().item()  # overall max width
    # --
    input_shape = BK.get_shape(widx_t)  # [*]
    mw_range_t = BK.arange_idx(max_width).view([1] * len(input_shape) + [-1])  # [*, MW]
    expanded_idxes = widx_t.unsqueeze(-1) + mw_range_t  # [*, MW]
    expanded_masks_bool = (mw_range_t < wlen_t.unsqueeze(-1))  # [*, MW]
    expanded_idxes.masked_fill_(~expanded_masks_bool, pad)  # [*, MW]
    return expanded_idxes, expanded_masks_bool.float()
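A concrete walkthrough in plain torch (pad=0): each start index is expanded into a run of max_width positions, and positions beyond the span's own width are masked and padded:

import torch

widx = torch.tensor([1, 4])                 # span starts
wlen = torch.tensor([3, 2])                 # span widths
max_width = int(wlen.max())                 # 3
rng = torch.arange(max_width).view(1, -1)   # [[0, 1, 2]]
idxes = widx.unsqueeze(-1) + rng            # [[1, 2, 3], [4, 5, 6]]
masks = rng < wlen.unsqueeze(-1)            # [[T, T, T], [T, T, F]]
idxes = idxes.masked_fill(~masks, 0)        # -> [[1, 2, 3], [4, 5, 0]]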
Code Example #7
File: srl.py Project: zzsfornlp/zmsp
 def loss(self, insts: List[Sent], input_expr: BK.Expr, mask_expr: BK.Expr):
     conf: MySRLConf = self.conf
     # --
     slen = BK.get_shape(mask_expr, -1)
     arr_items, expr_evt_labels, expr_arg_labels, expr_loss_weight_non = self.helper.prepare(insts, True)
     if conf.binary_evt:
         expr_evt_labels = (expr_evt_labels>0).long()  # either 0 or 1
     loss_items = []
     # =====
     # evt
     # -- prepare weights and masks
     evt_not_nil = (expr_evt_labels>0)  # [*, slen]
     evt_extra_weights = BK.where(evt_not_nil, mask_expr, expr_loss_weight_non.unsqueeze(-1)*conf.evt_loss_weight_non)
     evt_weights = self._prepare_loss_weights(mask_expr, evt_not_nil, conf.evt_loss_sample_neg, evt_extra_weights)
     # -- get losses
     _, all_evt_cfs, all_evt_scores = self.evt_node.get_all_values()  # [*, slen, Le]
     all_evt_losses = []
     for one_evt_scores in all_evt_scores:
         one_losses = BK.loss_nll(one_evt_scores, expr_evt_labels, label_smoothing=conf.evt_label_smoothing)
         all_evt_losses.append(one_losses)
     evt_loss_results = self.evt_node.helper.loss(all_losses=all_evt_losses, all_cfs=all_evt_cfs)
     for loss_t, loss_alpha, loss_name in evt_loss_results:
         one_evt_item = LossHelper.compile_leaf_loss("evt"+loss_name, (loss_t*evt_weights).sum(), evt_weights.sum(),
                                                     loss_lambda=conf.loss_evt*loss_alpha, gold=evt_not_nil.float().sum())
         loss_items.append(one_evt_item)
     # =====
     # arg
     _arg_loss_evt_sample_neg = conf.arg_loss_evt_sample_neg
     if _arg_loss_evt_sample_neg > 0:
         arg_evt_masks = ((BK.rand(mask_expr.shape)<_arg_loss_evt_sample_neg) | evt_not_nil).float() * mask_expr
     else:
         arg_evt_masks = evt_not_nil.float()  # [*, slen]
     # expand/flat the dims
     arg_flat_mask = (arg_evt_masks > 0)  # [*, slen]
     flat_mask_expr = mask_expr.unsqueeze(-2).expand(-1, slen, slen)[arg_flat_mask]  # [*, 1->slen, slen] => [??, slen]
     flat_arg_labels = expr_arg_labels[arg_flat_mask]  # [??, slen]
     flat_arg_not_nil = (flat_arg_labels > 0)  # [??, slen]
     flat_arg_weights = self._prepare_loss_weights(flat_mask_expr, flat_arg_not_nil, conf.arg_loss_sample_neg)
     # -- get losses
     _, all_arg_cfs, all_arg_scores = self.arg_node.get_all_values()  # [*, slen, slen, La]
     all_arg_losses = []
     for one_arg_scores in all_arg_scores:
         one_flat_arg_scores = one_arg_scores[arg_flat_mask]  # [??, slen]
         one_losses = BK.loss_nll(one_flat_arg_scores, flat_arg_labels, label_smoothing=conf.evt_label_smoothing)  # note: reuses evt_label_smoothing
         all_arg_losses.append(one_losses)
     all_arg_cfs = [z[arg_flat_mask] for z in all_arg_cfs]  # [??, slen]
     arg_loss_results = self.arg_node.helper.loss(all_losses=all_arg_losses, all_cfs=all_arg_cfs)
     for loss_t, loss_alpha, loss_name in arg_loss_results:
         one_arg_item = LossHelper.compile_leaf_loss("arg"+loss_name, (loss_t*flat_arg_weights).sum(), flat_arg_weights.sum(),
                                                     loss_lambda=conf.loss_arg*loss_alpha, gold=flat_arg_not_nil.float().sum())
         loss_items.append(one_arg_item)
     # =====
     # return loss
     ret_loss = LossHelper.combine_multiple_losses(loss_items)
     return ret_loss
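Both losses down-sample negatives with a uniform random mask while always keeping gold positions (see the arg_evt_masks line above). _prepare_loss_weights itself is not shown in this listing; a hedged sketch of the same rand-mask pattern:

import torch

mask_expr = torch.ones(2, 6)                          # [*, slen] valid tokens
not_nil = torch.zeros(2, 6, dtype=torch.bool)
not_nil[0, 1] = True                                  # one gold position
sample_neg = 0.5                                      # keep ~half the negatives
keep = (torch.rand(mask_expr.shape) < sample_neg) | not_nil
weights = keep.float() * mask_expr                    # gold positions always kept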
Code Example #8
File: expand.py Project: zzsfornlp/zmsp
 def decode_with_scores(left_scores: BK.Expr, right_scores: BK.Expr,
                        normalize: bool):
     if normalize:
         left_scores = BK.log_softmax(left_scores, -1)
         right_scores = BK.log_softmax(right_scores, -1)
     # pairwise adding
     score_shape = BK.get_shape(left_scores)
     pair_scores = left_scores.unsqueeze(-1) + right_scores.unsqueeze(-2)  # [*, slen_L, slen_R]
     flt_pair_scores = pair_scores.view(score_shape[:-1] + [-1])  # [*, slen*slen]
     # LR mask
     slen = score_shape[-1]
     arange_t = BK.arange_idx(slen)
     lr_mask = (arange_t.unsqueeze(-1) <= arange_t.unsqueeze(-2)).float().view(-1)  # [slen_L*slen_R]
     max_scores, max_idxes = (flt_pair_scores + (1. - lr_mask) * Constants.REAL_PRAC_MIN).max(-1)  # [*]
     left_idxes, right_idxes = max_idxes // slen, max_idxes % slen  # [*]
     return max_scores, left_idxes, right_idxes
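A worked toy version in plain torch (slen=4), using -inf instead of REAL_PRAC_MIN for the invalid left > right pairs:

import torch

slen = 4
left, right = torch.randn(slen), torch.randn(slen)
pair = left.unsqueeze(-1) + right.unsqueeze(-2)   # pair[i, j] = left[i] + right[j]
i = torch.arange(slen).unsqueeze(-1)
j = torch.arange(slen).unsqueeze(-2)
pair = pair.masked_fill(i > j, float("-inf"))     # keep only left <= right
best = pair.view(-1).argmax()
l_idx, r_idx = best // slen, best % slen          # recover the boundary pair
assert l_idx <= r_idx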
Code Example #9
File: crf.py Project: zzsfornlp/zmsp
 def loss(self, unary_scores: BK.Expr, input_mask: BK.Expr, gold_idxes: BK.Expr):
     mat_t = self.bigram.get_matrix()  # [L, L]
     if BK.is_zero_shape(unary_scores):  # note: avoid empty
         potential_t = BK.zeros(BK.get_shape(unary_scores)[:-2])  # [*]
     else:
         potential_t = BigramInferenceHelper.inference_forward(unary_scores, mat_t, input_mask, self.conf.crf_beam)  # [*]
     gold_single_scores_t = unary_scores.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1) * input_mask  # [*, slen]
     gold_bigram_scores_t = mat_t[gold_idxes[:, :-1], gold_idxes[:, 1:]] * input_mask[:, 1:]  # [*, slen-1]
     all_losses_t = (potential_t - (gold_single_scores_t.sum(-1) + gold_bigram_scores_t.sum(-1)))  # [*]
     if self.conf.loss_by_tok:
         ret_count = input_mask.sum()  # []
     else:
         ret_count = (input_mask.sum(-1)>0).float()  # [*]
     return all_losses_t, ret_count
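This is the standard CRF negative log-likelihood: the log-partition (potential_t, from the forward algorithm) minus the gold path score (gold unary scores plus gold transition scores). The gold-score part with concrete shapes, in plain torch:

import torch

B, slen, L = 2, 5, 3
unary = torch.randn(B, slen, L)
mat = torch.randn(L, L)                                       # bigram matrix
gold = torch.randint(0, L, (B, slen))
mask = torch.ones(B, slen)
u = unary.gather(-1, gold.unsqueeze(-1)).squeeze(-1) * mask   # [B, slen]
b = mat[gold[:, :-1], gold[:, 1:]] * mask[:, 1:]              # [B, slen-1]
gold_score = u.sum(-1) + b.sum(-1)                            # [B]; loss = logZ - gold_score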
Code Example #10
 def step_end(self, cache: DecCache, slice_main: BK.Expr,
              slice_mask: BK.Expr, pred_idxes: BK.Expr):
     # we do possible decoder step here
     conf: SeqLabelerConf = self.conf
     # --
     if conf.use_seqdec:
         embed_t = self.laber.lookup(pred_idxes)  # [*, E]
         input_t = self.sd_input_aff([slice_main, embed_t])  # [*, hid]
         if conf.sd_skip_non:  # further mask, todo(note): fixed non as 0!
             slice_mask = slice_mask * (pred_idxes > 0).float()
         hid_t = self.seqdec.go_feed(
             cache, input_t.unsqueeze(-2),
             slice_mask.unsqueeze(-1))  # [*, 1, hid]
     # add here for possible bigram usage
     cache.last_idxes = pred_idxes  # [*]
     return cache  # cache modified inplace
Code Example #11
File: expand.py Project: zzsfornlp/zmsp
 def score(self,
           input_main: BK.Expr,
           input_pair: BK.Expr,
           input_mask: BK.Expr,
           left_constraints: BK.Expr = None,
           right_constraints: BK.Expr = None):
     conf: SpanExpanderConf = self.conf
     # --
     # left & right
     rets = []
     seq_shape = BK.get_shape(input_mask)
     cur_mask = input_mask
     arange_t = BK.arange_idx(seq_shape[-1]).view(
         [1] * (len(seq_shape) - 1) + [-1])  # [*, slen]
     for scorer, cons_t in zip([self.s_left, self.s_right],
                               [left_constraints, right_constraints]):
         mm = cur_mask if cons_t is None else (
             cur_mask * (arange_t <= cons_t).float())  # [*, slen]
         ss = scorer(
             input_main,
             None if input_pair is None else input_pair.unsqueeze(-2),
             mm).squeeze(-1)  # [*, slen]
         rets.append(ss)
     return rets[0], rets[1]  # [*, slen] (already masked)
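The constraint tensors cap how far the expander may score: positions past the per-item constraint index are zeroed out of the mask. In plain torch:

import torch

cur_mask = torch.ones(2, 6)                 # [*, slen]
cons_t = torch.tensor([[3], [5]])           # furthest allowed index per item
arange_t = torch.arange(6).view(1, -1)      # [1, slen]
mm = cur_mask * (arange_t <= cons_t).float()
# -> [[1., 1., 1., 1., 0., 0.], [1., 1., 1., 1., 1., 1.]]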
Code Example #12
File: helper.py Project: zzsfornlp/zmsp
def log_sum_exp(t: BK.Expr, dim: int, t_max: BK.Expr = None):
    if t_max is None:
        t_max, _ = t.max(dim)  # get maximum value; [*, *]
    ret = t_max + (t - t_max.unsqueeze(dim)).exp().sum(dim).log()  # [*, *]
    return ret
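Subtracting the max before exponentiating is what keeps this numerically stable; a naive logsumexp overflows for large inputs:

import torch

t = torch.tensor([[1000., 1001., 1002.]])
naive = t.exp().sum(-1).log()                                    # inf: exp overflows
t_max, _ = t.max(-1)
stable = t_max + (t - t_max.unsqueeze(-1)).exp().sum(-1).log()   # ~1002.408
assert torch.isinf(naive).all() and torch.isfinite(stable).all()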
Code Example #13
 def mask(self, v: BK.Expr, erase_mask: BK.Expr):
     return super().mask(
         v, erase_mask.unsqueeze(-1))  # todo(note): simply allow all-mask
Code Example #14
 def loss(self,
          input_main: BK.Expr,
          input_pair: BK.Expr,
          input_mask: BK.Expr,
          gold_idxes: BK.Expr,
          loss_weight_expr: BK.Expr = None,
          extra_score: BK.Expr = None):
     conf: SeqLabelerConf = self.conf
     # --
     expr_main, expr_pair = self.transform_expr(input_main, input_pair)
     if self.loss_mle:
         # simply collect them all (not normalize here!)
         all_scores_t = self.score_all(
             expr_main,
             expr_pair,
             input_mask,
             gold_idxes,
             local_normalize=False,
             extra_score=extra_score)  # [*, slen, L]
         # negative log likelihood; todo(+1): repeat log-softmax here
         # all_losses_t = - all_scores_t.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1) * input_mask  # [*, slen]
         all_losses_t = BK.loss_nll(
             all_scores_t,
             gold_idxes,
             label_smoothing=self.conf.labeler_conf.label_smoothing)  # [*]
         all_losses_t *= input_mask
         if loss_weight_expr is not None:
             all_losses_t *= loss_weight_expr
         ret_loss = all_losses_t.sum()  # []
     elif self.loss_crf:
         # no normalization & no bigram
         single_scores_t = self.score_all(
             expr_main,
             expr_pair,
             input_mask,
             None,
             use_bigram=False,
             extra_score=extra_score)  # [*, slen, L]
         mat_t = self.bigram.get_matrix()  # [L, L]
         if BK.is_zero_shape(single_scores_t):  # note: avoid empty
             potential_t = BK.zeros(
                 BK.get_shape(single_scores_t)[:-2])  # [*]
         else:
             potential_t = BigramInferenceHelper.inference_forward(
                 single_scores_t, mat_t, input_mask, conf.beam_k)  # [*]
         gold_single_scores_t = single_scores_t.gather(
             -1,
             gold_idxes.unsqueeze(-1)).squeeze(-1) * input_mask  # [*, slen]
         gold_bigram_scores_t = mat_t[
             gold_idxes[:, :-1],
             gold_idxes[:, 1:]] * input_mask[:, 1:]  # [*, slen-1]
         all_losses_t = (
             potential_t -
             (gold_single_scores_t.sum(-1) + gold_bigram_scores_t.sum(-1))
         )  # [*]
         # todo(+N): also no label_smoothing for crf
         # todo(+N): for now, ignore loss_weight for crf mode!!
         # if loss_weight_expr is not None:
         #     assert BK.get_shape(loss_weight_expr, -1) == 1, "Currently CRF loss requires seq level loss_weight!!"
         #     all_losses_t *= loss_weight_expr
         ret_loss = all_losses_t.sum()  # []
     else:
         raise NotImplementedError()
     # ret_count
     if conf.loss_by_tok:  # sum all valid toks
         if conf.loss_by_tok_weighted and loss_weight_expr is not None:
             ret_count = (input_mask * loss_weight_expr).sum()
         else:
             ret_count = input_mask.sum()
     else:  # sum all valid batch items
         ret_count = input_mask.prod(-1).sum()
     return (ret_loss, ret_count)
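For the loss_mle branch, the masked per-token NLL is equivalent to the following plain-torch computation (a hedged sketch: no label smoothing, no seq-decoder, no extra scores):

import torch
import torch.nn.functional as F

scores = torch.randn(2, 5, 7)                  # [*, slen, L] unnormalized
gold = torch.randint(0, 7, (2, 5))             # [*, slen]
mask = torch.ones(2, 5)
nll = F.cross_entropy(scores.transpose(-1, -2), gold, reduction="none")  # [*, slen]
ret_loss = (nll * mask).sum()                  # scalar, as in the MLE branch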
Code Example #15
 def score_all(self,
               expr_main: BK.Expr,
               expr_pair: BK.Expr,
               input_mask: BK.Expr,
               gold_idxes: BK.Expr,
               local_normalize: bool = None,
               use_bigram: bool = True,
               extra_score: BK.Expr = None):
     conf: SeqLabelerConf = self.conf
     # first collect basic scores
     if conf.use_seqdec:
         # first prepare init hidden
         sd_init_t = self.prepare_sd_init(expr_main, expr_pair)  # [*, hid]
         # init cache: no mask at batch level
         sd_cache = self.seqdec.go_init(
             sd_init_t, init_mask=None)  # and no need to cum_state here!
         # prepare inputs at once
         if conf.sd_skip_non:
             gold_valid_mask = (gold_idxes > 0).float() * input_mask  # [*, slen], todo(note): fix 0 as non here!
             gv_idxes, gv_masks = BK.mask2idx(gold_valid_mask)  # [*, ?]
             bsize = BK.get_shape(gold_idxes, 0)
             arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
             # select and forward
             gv_embeds = self.laber.lookup(
                 gold_idxes[arange_t, gv_idxes])  # [*, ?, E]
             gv_input_t = self.sd_input_aff(
                 [expr_main[arange_t, gv_idxes], gv_embeds])  # [*, ?, hid]
             gv_hid_t = self.seqdec.go_feed(sd_cache, gv_input_t,
                                            gv_masks)  # [*, ?, hid]
             # select back and output_aff
             aug_hid_t = BK.concat([sd_init_t.unsqueeze(-2), gv_hid_t],
                                   -2)  # [*, 1+?, hid]
             sel_t = BK.pad(gold_valid_mask[:, :-1].cumsum(-1), (1, 0),
                            value=0.).long()  # [*, 1+(slen-1)]
             shifted_hid_t = aug_hid_t[arange_t, sel_t]  # [*, slen, hid]
         else:
             gold_idx_embeds = self.laber.lookup(gold_idxes)  # [*, slen, E]
             all_input_t = self.sd_input_aff(
                 [expr_main,
                  gold_idx_embeds])  # inputs to dec, [*, slen, hid]
             all_hid_t = self.seqdec.go_feed(
                 sd_cache, all_input_t,
                 input_mask)  # output-hids, [*, slen, hid]
             shifted_hid_t = BK.concat(
                 [sd_init_t.unsqueeze(-2), all_hid_t[:, :-1]],
                 -2)  # [*, slen, hid]
         # scorer
         pre_labeler_t = self.sd_output_aff([expr_main, shifted_hid_t])  # [*, slen, hid]
     else:
         pre_labeler_t = expr_main  # [*, slen, Dm']
     # score with labeler (no norm here since we may need to add other scores)
     scores_t = self.laber.score(
         pre_labeler_t,
         None if expr_pair is None else expr_pair.unsqueeze(-2),
         input_mask,
         extra_score=extra_score,
         local_normalize=False)  # [*, slen, L]
     # bigram score addition
     if conf.use_bigram and use_bigram:
         bigram_scores_t = self.bigram.get_matrix()[gold_idxes[:, :-1]]  # [*, slen-1, L]
         score_shape = BK.get_shape(bigram_scores_t)
         score_shape[1] = 1
         slice_t = BK.constants(score_shape, 0.)  # fix 0.: no transition from BOS (and EOS) for simplicity!
         bigram_scores_t = BK.concat([slice_t, bigram_scores_t], 1)  # [*, slen, L]
         scores_t += bigram_scores_t  # [*, slen, L]
     # local normalization?
     scores_t = self.laber.output_score(scores_t, local_normalize)
     return scores_t
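The teacher-forced bigram addition gives every position the transition row of the previous gold label, with a zero row at position 0 (no transition from BOS). With concrete shapes in plain torch:

import torch

B, slen, L = 2, 5, 7
trans = torch.randn(L, L)                        # bigram transition matrix
gold_idxes = torch.randint(0, L, (B, slen))
bigram_scores = trans[gold_idxes[:, :-1]]        # [B, slen-1, L]
bigram_scores = torch.cat([torch.zeros(B, 1, L), bigram_scores], 1)  # [B, slen, L]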