Example #1
 def __call__(self, query_up, key_up, rel_dist=None, input_scores=None):
     _att_scale_qk = self._att_scale_qk
     # -----
     # get dim info
     len_q, len_k = BK.get_shape(query_up, -2), BK.get_shape(key_up, -2)
     # get distance embeddings
     if rel_dist is None:
         rel_dist = self.get_rel_dist(len_q, len_k)
     if self.rel_dist_abs:  # use abs?
         rel_dist = BK.abs(rel_dist)
     dist_embs = self.E(rel_dist)  # [len_q, len_k, Demb]
     # -----
     # dist_up
     dist_up0 = self.affine_rel(dist_embs)  # [len_q, len_k, head*D]
     # -> [head, len_q, len_k, D]
     dist_up1 = dist_up0.view(BK.get_shape(dist_up0)[:-1] + self.split_dims)
     dist_up1 = dist_up1.transpose(-2, -3).transpose(-3, -4)
     # -----
     # all items are [*, head, len_q, len_k]
     posi_scores = (input_scores if (input_scores is not None) else 0.)
     # item (b): <query, dist>: [head, len_q, len_k, D] * [*, head, len_q, D, 1] -> [*, head, len_q, len_k]
     item_b = (BK.matmul(dist_up1, query_up.unsqueeze(-1)) /
               _att_scale_qk).squeeze(-1)
     posi_scores += item_b
     # todo(note): remove this item_c since it is not related to rel_dist
     # # item (c): <key, u>: [*, head, len_k, D] * [head, D, 1] -> [*, head, 1, len_k]
     # item_c = (BK.matmul(key_up, self.vec_u.unsqueeze(-1)) / _att_scale_qk).squeeze(-1).unsqueeze(-2)
     # posi_scores += item_c
     # item (d): <dist, v>: [head, len_q, len_k, D] * [head, 1, D, 1] -> [head, len_q, len_k]
     item_d = (BK.matmul(dist_up1,
                         self.vec_v.unsqueeze(-2).unsqueeze(-1)) /
               _att_scale_qk).squeeze(-1)
     posi_scores += item_d
     return posi_scores
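
For readers not familiar with the BK wrapper used above, here is a minimal standalone PyTorch sketch (shapes and names are assumptions, not the library's API) of the two relative-distance terms, <query, dist> (item b) and <dist, v> (item d):

    import math
    import torch

    def rel_dist_scores(query_up, dist_up, vec_v, scale):
        # query_up: [*, head, len_q, D]; dist_up: [head, len_q, len_k, D]; vec_v: [head, D]
        item_b = torch.matmul(dist_up, query_up.unsqueeze(-1)).squeeze(-1) / scale  # [*, head, len_q, len_k]
        item_d = torch.matmul(dist_up, vec_v.unsqueeze(-2).unsqueeze(-1)).squeeze(-1) / scale  # [head, len_q, len_k]
        return item_b + item_d  # broadcasts to [*, head, len_q, len_k]

    head, D, len_q, len_k = 4, 16, 5, 5
    scores = rel_dist_scores(torch.randn(2, head, len_q, D),
                             torch.randn(head, len_q, len_k, D),
                             torch.randn(head, D),
                             math.sqrt(D))
    print(scores.shape)  # torch.Size([2, 4, 5, 5])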
Example #2
 def __call__(self,
              input_repr,
              mask_arr,
              require_loss,
              require_pred,
              gold_pos_arr=None):
     enc0_expr = self.enc(input_repr, mask_arr)  # [*, len, d]
     #
     enc1_expr = enc0_expr
     pos_probs, pos_losses_expr, pos_preds_expr = None, None, None
     if self.jpos_multitask:
         # get probabilities
         pos_logits = self.pred(enc0_expr)  # [*, len, nl]
         pos_probs = BK.softmax(pos_logits, dim=-1)
         # stacking for input -> output
         if self.jpos_stacking:
             enc1_expr = enc0_expr + BK.matmul(pos_probs, self.pos_weights)
         # simple cross entropy loss
         if require_loss and self.jpos_lambda > 0.:
             gold_probs = BK.gather_one_lastdim(
                 pos_probs, gold_pos_arr).squeeze(-1)  # [*, len]
             # todo(warn): multiplying the factor here, but not masking here (masking happens in the final steps)
             pos_losses_expr = (-self.jpos_lambda) * gold_probs.log()
         # simple argmax for prediction
         if require_pred and self.jpos_decode:
             pos_preds_expr = pos_probs.max(dim=-1)[1]
     return enc1_expr, (pos_probs, pos_losses_expr, pos_preds_expr)
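
The stacking branch above turns the softmaxed POS distribution into a soft POS embedding and adds it to the encoder output. A minimal PyTorch sketch of just that step (all shapes, names, and the random weights are assumptions):

    import torch
    import torch.nn.functional as F

    bs, slen, d, n_labels = 2, 7, 32, 10
    enc0 = torch.randn(bs, slen, d)                     # encoder output, [*, len, d]
    pos_logits = torch.randn(bs, slen, n_labels)        # POS scores, [*, len, nl]
    pos_weights = torch.randn(n_labels, d)              # learnable soft POS embeddings
    pos_probs = F.softmax(pos_logits, dim=-1)           # [*, len, nl]
    enc1 = enc0 + torch.matmul(pos_probs, pos_weights)  # soft POS embedding added, [*, len, d]
    pos_preds = pos_probs.max(dim=-1)[1]                # argmax POS prediction, [*, len]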
Example #3
 def _score(self,
            input_expr,
            input_mask,
            scores_aug_tok=None,
            scores_aug_sent=None):
     # token level attention and score
     # calculate the attention
     query_tok = self.query_tok  # [L, D]
     query_tok_t = query_tok.transpose(0, 1)  # [D, L]
     att_scores = BK.matmul(input_expr, query_tok_t)  # [*, slen, L]
     att_scores += (1. - input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN
     if scores_aug_tok is not None:  # margin
         att_scores += scores_aug_tok
     attn = BK.softmax(att_scores, -2)  # [*, slen, L]
     score_tok = (att_scores * attn).sum(-2)  # [*, L]
     # token level labeling softmax
     attn2 = BK.softmax(att_scores.view(BK.get_shape(att_scores)[:-2] + [-1]), -1)  # [*, slen*L]
     # sent level score
     query_sent = self.query_sent  # [L, D]
     context_sent = input_expr[:, 0] + input_expr[:, -1]  # [*, D], simply adding the two ends
     score_sent = BK.matmul(context_sent, query_sent.transpose(0, 1))  # [*, L]
     # combine
     if self.lambda_score_tok < 0.:
         context_tok = BK.matmul(input_expr.transpose(-1, -2), attn).transpose(-1, -2).contiguous()  # [*, L, D]
         # 4*[*,L,D] -> [*, L]
         cur_lambda_score_tok = self.score_gate(
             [context_tok, query_tok.unsqueeze(0), context_sent.unsqueeze(-2), query_sent.unsqueeze(0)]).squeeze(-1)
     else:
         cur_lambda_score_tok = self.lambda_score_tok
     final_score = score_tok * cur_lambda_score_tok + score_sent * (
         1. - cur_lambda_score_tok)
     if scores_aug_sent is not None:  # margin
         final_score += scores_aug_sent
     if self.conf.score_sigmoid:
         final_score = BK.sigmoid(final_score)
     return final_score, attn, attn2  # [*, L], [*, slen, L], [*, slen*L]
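
The token-level part is an attention-pooled score per label: a softmax over tokens, then the attention-weighted sum of the same scores. A minimal PyTorch sketch with assumed shapes (the sentence-level score and the gate are omitted):

    import torch

    bs, slen, d, L = 2, 6, 16, 3
    input_expr = torch.randn(bs, slen, d)
    query_tok = torch.randn(L, d)
    input_mask = torch.ones(bs, slen)                                  # 1 = real token, 0 = padding
    att_scores = torch.matmul(input_expr, query_tok.t())               # [*, slen, L]
    att_scores = att_scores + (1. - input_mask).unsqueeze(-1) * -1e30  # mask out padded tokens
    attn = torch.softmax(att_scores, dim=-2)                           # attention over slen, per label
    score_tok = (att_scores * attn).sum(-2)                            # attention-weighted score, [*, L]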
Example #4
 def __call__(self, query, key, accu_attn, mask_k, mask_qk, rel_dist):
     conf = self.conf
     # == calculate the dot-product scores
     # project query and key: [bs, len_?, head*D]
     query_up, key_up = self.affine_q(query), self.affine_k(key)  # [*, len?, head?*Dqk]
     query_up = self._shape_project(query_up, True)  # [*, head?, len_q, D]
     key_up = self._shape_project(key_up, True)  # [*, head?, len_k, D]
     # original scores
     scores = BK.matmul(query_up, BK.transpose(key_up, -1, -2)) / self._att_scale_qk  # [*, head?, len_q, len_k]
     # == adding rel_dist ones
     if conf.use_rel_dist:
         scores = self.dist_helper(query_up,
                                   key_up,
                                   rel_dist=rel_dist,
                                   input_scores=scores)
     # transpose
     scores = scores.transpose(-2, -3).transpose(-1, -2)  # [*, len_q, len_k, head?]
     # == unhead score
     if conf.use_unhead_score:
         scores_t0, scores_t1 = BK.split(scores, [1, self.head_count], -1)  # [*, len_q, len_k, 1|head]
         scores = scores_t0 + scores_t1  # [*, len_q, len_k, head]
     # == combining with history accumulated attns
     if conf.use_lambq and accu_attn is not None:
         # todo(note): here we only consider "query" and "head", would it be necessary for "key"?
         lambq_vals = self.lambq_aff(query)  # [*, len_q, head]; e.g., with relu as the activation this is >= 0
         scores -= lambq_vals.unsqueeze(-2) * accu_attn
     # == score offset
     if conf.use_soff:
         # todo(note): here we only consider "query" and "head", key may be handled by "unhead_score"
         score_offset_t = self.soff_aff(query)  # [*, len_q, 1+head]
         score_offset_t0, score_offset_t1 = BK.split(score_offset_t, [1, self.head_count], -1)  # [*, len_q, 1|head]
         scores -= score_offset_t0.unsqueeze(-2)
         scores -= score_offset_t1.unsqueeze(-2)  # still [*, len_q, len_k, head]
     # == apply mask & no-self-loop
     # NEG_INF = Constants.REAL_PRAC_MIN
     NEG_INF = -1000.  # this should be enough
     NEG_INF2 = -2000.  # this should be enough
     if mask_k is not None:  # [*, 1, len_k, 1]
         scores += (1. - mask_k).unsqueeze(-2).unsqueeze(-1) * NEG_INF2
     if mask_qk is not None:  # [*, len_q, len_k, 1]
         scores += (1. - mask_qk).unsqueeze(-1) * NEG_INF2
     if self.no_self_loop:
         query_len = BK.get_shape(query, -2)
         assert query_len == BK.get_shape(key, -2), "Shape not matched for no_self_loop"
         scores += BK.eye(query_len).unsqueeze(-1) * NEG_INF  # [len_q, len_k, 1]
     return scores.contiguous()  # [*, len_q, len_k, head]
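
The masking at the end follows the standard additive pattern: invalid key positions (and, optionally, self-loops) receive a large negative offset so that a later softmax assigns them near-zero probability. A minimal PyTorch sketch with assumed shapes and constants:

    import torch

    NEG_INF = -1000.  # large negative offset, assumed "small enough" after softmax
    bs, len_q, len_k, head = 2, 4, 4, 3
    scores = torch.randn(bs, len_q, len_k, head)                           # [*, len_q, len_k, head]
    mask_k = torch.tensor([[1., 1., 1., 0.]]).expand(bs, len_k)            # last key position is padding
    scores = scores + (1. - mask_k).unsqueeze(-2).unsqueeze(-1) * NEG_INF  # broadcast as [*, 1, len_k, 1]
    scores = scores + torch.eye(len_q).unsqueeze(-1) * NEG_INF             # no self-loop, [len_q, len_k, 1]
    attn = torch.softmax(scores, dim=-2)                                   # normalize over len_k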
Example #5
 def lookup_soft(self, cascade_scores: List):
     all_embeds = []
     for i in range(self.eff_max_layer):
         cur_scores = cascade_scores[i] * self.lookup_soft_alphas[i]  # [*, ?]
         cur_embeds = BK.matmul(cur_scores, self.layered_embeds_lookup[i])  # [*, D]
         all_embeds.append(cur_embeds)
     ret_embed = self.lookup_summer(all_embeds)
     return ret_embed
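
Each layer's soft lookup is simply a score-weighted average over that layer's embedding table. A minimal PyTorch sketch of one layer (the per-layer alphas and the summer over layers are left out):

    import torch

    bs, n_types, d = 2, 6, 16
    scores = torch.softmax(torch.randn(bs, n_types), dim=-1)  # soft scores over the types, [*, ?]
    embed_table = torch.randn(n_types, d)                     # embedding table for this layer, [?, D]
    soft_embed = torch.matmul(scores, embed_table)            # probability-weighted embedding, [*, D]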
Example #6
 def _step(self, input_expr, input_mask, hard_coverage, prev_state, force_widx, force_lidx, free_beam_size):
     conf = self.conf
     free_mode = (force_widx is None)
     prev_state_h = prev_state[0]
     # =====
     # collect att scores
     key_up = self.affine_k([input_expr, hard_coverage.unsqueeze(-1)])  # [*, slen, h]
     query_up = self.affine_q([self.repos.unsqueeze(0), prev_state_h.unsqueeze(-2)])  # [*, R, h]
     orig_scores = BK.matmul(key_up, query_up.transpose(-2, -1))  # [*, slen, R]
     orig_scores += (1.-input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN  # [*, slen, R]
     # first maximum across the R dim (this step is hard max)
     maxr_scores, maxr_idxes = orig_scores.max(-1)  # [*, slen]
     if conf.zero_eos_score:
         # use mask to make it able to be backward
         tmp_mask = BK.constants(BK.get_shape(maxr_scores), 1.)
         tmp_mask.index_fill_(-1, BK.input_idx(0), 0.)
         maxr_scores *= tmp_mask
     # then select over the slen dim (this step is prob based)
     maxr_logprobs = BK.log_softmax(maxr_scores)  # [*, slen]
     if free_mode:
         cur_beam_size = min(free_beam_size, BK.get_shape(maxr_logprobs, -1))
         sel_tok_logprobs, sel_tok_idxes = maxr_logprobs.topk(cur_beam_size, dim=-1, sorted=False)  # [*, beam]
     else:
         sel_tok_idxes = force_widx.unsqueeze(-1)  # [*, 1]
         sel_tok_logprobs = maxr_logprobs.gather(-1, sel_tok_idxes)  # [*, 1]
     # then collect the info and perform labeling
     lf_input_expr = BK.gather_first_dims(input_expr, sel_tok_idxes, -2)  # [*, ?, ~]
     lf_coverage = hard_coverage.gather(-1, sel_tok_idxes).unsqueeze(-1)  # [*, ?, 1]
     lf_repos = self.repos[maxr_idxes.gather(-1, sel_tok_idxes)]  # [*, ?, ~]  # todo(+3): using soft version?
     lf_prev_state = prev_state_h.unsqueeze(-2)  # [*, 1, ~]
     lab_hid_expr = self.lab_f([lf_input_expr, lf_coverage, lf_repos, lf_prev_state])  # [*, ?, ~]
     # final predicting labels
     # todo(+N): here we select only max at labeling part, only beam at previous one
     if free_mode:
         sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, None)  # [*, ?]
     else:
         sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, force_lidx.unsqueeze(-1))
     # no lab-logprob (*=0) for eos (sel_tok==0)
     sel_lab_logprobs *= (sel_tok_idxes>0).float()
     # compute next-state [*, ?, ~]
     # todo(note): here we flatten the first two dims
     tmp_rnn_dims = BK.get_shape(sel_tok_idxes) + [-1]
     tmp_rnn_input = BK.concat([lab_hid_expr, sel_lab_embeds], -1)
     tmp_rnn_input = tmp_rnn_input.view(-1, BK.get_shape(tmp_rnn_input, -1))
     tmp_rnn_hidden = [z.unsqueeze(-2).expand(tmp_rnn_dims).contiguous().view(-1, BK.get_shape(z, -1))
                       for z in prev_state]  # [*, ?, ?, D]
     next_state = self.rnn_unit(tmp_rnn_input, tmp_rnn_hidden, None)
     next_state = [z.view(tmp_rnn_dims) for z in next_state]
     return sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state
Example #7
 def _raw_scores(self, input_expr):
     all_scores = []
     for i in range(self.eff_max_layer):
         # first, the scores of the current layer; here no dropout!
         pred_w, pred_b = self.layered_embeds_pred[i], self.biases_pred[i]  # [?, D], [?]
         cur_score = BK.matmul(input_expr, pred_w)  # [*, ?]
         if pred_b is not None:
             cur_score += pred_b
         # apply None mask (make it score 0., must be before adding prev)
         if self.zero_nil:
             cur_score *= (1. - self.layered_isnil[i])  # make it zero for NIL(None) types
         all_scores.append(cur_score)
     return all_scores
Example #8
 def loss(self, repr_t, orig_map: Dict, **kwargs):
     conf = self.conf
     _tie_input_embeddings = conf.tie_input_embeddings
     # --
     # specify input
     add_root_token = self.add_root_token
     # get from inputs
     if isinstance(repr_t, (list, tuple)):
         l2r_repr_t, r2l_repr_t = repr_t
     elif self.split_input_blm:
         l2r_repr_t, r2l_repr_t = BK.chunk(repr_t, 2, -1)
     else:
         l2r_repr_t, r2l_repr_t = repr_t, None
     # l2r and r2l
     word_t = BK.input_idx(orig_map["word"])  # [bs, rlen]
     slice_zero_t = BK.zeros([BK.get_shape(word_t, 0), 1]).long()  # [bs, 1]
     if add_root_token:
         l2r_trg_t = BK.concat([word_t, slice_zero_t], -1)  # pad one extra 0, [bs, rlen+1]
         r2l_trg_t = BK.concat([slice_zero_t, slice_zero_t, word_t[:, :-1]], -1)  # pad two extra 0 at front, [bs, 2+rlen-1]
     else:
         l2r_trg_t = BK.concat([word_t[:, 1:], slice_zero_t], -1)  # pad one extra 0, but remove the first one, [bs, -1+rlen+1]
         r2l_trg_t = BK.concat([slice_zero_t, word_t[:, :-1]], -1)  # pad one extra 0 at front, [bs, 1+rlen-1]
     # gather the losses
     all_losses = []
     pred_range_min, pred_range_max = max(1, conf.min_pred_rank), self.pred_size - 1
     if _tie_input_embeddings:
         pred_W = self.inputter_embed_node.E.E[:self.pred_size]  # [PSize, Dim]
     else:
         pred_W = None
     # get input embeddings for output
     for pred_name, hid_node, pred_node, input_t, trg_t in \
                 zip(["l2r", "r2l"], [self.l2r_hid_layer, self.r2l_hid_layer], [self.l2r_pred, self.r2l_pred],
                     [l2r_repr_t, r2l_repr_t], [l2r_trg_t, r2l_trg_t]):
         if input_t is None:
             continue
         # hidden
         hid_t = hid_node(input_t) if hid_node else input_t  # [bs, slen, hid]
         # pred: [bs, slen, Vsize]
         if _tie_input_embeddings:
             scores_t = BK.matmul(hid_t, pred_W.T)
         else:
             scores_t = pred_node(hid_t)
         # loss
         mask_t = ((trg_t >= pred_range_min) &
                   (trg_t <= pred_range_max)).float()  # [bs, slen]
         trg_t.clamp_(max=pred_range_max)  # make it in range
         losses_t = BK.loss_nll(scores_t, trg_t) * mask_t  # [bs, slen]
         _, argmax_idxes = scores_t.max(-1)  # [bs, slen]
         corrs_t = (argmax_idxes == trg_t).float() * mask_t  # [bs, slen]
         # compile leaf loss
         one_loss = LossHelper.compile_leaf_info(pred_name,
                                                 losses_t.sum(),
                                                 mask_t.sum(),
                                                 loss_lambda=1.,
                                                 corr=corrs_t.sum())
         all_losses.append(one_loss)
     return self._compile_component_loss("plm", all_losses)
Example #9
 def loss(self,
          repr_ts,
          input_erase_mask_arr,
          orig_map: Dict,
          active_hid=True,
          **kwargs):
     conf = self.conf
     _tie_input_embeddings = conf.tie_input_embeddings
     # prepare idxes for the masked ones
     if self.add_root_token:  # offset for the special root added in embedder
         mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr), padding_idx=-1)  # [bsize, ?]
         repr_mask_idxes = mask_idxes + 1
         mask_idxes.clamp_(min=0)
     else:
         mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr))  # [bsize, ?]
         repr_mask_idxes = mask_idxes
     # get the losses
     if BK.get_shape(mask_idxes, -1) == 0:  # no loss
         return self._compile_component_loss("mlm", [])
     else:
         if not isinstance(repr_ts, (list, tuple)):
             repr_ts = [repr_ts]
         target_word_scores = []
         target_pos_scores = None  # todo(+N): for simplicity, currently ignore the pos scores!!
         for layer_idx in conf.loss_layers:
             # calculate scores
             target_reprs = BK.gather_first_dims(repr_ts[layer_idx], repr_mask_idxes, 1)  # [bsize, ?, *]
             if self.hid_layer and active_hid:  # todo(+N): sometimes, we only want last softmax, need to ensure dim at outside!
                 target_hids = self.hid_layer(target_reprs)
             else:
                 target_hids = target_reprs
             if _tie_input_embeddings:
                 pred_W = self.inputter_word_node.E.E[:self.pred_word_size]  # [PSize, Dim]
                 target_word_scores.append(BK.matmul(target_hids, pred_W.T))  # List[bsize, ?, Vw]
             else:
                 target_word_scores.append(self.pred_word_layer(target_hids))  # List[bsize, ?, Vw]
         # gather the losses
         all_losses = []
         for pred_name, target_scores, loss_lambda, range_min, range_max in \
                 zip(["word", "pos"], [target_word_scores, target_pos_scores], [conf.lambda_word, conf.lambda_pos],
                     [conf.min_pred_rank, 0], [min(conf.max_pred_rank, self.pred_word_size-1), self.pred_pos_size-1]):
             if loss_lambda > 0.:
                 seq_idx_t = BK.input_idx(orig_map[pred_name])  # [bsize, slen]
                 target_idx_t = seq_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
                 ranged_mask_valids = mask_valids * (target_idx_t >= range_min).float() * (target_idx_t <= range_max).float()
                 target_idx_t[(ranged_mask_valids < 1.)] = 0  # make sure invalid ones are within range
                 # calculate for each layer
                 all_layer_losses, all_layer_scores = [], []
                 for one_layer_idx, one_target_scores in enumerate(target_scores):
                     # get loss: [bsize, ?]
                     one_pred_losses = BK.loss_nll(one_target_scores, target_idx_t) * conf.loss_weights[one_layer_idx]
                     all_layer_losses.append(one_pred_losses)
                     # get scores
                     one_pred_scores = BK.log_softmax(one_target_scores, -1) * conf.loss_weights[one_layer_idx]
                     all_layer_scores.append(one_pred_scores)
                 # combine all layers
                 pred_losses = self.loss_comb_f(all_layer_losses)
                 pred_loss_sum = (pred_losses * ranged_mask_valids).sum()
                 pred_loss_count = ranged_mask_valids.sum()
                 # argmax
                 _, argmax_idxes = self.score_comb_f(all_layer_scores).max(-1)
                 pred_corrs = (argmax_idxes == target_idx_t).float() * ranged_mask_valids
                 pred_corr_count = pred_corrs.sum()
                 # compile leaf loss
                 r_loss = LossHelper.compile_leaf_info(
                     pred_name,
                     pred_loss_sum,
                     pred_loss_count,
                     loss_lambda=loss_lambda,
                     corr=pred_corr_count)
                 all_losses.append(r_loss)
         return self._compile_component_loss("mlm", all_losses)