def _forward_expanded(self, x, incremental_state): """Turn the convolution filters into band matrices and do matrix multiplication. This is faster when the sequence is short, but less memory efficient. This is not used in the decoder during inference. """ T, B, C = x.size() K, H = self.kernel_size, self.num_heads R = C // H assert R * H == C == self.input_size weight = self.weight.view(H, K) if self.weight_softmax: weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight) weight = weight.view(1, H, K).expand(T * B, H, K).contiguous() weight = weight.view(T, B * H, K).transpose(0, 1) x = x.view(T, B * H, R).transpose(0, 1) P = self.padding_l if K > T and P == K - 1: weight = weight.narrow(2, K - T, T) K, P = T, T - 1 # turn the convolution filters into band matrices weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False) weight_expanded.as_strided((B * H, T, K), (T * (T + K - 1), T + K, 1)).copy_(weight) weight_expanded = weight_expanded.narrow(2, P, T) weight_expanded = self.weight_dropout_module(weight_expanded) output = torch.bmm(weight_expanded, x) output = output.transpose(0, 1).contiguous().view(T, B, C) return output
def get_normalized_probs_scriptable(
    self,
    net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
    log_probs: bool,
    sample: Optional[Dict[str, Tensor]] = None,
):
    """Get normalized probabilities (or log probs) from a net's output.

    Args:
        net_output: tuple whose first element is fed to the normalizer.
        log_probs: return log-probabilities when True, probabilities otherwise.
        sample: optional dict; when given it must contain ``"target"``, which
            is forwarded to the adaptive softmax.

    Returns:
        The normalized (log-)probabilities. With an adaptive softmax and
        ``log_probs=False`` the log-probabilities are exponentiated in place.
    """
    if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
        target: Optional[Tensor] = None
        if sample is not None:
            assert "target" in sample
            target = sample["target"]
        log_p = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
        if log_probs:
            return log_p
        return log_p.exp_()

    scores = net_output[0]
    if not log_probs:
        return utils.softmax(scores, dim=-1, onnx_trace=self.onnx_trace)
    return utils.log_softmax(scores, dim=-1, onnx_trace=self.onnx_trace)
def get_normalized_probs(self, net_output, log_probs):
    """Get normalized probabilities (or log probs) from a net's output.

    Args:
        net_output: mapping whose ``"encoder_out"`` entry holds the scores.
        log_probs: return log-probabilities when truthy, probabilities
            otherwise.

    Returns:
        The scores, cast with ``.float()`` and normalized over the last
        dimension.
    """
    scores = net_output["encoder_out"].float()
    normalize = utils.log_softmax if log_probs else utils.softmax
    return normalize(scores, dim=-1)
def get_normalized_probs(self, net_output, log_probs, sample=None):
    """Get normalized probabilities (or log probs) from a net's output.

    Args:
        net_output: either a tensor of logits or a sequence whose first
            element is that tensor.
        log_probs: return log-probabilities when truthy, probabilities
            otherwise.
        sample: optional dict; when given it must contain ``"target"``, which
            is forwarded to the adaptive softmax.

    Returns:
        The normalized (log-)probabilities. With an adaptive softmax and a
        falsy ``log_probs`` the log-probabilities are exponentiated in place.
    """
    adaptive = getattr(self, "adaptive_softmax", None)
    if adaptive is not None:
        target = None
        if sample is not None:
            assert "target" in sample
            target = sample["target"]
        log_p = adaptive.get_log_prob(net_output, target=target)
        if log_probs:
            return log_p
        return log_p.exp_()

    # A Pipe() module returns a tuple of tensors as the output.
    # In this case, the tuple has one element - the output tensor of logits
    if isinstance(net_output, torch.Tensor):
        scores = net_output
    else:
        scores = net_output[0]

    if not log_probs:
        return utils.softmax(scores, dim=-1, onnx_trace=False)
    return utils.log_softmax(scores, dim=-1, onnx_trace=False)
def _forward_unfolded(self, x, incremental_state): """The conventional implementation of convolutions. Unfolding the input by having a window shifting to the right.""" T, B, C = x.size() K, H = self.kernel_size, self.num_heads R = C // H assert R * H == C == self.input_size weight = self.weight.view(H, K) if incremental_state is not None: input_buffer = self._get_input_buffer(incremental_state) if input_buffer is None: input_buffer = x.new() x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) if self.kernel_size > 1: self._set_input_buffer( incremental_state, x_unfold[:, :, :, -self.kernel_size + 1:]) x_unfold = x_unfold.view(T * B * H, R, -1) else: # unfold the input: T x B x C --> T' x B x C x K x_unfold = unfold1d(x, self.kernel_size, self.padding_l, 0) x_unfold = x_unfold.view(T * B * H, R, K) if self.weight_softmax: weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight) if incremental_state is not None: weight = weight[:, -x_unfold.size(2):] K = weight.size(1) weight = (weight.view(1, H, K).expand(T * B, H, K).contiguous().view(T * B * H, K, 1)) weight = self.weight_dropout_module(weight) output = torch.bmm(x_unfold, weight) # T*B*H x R x 1 output = output.view(T, B, C) return output
def forward(
    self,
    query,
    key: Optional[Tensor],
    value: Optional[Tensor],
    key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    need_weights: bool = True,
    static_kv: bool = False,
    attn_mask: Optional[Tensor] = None,
    before_softmax: bool = False,
    need_head_weights: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Input shape: Time x Batch x Channel

    Compute multi-head attention of ``query`` over ``key``/``value``.

    Args:
        query: tensor whose size unpacks as `(tgt_len, bsz, embed_dim)`;
            `embed_dim` must equal `self.embed_dim`.
        key, value: tensors to attend over. For encoder-decoder attention
            they are dropped (set to None) when `static_kv` is set and a
            cached `prev_key` exists in the incremental state.
        key_padding_mask (ByteTensor, optional): mask to exclude
            keys that are pads, of shape `(batch, src_len)`, where
            padding elements are indicated by 1s.
        incremental_state (dict, optional): per-module cache holding
            `prev_key`, `prev_value` and `prev_key_padding_mask`; it is
            read and updated in place when given.
        need_weights (bool, optional): return the attention weights,
            averaged over heads (default: True).
        static_kv (bool, optional): reuse cached keys/values instead of
            recomputing or extending them (default: False).
        attn_mask (ByteTensor, optional): typically used to
            implement causal attention, where the mask prevents the
            attention from looking forward in time (default: None).
        before_softmax (bool, optional): return the raw attention
            weights and values before the attention softmax.
        need_head_weights (bool, optional): return the attention
            weights for each head. Implies *need_weights*. Default:
            return the average attention weights over all heads.

    Returns:
        `(attn, attn_weights)`. On the slow path `attn` is viewed as
        `(tgt_len, bsz, embed_dim)` before `self.out_proj`;
        `attn_weights` is None unless `need_weights` is set, in which case
        it is `(bsz, tgt_len, src_len)` (head-averaged) or
        `(num_heads, bsz, tgt_len, src_len)` when `need_head_weights` is
        set. With `before_softmax` the slow path instead returns the
        masked pre-softmax scores and the projected values.
    """
    if need_head_weights:
        need_weights = True

    is_tpu = query.device.type == "xla"

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == self.embed_dim
    assert list(query.size()) == [tgt_len, bsz, embed_dim]

    # Fast path: delegate to PyTorch's fused implementation when none of the
    # conditions below rule it out.
    # NOTE(review): this call receives `need_weights` but neither
    # `before_softmax` nor `need_head_weights`, so those two flags only take
    # effect on the slow path further down - confirm callers do not rely on
    # them when the fast path is eligible.
    if (not self.onnx_trace
            and not is_tpu  # don't use PyTorch version on TPUs
            and incremental_state is None and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()):
        assert key is not None and value is not None
        return F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            torch.empty([0]),
            torch.cat(
                (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout_module.p,
            self.out_proj.weight,
            self.out_proj.bias,
            self.training or self.dropout_module.apply_during_inference,
            key_padding_mask,
            need_weights,
            attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if saved_state is not None and "prev_key" in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert self.encoder_decoder_attention and not self.self_attention
                key = value = None
    else:
        saved_state = None

    # Project the inputs; which tensors feed k/v depends on the attention type.
    if self.self_attention:
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)
    elif self.encoder_decoder_attention:
        # encoder-decoder attention
        q = self.q_proj(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            # both k and v are projected from `key`; `value` is not read here
            k = self.k_proj(key)
            v = self.v_proj(key)

    else:
        assert key is not None and value is not None
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
    q *= self.scaling

    if self.bias_k is not None:
        assert self.bias_v is not None
        # Append the bias as one extra key/value position along dim 0 and
        # widen both masks by one zero column to match.
        k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                ],
                dim=1,
            )

    # Fold the heads into the batch dimension:
    # (len, bsz, embed_dim) -> (bsz * num_heads, len, head_dim)
    q = (q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1))
    if k is not None:
        k = (k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1))
    if v is not None:
        v = (v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1))

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        prev_key_padding_mask: Optional[Tensor] = None
        if "prev_key_padding_mask" in saved_state:
            prev_key_padding_mask = saved_state["prev_key_padding_mask"]
        assert k is not None and v is not None
        key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
            key_padding_mask=key_padding_mask,
            prev_key_padding_mask=prev_key_padding_mask,
            batch_size=bsz,
            src_len=k.size(1),
            static_kv=static_kv,
        )

        # Write the (possibly extended) keys, values and mask back to the cache.
        saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_key_padding_mask"] = key_padding_mask
        # In this branch incremental_state is never None
        assert incremental_state is not None
        incremental_state = self._set_input_buffer(incremental_state, saved_state)
    assert k is not None
    src_len = k.size(1)

    # This is part of a workaround to get around fork/join parallelism
    # not supporting Optional types.
    if key_padding_mask is not None and key_padding_mask.dim() == 0:
        key_padding_mask = None

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if self.add_zero_attn:
        assert v is not None
        # Append one all-zero key/value position and widen both masks by one
        # zero column to match.
        src_len += 1
        k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
        v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
                ],
                dim=1,
            )

    # Raw attention scores: (bsz * num_heads, tgt_len, src_len)
    attn_weights = torch.bmm(q, k.transpose(1, 2))
    attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

    assert list(
        attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if attn_mask is not None:
        # The mask is added to the scores (not used as a boolean fill).
        attn_mask = attn_mask.unsqueeze(0)
        if self.onnx_trace:
            attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
        attn_weights += attn_mask

    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        if not is_tpu:
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )
        else:
            # After the transpose the scores are
            # (tgt_len, num_heads, bsz, src_len), so the (bsz, src_len) mask
            # broadcasts without unsqueezing; the layout is restored below.
            attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask, float("-inf"))
            attn_weights = attn_weights.transpose(0, 2)
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if before_softmax:
        return attn_weights, v

    # Keep the softmax output separately so it can be returned as the
    # attention weights; the copy cast back with type_as feeds dropout.
    attn_weights_float = utils.softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = self.dropout_module(attn_weights)

    assert v is not None
    attn = torch.bmm(attn_probs, v)
    assert list(
        attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
    if self.onnx_trace and attn.size(1) == 1:
        # when ONNX tracing a single decoder step (sequence length == 1)
        # the transpose is a no-op copy before view, thus unnecessary
        attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
    else:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn = self.out_proj(attn)
    attn_weights: Optional[Tensor] = None
    if need_weights:
        # (num_heads, bsz, tgt_len, src_len) after the transpose
        attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
        if not need_head_weights:
            # average attention weights over heads
            attn_weights = attn_weights.mean(dim=0)

    return attn, attn_weights
def forward(
    self,
    query,
    key: Optional[Tensor],
    value: Optional[Tensor],
    key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    static_kv: bool = False,
    attn_mask: Optional[Tensor] = None,
    **unused_kwargs,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Input shape: Time x Batch x Channel

    Compute multi-head attention over this partition's
    `self.num_heads_partition` heads.

    Args:
        query: tensor whose size unpacks as `(tgt_len, bsz, embed_dim)`;
            `embed_dim` must equal `self.embed_dim`.
        key, value: tensors to attend over. For encoder-decoder attention
            they are dropped (set to None) when `static_kv` is set and a
            cached `prev_key` exists in the incremental state.
        key_padding_mask (ByteTensor, optional): mask to exclude
            keys that are pads, of shape `(batch, src_len)`, where
            padding elements are indicated by 1s.
        incremental_state (dict, optional): per-module cache holding
            `prev_key`, `prev_value` and `prev_key_padding_mask`; it is
            read and updated in place when given.
        static_kv (bool, optional): reuse cached keys/values instead of
            recomputing or extending them (default: False).
        attn_mask (ByteTensor, optional): typically used to
            implement causal attention, where the mask prevents the
            attention from looking forward in time (default: None).
        **unused_kwargs: accepted and ignored.

    Returns:
        `(attn, None)`: the output of `self.out_proj`, and always None in
        place of attention weights.
    """
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == self.embed_dim
    assert list(query.size()) == [tgt_len, bsz, embed_dim]

    is_tpu = query.device.type == "xla"

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if saved_state is not None and "prev_key" in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert self.encoder_decoder_attention and not self.self_attention
                key = value = None
    else:
        saved_state = None

    # Project the inputs; which tensors feed k/v depends on the attention type.
    if self.self_attention:
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)
    elif self.encoder_decoder_attention:
        # encoder-decoder attention
        q = self.q_proj(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            # both k and v are projected from `key`; `value` is not read here
            k = self.k_proj(key)
            v = self.v_proj(key)

    else:
        assert key is not None and value is not None
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
    q *= self.scaling

    # Fold this partition's heads into the batch dimension:
    # (len, bsz, ...) -> (bsz * num_heads_partition, len, head_dim)
    q = (q.contiguous().view(tgt_len, bsz * self.num_heads_partition, self.head_dim).transpose(0, 1))
    if k is not None:
        k = (k.contiguous().view(-1, bsz * self.num_heads_partition, self.head_dim).transpose(0, 1))
    if v is not None:
        v = (v.contiguous().view(-1, bsz * self.num_heads_partition, self.head_dim).transpose(0, 1))

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads_partition, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            prev_key = _prev_key.view(bsz * self.num_heads_partition, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            prev_value = _prev_value.view(bsz * self.num_heads_partition, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        prev_key_padding_mask: Optional[Tensor] = None
        if "prev_key_padding_mask" in saved_state:
            prev_key_padding_mask = saved_state["prev_key_padding_mask"]
        assert k is not None and v is not None
        key_padding_mask = (
            ModelParallelMultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            ))

        # Write the (possibly extended) keys, values and mask back to the cache.
        saved_state["prev_key"] = k.view(bsz, self.num_heads_partition, -1, self.head_dim)
        saved_state["prev_value"] = v.view(bsz, self.num_heads_partition, -1, self.head_dim)
        saved_state["prev_key_padding_mask"] = key_padding_mask
        # In this branch incremental_state is never None
        assert incremental_state is not None
        incremental_state = self._set_input_buffer(incremental_state, saved_state)
    assert k is not None
    src_len = k.size(1)

    # This is part of a workaround to get around fork/join parallelism
    # not supporting Optional types.
    if key_padding_mask is not None and key_padding_mask.dim() == 0:
        key_padding_mask = None

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    # Raw attention scores: (bsz * num_heads_partition, tgt_len, src_len)
    attn_weights = torch.bmm(q, k.transpose(1, 2))

    assert list(attn_weights.size()) == [
        bsz * self.num_heads_partition,
        tgt_len,
        src_len,
    ]

    if attn_mask is not None:
        # The mask is added to the scores (not used as a boolean fill).
        attn_mask = attn_mask.unsqueeze(0)
        attn_weights += attn_mask

    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads_partition, tgt_len, src_len)
        if not is_tpu:
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )
        else:
            # After the transpose the scores are
            # (tgt_len, num_heads_partition, bsz, src_len), so the
            # (bsz, src_len) mask broadcasts without unsqueezing; the layout
            # is restored below.
            attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask, float("-inf"))
            attn_weights = attn_weights.transpose(0, 2)
        attn_weights = attn_weights.view(bsz * self.num_heads_partition, tgt_len, src_len)

    attn_weights_float = utils.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights_float.type_as(attn_weights)

    # Dropout runs inside the context returned by the CUDA RNG tracker's
    # fork() - presumably so model-parallel ranks use a dedicated RNG state;
    # verify against get_cuda_rng_tracker.
    with get_cuda_rng_tracker().fork():
        attn_probs = self.dropout_module(attn_weights)

    assert v is not None
    attn = torch.bmm(attn_probs, v)
    assert list(attn.size()) == [
        bsz * self.num_heads_partition,
        tgt_len,
        self.head_dim,
    ]
    # Each partition holds embed_dim // model_parallel_size output channels.
    embed_dim_partition = embed_dim // self.model_parallel_size
    attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim_partition)
    attn = self.out_proj(attn)
    # return attn_weights None to keep the return type same as single gpu multihead attention
    # This will be deprecated.
    attn_weights: Optional[Tensor] = None

    return attn, attn_weights