    def forward(self, x, state_old):
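        """Run the dense sub-module stack and the optional self-attention branch.

        Each module in self.bd sees the channel-wise concatenation of the
        block input and all previous outputs; self.bd_out then fuses the
        stack. Streaming caches are consumed from and written back to
        `states`.
        """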
        states = States(state_old)

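        # dense stack: grow x by concatenating each sub-module's output
        # along the channel axis before calling the next one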
        for m in self.bd:
            x_add, state = m(x, states.state_old)
            states.update(state)
            x = torch.cat([x, x_add], 1)
        x, state = self.bd_out(x, states.state_old)
        states.update(state)

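        # no self-attention branch configured; return the dense output as-is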
        if self.sa is None:
            return x, states.state

        # apply the multi-head self-attention module
        # x: [B,C,F,T]
        x_add, state = self.sa(x, state_old=states.state_old)
        states.update(state)

        x = torch.cat([x, x_add], 1)

        x, state = self.sa_out(x, states.state_old)
        states.update(state)

        return x, states.state

    def forward(self, x, state_old=None):
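        """Streaming forward pass over one chunk of waveform.

        x: [B, N] time-domain samples; state_old: cache returned by the
        previous call (None for the first chunk). Returns the output
        waveform y, the masked spectrogram Y ([B,2,F,T]) and the new cache.
        """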

        states = States(state_old)

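        # prepend the tail of the previous chunk (wnd_length - hop_length
        # samples, taken from the cache) so analysis windows overlap
        # correctly across chunk boundaries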
        tail_size = self.wnd_length - self.hop_length
        x_padded = states.pad_left(x, tail_size, 1)

        X = self.encode(x_padded)
        # X: [B,2,F,T] spectrogram (channel 0 = real, channel 1 = imaginary)
        z = X

        # DOWN: encoder path
        skips = []
        for b in self.blocks_down:
            z, state = b(z, states.state_old)
            states.update(state)
            # keep the pre-pooling activation as a skip connection
            skips.append(z)
            # downsample before the next encoder level
            z = self.pool(z)

        # BOTTOM: bottleneck block at the lowest resolution
        z, state = self.block_bottom(z, states.state_old)
        states.update(state)

        # UP: decoder path
        for skip, conv_up, block_up in zip(reversed(skips), self.convs_up,
                                           self.blocks_up):
            # nearest-neighbour upsampling back toward the skip's resolution
            z = torch.nn.functional.interpolate(z,
                                                scale_factor=2,
                                                mode='nearest')
            # with odd sizes the pool/upsample round trip leaves z one
            # frequency bin short of the skip; replicate-pad one bin to match
            Fs = get_shape(skip)[-2]
            Fz = get_shape(z)[-2]
            if Fz != Fs:
                z = torch.nn.functional.pad(z, (0, 0, 0, 1), mode='replicate')
            # fuse the skip connection along the channel axis
            z = torch.cat([z, skip], 1)

            # causal convolution padding: the left (past) context along the
            # time axis comes from the streaming cache, the rest is zeros
            pad = self.convs_up_pad
            if pad[0] > 0:
                z = states.pad_left(z, pad[0], 3)
                pad = (0, ) + pad[1:]
            z = torch.nn.functional.pad(z, pad, mode='constant', value=0)
            z = conv_up(z)

            z, state = block_up(z, states.state_old)
            states.update(state)

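        # shift X right by `ahead` frames so the input spectrogram lines up
        # with the mask, which is computed with `ahead` frames of look-ahead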
        X = states.pad_left(X, self.ahead, 3, shift_right=True)

        # [B,2,F,T] -> [B,F,T], [B,F,T]
        Mr, Mi = z[:, 0], z[:, 1]
        Xr, Xi = X[:, 0], X[:, 1]

        # mask in complex space: Y = X * M, i.e.
        # (Xr + j*Xi)(Mr + j*Mi) = (Xr*Mr - Xi*Mi) + j*(Xr*Mi + Xi*Mr)
        Yr = Xr * Mr - Xi * Mi
        Yi = Xr * Mi + Xi * Mr

        # [B,F,T], [B,F,T] -> [B,2,F,T]
        Y = torch.stack([Yr, Yi], 1)

        # decode back to a waveform and return only the valid samples:
        # drop the warm-up tail at the start and the look-ahead at the end
        Y_padded = states.pad_left(Y, self.ahead_ifft, 3)
        y = self.decode(Y_padded)
        y = y[:, tail_size:-self.ahead_ifft * self.hop_length]

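        # every cached state should have been consumed by the sub-modules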
        assert not states.state_old
        return y, Y, states.state
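
# Minimal usage sketch (illustrative; `model` and `waveform_chunks` are
# placeholder names, not part of this module). The streaming state returned
# by one call is fed into the next, so long audio can be processed chunk by
# chunk:
#
#   state = None
#   for chunk in waveform_chunks:   # each chunk: [B, n_samples] waveform
#       y, Y, state = model(chunk, state)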