def _forward_unfolded(self, x, incremental_state):
    '''Compute the convolution by explicitly unfolding the input into windows.

    Every position is expanded into a window along a trailing kernel axis;
    the output is the batched matrix product of those windows with the
    per-head kernel weights.

    Args:
        x: 3-D tensor whose size unpacks as (T, B, C). C must equal
            ``self.input_size`` and be divisible by ``self.num_heads``.
        incremental_state: when not None, the windows are formed from the
            buffer returned by ``self._get_input_buffer`` concatenated with
            ``x``, and the trailing ``kernel_size - 1`` entries are written
            back through ``self._set_input_buffer``. When None, the windows
            are produced by ``unfold1d``.

    Returns:
        The convolved tensor, viewed as (T, B, C).
    '''
    seq_len, batch, channels = x.size()
    kernel, heads = self.kernel_size, self.num_heads
    per_head = channels // heads
    assert per_head * heads == channels == self.input_size

    decoding = incremental_state is not None
    # One bmm row per (time step, batch element, head) triple.
    rows = seq_len * batch * heads

    taps = self.weight.view(heads, kernel)

    if decoding:
        history = self._get_input_buffer(incremental_state)
        if history is None:
            # Nothing buffered yet: start from an empty tensor.
            history = x.new()
        windows = torch.cat([history, x.unsqueeze(3)], dim=3)
        if self.kernel_size > 1:
            # Keep only the most recent kernel_size - 1 entries for next call.
            self._set_input_buffer(
                incremental_state,
                windows[:, :, :, -self.kernel_size + 1:],
            )
        windows = windows.view(rows, per_head, -1)
    else:
        # unfold the input: T x B x C --> T' x B x C x K
        windows = unfold1d(x, self.kernel_size, self.padding_l, 0)
        windows = windows.view(rows, per_head, kernel)

    if self.weight_softmax:
        normalised = utils.softmax(taps, dim=1, onnx_trace=self.onnx_trace)
        taps = normalised.type_as(taps)

    if decoding:
        # The window may be shorter than the kernel; use only as many
        # trailing taps as the window has entries.
        taps = taps[:, -windows.size(2):]
        kernel = taps.size(1)

    # Share the same per-head taps across all T*B positions.
    taps = taps.view(1, heads, kernel).expand(seq_len * batch, heads, kernel)
    taps = taps.contiguous().view(rows, kernel, 1)

    taps = self.weight_dropout_module(taps)
    result = torch.bmm(windows, taps)  # T*B*H x R x 1
    return result.view(seq_len, batch, channels)
def _forward_unfolded(self, x, incremental_state, query):
    '''Compute the convolution by explicitly unfolding the input into windows.

    The kernel weights are not a fixed parameter here: they are produced by
    ``self.weight_linear`` and viewed as one row of taps per
    (time step, batch element, head) triple.

    Args:
        x: 3-D tensor whose size unpacks as (T, B, C). C must equal
            ``self.input_size`` and be divisible by ``self.num_heads``.
            When ``self.in_proj`` is set, ``x`` is replaced by the first
            ``input_size`` features of ``self.weight_linear(x)``.
        incremental_state: when not None, the windows are formed from the
            buffer returned by ``self._get_input_buffer`` concatenated with
            ``x``, and the trailing ``kernel_size - 1`` entries are written
            back through ``self._set_input_buffer``. When None, the windows
            are produced by ``unfold1d``.
        query: input to ``self.weight_linear`` when ``self.in_proj`` is
            false; not read otherwise.

    Returns:
        The convolved tensor, viewed as (T, B, C).
    '''
    seq_len, batch, channels = x.size()
    kernel, heads = self.kernel_size, self.num_heads
    per_head = channels // heads
    assert per_head * heads == channels == self.input_size

    # One bmm row per (time step, batch element, head) triple.
    rows = seq_len * batch * heads

    if self.in_proj:
        # A single projection yields both the features and the taps.
        projected = self.weight_linear(x)
        x = projected.narrow(2, 0, self.input_size).contiguous()
        taps = projected.narrow(2, self.input_size, heads * kernel)
        taps = taps.contiguous().view(rows, -1)
    else:
        taps = self.weight_linear(query).view(rows, -1)

    decoding = incremental_state is not None

    # renorm_padding is only implemented in _forward_expanded
    assert not (self.renorm_padding and not decoding)

    if decoding:
        history = self._get_input_buffer(incremental_state)
        if history is None:
            # Nothing buffered yet: start from an empty tensor.
            history = x.new()
        windows = torch.cat([history, x.unsqueeze(3)], dim=3)
        if self.kernel_size > 1:
            # Keep only the most recent kernel_size - 1 entries for next call.
            self._set_input_buffer(
                incremental_state,
                windows[:, :, :, -self.kernel_size + 1:],
            )
        windows = windows.view(rows, per_head, -1)
    else:
        left_pad = self.padding_l
        if kernel > seq_len and left_pad == kernel - 1:
            # Kernel longer than the sequence with K - 1 left padding:
            # keep the last T taps and shrink kernel/padding to match.
            taps = taps.narrow(1, kernel - seq_len, seq_len)
            kernel, left_pad = seq_len, seq_len - 1
        # unfold the input: T x B x C --> T' x B x C x K
        windows = unfold1d(x, kernel, left_pad, 0)
        windows = windows.view(rows, per_head, kernel)

    use_softmax = self.weight_softmax
    renorm = self.renorm_padding

    if use_softmax and not renorm:
        # Normalise over the full tap row before it is trimmed.
        taps = F.softmax(taps, dim=1)
    taps = taps.narrow(1, 0, kernel)

    if decoding:
        # The window may be shorter than the kernel; use only as many
        # trailing taps as the window has entries.
        taps = taps[:, -windows.size(2):]
        kernel = taps.size(1)

    if use_softmax and renorm:
        # Normalise after trimming, i.e. over the retained taps only.
        taps = F.softmax(taps, dim=1)

    taps = F.dropout(
        taps, self.weight_dropout, training=self.training, inplace=False
    )

    tap_column = taps.unsqueeze(2)
    if bmm_fp16_support:
        result = torch.bmm(windows, tap_column)  # T*B*H x R x 1
    else:
        # Run the product on float casts, then cast back to the taps' type.
        result = torch.bmm(windows.float(), tap_column.float())
        result = result.type_as(taps)  # T*B*H x R x 1
    return result.view(seq_len, batch, channels)