def perform_attention(self, query, dim_per_head, key, relations_keys, mask,
                      value, relations_values, batch_size, head_count,
                      query_len, key_len, shape, unshape, attn_type=None):
    # 2) Calculate and scale scores.
    query = query / math.sqrt(dim_per_head)
    # batch x num_heads x query_len x key_len
    query_key = torch.matmul(query, key.transpose(2, 3))

    if self.max_relative_positions > 0 and attn_type == "self":
        scores = query_key + relative_matmul(query, relations_keys, True)
    else:
        scores = query_key
    scores = scores.float()

    if mask is not None:
        mask = mask.unsqueeze(1)  # [B, 1, 1, T_values]
        scores = scores.masked_fill(mask, -1e18)

    # 3) Apply attention dropout and compute context vectors.
    attn = self.softmax(scores).to(query.dtype)
    drop_attn = self.dropout(attn)
    context_original = torch.matmul(drop_attn, value)

    if self.max_relative_positions > 0 and attn_type == "self":
        context = unshape(
            context_original
            + relative_matmul(drop_attn, relations_values, False))
    else:
        context = unshape(context_original)

    output = self.final_linear(context)

    # Return one attn
    top_attn = attn \
        .view(batch_size, head_count, query_len, key_len)[:, 0, :, :] \
        .contiguous()

    return output, top_attn
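# A minimal, self-contained sketch of the scoring/masking/softmax/context steps
# that perform_attention runs per head, assuming toy sizes; the name
# demo_scaled_dot_attention and the random inputs are illustrative only and not
# part of this module.
import math

import torch
import torch.nn.functional as F


def demo_scaled_dot_attention(query, key, value, mask=None):
    """query/key/value: [batch, heads, len, dim_per_head]; mask: broadcastable bool over [batch, heads, q_len, k_len]."""
    dim_per_head = query.size(-1)
    # 2) Calculate and scale scores: [batch, heads, q_len, k_len].
    scores = torch.matmul(query / math.sqrt(dim_per_head), key.transpose(2, 3))
    if mask is not None:
        # Positions where mask is True receive (effectively) zero attention.
        scores = scores.float().masked_fill(mask, -1e18)
    # 3) Normalize and compute context vectors: [batch, heads, q_len, dim_per_head].
    attn = F.softmax(scores, dim=-1).to(query.dtype)
    context = torch.matmul(attn, value)
    return context, attn


if __name__ == "__main__":
    q = torch.randn(2, 8, 5, 64)
    k = torch.randn(2, 8, 7, 64)
    v = torch.randn(2, 8, 7, 64)
    pad = torch.zeros(2, 1, 1, 7, dtype=torch.bool)
    pad[:, :, :, -2:] = True  # pretend the last two keys are padding
    ctx, attn = demo_scaled_dot_attention(q, k, v, pad)
    print(ctx.shape, attn.shape)  # torch.Size([2, 8, 5, 64]) torch.Size([2, 8, 5, 7])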
def forward(self, key, value, query, mask=None, layer_cache=None, attn_type=None, gold_par_attn=None, gold_ch_attn=None): """ Compute the context vector and the attention vectors. Args: key (FloatTensor): set of `key_len` key vectors ``(batch, key_len, dim)`` value (FloatTensor): set of `key_len` value vectors ``(batch, key_len, dim)`` query (FloatTensor): set of `query_len` query vectors ``(batch, query_len, dim)`` mask: binary mask 1/0 indicating which keys have zero / non-zero attention ``(batch, query_len, key_len)`` Returns: (FloatTensor, FloatTensor): * output context vectors ``(batch, query_len, dim)`` * one of the attention vectors ``(batch, query_len, key_len)`` """ batch_size = key.size(0) dim_per_head = self.dim_per_head head_count = self.head_count key_len = key.size(1) query_len = query.size(1) def shape(x): """Projection.""" return x.view(batch_size, -1, head_count, dim_per_head) \ .transpose(1, 2) def unshape(x): """Compute context.""" return x.transpose(1, 2).contiguous() \ .view(batch_size, -1, head_count * dim_per_head) def predict_ch_label(attn, value): if not self.opt.biaffine: value = unshape(value) ch_attn_repeat = torch.repeat_interleave(attn, value.size(2), dim=2) \ .view(value.size(0), value.size(1), value.size(1), value.size(2)) value_repeat = torch.repeat_interleave(value, value.size(1), dim=1) \ .view(value.size(0), value.size(1), value.size(1), value.size(2)).transpose(1,2).contiguous() chs = ch_attn_repeat * value_repeat ch_label_h = torch.cat([chs, value_repeat], 3) if self.opt.biaffine: w_ch = self.Wlabel_ch_linear(chs, value_repeat) b_ch = self.blabel_ch_linear(ch_label_h) ch_labels = w_ch + b_ch else: ch_labels = self.p_ch_label(ch_label_h) return ch_labels def predict_par_label(attn, value): if not self.opt.biaffine: value = unshape(value) par = torch.matmul(attn, value) par_label_h = torch.cat([par, value], 2) #par_label_h = torch.cat([value, par],2) if self.opt.biaffine: w_par = self.Wlabel_par_linear(par, value) b_par = self.blabel_par_linear(par_label_h) par_labels = w_par + b_par else: par_labels = self.p_par_label(par_label_h) #par_labels = self.p_par_label(par_label_h) return par_labels # 1) Project key, value, and query. 
if layer_cache is not None: if attn_type == "self": query, key, value = self.linear_query(query),\ self.linear_keys(query),\ self.linear_values(query) key = shape(key) value = shape(value) if layer_cache["self_keys"] is not None: key = torch.cat((layer_cache["self_keys"], key), dim=2) if layer_cache["self_values"] is not None: value = torch.cat((layer_cache["self_values"], value), dim=2) layer_cache["self_keys"] = key layer_cache["self_values"] = value elif attn_type == "context": query = self.linear_query(query) if layer_cache["memory_keys"] is None: key, value = self.linear_keys(key),\ self.linear_values(value) key = shape(key) value = shape(value) else: key, value = layer_cache["memory_keys"],\ layer_cache["memory_values"] layer_cache["memory_keys"] = key layer_cache["memory_values"] = value else: key = self.linear_keys(key) value = self.linear_values(value) query = self.linear_query(query) key = shape(key) value = shape(value) if self.max_relative_positions > 0 and attn_type == "self": key_len = key.size(2) # 1 or key_len x key_len relative_positions_matrix = generate_relative_positions_matrix( key_len, self.max_relative_positions, cache=True if layer_cache is not None else False) # 1 or key_len x key_len x dim_per_head relations_keys = self.relative_positions_embeddings( relative_positions_matrix.to(key.device)) # 1 or key_len x key_len x dim_per_head relations_values = self.relative_positions_embeddings( relative_positions_matrix.to(key.device)) query = shape(query) key_len = key.size(2) query_len = query.size(2) # 2) Calculate and scale scores. query = query / math.sqrt(dim_per_head) if self.label_emb is not None and self.opt.biaffine: query_key = torch.matmul(query[:, 2:], key.transpose(2, 3)[:, 2:]) w_par = torch.matmul(self.Warc_par_linear(query[:, 0]), key.transpose(2, 3)[:, 0]) w_ch = torch.matmul(self.Warc_ch_linear(query[:, 1]), key.transpose(2, 3)[:, 1]) b_par = self.barc_par_linear(query[:, 0]).repeat_interleave( query.size(2), dim=2) b_ch = self.barc_ch_linear(query[:, 1]).repeat_interleave( query.size(2), dim=2) arc_par = w_par + b_par arc_ch = w_ch + b_ch query_key = torch.cat( [arc_par.unsqueeze(1), arc_ch.unsqueeze(1), query_key], dim=1) else: # batch x num_heads x query_len x key_len query_key = torch.matmul(query, key.transpose(2, 3)) if self.max_relative_positions > 0 and attn_type == "self": scores = query_key + relative_matmul(query, relations_keys, True) else: scores = query_key scores = scores.float() if mask is not None: mask = mask.unsqueeze(1) # [B, 1, 1, T_values] scores = scores.masked_fill(mask, -1e18) # 3) Apply attention dropout and compute context vectors. 
attn = self.softmax(scores).to(query.dtype) drop_attn = self.dropout(attn) context_original = torch.matmul(drop_attn, value) if self.max_relative_positions > 0 and attn_type == "self": context = unshape( context_original + relative_matmul(drop_attn, relations_values, False)) else: context = unshape(context_original) ch_labels = None par_labels = None if gold_ch_attn is not None: if self.opt.biaffine: par_labels = predict_par_label(gold_par_attn, value[:, 0]) ch_labels = predict_ch_label(gold_ch_attn, value[:, 1]) else: par_labels = predict_par_label(gold_par_attn, value) ch_labels = predict_ch_label(gold_ch_attn, value) output = self.final_linear(context) top_attn = attn \ .view(batch_size, head_count, query_len, key_len)[:, 0, :, :] \ .contiguous() second_attn = attn \ .view(batch_size, head_count, query_len, key_len)[:, 1, :, :] \ .contiguous() return output, top_attn, second_attn, ch_labels, par_labels
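# The forward above scores parent/child arcs on dedicated heads and, when
# opt.biaffine is set, combines a bilinear term with a linear "bias" term
# (Warc_*/barc_* and Wlabel_*/blabel_*). Below is a minimal sketch of that
# biaffine scoring pattern; DemoBiaffineScorer and its layer layout are assumed
# for illustration and are not the module's actual layers.
import torch
import torch.nn as nn


class DemoBiaffineScorer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.W = nn.Linear(dim, dim, bias=False)   # bilinear part: h_i W h_j^T
        self.b = nn.Linear(dim, 1, bias=True)      # linear part:   h_i b

    def forward(self, head_repr, dep_repr):
        """head_repr, dep_repr: [batch, seq_len, dim] -> arc scores [batch, seq_len, seq_len]."""
        bilinear = torch.matmul(self.W(head_repr), dep_repr.transpose(1, 2))
        bias = self.b(head_repr)                   # [batch, seq_len, 1]
        return bilinear + bias                     # broadcast over the dependent axis


if __name__ == "__main__":
    scorer = DemoBiaffineScorer(dim=32)
    h = torch.randn(4, 10, 32)
    print(scorer(h, h).shape)  # torch.Size([4, 10, 10])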
def forward(self, key, value, query, mask=None,
            layer_cache=None, attn_type=None, decoder=False):
    """
    Compute the context vector and the attention vectors.

    Args:
       key (FloatTensor): set of `key_len`
           key vectors ``(batch, key_len, dim)``
       value (FloatTensor): set of `key_len`
           value vectors ``(batch, key_len, dim)``
       query (FloatTensor): set of `query_len`
           query vectors ``(batch, query_len, dim)``
       mask: binary mask 1/0 indicating which keys have
           zero / non-zero attention ``(batch, query_len, key_len)``
       decoder: indicates the self-attention is coming from the decoder.
    Returns:
       (FloatTensor, FloatTensor):

       * output context vectors ``(batch, query_len, dim)``
       * Attention vector in heads ``(batch, head, query_len, key_len)``.
    """

    # CHECKS
    # batch, k_len, d = key.size()
    # batch_, k_len_, d_ = value.size()
    # aeq(batch, batch_)
    # aeq(k_len, k_len_)
    # aeq(d, d_)
    # batch_, q_len, d_ = query.size()
    # aeq(batch, batch_)
    # aeq(d, d_)
    # aeq(self.model_dim % 8, 0)
    # if mask is not None:
    #    batch_, q_len_, k_len_ = mask.size()
    #    aeq(batch_, batch)
    #    aeq(k_len_, k_len)
    #    aeq(q_len_ == q_len)
    # END CHECKS

    batch_size = key.size(0)
    dim_per_head = self.dim_per_head
    head_count = self.head_count
    key_len = key.size(1)
    query_len = query.size(1)
    use_causal = decoder is True and attn_type == 'self' and self.training is True

    def shape(x):
        """Projection."""
        return x.view(batch_size, -1, head_count, dim_per_head) \
            .transpose(1, 2)

    def unshape(x):
        """Compute context."""
        return x.transpose(1, 2).contiguous() \
            .view(batch_size, -1, head_count * dim_per_head)

    # 1) Project key, value, and query.
    if layer_cache is not None:
        if attn_type == "self":
            query, key, value = self.linear_query(query),\
                                self.linear_keys(query),\
                                self.linear_values(query)
            key = shape(key)
            value = shape(value)
            if layer_cache["self_keys"] is not None:
                key = torch.cat((layer_cache["self_keys"], key), dim=2)
            if layer_cache["self_values"] is not None:
                value = torch.cat((layer_cache["self_values"], value), dim=2)
            layer_cache["self_keys"] = key
            layer_cache["self_values"] = value
        elif attn_type == "context":
            query = self.linear_query(query)
            if layer_cache["memory_keys"] is None:
                key, value = self.linear_keys(key),\
                             self.linear_values(value)
                key = shape(key)
                value = shape(value)
            else:
                key, value = layer_cache["memory_keys"],\
                             layer_cache["memory_values"]
            layer_cache["memory_keys"] = key
            layer_cache["memory_values"] = value
    else:
        key = self.linear_keys(key)
        value = self.linear_values(value)
        query = self.linear_query(query)
        key = shape(key)
        value = shape(value)

    if self.max_relative_positions > 0 and attn_type == "self":
        key_len = key.size(2)
        # 1 or key_len x key_len
        relative_positions_matrix = generate_relative_positions_matrix(
            key_len, self.max_relative_positions,
            cache=True if layer_cache is not None else False)
        # 1 or key_len x key_len x dim_per_head
        relations_keys = self.relative_positions_embeddings(
            relative_positions_matrix.to(key.device))
        # 1 or key_len x key_len x dim_per_head
        relations_values = self.relative_positions_embeddings(
            relative_positions_matrix.to(key.device))

    query = shape(query)
    key_len = key.size(2)
    query_len = query.size(2)

    # 2) Calculate and scale scores.
    query = query / math.sqrt(dim_per_head)
    # batch x num_heads x query_len x key_len
    query_key = torch.matmul(query, key.transpose(2, 3))

    # Elliott: mask out some backward pass.
    if use_causal:
        assert query_len == key_len
        bk_mask = 1.0 - torch.diag(
            torch.ones(query_len, device=mask.device)).unsqueeze(0).repeat(
                [batch_size, 1, 1]).to(mask.dtype)  # [bz, len, len]
        bk_mask = (bk_mask + mask).gt(0).to(query_key.dtype)  # [bz, len, len]
        incre_mask = (bk_mask.to(mask.dtype) - mask).to(
            query_key.dtype).unsqueeze(1)  # [bz, 1, len, len]
        bk_mask = bk_mask.unsqueeze(1)  # [bz, 1, len, len]
        query_key_detach = query_key.detach()  # [bz, num_heads, len, len]
        query_key = bk_mask * query_key + incre_mask * query_key_detach

    if self.max_relative_positions > 0 and attn_type == "self":
        scores = query_key + relative_matmul(query, relations_keys, True)
    else:
        scores = query_key
    scores = scores.float()

    if mask is not None:
        mask = mask.unsqueeze(1)  # [B, 1, 1, T_values]
        scores = scores.masked_fill(mask, -1e18)

    # 3) Apply attention dropout and compute context vectors.
    attn = self.softmax(scores).to(query.dtype)
    drop_attn = self.dropout(attn)

    if use_causal:
        # [bz, num_heads, q_len, k_len, 1] * [bz, num_heads, 1, k_len, dim]
        # --> [bz, num_heads, q_len, k_len, dim]
        context_original = drop_attn.unsqueeze(-1) * value.unsqueeze(2)
        context_original_detach = context_original.detach()
        context_original = bk_mask.unsqueeze(-1) * context_original \
            + incre_mask.unsqueeze(-1) * context_original_detach
        # [bz, num_heads, q_len, dim]
        context_original = context_original.sum(3)
    else:
        context_original = torch.matmul(drop_attn, value)

    if self.max_relative_positions > 0 and attn_type == "self":
        context = unshape(
            context_original
            + relative_matmul(drop_attn, relations_values, False))
    else:
        context = unshape(context_original)

    output = self.final_linear(context)
    # CHECK
    # batch_, q_len_, d_ = output.size()
    # aeq(q_len, q_len_)
    # aeq(batch, batch_)
    # aeq(d, d_)

    # Return multi-head attn
    attns = attn \
        .view(batch_size, head_count, query_len, key_len)

    return output, attns
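# A minimal sketch of the detach-mixing trick used in the use_causal branch
# above: the forward value is unchanged, but gradients flow only through the
# entries selected by keep_mask. The tensor sizes here are illustrative.
import torch

scores = torch.randn(2, 4, 5, 5, requires_grad=True)
keep_mask = torch.bernoulli(torch.full((2, 1, 5, 5), 0.5))  # 1 = keep gradient, 0 = block

# Forward result equals `scores`; backward sees only the masked-in entries.
mixed = keep_mask * scores + (1.0 - keep_mask) * scores.detach()
mixed.sum().backward()

# The gradient is exactly the mask: blocked entries get zero gradient.
print(torch.equal(scores.grad, keep_mask.expand_as(scores)))  # True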
def forward(self, key, value, query, mask=None, layer_cache=None, type=None): """ Compute the context vector and the attention vectors. Args: key (FloatTensor): set of `key_len` key vectors ``(batch, key_len, dim)`` value (FloatTensor): set of `key_len` value vectors ``(batch, key_len, dim)`` query (FloatTensor): set of `query_len` query vectors ``(batch, query_len, dim)`` mask: binary mask indicating which keys have non-zero attention ``(batch, query_len, key_len)`` Returns: (FloatTensor, FloatTensor): * output context vectors ``(batch, query_len, dim)`` * one of the attention vectors ``(batch, query_len, key_len)`` """ # CHECKS # batch, k_len, d = key.size() # batch_, k_len_, d_ = value.size() # aeq(batch, batch_) # aeq(k_len, k_len_) # aeq(d, d_) # batch_, q_len, d_ = query.size() # aeq(batch, batch_) # aeq(d, d_) # aeq(self.model_dim % 8, 0) # if mask is not None: # batch_, q_len_, k_len_ = mask.size() # aeq(batch_, batch) # aeq(k_len_, k_len) # aeq(q_len_ == q_len) # END CHECKS batch_size = key.size(0) dim_per_head = self.dim_per_head head_count = self.head_count key_len = key.size(1) query_len = query.size(1) device = key.device def shape(x): """Projection.""" return x.view(batch_size, -1, head_count, dim_per_head) \ .transpose(1, 2) def unshape(x): """Compute context.""" return x.transpose(1, 2).contiguous() \ .view(batch_size, -1, head_count * dim_per_head) # 1) Project key, value, and query. if layer_cache is not None: if type == "self": query, key, value = self.linear_query(query),\ self.linear_keys(query),\ self.linear_values(query) key = shape(key) value = shape(value) if layer_cache["self_keys"] is not None: key = torch.cat((layer_cache["self_keys"].to(device), key), dim=2) if layer_cache["self_values"] is not None: value = torch.cat( (layer_cache["self_values"].to(device), value), dim=2) layer_cache["self_keys"] = key layer_cache["self_values"] = value elif type == "context": query = self.linear_query(query) if layer_cache["memory_keys"] is None: key, value = self.linear_keys(key),\ self.linear_values(value) key = shape(key) value = shape(value) else: key, value = layer_cache["memory_keys"],\ layer_cache["memory_values"] layer_cache["memory_keys"] = key layer_cache["memory_values"] = value else: key = self.linear_keys(key) value = self.linear_values(value) query = self.linear_query(query) key = shape(key) value = shape(value) if self.max_relative_positions > 0 and type == "self": key_len = key.size(2) # 1 or key_len x key_len relative_positions_matrix = generate_relative_positions_matrix( key_len, self.max_relative_positions, cache=True if layer_cache is not None else False) # 1 or key_len x key_len x dim_per_head relations_keys = self.relative_positions_embeddings( relative_positions_matrix.to(device)) # 1 or key_len x key_len x dim_per_head relations_values = self.relative_positions_embeddings( relative_positions_matrix.to(device)) query = shape(query) key_len = key.size(2) query_len = query.size(2) # 2) Calculate and scale scores. query = query / math.sqrt(dim_per_head) # batch x num_heads x query_len x key_len query_key = torch.matmul(query, key.transpose(2, 3)) if self.max_relative_positions > 0 and type == "self": scores = query_key + relative_matmul(query, relations_keys, True) else: scores = query_key scores = scores.float() if mask is not None: mask = mask.unsqueeze(1) # [B, 1, 1, T_values] scores = scores.masked_fill(mask, -1e18) # 3) Apply attention dropout and compute context vectors. 
attn = self.softmax(scores).to(query.dtype) drop_attn = self.dropout(attn) context_original = torch.matmul(drop_attn, value) if self.max_relative_positions > 0 and type == "self": context = unshape( context_original + relative_matmul(drop_attn, relations_values, False)) else: context = unshape(context_original) output = self.final_linear(context) # CHECK # batch_, q_len_, d_ = output.size() # aeq(q_len, q_len_) # aeq(batch, batch_) # aeq(d, d_) # Return one attn top_attn = attn \ .view(batch_size, head_count, query_len, key_len)[:, 0, :, :] \ .contiguous() return output, top_attn
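# The forward above adds relative-position information in the style of
# Shaw et al. (2018): a clipped relative-distance matrix indexes an embedding
# table, and its contribution is added to the content scores. Below is a minimal
# sketch under assumed small sizes; demo_relative_positions_matrix and
# demo_relative_key_scores are illustrative stand-ins, not the repo's
# generate_relative_positions_matrix / relative_matmul.
import torch
import torch.nn as nn


def demo_relative_positions_matrix(length, max_relative_positions):
    """Clipped relative distances shifted into [0, 2 * max_relative_positions]."""
    positions = torch.arange(length)
    distance = positions.unsqueeze(0) - positions.unsqueeze(1)            # [len, len]
    distance = distance.clamp(-max_relative_positions, max_relative_positions)
    return distance + max_relative_positions


def demo_relative_key_scores(query, relations_keys):
    """query: [batch, heads, q_len, dim]; relations_keys: [q_len, k_len, dim]."""
    b, h, q_len, dim = query.shape
    # Each query position attends against its own row of relation embeddings:
    # [q_len, batch*heads, dim] x [q_len, dim, k_len] -> [q_len, batch*heads, k_len]
    q_t = query.permute(2, 0, 1, 3).reshape(q_len, b * h, dim)
    rel = torch.matmul(q_t, relations_keys.transpose(1, 2))
    return rel.reshape(q_len, b, h, -1).permute(1, 2, 0, 3)               # [batch, heads, q_len, k_len]


if __name__ == "__main__":
    length, dim, max_rel = 6, 16, 4
    table = nn.Embedding(2 * max_rel + 1, dim)
    rel_keys = table(demo_relative_positions_matrix(length, max_rel))     # [len, len, dim]
    q = torch.randn(2, 8, length, dim)
    k = torch.randn(2, 8, length, dim)
    scores = torch.matmul(q, k.transpose(2, 3)) + demo_relative_key_scores(q, rel_keys)
    print(scores.shape)  # torch.Size([2, 8, 6, 6])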
def forward(self, key, value, query, mask=None, layer_cache=None, type=None): """ Compute the context vector and the attention vectors. Args: key (FloatTensor): set of `key_len` key vectors ``(batch, key_len, dim)`` value (FloatTensor): set of `key_len` value vectors ``(batch, key_len, dim)`` query (FloatTensor): set of `query_len` query vectors ``(batch, query_len, dim)`` mask: binary mask indicating which keys have non-zero attention ``(batch, query_len, key_len)`` Returns: (FloatTensor, FloatTensor): * output context vectors ``(batch, query_len, dim)`` * one of the attention vectors ``(batch, query_len, key_len)`` """ # CHECKS # batch, k_len, d = key.size() # batch_, k_len_, d_ = value.size() # aeq(batch, batch_) # aeq(k_len, k_len_) # aeq(d, d_) # batch_, q_len, d_ = query.size() # aeq(batch, batch_) # aeq(d, d_) # aeq(self.model_dim % 8, 0) # if mask is not None: # batch_, q_len_, k_len_ = mask.size() # aeq(batch_, batch) # aeq(k_len_, k_len) # aeq(q_len_ == q_len) # END CHECKS batch_size = key.size(0) dim_per_head = self.dim_per_head head_count = self.head_count key_len = key.size(1) query_len = query.size(1) device = key.device def shape(x): """Projection.""" return x.view(batch_size, -1, head_count, dim_per_head) \ .transpose(1, 2) def unshape(x): """Compute context.""" return x.transpose(1, 2).contiguous() \ .view(batch_size, -1, head_count * dim_per_head) # 1) Project key, value, and query. if layer_cache is not None: if type == "self": query, key, value = self.linear_query(query),\ self.linear_keys(query),\ self.linear_values(query) key = shape(key) value = shape(value) if layer_cache["self_keys"] is not None: key = torch.cat( (layer_cache["self_keys"].to(device), key), dim=2) if layer_cache["self_values"] is not None: value = torch.cat( (layer_cache["self_values"].to(device), value), dim=2) layer_cache["self_keys"] = key layer_cache["self_values"] = value elif type == "context": query = self.linear_query(query) if layer_cache["memory_keys"] is None: key, value = self.linear_keys(key),\ self.linear_values(value) key = shape(key) value = shape(value) else: key, value = layer_cache["memory_keys"],\ layer_cache["memory_values"] layer_cache["memory_keys"] = key layer_cache["memory_values"] = value else: key = self.linear_keys(key) value = self.linear_values(value) query = self.linear_query(query) key = shape(key) value = shape(value) if self.max_relative_positions > 0 and type == "self": key_len = key.size(2) # 1 or key_len x key_len relative_positions_matrix = generate_relative_positions_matrix( key_len, self.max_relative_positions, cache=True if layer_cache is not None else False) # 1 or key_len x key_len x dim_per_head relations_keys = self.relative_positions_embeddings( relative_positions_matrix.to(device)) # 1 or key_len x key_len x dim_per_head relations_values = self.relative_positions_embeddings( relative_positions_matrix.to(device)) query = shape(query) key_len = key.size(2) query_len = query.size(2) # 2) Calculate and scale scores. query = query / math.sqrt(dim_per_head) # batch x num_heads x query_len x key_len query_key = torch.matmul(query, key.transpose(2, 3)) if self.max_relative_positions > 0 and type == "self": scores = query_key + relative_matmul(query, relations_keys, True) else: scores = query_key scores = scores.float() if mask is not None: mask = mask.unsqueeze(1) # [B, 1, 1, T_values] scores = scores.masked_fill(mask, -1e18) # 3) Apply attention dropout and compute context vectors. 
attn = self.softmax(scores).to(query.dtype) drop_attn = self.dropout(attn) context_original = torch.matmul(drop_attn, value) if self.max_relative_positions > 0 and type == "self": context = unshape(context_original + relative_matmul(drop_attn, relations_values, False)) else: context = unshape(context_original) output = self.final_linear(context) # CHECK # batch_, q_len_, d_ = output.size() # aeq(q_len, q_len_) # aeq(batch, batch_) # aeq(d, d_) # Return one attn top_attn = attn \ .view(batch_size, head_count, query_len, key_len)[:, 0, :, :] \ .contiguous() return output, top_attn
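# A minimal sketch of the "self" layer_cache protocol shared by the forwards in
# this file: at each decoding step only the newest token is projected, and its
# shaped key/value are appended to the cache along the length dimension (dim=2).
# The sizes below are illustrative.
import torch

batch, heads, dim_per_head = 2, 8, 64
layer_cache = {"self_keys": None, "self_values": None}

for step in range(3):
    # Pretend these are shape(linear_keys(x_t)) / shape(linear_values(x_t)) for one new token.
    new_key = torch.randn(batch, heads, 1, dim_per_head)
    new_value = torch.randn(batch, heads, 1, dim_per_head)

    if layer_cache["self_keys"] is not None:
        new_key = torch.cat((layer_cache["self_keys"], new_key), dim=2)
    if layer_cache["self_values"] is not None:
        new_value = torch.cat((layer_cache["self_values"], new_value), dim=2)

    layer_cache["self_keys"] = new_key
    layer_cache["self_values"] = new_value
    print(step, layer_cache["self_keys"].shape)  # length dimension grows: 1, 2, 3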
def forward(self, key, value, query, mask=None, layer_cache=None, attn_type=None): """ Compute the context vector and the attention vectors. Args: key (FloatTensor): set of `key_len` key vectors ``(batch, key_len, dim)`` value (FloatTensor): set of `key_len` value vectors ``(batch, key_len, dim)`` query (FloatTensor): set of `query_len` query vectors ``(batch, query_len, dim)`` mask: binary mask 1/0 indicating which keys have zero / non-zero attention ``(batch, query_len, key_len)`` Returns: (FloatTensor, FloatTensor): * output context vectors ``(batch, query_len, dim)`` * one of the attention vectors ``(batch, query_len, key_len)`` """ # CHECKS # batch, k_len, d = key.size() # batch_, k_len_, d_ = value.size() # aeq(batch, batch_) # aeq(k_len, k_len_) # aeq(d, d_) # batch_, q_len, d_ = query.size() # aeq(batch, batch_) # aeq(d, d_) # aeq(self.model_dim % 8, 0) # if mask is not None: # batch_, q_len_, k_len_ = mask.size() # aeq(batch_, batch) # aeq(k_len_, k_len) # aeq(q_len_ == q_len) # END CHECKS batch_size = key.size(0) dim_per_head = self.dim_per_head head_count = self.head_count key_len = key.size(1) query_len = query.size(1) def shape(x): """Projection.""" return x.view(batch_size, -1, head_count, dim_per_head) \ .transpose(1, 2) def unshape(x): """Compute context.""" return x.transpose(1, 2).contiguous() \ .view(batch_size, -1, head_count * dim_per_head) if self.with_saliency_selection: selection_query = self.linear_selection_query(query) selection_key = self.linear_selection_key(key) selection_key = shape(selection_key) selection_query = shape(selection_query) # 1) Project key, value, and query. if layer_cache is not None: if attn_type == "self": query, key, value = self.linear_query(query),\ self.linear_keys(query),\ self.linear_values(query) key = shape(key) value = shape(value) if layer_cache["self_keys"] is not None: key = torch.cat( (layer_cache["self_keys"], key), dim=2) if layer_cache["self_values"] is not None: value = torch.cat( (layer_cache["self_values"], value), dim=2) layer_cache["self_keys"] = key layer_cache["self_values"] = value elif attn_type == "context": query = self.linear_query(query) if layer_cache["memory_keys"] is None: key, value = self.linear_keys(key),\ self.linear_values(value) key = shape(key) value = shape(value) else: key, value = layer_cache["memory_keys"],\ layer_cache["memory_values"] layer_cache["memory_keys"] = key layer_cache["memory_values"] = value else: key = self.linear_keys(key) value = self.linear_values(value) query = self.linear_query(query) key = shape(key) value = shape(value) if self.max_relative_positions > 0 and attn_type == "self": key_len = key.size(2) # 1 or key_len x key_len relative_positions_matrix = generate_relative_positions_matrix( key_len, self.max_relative_positions, cache=True if layer_cache is not None else False) # 1 or key_len x key_len x dim_per_head relations_keys = self.relative_positions_embeddings( relative_positions_matrix.to(key.device)) # 1 or key_len x key_len x dim_per_head relations_values = self.relative_positions_embeddings( relative_positions_matrix.to(key.device)) if self.with_focus_attention == True: glo = torch.mean(query, dim=1, keepdim=True) c = self.tanh(self.linear_focus_query(query) + self.linear_focus_global(glo)) # c = self.tanh(self.linear_focus_query(query))# + self.linear_focus_global(glo)) c = shape(c) p = c * self.up p = p.sum(3).squeeze() z = c * self.uz z = z.sum(3).squeeze() P = self.sigmoid(p) * key_len Z = self.sigmoid(z) * key_len j = torch.arange(start=0, end=key_len, 
                         dtype=P.dtype).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(P.device)
        P = P.unsqueeze(-1)
        Z = Z.unsqueeze(-1)
        G = - (j - P)**2 * 2 / (Z**2)

    query = shape(query)

    if self.with_saliency_selection == True:
        # gate_key = self.linear_selection_key(unshape(key))
        # gate_query = self.linear_selection_query(unshape(query))
        # gate_key = shape(gate_key)
        # gate_query = shape(gate_query)
        gate = self.sigmoid(torch.matmul(selection_query,
                                         selection_key.transpose(2, 3)))

    key_len = key.size(2)
    query_len = query.size(2)

    # 2) Calculate and scale scores.
    query = query / math.sqrt(dim_per_head)
    # batch x num_heads x query_len x key_len
    query_key = torch.matmul(query, key.transpose(2, 3))

    if self.max_relative_positions > 0 and attn_type == "self":
        scores = query_key + relative_matmul(query, relations_keys, True)
    else:
        scores = query_key
    scores = scores.float()

    if self.with_focus_attention == True:
        scores = scores + G

    if mask is not None:
        mask = mask.unsqueeze(1)  # [B, 1, 1, T_values]
        scores = scores.masked_fill(mask, -1e18)

    # 3) Apply attention dropout and compute context vectors.
    attn = self.softmax(scores).to(query.dtype)
    if self.with_saliency_selection:
        new_attn = attn * gate
        drop_attn = self.dropout(new_attn)
    else:
        drop_attn = self.dropout(attn)
    context_original = torch.matmul(drop_attn, value)

    if self.max_relative_positions > 0 and attn_type == "self":
        context = unshape(
            context_original
            + relative_matmul(drop_attn, relations_values, False))
    else:
        context = unshape(context_original)

    output = self.final_linear(context)

    # CHECK
    # batch_, q_len_, d_ = output.size()
    # aeq(q_len, q_len_)
    # aeq(batch, batch_)
    # aeq(d, d_)

    # Return one attn
    top_attn = attn \
        .view(batch_size, head_count, query_len, key_len)[:, 0, :, :] \
        .contiguous()

    return output, top_attn
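# A minimal sketch of the Gaussian focus bias added to the scores in the forward
# above: each query predicts a centre P and window Z in [0, key_len], and key
# positions far from the centre are penalised before the softmax. The shapes and
# the random "predictions" here are illustrative only.
import torch

batch, heads, q_len, k_len = 2, 4, 5, 7
scores = torch.randn(batch, heads, q_len, k_len)

P = torch.sigmoid(torch.randn(batch, heads, q_len, 1)) * k_len  # predicted centre per query
Z = torch.sigmoid(torch.randn(batch, heads, q_len, 1)) * k_len  # predicted window per query

j = torch.arange(k_len, dtype=scores.dtype).view(1, 1, 1, k_len)  # key positions (kept on scores.device in practice)
G = -(j - P) ** 2 * 2 / (Z ** 2)                                   # [batch, heads, q_len, k_len]

focused_scores = scores + G
attn = torch.softmax(focused_scores, dim=-1)
print(attn.shape)  # torch.Size([2, 4, 5, 7])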
def forward(self, self_kvq, ctx_kv, self_mask=None, ctx_mask=None, layer_cache=None, type=None): """ Compute the context vector and the attention vectors. Args: self_kvq (FloatTensor): set of `self_len` key vectors ``(batch, self_len, dim)`` ctz_kv (FloatTensor): set of `ctx_len` value vectors ``(batch, ctx_len, dim)`` mask: binary mask indicating which keys have non-zero attention ``(batch, self_len, self_len)`` Returns: (FloatTensor, FloatTensor): * output context vectors ``(batch, self_len, dim)`` * one of the attention vectors ``(batch, self_len, ctx_len)`` """ # CHECKS # batch, k_len, d = key.size() # batch_, k_len_, d_ = value.size() # aeq(batch, batch_) # aeq(k_len, k_len_) # aeq(d, d_) # batch_, q_len, d_ = query.size() # aeq(batch, batch_) # aeq(d, d_) # aeq(self.model_dim % 8, 0) # if mask is not None: # batch_, q_len_, k_len_ = mask.size() # aeq(batch_, batch) # aeq(k_len_, k_len) # aeq(q_len_ == q_len) # END CHECKS batch_size = self_kvq.size(0) dim_per_head = self.dim_per_head head_count = self.head_count self_len = self_kvq.size(1) ctx_len = ctx_kv.size(1) device = self_kvq.device def shape(x): """Projection.""" return x.view(batch_size, -1, head_count, dim_per_head) \ .transpose(1, 2) def unshape(x): """Compute context.""" return x.transpose(1, 2).contiguous() \ .view(batch_size, -1, head_count * dim_per_head) # 1) Project key, value, and query. if layer_cache is not None: query, self_key, self_value = self.linear_query(self_kvq),\ self.linear_keys(self_kvq),\ self.linear_values(self_kvq) #self_key = shape(self_key) #self_value = shape(self_value) if layer_cache["self_keys"] is not None: self_key = torch.cat( (layer_cache["self_keys"].to(device), self_key), dim=1) if layer_cache["self_values"] is not None: self_value = torch.cat( (layer_cache["self_values"].to(device), self_value), dim=1) layer_cache["self_keys"] = self_key layer_cache["self_values"] = self_value if layer_cache["memory_keys"] is None: ctx_key = self.ctx_linear_keys(ctx_kv) # [batch, ctx_len, dim] ctx_value = self.ctx_linear_values(ctx_kv) layer_cache["memory_keys"] = ctx_key layer_cache["memory_values"] = ctx_value else: ctx_key = layer_cache["memory_keys"] ctx_value = layer_cache["memory_values"] else: self_key = self.linear_keys(self_kvq) # [batch, self_len, dim] self_value = self.linear_values(self_kvq) query = self.linear_query(self_kvq) ctx_key = self.ctx_linear_keys(ctx_kv) # [batch, ctx_len, dim] ctx_value = self.ctx_linear_values(ctx_kv) self_len = self_key.size( 1) # Need to do this again to include the layer_cache length ctx_len = ctx_key.shape[1] key = torch.cat((self_key, ctx_key), dim=1) value = torch.cat((self_value, ctx_value), dim=1) key = shape(key) value = shape(value) if self.max_relative_positions > 0 and type == "self": raise NotImplementedError key_len = key.size(2) # 1 or key_len x key_len relative_positions_matrix = generate_relative_positions_matrix( key_len, self.max_relative_positions, cache=True if layer_cache is not None else False) # 1 or key_len x key_len x dim_per_head relations_keys = self.relative_positions_embeddings( relative_positions_matrix.to(device)) # 1 or key_len x key_len x dim_per_head relations_values = self.relative_positions_embeddings( relative_positions_matrix.to(device)) query = shape(query) key_len = key.size(2) # self_len+ctx_len query_len = query.size(2) # self_len # 2) Calculate and scale scores. 
query = query / math.sqrt(dim_per_head) # batch x num_heads x query_len x key_len query_key = torch.matmul(query, key.transpose( 2, 3)) # [batch, head, self_len, self_len+ctx_len] if self.ctx_weight_param: query_key[..., self_len:] += self.ctx_bias #print(query_key.mean(), query_key.std()) if self.max_relative_positions > 0 and type == "self": scores = query_key + relative_matmul(query, relations_keys, True) else: scores = query_key scores = scores.float() if self_mask is not None: self_mask = self_mask.unsqueeze(1) # [B, 1, self_len, self_len] scores[:, :, :, :self_len] = scores[:, :, :, : self_len].masked_fill( self_mask, -1e18) if ctx_mask is not None: ctx_mask = ctx_mask.unsqueeze(1) # [B, 1, 1, ctx_len] scores[:, :, :, self_len:] = scores[:, :, :, self_len:].masked_fill(ctx_mask, -1e18) # 3) Apply attention dropout and compute context vectors. attn = self.softmax(scores).to(query.dtype) drop_attn = self.dropout(attn) context_original = torch.matmul(drop_attn, value) # [batch, head, self_len, dim] if self.max_relative_positions > 0 and type == "self": context = unshape( context_original + relative_matmul(drop_attn, relations_values, False)) else: context = unshape(context_original) output = self.final_linear(context) # CHECK # batch_, q_len_, d_ = output.size() # aeq(q_len, q_len_) # aeq(batch, batch_) # aeq(d, d_) # Return one attn (to context) ctx_attn_probs = attn[:, :, :, self_len:] ctx_attn_probs = ctx_attn_probs / ctx_attn_probs.sum(dim=-1, keepdim=True) top_attn = ctx_attn_probs \ .view(batch_size, head_count, query_len, ctx_len)[:, 0, :, :] \ .contiguous() return output, top_attn, attn
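# A minimal sketch of the joint self+context scoring used above: scores over the
# concatenated [self_len + ctx_len] keys are masked per slice, softmaxed once,
# and the context slice is renormalised to report a context-only distribution.
# The sizes, the causal self mask, and the padding pattern are illustrative.
import torch

batch, heads, self_len, ctx_len, dim = 2, 4, 5, 3, 16
query = torch.randn(batch, heads, self_len, dim)
key = torch.randn(batch, heads, self_len + ctx_len, dim)

scores = torch.matmul(query / dim ** 0.5, key.transpose(2, 3))   # [B, H, self_len, self_len+ctx_len]

self_mask = torch.triu(torch.ones(self_len, self_len, dtype=torch.bool), diagonal=1)
ctx_mask = torch.zeros(batch, 1, 1, ctx_len, dtype=torch.bool)
ctx_mask[:, :, :, -1] = True                                      # pretend the last ctx position is padding

scores[:, :, :, :self_len] = scores[:, :, :, :self_len].masked_fill(self_mask, -1e18)
scores[:, :, :, self_len:] = scores[:, :, :, self_len:].masked_fill(ctx_mask, -1e18)

attn = torch.softmax(scores, dim=-1)
ctx_attn = attn[:, :, :, self_len:]
ctx_attn = ctx_attn / ctx_attn.sum(dim=-1, keepdim=True)          # renormalise over context keys only
print(attn.shape, ctx_attn.sum(-1)[0, 0])                         # each row of ctx_attn sums to 1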
def forward(self, key, value, query, grh=None, mask=None, layer_cache=None, attn_type=None): """ Compute the context vector and the attention vectors. Args: key (FloatTensor): set of `key_len` key vectors ``(batch, key_len, dim)`` value (FloatTensor): set of `key_len` value vectors ``(batch, key_len, dim)`` query (FloatTensor): set of `query_len` query vectors ``(batch, query_len, dim)`` mask: binary mask 1/0 indicating which keys have zero / non-zero attention ``(batch, query_len, key_len)`` Returns: (FloatTensor, FloatTensor): * output context vectors ``(batch, query_len, dim)`` * one of the attention vectors ``(batch, query_len, key_len)`` """ # CHECKS # batch, k_len, d = key.size() # batch_, k_len_, d_ = value.size() # aeq(batch, batch_) # aeq(k_len, k_len_) # aeq(d, d_) # batch_, q_len, d_ = query.size() # aeq(batch, batch_) # aeq(d, d_) # aeq(self.model_dim % 8, 0) # if mask is not None: # batch_, q_len_, k_len_ = mask.size() # aeq(batch_, batch) # aeq(k_len_, k_len) # aeq(q_len_ == q_len) # END CHECKS batch_size = key.size(0) dim_per_head = self.dim_per_head head_count = self.head_count key_len = key.size(1) query_len = query.size(1) def shape(x): """Projection.""" return x.view(batch_size, -1, head_count, dim_per_head) \ .transpose(1, 2) def unshape(x): """Compute context.""" return x.transpose(1, 2).contiguous() \ .view(batch_size, -1, head_count * dim_per_head) # 1) Project key, value, and query. if layer_cache is not None: if attn_type == "self": query, key, value = self.linear_query(query),\ self.linear_keys(query),\ self.linear_values(query) key = shape(key) value = shape(value) if layer_cache["self_keys"] is not None: key = torch.cat((layer_cache["self_keys"], key), dim=2) if layer_cache["self_values"] is not None: value = torch.cat((layer_cache["self_values"], value), dim=2) layer_cache["self_keys"] = key layer_cache["self_values"] = value elif attn_type == "context": query = self.linear_query(query) if layer_cache["memory_keys"] is None: key, value = self.linear_keys(key),\ self.linear_values(value) key = shape(key) value = shape(value) else: key, value = layer_cache["memory_keys"],\ layer_cache["memory_values"] layer_cache["memory_keys"] = key layer_cache["memory_values"] = value else: key = self.linear_keys(key) value = self.linear_values(value) query = self.linear_query(query) key = shape(key) value = shape(value) if self.max_relative_positions > 0 and attn_type == "self": key_len = key.size(2) # 1 or key_len x key_len relative_positions_matrix = generate_relative_positions_matrix( key_len, self.max_relative_positions, cache=True if layer_cache is not None else False) # 1 or key_len x key_len x dim_per_head relations_keys = self.relative_positions_embeddings( relative_positions_matrix.to(key.device)) # 1 or key_len x key_len x dim_per_head relations_values = self.relative_positions_embeddings( relative_positions_matrix.to(key.device)) query = shape(query) key_len = key.size(2) query_len = query.size(2) # 2) Calculate and scale scores. query = query / math.sqrt(dim_per_head) # batch x num_heads x query_len x key_len query_key = torch.matmul(query, key.transpose(2, 3)) if self.max_relative_positions > 0 and attn_type == "self": scores = query_key + relative_matmul(query, relations_keys, True) else: scores = query_key scores = scores.float() if mask is not None: mask = mask.unsqueeze(1) # [B, 1, 1, T_values] scores = scores.masked_fill(mask, -1e18) # 3) Apply attention dropout and compute context vectors. 
attn = self.softmax(scores).to(query.dtype) drop_attn = self.dropout(attn) context_original = torch.matmul(drop_attn, value) if self.max_relative_positions > 0 and attn_type == "self": context = unshape( context_original + relative_matmul(drop_attn, relations_values, False)) else: context = unshape(context_original) if self.gate: gate = torch.sigmoid(self.gate_linear[0](context)) gate_context = gate * context # CHECK # batch_, q_len_, d_ = output.size() # aeq(q_len, q_len_) # aeq(batch, batch_) # aeq(d, d_) # Return one attn top_attn = attn \ .view(batch_size, head_count, query_len, key_len)[:, 0, :, :] \ .contiguous() ## above is fully-connected graph ## Multi-View self-attention if grh is not None: assert query_len == key_len #assert key_len-1 != grh[0][-1][0], "the num of nodes is not consistent" views = [] index = [(0, 1), (2, 3), (4, 5), (6, 7)] # whole sub graph h_i = self.linear_attention[index[-1][0]](value) h_j = self.linear_attention[index[-1][1]](value) e = nn.functional.leaky_relu( h_i + h_j.transpose(2, 3)) # default alpha=0.01, but =0.2 in tf grh_mask = torch.ones_like(grh) adj = (grh_mask < grh).unsqueeze(1).expand(-1, self.head_count, -1, -1) zero = torch.ones_like(e) * (-9e15) e_shape = e.shape attention = self.softmax(e.where(adj > 0, zero)) # 17 8 56 56 whole_sub_view = torch.matmul(attention, value) if self.gate: gate = torch.sigmoid(self.gate_linear[-1]( unshape(whole_sub_view))) views.append(gate * unshape(whole_sub_view)) else: views.append(unshape(whole_sub_view)) # edge-aware sub graph for i in range(self.edge_type - 1): h_i = self.linear_attention[index[i][0]](value) h_j = self.linear_attention[index[i][1]](value) e = nn.functional.leaky_relu( h_i + h_j.transpose(2, 3)) # default alpha=0.01, but =0.2 in tf label_id = i + 2 # +2 because the followed is ones_like, so 1 can't be the edge type grh_mask = torch.ones_like(grh) * label_id eye = (torch.eye(grh.size(-1), dtype=torch.int64) * (4 - label_id)).cuda() # here doesnt support multi-gpu grh_mask = grh_mask + eye adj = (grh_mask == grh).unsqueeze(1).expand( -1, self.head_count, -1, -1) zero = torch.ones_like(e) * (-9e15) e_shape = e.shape attention = self.softmax(e.where(adj > 0, zero)) # 17 8 56 56 sub_view = torch.matmul(attention, value) if self.gate: gate = torch.sigmoid(self.gate_linear[i + 1]( unshape(sub_view))) views.append(gate * unshape(sub_view)) else: views.append(unshape(sub_view)) if self.fusion == "cat": ## TODO: MAX_P LSTM if self.gate: Views = [gate_context] + views else: Views = [context] + views output = torch.cat(Views, dim=-1) return self.sub_final_linear(output), top_attn output = self.final_linear(context) return output, top_attn
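# A minimal sketch of one edge-masked "view" from the multi-view graph attention
# above: GAT-style additive scores are built from two projections, entries with
# no edge in the adjacency receive a large negative score, and the result is
# softmaxed per node. demo_graph_attention_view, the projections, and the random
# adjacency are illustrative, not the module's linear_attention layers.
import torch
import torch.nn as nn
import torch.nn.functional as F


def demo_graph_attention_view(value, adjacency, proj_i, proj_j):
    """value: [batch, heads, nodes, dim]; adjacency: [batch, nodes, nodes] bool."""
    h_i = proj_i(value)                                   # [batch, heads, nodes, 1]
    h_j = proj_j(value)                                   # [batch, heads, nodes, 1]
    e = F.leaky_relu(h_i + h_j.transpose(2, 3))           # additive scores [batch, heads, nodes, nodes]
    adj = adjacency.unsqueeze(1).expand(-1, e.size(1), -1, -1)
    e = e.masked_fill(~adj, -9e15)                        # keep only edges present in the graph
    attention = torch.softmax(e, dim=-1)
    return torch.matmul(attention, value)                 # [batch, heads, nodes, dim]


if __name__ == "__main__":
    batch, heads, nodes, dim = 2, 4, 6, 16
    value = torch.randn(batch, heads, nodes, dim)
    adjacency = torch.rand(batch, nodes, nodes) > 0.5
    adjacency |= torch.eye(nodes, dtype=torch.bool)       # ensure every node at least sees itself
    proj_i, proj_j = nn.Linear(dim, 1), nn.Linear(dim, 1)
    print(demo_graph_attention_view(value, adjacency, proj_i, proj_j).shape)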