    def forward(self, x, state_old):
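        """Streaming forward pass of the convolution block.

        Runs normalization and activation, then pads the time axis: the
        past part of the padding is taken from ``state_old`` (frames saved
        on the previous call), the rest is zero-filled, and the result is
        convolved.  Returns the output and the updated streaming state.
        """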
        states = States(state_old)
        x = self.bn(x)
        x = self.act(x)
        if self.pad is not None:
            pad = self.pad
            if pad[0] > 0:
                # the past (left) part of the time padding comes from frames
                # saved on the previous call; the tail of x is stored for the
                # next one
                x = states.pad_left(x, pad[0], 3)
                pad = (0, ) + pad[1:]
            x = torch.nn.functional.pad(x, pad, mode='constant', value=0)

        x = self.conv(x)
        return x, states.state
    def forward(self, x, state_old=None):
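        """Streaming forward pass of the local self-attention block.

        The input is [B, C, F, T]; attention runs along the time axis,
        independently for every frequency bin, and each frame attends to
        attn_range[0] past and attn_range[1] future frames.  Keys/values
        from the previous chunk are carried in ``state_old``.
        """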

        B, C, F, _ = get_shape(x)

        # [B,C,F,T] -> [B,3C,F,T]
        qkv = self.linear_qkv(x)

        # [B,3C,F,T] -> [B,F,3C,T]
        qkv = qkv.transpose(1, 2)

        # [B,F,3C,T] -> 3*[B,F,C,T]
        query, key, value = qkv.chunk(3, dim=2)

        # [B,F,C,T] -> [B,F,C,attn_range[0]+T]
        # prepend the cached past frames so queries can see attn_range[0]
        # frames of left context across chunk boundaries
        states = States(state_old)
        key = states.pad_left(key, self.attn_range[0], dim=3)
        value = states.pad_left(value, self.attn_range[0], dim=3)

        # [B, F, C, T] -> [BF, n_heads, k_channels, T]
        T_q = get_shape(query)[3]
        T_kv = get_shape(key)[3]
        assert C == self.n_heads * self.k_channels
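        # merge the batch and frequency dims so attention is computed
        # independently per frequency bin; C splits into n_heads heads
        # of k_channels each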
        query = query.reshape(B * F, self.n_heads, self.k_channels, T_q)
        key = key.reshape(B * F, self.n_heads, self.k_channels, T_kv)
        value = value.reshape(B * F, self.n_heads, self.k_channels, T_kv)

        # transpose [BF, n_heads, k_channels, T] -> [BF, n_heads, T, k_channels]
        query = query.transpose(2, 3)
        key = key.transpose(2, 3)
        value = value.transpose(2, 3)

        assert self.attn_unroll > 0
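        # process queries in chunks of at most `attn_unroll` frames; each
        # chunk only looks at the keys/values inside its local window,
        # which keeps the score matrices small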
        q_s = 0
        outputs = []
        while q_s < T_q:
            q_e = min(q_s + self.attn_unroll, T_q)
            kv_s = q_s
            kv_e = min(q_e + sum(self.attn_range), T_kv)
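            # kv window for this chunk, in the left-padded key/value tensor:
            # it covers attn_range[0] past frames before q_s and
            # attn_range[1] future frames after the last query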
            if q_s == 0 and q_e == T_q:
                query_slice = query
            else:
                query_slice = query[:, :, q_s:q_e, :]
            if kv_s == 0 and kv_e == T_kv:
                key_slice = key
                value_slice = value
            else:
                key_slice = key[:, :, kv_s:kv_e, :]
                value_slice = value[:, :, kv_s:kv_e, :]
            scores = torch.matmul(query_slice, key_slice.transpose(
                -2, -1)) / math.sqrt(self.k_channels)

            # mask out scores for keys that fall outside the attention range
            q_n = q_e - q_s
            k_mn = 1 + sum(self.attn_range)
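            # build the banded mask with a pad-and-reshape trick: start from
            # a (q_n, k_mn) block of ones, pad q_n zeros on the right, then
            # flatten and re-window into rows of length k_mn + q_n - 1, so
            # row i ends up with ones in columns i .. i + k_mn - 1, i.e.
            # query i sees its attn_range[0] past frames, itself and its
            # attn_range[1] future frames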
            mask = torch.ones(q_n,
                              k_mn,
                              dtype=scores.dtype,
                              device=scores.device)
            mask = torch.nn.functional.pad(mask, (0, q_n),
                                           mode='constant',
                                           value=0)
            mask = mask.reshape(-1)[:q_n * (k_mn + q_n - 1)]
            mask = mask.reshape(q_n, (k_mn + q_n - 1))
            mask = mask[:, :kv_e - kv_s]
            mask = mask.unsqueeze(0).unsqueeze(0)

            # push masked positions to a large negative value before softmax
            scores = scores * mask - 1e4 * (1 - mask)
            p_attn = torch.nn.functional.softmax(scores,
                                                 dim=3)  # [b, n_h, l_q, l_kv]

            output = torch.matmul(p_attn, value_slice)
            outputs.append(output)
            q_s = q_e

        output = torch.cat(outputs, 2) if len(outputs) > 1 else outputs[0]

        # [BF, n_h, T_q, d_k] -> [BF, n_h, d_k, T_q]
        output = output.transpose(2, 3)

        # [BF, n_h, d_k, T_q] -> [B, F, C, T_q]
        output = output.reshape(B, F, C, T_q)

        # [B, F, C, T] -> [B, C, F, T]
        output = output.transpose(1, 2)

        # [B, C, F, T] -> [B, C_out, F, T] (output projection conv_o)
        output = self.conv_o(output)

        return output, states.state
    def forward(self, x, state_old=None):
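        """Streaming forward pass of the enhancement network.

        The waveform chunk is framed with an overlap carried over from the
        previous chunk and encoded into a complex spectrogram X ([B,2,F,T]).
        An encoder/decoder with skip connections predicts a complex mask M,
        the masked spectrogram Y = X * M is decoded back to a waveform, and
        only the samples that are already valid are returned together with
        Y and the updated streaming state.
        """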

        states = States(state_old)

        # keep the last (wnd_length - hop_length) samples of the previous
        # chunk so the analysis frames overlap correctly across chunks
        tail_size = self.wnd_length - self.hop_length
        x_padded = states.pad_left(x, tail_size, 1)

        X = self.encode(x_padded)
        # X: complex spectrogram, packed as [B,2,F,T] (real and imaginary channels)
        z = X

        #DOWN
        skips = []
        for b in self.blocks_down:

            z, state = b(z, states.state_old)
            states.update(state)

            skips.append(z)
            z = self.pool(z)

        #BOTTOM
        z, state = self.block_bottom(z, states.state_old)
        states.update(state)

        #UP
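        # each decoder stage: upsample the pooled features back to the
        # resolution of the matching encoder skip (padding the frequency
        # axis by one if the sizes are off by one), concatenate the skip,
        # then apply the stateful, causally padded up-convolution followed
        # by the up block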
        for skip, conv_up, block_up in zip(reversed(skips), self.convs_up,
                                           self.blocks_up):
            z = torch.nn.functional.interpolate(z,
                                                scale_factor=2,
                                                mode='nearest')
            Fs = get_shape(skip)[-2]
            Fz = get_shape(z)[-2]
            if Fz != Fs:
                z = torch.nn.functional.pad(z, (0, 0, 0, 1), mode='replicate')
            z = torch.cat([z, skip], 1)

            pad = self.convs_up_pad
            if pad[0] > 0:
                z = states.pad_left(z, pad[0], 3)
                pad = (0, ) + pad[1:]
            z = torch.nn.functional.pad(z, pad, mode='constant', value=0)
            z = conv_up(z)

            z, state = block_up(z, states.state_old)
            states.update(state)

        # shift X right by `ahead` frames (reusing frames cached from the
        # previous chunk) so that the mask, computed with lookahead, lines
        # up with the spectrogram frames it was predicted for
        X = states.pad_left(X, self.ahead, 3, shift_right=True)

        # [B,2,F,T] -> two [B,F,T] tensors (real and imaginary parts)
        Mr, Mi = z[:, 0], z[:, 1]
        Xr, Xi = X[:, 0], X[:, 1]

        # mask in complex space
        Yr = Xr * Mr - Xi * Mi
        Yi = Xr * Mi + Xi * Mr

        # stack real/imag: two [B,F,T] -> [B,2,F,T]
        Y = torch.stack([Yr, Yi], 1)

        # decode and return only the valid samples: trim the framing overlap
        # at the start and the synthesis lookahead at the end
        Y_padded = states.pad_left(Y, self.ahead_ifft, 3)
        y = self.decode(Y_padded)
        y = y[:, tail_size:-self.ahead_ifft * self.hop_length]

        # every entry of the incoming state must have been used up by now
        assert not states.state_old
        return y, Y, states.state