def __call__(self, query_up, key_up, rel_dist=None, input_scores=None):
    _att_scale_qk = self._att_scale_qk
    # -----
    # get dim info
    len_q, len_k = BK.get_shape(query_up, -2), BK.get_shape(key_up, -2)
    # get distance embeddings
    if rel_dist is None:
        rel_dist = self.get_rel_dist(len_q, len_k)
    if self.rel_dist_abs:  # use abs?
        rel_dist = BK.abs(rel_dist)
    dist_embs = self.E(rel_dist)  # [len_q, len_k, Demb]
    # -----
    # dist_up: [len_q, len_k, head*D] -> [head, len_q, len_k, D]
    dist_up0 = self.affine_rel(dist_embs)  # [len_q, len_k, head*D]
    dist_up1 = dist_up0.view(BK.get_shape(dist_up0)[:-1] + self.split_dims).transpose(-2, -3).transpose(-3, -4)
    # -----
    # all items are [*, head, len_q, len_k]
    posi_scores = (input_scores if (input_scores is not None) else 0.)
    # item (b): <query, dist>: [head, len_q, len_k, D] * [*, head, len_q, D, 1] -> [*, head, len_q, len_k]
    item_b = (BK.matmul(dist_up1, query_up.unsqueeze(-1)) / _att_scale_qk).squeeze(-1)
    posi_scores += item_b
    # todo(note): remove this item_c since it is not related with rel_dist
    # # item (c): <key, u>: [*, head, len_k, D] * [head, D, 1] -> [*, head, 1, len_k]
    # item_c = (BK.matmul(key_up, self.vec_u.unsqueeze(-1)) / _att_scale_qk).squeeze(-1).unsqueeze(-2)
    # posi_scores += item_c
    # item (d): <dist, v>: [head, len_q, len_k, D] * [head, 1, D, 1] -> [head, len_q, len_k]
    item_d = (BK.matmul(dist_up1, self.vec_v.unsqueeze(-2).unsqueeze(-1)) / _att_scale_qk).squeeze(-1)
    posi_scores += item_d
    return posi_scores
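# Illustrative standalone sketch (plain PyTorch, not the BK backend) of how the two
# relative-distance items above combine: item (b) pairs each query with the distance
# embedding of its (q, k) offset, item (d) pairs that embedding with a learned per-head
# vector. All sizes and names below (n_head, d_k, ...) are made-up placeholders.
import torch

n_head, d_k, len_q, len_k = 4, 16, 5, 5
dist_up = torch.randn(n_head, len_q, len_k, d_k)    # projected distance embeddings
query_up = torch.randn(2, n_head, len_q, d_k)       # [bs, head, len_q, D]
vec_v = torch.randn(n_head, d_k)                    # per-head bias vector
scale = d_k ** 0.5

item_b = torch.einsum('hqkd,bhqd->bhqk', dist_up, query_up) / scale  # [bs, head, len_q, len_k]
item_d = torch.einsum('hqkd,hd->hqk', dist_up, vec_v) / scale        # [head, len_q, len_k], broadcast over batch
posi_scores = item_b + item_d
print(posi_scores.shape)  # torch.Size([2, 4, 5, 5])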
def __call__(self, input_repr, mask_arr, require_loss, require_pred, gold_pos_arr=None):
    enc0_expr = self.enc(input_repr, mask_arr)  # [*, len, d]
    # default output is the plain encoding (overridden below if stacking)
    enc1_expr = enc0_expr
    pos_probs, pos_losses_expr, pos_preds_expr = None, None, None
    if self.jpos_multitask:
        # get probabilities
        pos_logits = self.pred(enc0_expr)  # [*, len, nl]
        pos_probs = BK.softmax(pos_logits, dim=-1)
        # stacking for input -> output
        if self.jpos_stacking:
            enc1_expr = enc0_expr + BK.matmul(pos_probs, self.pos_weights)
        # simple cross entropy loss
        if require_loss and self.jpos_lambda > 0.:
            gold_probs = BK.gather_one_lastdim(pos_probs, gold_pos_arr).squeeze(-1)  # [*, len]
            # todo(warn): multiplying the factor here, but not masking here (masking in the final steps)
            pos_losses_expr = (-self.jpos_lambda) * gold_probs.log()
        # simple argmax for prediction
        if require_pred and self.jpos_decode:
            pos_preds_expr = pos_probs.max(dim=-1)[1]
    return enc1_expr, (pos_probs, pos_losses_expr, pos_preds_expr)
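# Hedged toy illustration of the "jpos_stacking" step above (plain PyTorch, placeholder
# sizes): predicted POS probabilities are turned into a soft label embedding via a matmul
# with a learned weight table and added back onto the encoding.
import torch

bs, slen, d, n_labels = 2, 7, 8, 5
enc0 = torch.randn(bs, slen, d)
pos_weights = torch.randn(n_labels, d)                           # one row per POS tag
pos_probs = torch.softmax(torch.randn(bs, slen, n_labels), dim=-1)
enc1 = enc0 + pos_probs @ pos_weights                            # [bs, slen, d]
print(enc1.shape)  # torch.Size([2, 7, 8])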
def _score(self, input_expr, input_mask, scores_aug_tok=None, scores_aug_sent=None):
    # token level attention and score
    # calculate the attention
    query_tok = self.query_tok  # [L, D]
    query_tok_t = query_tok.transpose(0, 1)  # [D, L]
    att_scores = BK.matmul(input_expr, query_tok_t)  # [*, slen, L]
    att_scores += (1. - input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN
    if scores_aug_tok is not None:  # margin
        att_scores += scores_aug_tok
    attn = BK.softmax(att_scores, -2)  # [*, slen, L]
    score_tok = (att_scores * attn).sum(-2)  # [*, L]
    # token level labeling softmax
    attn2 = BK.softmax(att_scores.view(BK.get_shape(att_scores)[:-2] + [-1]), -1)  # [*, slen*L]
    # sent level score
    query_sent = self.query_sent  # [L, D]
    context_sent = input_expr[:, 0] + input_expr[:, -1]  # [*, D], simply adding the two ends
    score_sent = BK.matmul(context_sent, self.query_sent.transpose(0, 1))  # [*, L]
    # combine
    if self.lambda_score_tok < 0.:
        context_tok = BK.matmul(input_expr.transpose(-1, -2), attn).transpose(-1, -2).contiguous()  # [*, L, D]
        # 4*[*,L,D] -> [*, L]
        cur_lambda_score_tok = self.score_gate(
            [context_tok, query_tok.unsqueeze(0), context_sent.unsqueeze(-2), query_sent.unsqueeze(0)]).squeeze(-1)
    else:
        cur_lambda_score_tok = self.lambda_score_tok
    final_score = score_tok * cur_lambda_score_tok + score_sent * (1. - cur_lambda_score_tok)
    if scores_aug_sent is not None:  # margin
        final_score += scores_aug_sent
    if self.conf.score_sigmoid:
        final_score = BK.sigmoid(final_score)
    return final_score, attn, attn2  # [*, L], [*, slen, L], [*, slen*L]
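# Toy illustration (plain PyTorch, placeholder sizes) of the token-level score above:
# attention weights over the sequence (softmax along slen) re-weight the raw scores,
# and summing over slen gives one expected score per label.
import torch

bs, slen, L = 2, 6, 4
att_scores = torch.randn(bs, slen, L)
attn = torch.softmax(att_scores, dim=-2)   # [*, slen, L], normalized over tokens
score_tok = (att_scores * attn).sum(-2)    # [*, L], attention-weighted expectation
print(score_tok.shape)  # torch.Size([2, 4])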
def __call__(self, query, key, accu_attn, mask_k, mask_qk, rel_dist):
    conf = self.conf
    # == calculate the dot-product scores
    # calculate the three: [bs, len_?, head*D]; and also add sta ones if needed
    query_up, key_up = self.affine_q(query), self.affine_k(key)  # [*, len?, head?*Dqk]
    query_up, key_up = self._shape_project(query_up, True), self._shape_project(key_up, True)  # [*, head?, len_?, D]
    # original scores
    scores = BK.matmul(query_up, BK.transpose(key_up, -1, -2)) / self._att_scale_qk  # [*, head?, len_q, len_k]
    # == adding rel_dist ones
    if conf.use_rel_dist:
        scores = self.dist_helper(query_up, key_up, rel_dist=rel_dist, input_scores=scores)
    # transpose
    scores = scores.transpose(-2, -3).transpose(-1, -2)  # [*, len_q, len_k, head?]
    # == unhead score
    if conf.use_unhead_score:
        scores_t0, score_t1 = BK.split(scores, [1, self.head_count], -1)  # [*, len_q, len_k, 1|head]
        scores = scores_t0 + score_t1  # [*, len_q, len_k, head]
    # == combining with history accumulated attns
    if conf.use_lambq and accu_attn is not None:
        # todo(note): here we only consider "query" and "head", would it be necessary for "key"?
        lambq_vals = self.lambq_aff(query)  # [*, len_q, head]; if using e.g. relu as the activation, this >= 0
        scores -= lambq_vals.unsqueeze(-2) * accu_attn
    # == score offset
    if conf.use_soff:
        # todo(note): here we only consider "query" and "head", key may be handled by "unhead_score"
        score_offset_t = self.soff_aff(query)  # [*, len_q, 1+head]
        score_offset_t0, score_offset_t1 = BK.split(score_offset_t, [1, self.head_count], -1)  # [*, len_q, 1|head]
        scores -= score_offset_t0.unsqueeze(-2)
        scores -= score_offset_t1.unsqueeze(-2)  # still [*, len_q, len_k, head]
    # == apply mask & no-self-loop
    # NEG_INF = Constants.REAL_PRAC_MIN
    NEG_INF = -1000.  # this should be enough
    NEG_INF2 = -2000.  # this should be enough
    if mask_k is not None:  # [*, 1, len_k, 1]
        scores += (1. - mask_k).unsqueeze(-2).unsqueeze(-1) * NEG_INF2
    if mask_qk is not None:  # [*, len_q, len_k, 1]
        scores += (1. - mask_qk).unsqueeze(-1) * NEG_INF2
    if self.no_self_loop:
        query_len = BK.get_shape(query, -2)
        assert query_len == BK.get_shape(key, -2), "Shape not matched for no_self_loop"
        scores += BK.eye(query_len).unsqueeze(-1) * NEG_INF  # [len_q, len_k, 1]
    return scores.contiguous()  # [*, len_q, len_k, head]
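# Small self-contained illustration (plain PyTorch, placeholder sizes) of the additive
# masking used above: invalid keys get a large negative constant added so that a later
# softmax drives their attention weights toward zero. NEG_INF mirrors the -1000. constant.
import torch

NEG_INF = -1000.
scores = torch.randn(1, 3, 3, 2)               # [*, len_q, len_k, head]
mask_k = torch.tensor([[1., 1., 0.]])          # last key position is padding
scores = scores + (1. - mask_k).unsqueeze(-2).unsqueeze(-1) * NEG_INF
attn = torch.softmax(scores, dim=-2)           # normalize over len_k
print(attn[0, :, -1, :])  # ~0 weight on the masked key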
def lookup_soft(self, cascade_scores: List):
    all_embeds = []
    for i in range(self.eff_max_layer):
        cur_scores = cascade_scores[i] * self.lookup_soft_alphas[i]  # [*, ?]
        cur_embeds = BK.matmul(cur_scores, self.layered_embeds_lookup[i])  # [*, D]
        all_embeds.append(cur_embeds)
    ret_embed = self.lookup_summer(all_embeds)
    return ret_embed
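# Minimal sketch of the soft lookup above (plain PyTorch, placeholder sizes): instead of
# indexing one row per prediction, a score/probability vector over the layer's types is
# multiplied with the whole embedding table, giving an expected embedding.
import torch

bs, n_types, d = 3, 6, 8
layer_scores = torch.softmax(torch.randn(bs, n_types), dim=-1)  # [*, ?] distribution over types
layer_table = torch.randn(n_types, d)                           # lookup table for this layer
soft_embed = layer_scores @ layer_table                         # [*, D], weighted mix of rows
print(soft_embed.shape)  # torch.Size([3, 8])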
def _step(self, input_expr, input_mask, hard_coverage, prev_state, force_widx, force_lidx, free_beam_size):
    conf = self.conf
    free_mode = (force_widx is None)
    prev_state_h = prev_state[0]
    # =====
    # collect att scores
    key_up = self.affine_k([input_expr, hard_coverage.unsqueeze(-1)])  # [*, slen, h]
    query_up = self.affine_q([self.repos.unsqueeze(0), prev_state_h.unsqueeze(-2)])  # [*, R, h]
    orig_scores = BK.matmul(key_up, query_up.transpose(-2, -1))  # [*, slen, R]
    orig_scores += (1. - input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN  # [*, slen, R]
    # first maximum across the R dim (this step is hard max)
    maxr_scores, maxr_idxes = orig_scores.max(-1)  # [*, slen]
    if conf.zero_eos_score:
        # use mask to make it able to be backward
        tmp_mask = BK.constants(BK.get_shape(maxr_scores), 1.)
        tmp_mask.index_fill_(-1, BK.input_idx(0), 0.)
        maxr_scores *= tmp_mask
    # then select over the slen dim (this step is prob based)
    maxr_logprobs = BK.log_softmax(maxr_scores)  # [*, slen]
    if free_mode:
        cur_beam_size = min(free_beam_size, BK.get_shape(maxr_logprobs, -1))
        sel_tok_logprobs, sel_tok_idxes = maxr_logprobs.topk(cur_beam_size, dim=-1, sorted=False)  # [*, beam]
    else:
        sel_tok_idxes = force_widx.unsqueeze(-1)  # [*, 1]
        sel_tok_logprobs = maxr_logprobs.gather(-1, sel_tok_idxes)  # [*, 1]
    # then collect the info and perform labeling
    lf_input_expr = BK.gather_first_dims(input_expr, sel_tok_idxes, -2)  # [*, ?, ~]
    lf_coverage = hard_coverage.gather(-1, sel_tok_idxes).unsqueeze(-1)  # [*, ?, 1]
    lf_repos = self.repos[maxr_idxes.gather(-1, sel_tok_idxes)]  # [*, ?, ~]
    # todo(+3): using soft version?
    lf_prev_state = prev_state_h.unsqueeze(-2)  # [*, 1, ~]
    lab_hid_expr = self.lab_f([lf_input_expr, lf_coverage, lf_repos, lf_prev_state])  # [*, ?, ~]
    # final predicting labels
    # todo(+N): here we select only max at labeling part, only beam at previous one
    if free_mode:
        sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, None)  # [*, ?]
    else:
        sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, force_lidx.unsqueeze(-1))
    # no lab-logprob (*=0) for eos (sel_tok==0)
    sel_lab_logprobs *= (sel_tok_idxes > 0).float()
    # compute next-state [*, ?, ~]
    # todo(note): here we flatten the first two dims
    tmp_rnn_dims = BK.get_shape(sel_tok_idxes) + [-1]
    tmp_rnn_input = BK.concat([lab_hid_expr, sel_lab_embeds], -1)
    tmp_rnn_input = tmp_rnn_input.view(-1, BK.get_shape(tmp_rnn_input, -1))
    tmp_rnn_hidden = [z.unsqueeze(-2).expand(tmp_rnn_dims).contiguous().view(-1, BK.get_shape(z, -1))
                      for z in prev_state]  # [*, ?, ?, D]
    next_state = self.rnn_unit(tmp_rnn_input, tmp_rnn_hidden, None)
    next_state = [z.view(tmp_rnn_dims) for z in next_state]
    return sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state
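# Standalone sketch (plain PyTorch, placeholder sizes) of the free-mode vs. forced-mode
# selection above: in free mode the top-k tokens are kept for beam search, while in forced
# mode the gold index is simply gathered so its log-probability can still carry gradients.
import torch

bs, slen, beam = 2, 6, 3
logprobs = torch.log_softmax(torch.randn(bs, slen), dim=-1)          # [*, slen]

# free mode: beam of candidate tokens
sel_logprobs, sel_idxes = logprobs.topk(beam, dim=-1, sorted=False)  # [*, beam]

# forced mode: gather the gold token's log-probability
force_widx = torch.tensor([1, 4])                                    # hypothetical gold indices
forced_logprobs = logprobs.gather(-1, force_widx.unsqueeze(-1))      # [*, 1]
print(sel_idxes.shape, forced_logprobs.shape)  # torch.Size([2, 3]) torch.Size([2, 1])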
def _raw_scores(self, input_expr):
    all_scores = []
    for i in range(self.eff_max_layer):
        # first, the scores of the current layer; here no dropout!
        pred_w, pred_b = self.layered_embeds_pred[i], self.biases_pred[i]  # [?, D], [?]
        cur_score = BK.matmul(input_expr, pred_w)  # [*, ?]
        if pred_b is not None:
            cur_score += pred_b
        # apply None mask (make it score 0., must be before adding prev)
        if self.zero_nil:
            cur_score *= (1. - self.layered_isnil[i])  # make it zero for NIL(None) types
        all_scores.append(cur_score)
    return all_scores
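# Rough standalone sketch (plain PyTorch, placeholder sizes) of the per-layer scoring
# with NIL zeroing above: scores come from a plain matmul against the layer's prediction
# embeddings, and columns corresponding to the NIL/None type are multiplied by zero.
import torch

bs, d, n_types = 2, 8, 5
input_expr = torch.randn(bs, d)
pred_w = torch.randn(d, n_types)            # prediction embeddings, laid out [D, ?] for the matmul
isnil = torch.tensor([1., 0., 0., 0., 0.])  # 1. marks the NIL(None) type
scores = input_expr @ pred_w                # [*, ?]
scores = scores * (1. - isnil)              # NIL column forced to score 0.
print(scores[:, 0])  # the NIL column is exactly 0.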
def loss(self, repr_t, orig_map: Dict, **kwargs):
    conf = self.conf
    _tie_input_embeddings = conf.tie_input_embeddings
    # --
    # specify input
    add_root_token = self.add_root_token
    # get from inputs
    if isinstance(repr_t, (list, tuple)):
        l2r_repr_t, r2l_repr_t = repr_t
    elif self.split_input_blm:
        l2r_repr_t, r2l_repr_t = BK.chunk(repr_t, 2, -1)
    else:
        l2r_repr_t, r2l_repr_t = repr_t, None  # l2r and r2l
    word_t = BK.input_idx(orig_map["word"])  # [bs, rlen]
    slice_zero_t = BK.zeros([BK.get_shape(word_t, 0), 1]).long()  # [bs, 1]
    if add_root_token:
        l2r_trg_t = BK.concat([word_t, slice_zero_t], -1)  # pad one extra 0, [bs, rlen+1]
        r2l_trg_t = BK.concat([slice_zero_t, slice_zero_t, word_t[:, :-1]], -1)  # pad two extra 0 at front, [bs, 2+rlen-1]
    else:
        l2r_trg_t = BK.concat([word_t[:, 1:], slice_zero_t], -1)  # pad one extra 0, but remove the first one, [bs, -1+rlen+1]
        r2l_trg_t = BK.concat([slice_zero_t, word_t[:, :-1]], -1)  # pad one extra 0 at front, [bs, 1+rlen-1]
    # gather the losses
    all_losses = []
    pred_range_min, pred_range_max = max(1, conf.min_pred_rank), self.pred_size - 1
    # get input embeddings for output (tied input/output embeddings)
    if _tie_input_embeddings:
        pred_W = self.inputter_embed_node.E.E[:self.pred_size]  # [PSize, Dim]
    else:
        pred_W = None
    for pred_name, hid_node, pred_node, input_t, trg_t in \
            zip(["l2r", "r2l"], [self.l2r_hid_layer, self.r2l_hid_layer], [self.l2r_pred, self.r2l_pred],
                [l2r_repr_t, r2l_repr_t], [l2r_trg_t, r2l_trg_t]):
        if input_t is None:
            continue
        # hidden
        hid_t = hid_node(input_t) if hid_node else input_t  # [bs, slen, hid]
        # pred: [bs, slen, Vsize]
        if _tie_input_embeddings:
            scores_t = BK.matmul(hid_t, pred_W.T)
        else:
            scores_t = pred_node(hid_t)
        # loss
        mask_t = ((trg_t >= pred_range_min) & (trg_t <= pred_range_max)).float()  # [bs, slen]
        trg_t.clamp_(max=pred_range_max)  # make it in range
        losses_t = BK.loss_nll(scores_t, trg_t) * mask_t  # [bs, slen]
        _, argmax_idxes = scores_t.max(-1)  # [bs, slen]
        corrs_t = (argmax_idxes == trg_t).float() * mask_t  # [bs, slen]
        # compile leaf loss
        one_loss = LossHelper.compile_leaf_info(pred_name, losses_t.sum(), mask_t.sum(), loss_lambda=1., corr=corrs_t.sum())
        all_losses.append(one_loss)
    return self._compile_component_loss("plm", all_losses)
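# Hedged sketch of the tie_input_embeddings path above (plain PyTorch, placeholder names
# and sizes, not the repo's actual modules): the output projection reuses the input
# embedding matrix, so prediction scores are just hid @ E.T with no separate output weights.
import torch

vocab, d = 10, 8
E = torch.randn(vocab, d)    # shared input embedding table
hid = torch.randn(2, 5, d)   # [bs, slen, hid]
scores = hid @ E.T           # [bs, slen, vocab]
print(scores.shape)  # torch.Size([2, 5, 10])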
def loss(self, repr_ts, input_erase_mask_arr, orig_map: Dict, active_hid=True, **kwargs):
    conf = self.conf
    _tie_input_embeddings = conf.tie_input_embeddings
    # prepare idxes for the masked ones
    if self.add_root_token:  # offset for the special root added in embedder
        mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr), padding_idx=-1)  # [bsize, ?]
        repr_mask_idxes = mask_idxes + 1
        mask_idxes.clamp_(min=0)
    else:
        mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr))  # [bsize, ?]
        repr_mask_idxes = mask_idxes
    # get the losses
    if BK.get_shape(mask_idxes, -1) == 0:  # no loss
        return self._compile_component_loss("mlm", [])
    if not isinstance(repr_ts, (list, tuple)):
        repr_ts = [repr_ts]
    target_word_scores = []
    target_pos_scores = None  # todo(+N): for simplicity, currently ignore this one!!
    for layer_idx in conf.loss_layers:
        # calculate scores
        target_reprs = BK.gather_first_dims(repr_ts[layer_idx], repr_mask_idxes, 1)  # [bsize, ?, *]
        if self.hid_layer and active_hid:
            # todo(+N): sometimes, we only want last softmax, need to ensure dim at outside!
            target_hids = self.hid_layer(target_reprs)
        else:
            target_hids = target_reprs
        if _tie_input_embeddings:
            pred_W = self.inputter_word_node.E.E[:self.pred_word_size]  # [PSize, Dim]
            target_word_scores.append(BK.matmul(target_hids, pred_W.T))  # List[bsize, ?, Vw]
        else:
            target_word_scores.append(self.pred_word_layer(target_hids))  # List[bsize, ?, Vw]
    # gather the losses
    all_losses = []
    for pred_name, target_scores, loss_lambda, range_min, range_max in \
            zip(["word", "pos"], [target_word_scores, target_pos_scores], [conf.lambda_word, conf.lambda_pos],
                [conf.min_pred_rank, 0], [min(conf.max_pred_rank, self.pred_word_size-1), self.pred_pos_size-1]):
        if loss_lambda > 0.:
            seq_idx_t = BK.input_idx(orig_map[pred_name])  # [bsize, slen]
            target_idx_t = seq_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
            ranged_mask_valids = mask_valids * (target_idx_t >= range_min).float() * (target_idx_t <= range_max).float()
            target_idx_t[(ranged_mask_valids < 1.)] = 0  # make sure invalid ones are in range
            # calculate for each layer
            all_layer_losses, all_layer_scores = [], []
            for one_layer_idx, one_target_scores in enumerate(target_scores):
                # get loss: [bsize, ?]
                one_pred_losses = BK.loss_nll(one_target_scores, target_idx_t) * conf.loss_weights[one_layer_idx]
                all_layer_losses.append(one_pred_losses)
                # get scores
                one_pred_scores = BK.log_softmax(one_target_scores, -1) * conf.loss_weights[one_layer_idx]
                all_layer_scores.append(one_pred_scores)
            # combine all layers
            pred_losses = self.loss_comb_f(all_layer_losses)
            pred_loss_sum = (pred_losses * ranged_mask_valids).sum()
            pred_loss_count = ranged_mask_valids.sum()
            # argmax
            _, argmax_idxes = self.score_comb_f(all_layer_scores).max(-1)
            pred_corrs = (argmax_idxes == target_idx_t).float() * ranged_mask_valids
            pred_corr_count = pred_corrs.sum()
            # compile leaf loss
            r_loss = LossHelper.compile_leaf_info(pred_name, pred_loss_sum, pred_loss_count,
                                                  loss_lambda=loss_lambda, corr=pred_corr_count)
            all_losses.append(r_loss)
    return self._compile_component_loss("mlm", all_losses)
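# Rough sketch (plain PyTorch) of the masked-position bookkeeping above: the erase mask is
# turned into padded index lists plus a validity mask, and representations/targets are then
# gathered only at those positions. mask2idx is re-implemented here as a toy helper, not
# the backend's actual function.
import torch

def toy_mask2idx(mask, padding_idx=0):
    # per-row indices of the 1s, right-padded with padding_idx, plus a validity mask
    max_n = int(mask.sum(-1).max().item())
    idxes = torch.full(mask.shape[:-1] + (max_n,), padding_idx, dtype=torch.long)
    valids = torch.zeros(mask.shape[:-1] + (max_n,))
    for b in range(mask.shape[0]):
        pos = mask[b].nonzero(as_tuple=False).squeeze(-1)
        idxes[b, :len(pos)] = pos
        valids[b, :len(pos)] = 1.
    return idxes, valids

erase_mask = torch.tensor([[0., 1., 0., 1.], [1., 0., 0., 0.]])
mask_idxes, mask_valids = toy_mask2idx(erase_mask)
print(mask_idxes)   # tensor([[1, 3], [0, 0]])
print(mask_valids)  # tensor([[1., 1.], [1., 0.]])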