def get_incremental_state(
    module: MultiheadAttention,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    key: str,
) -> Optional[Dict[str, Optional[Tensor]]]:
    """Helper for getting incremental state for an nn.Module."""
    return module.get_incremental_state(incremental_state, key)
def test_append_prev_key_padding_mask(self):
    bsz = 1
    src_len = 4

    cases = [
        # no padding mask
        (None, None, None),
        # current padding mask only
        (
            torch.tensor([[1]]).bool(),
            None,
            torch.tensor([[0, 0, 0, 1]]).bool(),
        ),
        # previous padding mask only
        (
            None,
            torch.tensor([[0, 1, 0]]).bool(),
            torch.tensor([[0, 1, 0, 0]]).bool(),
        ),
        # both padding masks
        (
            torch.tensor([[1]]).bool(),
            torch.tensor([[0, 1, 0]]).bool(),
            torch.tensor([[0, 1, 0, 1]]).bool(),
        ),
        # prev_key_padding_mask already full
        (
            torch.tensor([[0, 1, 0, 1]]).bool(),
            None,
            torch.tensor([[0, 1, 0, 1]]).bool(),
        ),
        # key_padding_mask already full
        (
            None,
            torch.tensor([[0, 1, 0, 1]]).bool(),
            torch.tensor([[0, 1, 0, 1]]).bool(),
        ),
    ]

    for c in cases:
        key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
            c[0],
            c[1],
            batch_size=bsz,
            src_len=src_len,
            static_kv=False,
        )

        if key_padding_mask is not None:
            self.assertTrue(
                torch.all(torch.eq(key_padding_mask, c[2])),
                f"Unexpected resultant key padding mask: {key_padding_mask}"
                f" given current: {c[0]} and previous: {c[1]}",
            )
            self.assertEqual(key_padding_mask.size(0), bsz)
            self.assertEqual(key_padding_mask.size(1), src_len)
        else:
            self.assertIsNone(c[2])
def test_add_bias_parity():
    # values don't matter for this test.
    mha = MultiheadAttention(
        embed_dim=8,
        num_heads=2,
        dropout=0.0,
        add_bias_kv=True,
        add_zero_attn=True,
    )

    def old_bias_code(k, v, key_padding_mask, attn_mask, bsz):
        k = torch.cat([k, mha.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, mha.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                ],
                dim=1,
            )
        return k, v, key_padding_mask, attn_mask

    seq_len = 64
    bsz = 8
    embedding = 8
    key_padding_mask = torch.rand((bsz, seq_len))
    attn_mask = torch.rand((seq_len, seq_len))
    k = torch.rand((seq_len, bsz, embedding))
    v = torch.rand((seq_len, bsz, embedding))

    k_orig, v_orig, kp_mask_orig, a_mask_orig = old_bias_code(
        k, v, key_padding_mask, attn_mask, bsz)
    k_new, v_new, kp_mask_new, a_mask_new = mha._add_bias(
        k, v, key_padding_mask, attn_mask, bsz)

    assert torch.equal(k_orig, k_new)
    assert torch.equal(v_orig, v_new)
    assert torch.equal(kp_mask_orig, kp_mask_new)
    assert torch.equal(a_mask_orig, a_mask_new)
def set_incremental_state(
    module: MultiheadAttention,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    key: str,
    value: Dict[str, Optional[Tensor]],
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
    """Helper for setting incremental state for an nn.Module."""
    if incremental_state is not None:
        result = module.set_incremental_state(incremental_state, key, value)
        if result is not None:
            incremental_state = result
    return incremental_state
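# A minimal usage sketch of the two helpers above (not from the original source;
# the key name "attn_state" and this wrapper function are illustrative assumptions).
# It stores projected key/value tensors for one decoding step and reads them back.
def _cache_attn_state_example(mha: MultiheadAttention, k: Tensor, v: Tensor):
    incremental_state: Dict[str, Dict[str, Optional[Tensor]]] = {}
    # cache the tensors under a named slot for the next incremental step
    incremental_state = set_incremental_state(
        mha, incremental_state, "attn_state", {"prev_key": k, "prev_value": v}
    )
    # read the slot back; returns None if nothing was stored under that key
    return get_incremental_state(mha, incremental_state, "attn_state")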
def test_pruning_heads(self):
    embed_dim = 768
    num_heads = 12
    num_heads_to_keep = 8
    dummy_input = torch.randn(32, 2, embed_dim)
    mha = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
    reserve_head_index = mha._get_reserve_head_index(
        num_heads_to_keep=num_heads_to_keep)
    mha._adaptive_prune_heads(reserve_head_index=reserve_head_index)
    mha._set_skip_embed_dim_check()
    mha(query=dummy_input, key=dummy_input, value=dummy_input)
    self.assertEqual(mha.head_dim, embed_dim / num_heads)
    self.assertEqual(mha.num_heads, num_heads_to_keep)
def test_mask_padding_parity():
    def old_padding_code(key_padding_mask, attn_mask):
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0),
                                1).type_as(key_padding_mask),
                ],
                dim=1,
            )
        return key_padding_mask, attn_mask

    # values don't matter for this test.
    mha = MultiheadAttention(
        embed_dim=8,
        num_heads=2,
        dropout=0.0,
        add_bias_kv=True,
        add_zero_attn=True,
    )

    key_padding_mask = torch.rand((8, 64))
    attn_mask = torch.rand((64, 64))

    kp_mask_orig, a_mask_orig = old_padding_code(key_padding_mask, attn_mask)
    kp_mask_new, a_mask_new = mha._pad_masks(key_padding_mask, attn_mask)

    assert kp_mask_orig.size() == kp_mask_new.size()
    assert a_mask_orig.size() == a_mask_new.size()
    assert torch.equal(kp_mask_orig, kp_mask_new)
    assert torch.equal(a_mask_orig, a_mask_new)
def build_self_attention(
    self,
    embed_dim,
    num_attention_heads,
    dropout,
    self_attention,
    q_noise,
    qn_block_size,
):
    return MultiheadAttention(
        embed_dim,
        num_attention_heads,
        dropout=dropout,
        self_attention=True,
        q_noise=q_noise,
        qn_block_size=qn_block_size,
    )
def __init__(
    self,
    embedding_dim: int = 768,
    ffn_embedding_dim: int = 3072,
    num_attention_heads: int = 8,
    dropout: float = 0.1,
    attention_dropout: float = 0.1,
    activation_dropout: float = 0.1,
    activation_fn: str = "relu",
    layer_norm_first: bool = False,
) -> None:
    super().__init__()
    # Initialize parameters
    self.embedding_dim = embedding_dim
    self.dropout = dropout
    self.activation_dropout = activation_dropout

    # Initialize blocks
    self.activation_fn = utils.get_activation_fn(activation_fn)
    self.self_attn = MultiheadAttention(
        self.embedding_dim,
        num_attention_heads,
        dropout=attention_dropout,
        self_attention=True,
    )

    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(self.activation_dropout)
    self.dropout3 = nn.Dropout(dropout)

    self.layer_norm_first = layer_norm_first

    # layer norm associated with the self attention layer
    self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
    self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
    self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

    # layer norm associated with the position wise feed-forward NN
    self.final_layer_norm = LayerNorm(self.embedding_dim)
def forward(self,
            query,
            key: Optional[Tensor],
            value: Optional[Tensor],
            key_padding_mask: Optional[Tensor] = None,
            incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
            need_weights: bool = True,
            static_kv: bool = False,
            attn_mask: Optional[Tensor] = None,
            before_softmax: bool = False,
            need_head_weights: bool = False,
            mask=None,
            loss_type: str = 'nmt') -> Tuple[Tensor, Optional[Tensor]]:
    if need_head_weights:
        need_weights = True

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == self.embed_dim
    assert list(query.size()) == [tgt_len, bsz, embed_dim]

    if (not self.onnx_trace
            and not self.tpu  # don't use PyTorch version on TPUs
            and incremental_state is None and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()):
        assert key is not None and value is not None
        return F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            torch.empty([0]),
            torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout_module.p,
            self.out_proj.weight,
            self.out_proj.bias,
            self.training or self.dropout_module.apply_during_inference,
            key_padding_mask,
            need_weights,
            attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if saved_state is not None and "prev_key" in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert self.encoder_decoder_attention and not self.self_attention
                key = value = None
    else:
        saved_state = None

    if self.self_attention:
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)
    elif self.encoder_decoder_attention:
        # encoder-decoder attention
        q = self.q_proj(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            k = self.k_proj(key)
            v = self.v_proj(key)
    else:
        assert key is not None and value is not None
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
    q *= self.scaling

    if self.bias_k is not None:
        assert self.bias_v is not None
        k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                ],
                dim=1,
            )

    q = (q.contiguous().view(tgt_len, bsz * self.num_heads,
                             self.head_dim).transpose(0, 1))
    if k is not None:
        k = (k.contiguous().view(-1, bsz * self.num_heads,
                                 self.head_dim).transpose(0, 1))
    if v is not None:
        v = (v.contiguous().view(-1, bsz * self.num_heads,
                                 self.head_dim).transpose(0, 1))

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            prev_value = _prev_value.view(bsz * self.num_heads, -1,
                                          self.head_dim)
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        prev_key_padding_mask: Optional[Tensor] = None
        if "prev_key_padding_mask" in saved_state:
            prev_key_padding_mask = saved_state["prev_key_padding_mask"]
        assert k is not None and v is not None
        key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
            key_padding_mask=key_padding_mask,
            prev_key_padding_mask=prev_key_padding_mask,
            batch_size=bsz,
            src_len=k.size(1),
            static_kv=static_kv,
        )

        saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_key_padding_mask"] = key_padding_mask
        # In this branch incremental_state is never None
        assert incremental_state is not None
        incremental_state = self._set_input_buffer(incremental_state, saved_state)
    assert k is not None
    src_len = k.size(1)

    # This is part of a workaround to get around fork/join parallelism
    # not supporting Optional types.
    if key_padding_mask is not None and key_padding_mask.dim() == 0:
        key_padding_mask = None

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if self.add_zero_attn:
        assert v is not None
        src_len += 1
        k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
        v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0),
                                1).type_as(key_padding_mask),
                ],
                dim=1,
            )

    attn_weights = torch.bmm(q, k.transpose(1, 2))
    attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

    assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        if self.onnx_trace:
            attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
        attn_weights += attn_mask

    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        if not self.tpu:
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"))
        else:
            attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.masked_fill(key_padding_mask,
                                                    float('-inf'))
            attn_weights = attn_weights.transpose(0, 2)
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if before_softmax:
        return attn_weights, v

    attn_weights_float = utils.softmax(attn_weights,
                                       dim=-1,
                                       onnx_trace=self.onnx_trace)
    if loss_type == 'nmt':
        attn_weights_float = attn_weights_float + 0.1 * torch.mul(
            attn_weights_float, torch.exp(1 - mask))
    elif loss_type == 'mask':
        # attn_weights_float = torch.mul(attn_weights, mask)
        attn_weights_float = torch.add(
            torch.mul(attn_weights_float, mask),
            torch.mul(torch.mean(attn_weights_float, -1, True), 1 - mask))
    # tmp=attn_weights_float
    # if key_padding_mask is not None:
    #     attn_weights_float = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len)
    #     attn_weights_float = attn_weights_float.masked_fill(
    #         key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
    #         float("-inf")
    #     )
    #     attn_weights_float = attn_weights_float.view(bsz * self.num_heads, tgt_len, src_len)
    #     attn_weights_float = utils.softmax(attn_weights_float, dim=-1, onnx_trace=self.onnx_trace)

    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = self.dropout_module(attn_weights)
    assert v is not None
    attn = torch.bmm(attn_probs, v)
    assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
    if self.onnx_trace and attn.size(1) == 1:
        # when ONNX tracing a single decoder step (sequence length == 1)
        # the transpose is a no-op copy before view, thus unnecessary
        attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
    else:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn = self.out_proj(attn)
    attn_weights: Optional[Tensor] = None
    if need_weights:
        attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len,
                                               src_len).transpose(1, 0)
        if not need_head_weights:
            # average attention weights over heads
            attn_weights = attn_weights.mean(dim=0)
        # attn_weights = attn_weights[0]

    return attn, attn_weights
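# A standalone toy example (illustrative values, not part of the original module)
# of the re-weighting applied above when loss_type == 'mask': positions where
# mask == 1 keep their softmax weight, positions where mask == 0 are replaced by
# the row-wise mean weight.
attn_weights_float = torch.tensor([[0.7, 0.2, 0.1]])
mask = torch.tensor([[1.0, 1.0, 0.0]])
reweighted = torch.add(
    torch.mul(attn_weights_float, mask),
    torch.mul(torch.mean(attn_weights_float, -1, True), 1 - mask),
)
# reweighted -> tensor([[0.7000, 0.2000, 0.3333]]); the masked-out position takes
# the mean of its row (0.3333) instead of its original weight.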
def benchmark_multihead_attention(
    label="",
    attn_dtype=torch.uint8,
    key_padding_dtype=torch.uint8,
    add_bias_kv=False,
    add_zero_attn=False,
    static_kv=False,
    batch_size=20,
    embedding=EMB,
    seq_len=SEQ,
    num_heads=HEADS,
):
    results = []
    # device = torch.device("cuda")
    xformers_att_config = '{"name": "scaled_dot_product"}'

    attn_mask = _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len)
    key_padding_mask = _get_mask(to_dtype=key_padding_dtype,
                                 dim0=batch_size,
                                 dim1=seq_len)

    q = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
    k = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
    v = torch.rand(seq_len, batch_size, embedding, requires_grad=True)

    _reset_seeds()

    original_mha = MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        xformers_att_config=None,
        add_bias_kv=add_bias_kv,
        add_zero_attn=add_zero_attn,
    )

    xformers_mha = MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        xformers_att_config=xformers_att_config,
        add_bias_kv=add_bias_kv,
        add_zero_attn=add_zero_attn,
    )

    def original_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
        original_mha(
            query=q,
            key=k,
            value=v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            static_kv=static_kv,
        )

    def xformers_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
        xformers_mha(
            query=q,
            key=k,
            value=v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            static_kv=static_kv,
        )

    def original_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
        output, _ = original_mha(
            query=q,
            key=k,
            value=v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            static_kv=static_kv,
        )
        loss = torch.norm(output)
        loss.backward()

    def xformers_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
        output, _ = xformers_mha(
            query=q,
            key=k,
            value=v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            static_kv=static_kv,
        )
        loss = torch.norm(output)
        loss.backward()

    fns = [
        original_bench_fw,
        xformers_bench_fw,
        original_bench_fw_bw,
        xformers_bench_fw_bw,
    ]

    for fn in fns:
        results.append(
            benchmark.Timer(
                stmt="fn(q, k, v, key_padding_mask, attn_mask, static_kv)",
                globals={
                    "q": q,
                    "k": k,
                    "v": v,
                    "key_padding_mask": key_padding_mask,
                    "attn_mask": attn_mask,
                    "static_kv": static_kv,
                    "fn": fn,
                },
                label="multihead fw + bw",
                sub_label=f"{fn.__name__}",
                description=label,
            ).blocked_autorange(min_run_time=1))

    compare = benchmark.Compare(results)
    compare.print()
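# A minimal driver sketch (an assumption, not copied from the original script)
# showing how benchmark_multihead_attention might be invoked to compare the
# original and xformers-backed MultiheadAttention; assumes _get_mask supports
# the uint8 and bool dtypes used here.
if __name__ == "__main__":
    # baseline: uint8 attention and key-padding masks
    benchmark_multihead_attention(label="uint8 masks")
    # bool masks
    benchmark_multihead_attention(
        label="bool masks",
        attn_dtype=torch.bool,
        key_padding_dtype=torch.bool,
    )
    # exercise the bias_kv / zero_attn code paths
    benchmark_multihead_attention(
        label="bias_kv + zero_attn",
        add_bias_kv=True,
        add_zero_attn=True,
    )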
def forward(
    self,
    query,
    key: Optional[Tensor],
    value: Optional[Tensor],
    key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    need_weights: bool = True,
    static_kv: bool = False,
    attn_mask: Optional[Tensor] = None,
    before_softmax: bool = False,
    need_head_weights: bool = True,
    gold_dependency: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Input shape: Time x Batch x Channel

    Args:
        key_padding_mask (ByteTensor, optional): mask to exclude
            keys that are pads, of shape `(batch, src_len)`, where
            padding elements are indicated by 1s.
        need_weights (bool, optional): return the attention weights,
            averaged over heads (default: False).
        attn_mask (ByteTensor, optional): typically used to
            implement causal attention, where the mask prevents the
            attention from looking forward in time (default: None).
        before_softmax (bool, optional): return the raw attention
            weights and values before the attention softmax.
        need_head_weights (bool, optional): return the attention
            weights for each head. Implies *need_weights*. Default:
            return the average attention weights over all heads.
    """
    if need_head_weights:
        need_weights = True

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == self.embed_dim
    assert list(query.size()) == [tgt_len, bsz, embed_dim]

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if saved_state is not None and "prev_key" in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert self.encoder_decoder_attention and not self.self_attention
                key = value = None
    else:
        saved_state = None

    if self.self_attention:
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)
    elif self.encoder_decoder_attention:
        # encoder-decoder attention
        q = self.q_proj(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            k = self.k_proj(key)
            v = self.v_proj(key)
    else:
        assert key is not None and value is not None
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
    q *= self.scaling

    # Biaffine operation
    if gold_dependency is None:
        q[:, :, :self.dep_dim] = F.linear(q[:, :, :self.dep_dim],
                                          self.weight_biaffine)
        bias_biaffine = F.linear(k[:, :, :self.dep_dim], self.bias_biaffine)
        bias_biaffine = (
            bias_biaffine.contiguous()
            .view(-1, bsz)
            .transpose(0, 1)
            .unsqueeze(1)
            .unsqueeze(2)
        )

    if self.bias_k is not None:
        assert self.bias_v is not None
        k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                ],
                dim=1,
            )

    q = (
        q.contiguous()
        .view(tgt_len, bsz * self.num_heads, self.head_dim)
        .transpose(0, 1)
    )
    if k is not None:
        k = (
            k.contiguous()
            .view(-1, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
    if v is not None:
        v = (
            v.contiguous()
            .view(-1, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        prev_key_padding_mask: Optional[Tensor] = None
        if "prev_key_padding_mask" in saved_state:
            prev_key_padding_mask = saved_state["prev_key_padding_mask"]
        assert k is not None and v is not None
        key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
            key_padding_mask=key_padding_mask,
            prev_key_padding_mask=prev_key_padding_mask,
            batch_size=bsz,
            src_len=k.size(1),
            static_kv=static_kv,
        )

        saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_key_padding_mask"] = key_padding_mask
        # In this branch incremental_state is never None
        assert incremental_state is not None
        incremental_state = self._set_input_buffer(incremental_state, saved_state)
    assert k is not None
    src_len = k.size(1)

    # This is part of a workaround to get around fork/join parallelism
    # not supporting Optional types.
    if key_padding_mask is not None and key_padding_mask.dim() == 0:
        key_padding_mask = None

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if self.add_zero_attn:
        assert v is not None
        src_len += 1
        k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
        v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0), 1).type_as(
                        key_padding_mask
                    ),
                ],
                dim=1,
            )

    attn_weights = torch.bmm(q, k.transpose(1, 2))

    # DBSA's biaffine operation (bias term)
    if gold_dependency is None:
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights[:, :self.dep_heads, :, :] += bias_biaffine
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

    assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        if self.onnx_trace:
            attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
        attn_weights += attn_mask

    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        if not self.tpu:
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )
        else:
            attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.masked_fill(key_padding_mask, float('-inf'))
            attn_weights = attn_weights.transpose(0, 2)
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if before_softmax:
        return attn_weights, v

    attn_weights_float = utils.softmax(
        attn_weights, dim=-1, onnx_trace=self.onnx_trace
    )

    # gold dependency
    if gold_dependency is not None:
        attn_weights_float = (
            attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len)
            .transpose(0, 1)
            .contiguous()
            .view(self.num_heads, bsz * tgt_len, src_len)
        )
        attn_weights_float[:self.dep_heads] = 0
        attn_weights_float[:self.dep_heads,
                           gold_dependency[:, 0][:, None],
                           gold_dependency[:, 1][:, None]] = 1
        attn_weights_float = (
            attn_weights_float.view(self.num_heads, bsz, tgt_len, src_len)
            .transpose(0, 1)
            .contiguous()
            .view(bsz * self.num_heads, tgt_len, src_len)
        )

    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = self.dropout_module(attn_weights)
    assert v is not None
    attn = torch.bmm(attn_probs, v)
    assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
    if self.onnx_trace and attn.size(1) == 1:
        # when ONNX tracing a single decoder step (sequence length == 1)
        # the transpose is a no-op copy before view, thus unnecessary
        attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
    else:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn = self.out_proj(attn)
    attn_weights: Optional[Tensor] = None
    if need_weights:
        attn_weights = attn_weights_float.view(
            bsz, self.num_heads, tgt_len, src_len
        ).transpose(1, 0)
        if not need_head_weights:
            # average attention weights over heads
            attn_weights = attn_weights.mean(dim=0)

    return attn, attn_weights
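# A toy sketch (assumed shapes and values, not from the original code) of the
# gold-dependency override above: for the first dep_heads heads, all attention
# mass is moved onto the gold head position of each (batch, target) pair.
num_heads, dep_heads, bsz, tgt_len, src_len = 2, 1, 1, 3, 3
attn = torch.full((num_heads, bsz * tgt_len, src_len), 1.0 / src_len)
# each row of gold_dependency: (flattened batch*target index, gold source position)
gold_dependency = torch.tensor([[0, 1], [1, 0], [2, 2]])
attn[:dep_heads] = 0
attn[:dep_heads, gold_dependency[:, 0][:, None], gold_dependency[:, 1][:, None]] = 1
# head 0 now attends only to source positions 1, 0, 2 for targets 0, 1, 2;
# head 1 keeps its uniform weights.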
def test_xformers_single_forward_parity(
    device,
    attn_dtype,
    key_padding_dtype,
    add_bias_kv,
    add_zero_attn,
    static_kv,
    batch_size,
    embedding,
    seq_len,
    num_heads,
):
    xformers_att_config = '{"name": "scaled_dot_product"}'

    attn_mask = (None if attn_dtype is None else _get_mask(
        to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len).to(device))
    key_padding_mask = (None if key_padding_dtype is None else _get_mask(
        to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len).to(device))

    q = torch.rand(seq_len, batch_size, embedding).to(device)
    q.requires_grad = True
    k = torch.rand(seq_len, batch_size, embedding).to(device)
    k.requires_grad = True
    v = torch.rand(seq_len, batch_size, embedding).to(device)
    v.requires_grad = True

    q_ = q.detach().clone()
    q_.requires_grad = True
    k_ = k.detach().clone()
    k_.requires_grad = True
    v_ = v.detach().clone()
    v_.requires_grad = True

    # TODO: dropouts in the two implementations lead to different entries dropped.
    _reset_seeds()
    xformers_mha = MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        xformers_att_config=xformers_att_config,
        add_bias_kv=add_bias_kv,
        add_zero_attn=add_zero_attn,
    ).to(device)
    xformers_output, _ = xformers_mha(
        q,
        k,
        v,
        key_padding_mask=key_padding_mask,
        attn_mask=attn_mask,
        static_kv=static_kv,
    )

    _reset_seeds()
    original_mha = MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        xformers_att_config=None,
        add_bias_kv=add_bias_kv,
        add_zero_attn=add_zero_attn,
    ).to(device)
    original_output, _ = original_mha(
        q_,
        k_,
        v_,
        key_padding_mask=key_padding_mask,
        attn_mask=attn_mask,
        static_kv=static_kv,
    )

    # account for when nan != nan
    if xformers_output.isnan().any() or original_output.isnan().any():
        rand = random.uniform(0, 1)
        xformers_output = xformers_output.masked_fill(xformers_output.isnan(),
                                                      rand)
        original_output = original_output.masked_fill(original_output.isnan(),
                                                      rand)

    # torch.equal works for cpu, on cuda allclose is needed.
    assert torch.allclose(
        xformers_output, original_output, atol=1e-06
    ), f"max diff is {torch.max(torch.abs(xformers_output - original_output))}"

    loss_xformers = torch.norm(xformers_output)
    loss_original = torch.norm(original_output)
    loss_xformers.backward()
    loss_original.backward()

    # torch.equal works for cpu, on cuda allclose is needed.
    assert torch.allclose(
        q.grad, q_.grad), f"max diff is {torch.max(torch.abs(q.grad - q_.grad))}"
    assert torch.allclose(
        k.grad, k_.grad), f"max diff is {torch.max(torch.abs(k.grad - k_.grad))}"
    assert torch.allclose(
        v.grad, v_.grad), f"max diff is {torch.max(torch.abs(v.grad - v_.grad))}"
def test_xformers_blocksparse_parity(
    device,
    add_zero_attn,
    batch_size,
    embedding,
    seq_len,
    num_heads,
):
    xformers_att_config = '{"name": "scaled_dot_product"}'
    xformers_blocksparse_blocksize = 16
    xformers_blocksparse_layout = torch.ones(
        seq_len // xformers_blocksparse_blocksize,
        seq_len // xformers_blocksparse_blocksize,
        dtype=torch.int32,
    )

    q = torch.rand(seq_len, batch_size, embedding).to(device).half()
    q.requires_grad = True
    k = torch.rand(seq_len, batch_size, embedding).to(device).half()
    k.requires_grad = True
    v = torch.rand(seq_len, batch_size, embedding).to(device).half()
    v.requires_grad = True

    q_ = q.detach().clone().half()
    q_.requires_grad = True
    k_ = k.detach().clone().half()
    k_.requires_grad = True
    v_ = v.detach().clone().half()
    v_.requires_grad = True

    _reset_seeds()
    xf_blocksparse_mha = (MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        add_zero_attn=add_zero_attn,
        xformers_att_config=xformers_att_config,
        xformers_blocksparse_layout=xformers_blocksparse_layout,
        xformers_blocksparse_blocksize=xformers_blocksparse_blocksize,
    ).to(device).half())

    xf_blocksparse_output, _ = xf_blocksparse_mha(
        q,
        k,
        v,
    )

    _reset_seeds()
    xformers_mha = (MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        add_zero_attn=add_zero_attn,
        xformers_att_config=xformers_att_config,
        xformers_blocksparse_layout=None,
    ).to(device).half())

    xformers_output, _ = xformers_mha(
        q_,
        k_,
        v_,
    )

    # account for when nan != nan
    rand = random.uniform(0, 1)
    xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
    xf_blocksparse_output = xf_blocksparse_output.masked_fill(
        xf_blocksparse_output.isnan(), rand)

    assert_almost_equal(xformers_output, xf_blocksparse_output)

    loss_blocksparse = torch.norm(xformers_output)
    loss_original = torch.norm(xf_blocksparse_output)
    loss_blocksparse.backward()
    loss_original.backward()

    q.masked_fill(q.isnan(), rand)
    q_.masked_fill(q_.isnan(), rand)
    k.masked_fill(k.isnan(), rand)
    k_.masked_fill(k_.isnan(), rand)
    v.masked_fill(v.isnan(), rand)
    v_.masked_fill(v_.isnan(), rand)

    assert_almost_equal(q.grad, q_.grad)
    assert_almost_equal(k.grad, k_.grad)
    assert_almost_equal(v.grad, v_.grad)
def forward(
    self,
    query,
    key: Optional[Tensor],
    value: Optional[Tensor],
    key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    need_weights: bool = True,
    static_kv: bool = False,
    attn_mask: Optional[Tensor] = None,
    before_softmax: bool = False,
    need_head_weights: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Input shape: Time x Batch x Channel

    Args:
        key_padding_mask (ByteTensor, optional): mask to exclude
            keys that are pads, of shape `(batch, src_len)`, where
            padding elements are indicated by 1s.
        need_weights (bool, optional): return the attention weights,
            averaged over heads (default: False).
        attn_mask (ByteTensor, optional): typically used to
            implement causal attention, where the mask prevents the
            attention from looking forward in time (default: None).
        before_softmax (bool, optional): return the raw attention
            weights and values before the attention softmax.
        need_head_weights (bool, optional): return the attention
            weights for each head. Implies *need_weights*. Default:
            return the average attention weights over all heads.
    """
    if need_head_weights:
        need_weights = True

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == self.embed_dim
    assert list(query.size()) == [tgt_len, bsz, embed_dim]

    if (not self.onnx_trace
            and not self.tpu  # don't use PyTorch version on TPUs
            and incremental_state is None and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()):
        assert key is not None and value is not None
        return F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            torch.empty([0]),
            torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout_module.p,
            self.out_proj.weight,
            self.out_proj.bias,
            self.training or self.dropout_module.apply_during_inference,
            key_padding_mask,
            need_weights,
            attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if saved_state is not None and "prev_key" in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert self.encoder_decoder_attention and not self.self_attention
                key = value = None
    else:
        saved_state = None

    if self.self_attention:
        q: Tensor = self.q_proj(query)
        k: Tensor = self.k_proj(query)
        v: Tensor = self.v_proj(query)
    elif self.encoder_decoder_attention:
        q = self.q_proj(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            if self.beam_size > 1 and bsz == key.size(1):
                # key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
                key = key.view(key.size(0), -1, self.beam_size,
                               key.size(2))[:, :, 0, :]
                if key_padding_mask is not None:
                    key_padding_mask = key_padding_mask.view(
                        -1, self.beam_size, key_padding_mask.size(1))[:, 0, :]
            k = self.k_proj(key)
            v = self.v_proj(key)
    else:
        assert key is not None and value is not None
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
    q *= self.scaling

    if self.bias_k is not None:
        assert self.bias_v is not None
        k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                ],
                dim=1,
            )

    q = (q.contiguous().view(tgt_len, bsz * self.num_heads,
                             self.head_dim).transpose(0, 1))
    kv_bsz = 0
    if k is not None:
        kv_bsz = k.size(1)
        k = (k.contiguous().view(-1, kv_bsz * self.num_heads,
                                 self.head_dim).transpose(0, 1))
    if v is not None:
        assert kv_bsz
        v = (v.contiguous().view(-1, kv_bsz * self.num_heads,
                                 self.head_dim).transpose(0, 1))

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            kv_bsz = _prev_key.size(0)
            prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            assert kv_bsz == _prev_value.size(0)
            prev_value = _prev_value.view(kv_bsz * self.num_heads, -1,
                                          self.head_dim)
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        prev_key_padding_mask: Optional[Tensor] = None
        if "prev_key_padding_mask" in saved_state:
            prev_key_padding_mask = saved_state["prev_key_padding_mask"]
        assert k is not None and v is not None
        key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
            key_padding_mask=key_padding_mask,
            prev_key_padding_mask=prev_key_padding_mask,
            batch_size=kv_bsz,
            src_len=k.size(1),
            static_kv=static_kv,
        )

        saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1,
                                         self.head_dim)
        saved_state["prev_value"] = v.view(kv_bsz, self.num_heads, -1,
                                           self.head_dim)
        saved_state["prev_key_padding_mask"] = key_padding_mask
        # In this branch incremental_state is never None
        assert incremental_state is not None
        incremental_state = self._set_input_buffer(incremental_state, saved_state)
    assert k is not None
    src_len = k.size(1)

    # This is part of a workaround to get around fork/join parallelism
    # not supporting Optional types.
    if key_padding_mask is not None and key_padding_mask.dim() == 0:
        key_padding_mask = None

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == kv_bsz
        assert key_padding_mask.size(1) == src_len

    if self.add_zero_attn:
        assert v is not None
        src_len += 1
        k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
        v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0),
                                1).type_as(key_padding_mask),
                ],
                dim=1,
            )

    if self.encoder_decoder_attention and bsz != kv_bsz:
        q_shape = (kv_bsz, -1, self.num_heads) + q.size()[1:]
        k_shape = (kv_bsz, self.num_heads) + k.size()[1:]
        attn_weights = torch.einsum('bxhtd,bhsd->bxhts', q.view(q_shape),
                                    k.view(k_shape))
        aw_shape = (-1, ) + attn_weights.size()[-2:]
        attn_weights = attn_weights.reshape(aw_shape)
    else:
        attn_weights = torch.bmm(q, k.transpose(1, 2))
    attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

    assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        if self.onnx_trace:
            attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
        attn_weights += attn_mask

    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        if not self.tpu:
            attn_weights = attn_weights.view(kv_bsz, -1, self.num_heads,
                                             tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(
                    torch.bool),
                float("-inf"),
            )
        else:
            attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.masked_fill(key_padding_mask,
                                                    float("-inf"))
            attn_weights = attn_weights.transpose(0, 2)
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if before_softmax:
        return attn_weights, v

    attn_weights_float = utils.softmax(attn_weights,
                                       dim=-1,
                                       onnx_trace=self.onnx_trace)
    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = self.dropout_module(attn_weights)

    assert v is not None
    if self.encoder_decoder_attention and bsz != kv_bsz:
        ap_shape = (kv_bsz, -1, self.num_heads) + attn_probs.size()[1:]
        v_shape = (-1, self.num_heads) + v.size()[1:]
        attn = torch.einsum('bxhts,bhsd->bxhtd', attn_probs.view(ap_shape),
                            v.view(v_shape))
        a_shape = (-1, ) + attn.size()[-2:]
        attn = attn.reshape(a_shape)
    else:
        attn = torch.bmm(attn_probs, v)
    assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
    if self.onnx_trace and attn.size(1) == 1:
        # when ONNX tracing a single decoder step (sequence length == 1)
        # the transpose is a no-op copy before view, thus unnecessary
        attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
    else:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn = self.out_proj(attn)
    attn_weights: Optional[Tensor] = None
    if need_weights:
        attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len,
                                               src_len).transpose(1, 0)
        if not need_head_weights:
            # average attention weights over heads
            attn_weights = attn_weights.mean(dim=0)

    return attn, attn_weights
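# A minimal usage sketch (an assumption, not part of the original module) of the
# forward API documented above: self-attention over a (tgt_len, bsz, embed_dim)
# input with a causal attn_mask and a key_padding_mask marking padded positions.
tgt_len, bsz, embed_dim, num_heads = 5, 2, 16, 4
mha = MultiheadAttention(embed_dim, num_heads, self_attention=True)
x = torch.rand(tgt_len, bsz, embed_dim)
# upper-triangular -inf mask keeps each position from attending to the future
causal_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)
# mark the last two positions of the second sequence as padding
key_padding_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)
key_padding_mask[1, -2:] = True
out, avg_attn = mha(
    query=x,
    key=x,
    value=x,
    key_padding_mask=key_padding_mask,
    attn_mask=causal_mask,
    need_weights=True,
)
# out: (tgt_len, bsz, embed_dim); avg_attn: attention weights averaged over heads.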