def _forward_unfolded(self, x, incremental_state):
        '''The conventional implementation of convolutions.
        Unfolds the input by sliding a window over it from left to right.'''
        T, B, C = x.size()  # time, batch, channels
        K, H = self.kernel_size, self.num_heads
        R = C // H  # channels per head
        assert R * H == C == self.input_size

        weight = self.weight.view(H, K)
        if incremental_state is not None:
            # Incremental decoding: keep a rolling buffer of the last K-1 input
            # steps and append the current step to form the convolution window.
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                input_buffer = x.new()
            x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
            if self.kernel_size > 1:
                self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:])
            x_unfold = x_unfold.view(T*B*H, R, -1)
        else:
            # unfold the input: T x B x C --> T' x B x C x K
            x_unfold = unfold1d(x, self.kernel_size, self.padding_l, 0)
            x_unfold = x_unfold.view(T*B*H, R, K)

        if self.weight_softmax:
            weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)

        if incremental_state is not None:
            # Early in incremental decoding the buffer holds fewer than K steps,
            # so only the trailing taps of the kernel are applied.
            weight = weight[:, -x_unfold.size(2):]
            K = weight.size(1)

        weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1)

        weight = self.weight_dropout_module(weight)
        output = torch.bmm(x_unfold, weight)  # T*B*H x R x 1
        output = output.view(T, B, C)
        return output
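
Both examples call a helper unfold1d that is not shown above. Below is a minimal sketch, modeled on the helper of the same name in fairseq (fairseq/modules/unfold.py): it pads the time dimension and exposes each K-step window as a strided view of the padded tensor.

import torch.nn.functional as F


def unfold1d(x, kernel_size, padding_l, pad_value=0):
    '''Unfold a T x B x C tensor into T x B x C x K windows along the time axis.'''
    if kernel_size > 1:
        T, B, C = x.size()
        # Pad the time dimension so every position has a full window of size K.
        x = F.pad(x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value)
        # Expose the K-step window at each position as a strided view (no copy).
        x = x.as_strided((T, B, C, kernel_size), (B * C, C, 1, B * C))
    else:
        x = x.unsqueeze(3)
    return x
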
Example 2
    def _forward_unfolded(self, x, incremental_state, query):
        '''The conventional implementation of convolutions.
        Unfolds the input by sliding a window over it from left to right.'''
        T, B, C = x.size()
        K, H = self.kernel_size, self.num_heads
        R = C // H
        assert R * H == C == self.input_size

        if self.in_proj:
            # A single projection of x produces both the convolution input and the
            # per-position kernel weights.
            proj = self.weight_linear(x)
            x = proj.narrow(2, 0, self.input_size).contiguous()
            weight = proj.narrow(2, self.input_size, H * K).contiguous().view(T * B * H, -1)
        else:
            # Otherwise the dynamic kernel weights are predicted from the query.
            weight = self.weight_linear(query).view(T * B * H, -1)

        # renorm_padding is only implemented in _forward_expanded
        assert not self.renorm_padding or incremental_state is not None

        if incremental_state is not None:
            # Incremental decoding: same rolling-buffer scheme as in the first example.
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                input_buffer = x.new()
            x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
            if self.kernel_size > 1:
                self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size + 1:])
            x_unfold = x_unfold.view(T * B * H, R, -1)
        else:
            padding_l = self.padding_l
            if K > T and padding_l == K - 1:
                # Sequence shorter than the kernel: keep only the last T taps so
                # the causal window never extends past the input.
                weight = weight.narrow(1, K - T, T)
                K, padding_l = T, T - 1
            # unfold the input: T x B x C --> T' x B x C x K
            x_unfold = unfold1d(x, K, padding_l, 0)
            x_unfold = x_unfold.view(T * B * H, R, K)

        if self.weight_softmax and not self.renorm_padding:
            weight = F.softmax(weight, dim=1)
        weight = weight.narrow(1, 0, K)

        if incremental_state is not None:
            weight = weight[:, -x_unfold.size(2):]
            K = weight.size(1)

        if self.weight_softmax and self.renorm_padding:
            weight = F.softmax(weight, dim=1)

        weight = F.dropout(weight, self.weight_dropout, training=self.training, inplace=False)
        # bmm_fp16_support is defined at module level (not shown here); it is True
        # when torch.bmm can operate on half-precision inputs directly.
        if bmm_fp16_support:
            output = torch.bmm(x_unfold, weight.unsqueeze(2))  # T*B*H x R x 1
        else:
            # Fall back to fp32 for the batched matmul, then cast back.
            output = torch.bmm(x_unfold.float(), weight.unsqueeze(2).float()).type_as(weight)  # T*B*H x R x 1
        output = output.view(T, B, C)
        return output
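
For reference, the unfold-then-bmm computation at the heart of both examples can be exercised standalone. The sizes and the random weight tensor below are hypothetical, chosen only for illustration, and the snippet assumes the unfold1d sketch given after the first example; it mirrors the non-incremental branch with weight_softmax enabled.

import torch
import torch.nn.functional as F

T, B, C, H, K = 5, 2, 8, 4, 3        # time, batch, channels, heads, kernel size
R = C // H                           # channels per head
padding_l = K - 1                    # left-only padding -> causal windows

x = torch.randn(T, B, C)             # input in the T x B x C layout used above
weight = torch.randn(T * B * H, K)   # one kernel per (timestep, batch, head)

x_unfold = unfold1d(x, K, padding_l, 0)            # T x B x C x K
x_unfold = x_unfold.view(T * B * H, R, K)          # group channels by head

weight = F.softmax(weight, dim=1)                  # normalize each kernel over its K taps
output = torch.bmm(x_unfold, weight.unsqueeze(2))  # T*B*H x R x 1
output = output.view(T, B, C)
print(output.shape)                  # torch.Size([5, 2, 8])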