def __call__(self, query, key, accu_attn, mask_k, mask_qk, rel_dist):
    """Compute (un-normalized) multi-head attention scores.

    Builds scaled dot-product scores from `query`/`key`, then applies a series
    of optional, config-gated adjustments: relative-distance scores, an extra
    "unhead" score channel, a penalty from previously accumulated attention
    (`accu_attn`), a learned per-query score offset, key/pair masking, and a
    no-self-loop penalty on the diagonal.

    Args:
        query: query input; presumably [*, len_q, D_in] — shape not visible here, TODO confirm.
        key: key input; presumably [*, len_k, D_in] — TODO confirm.
        accu_attn: history-accumulated attention weights (or None); used only
            when `conf.use_lambq` is set.
        mask_k: per-key validity mask (or None); 1=valid, 0=masked (inferred
            from the `(1. - mask_k) * NEG_INF2` usage).
        mask_qk: per-(query,key)-pair mask (or None); same 1/0 convention.
        rel_dist: relative-distance input, forwarded to `self.dist_helper`
            when `conf.use_rel_dist` is set.

    Returns:
        Contiguous score tensor of shape [*, len_q, len_k, head].
    """
    conf = self.conf
    # == calculate the dot-product scores
    # calculate the three:
    # [bs, len_?, head*D]; and also add sta ones if needed
    query_up, key_up = self.affine_q(query), self.affine_k(key)  # [*, len?, head?*Dqk]
    query_up, key_up = self._shape_project(query_up, True), self._shape_project(key_up, True)  # [*, head?, len_?, D]
    # original scores: scaled dot-product attention logits
    scores = BK.matmul(query_up, BK.transpose(key_up, -1, -2)) / self._att_scale_qk  # [*, head?, len_q, len_k]
    # == adding rel_dist ones (relative-position-aware scoring)
    if conf.use_rel_dist:
        scores = self.dist_helper(query_up, key_up, rel_dist=rel_dist, input_scores=scores)
    # transpose so the head dimension comes last
    scores = scores.transpose(-2, -3).transpose(-1, -2)  # [*, len_q, len_k, head?]
    # == unhead score: split off one extra head-independent score channel and
    # broadcast-add it onto all real heads
    if conf.use_unhead_score:
        scores_t0, score_t1 = BK.split(scores, [1, self.head_count], -1)  # [*, len_q, len_k, 1|head]
        scores = scores_t0 + score_t1  # [*, len_q, len_k, head]
    # == combining with history accumulated attns: penalize positions already
    # attended to, scaled by a learned per-query/per-head factor
    if conf.use_lambq and accu_attn is not None:
        # todo(note): here we only consider "query" and "head", would it be necessary for "key"?
        lambq_vals = self.lambq_aff(query)  # [*, len_q, head], if for eg., using relu as fact, this>=0
        scores -= lambq_vals.unsqueeze(-2) * accu_attn
    # == score offset: learned per-query offset with one shared + per-head part
    if conf.use_soff:
        # todo(note): here we only consider "query" and "head", key may be handled by "unhead_score"
        score_offset_t = self.soff_aff(query)  # [*, len_q, 1+head]
        score_offset_t0, score_offset_t1 = BK.split(score_offset_t, [1, self.head_count], -1)  # [*, len_q, 1|head]
        scores -= score_offset_t0.unsqueeze(-2)
        scores -= score_offset_t1.unsqueeze(-2)  # still [*, len_q, len_k, head]
    # == apply mask & no-self-loop
    # NEG_INF = Constants.REAL_PRAC_MIN
    # NOTE: two distinct magnitudes so mask penalties (NEG_INF2) dominate the
    # no-self-loop penalty (NEG_INF) if both apply to the same cell.
    NEG_INF = -1000.  # this should be enough
    NEG_INF2 = -2000.  # this should be enough
    if mask_k is not None:  # [*, 1, len_k, 1]
        scores += (1. - mask_k).unsqueeze(-2).unsqueeze(-1) * NEG_INF2
    if mask_qk is not None:  # [*, len_q, len_k, 1]
        scores += (1. - mask_qk).unsqueeze(-1) * NEG_INF2
    if self.no_self_loop:
        # self-loop masking only makes sense when query and key cover the same sequence
        query_len = BK.get_shape(query, -2)
        assert query_len == BK.get_shape(key, -2), "Shape not matched for no_self_loop"
        scores += BK.eye(query_len).unsqueeze(-1) * NEG_INF  # [len_q, len_k, 1]
    return scores.contiguous()  # [*, len_q, len_k, head]
def _normalize(self, cnode: ConcreteNode, orig_scores, use_noop: bool, noop_fixed_val: float, temperature: float, dim: int):
    """Normalize scores along `dim`, optionally with an appended no-op slot.

    When `use_noop` is set, a constant-valued slice (`noop_fixed_val`) is
    concatenated along `dim` before normalizing, so probability mass can
    "escape" to a do-nothing option; the result is then split back apart.

    Returns:
        (prob_valid, prob_noop, prob_full) — the probabilities over the
        original entries, the no-op probability (None when `use_noop` is
        off), and the full normalized tensor.
    """
    shape = BK.get_shape(orig_scores)
    valid_size = shape[dim]  # original extent along the normalization dim
    shape[dim] = 1  # shape of the single no-op slice
    if use_noop:
        # append the fixed-score no-op slot: [*, D, *] -> [*, D+1, *]
        noop_slice = BK.constants(shape, value=noop_fixed_val)
        scores_in = BK.concat([orig_scores, noop_slice], dim=dim)
    else:
        scores_in = orig_scores  # [*, D, *] unchanged
    # normalize with the concrete-node function (e.g. softmax-like)
    prob_full = cnode(scores_in, temperature=temperature, dim=dim)  # [*, ?, *]
    if not use_noop:
        return prob_full, None, prob_full
    # separate valid-entry probabilities from the no-op probability
    prob_valid, prob_noop = BK.split(prob_full, [valid_size, 1], dim)  # [*, D|1, *]
    return prob_valid, prob_noop, prob_full