class Convolution(Layer, ParamMixin):
    def __init__(self, n_filters=8, filter_shape=(3, 3), padding=(0, 0), stride=(1, 1), parameters=None):
        """A 2D convolutional layer.
        Input shape: (n_images, n_channels, height, width)

        Parameters
        ----------
        n_filters : int, default 8
            The number of filters (kernels).
        filter_shape : tuple(int, int), default (3, 3)
            The shape of the filters. (height, width)
        parameters : Parameters instance, default None
            Parameter container; a new ``Parameters()`` is created when None.
        stride : tuple(int, int), default (1, 1)
            The step of the convolution. (height, width).
        padding : tuple(int, int), default (0, 0)
            The number of pixels to add to each side of the input. (height, width)
        """
        self.padding = padding
        self._params = parameters
        self.stride = stride
        self.filter_shape = filter_shape
        self.n_filters = n_filters
        if self._params is None:
            self._params = Parameters()

    def setup(self, X_shape):
        """Record the input height/width and set up the weights and biases.

        X_shape is read as (n_images, n_channels, height, width).
        """
        n_channels, self.height, self.width = X_shape[1:]

        # tuple() lets a list-valued filter_shape be concatenated as well.
        W_shape = (self.n_filters, n_channels) + tuple(self.filter_shape)
        # One bias per filter; this is a plain int, exactly as before.
        b_shape = self.n_filters
        self._params.setup_weights(W_shape, b_shape)

    def forward_pass(self, X):
        """Convolve X with the filters through a single matrix product.

        Returns an array laid out as (n_images, n_filters, out_height, out_width).
        """
        n_images, _, height, width = self.shape(X.shape)
        self.last_input = X
        # Cached for backward_pass: the unrolled input and the flattened filters.
        self.col = image_to_column(X, self.filter_shape, self.stride, self.padding)
        self.col_W = self._params['W'].reshape(self.n_filters, -1).T

        out = np.dot(self.col, self.col_W) + self._params['b']
        # Filters are the last axis after the product; move them to axis 1.
        out = out.reshape(n_images, height, width, -1).transpose(0, 3, 1, 2)
        return out

    def backward_pass(self, delta):
        """Store the gradients for 'W' and 'b' and return the input gradient."""
        # Put filters last and flatten so rows line up with rows of self.col.
        delta = delta.transpose(0, 2, 3, 1).reshape(-1, self.n_filters)

        d_W = np.dot(self.col.T, delta).transpose(1, 0).reshape(self._params['W'].shape)
        d_b = np.sum(delta, axis=0)
        self._params.update_grad('b', d_b)
        self._params.update_grad('W', d_W)

        d_c = np.dot(delta, self.col_W.T)
        return column_to_image(d_c, self.last_input.shape, self.filter_shape, self.stride, self.padding)

    def shape(self, x_shape):
        """Return the output shape (n_images, n_filters, out_height, out_width).

        Relies on self.height and self.width, which are assigned in setup().
        """
        height, width = convoltuion_shape(self.height, self.width, self.filter_shape, self.stride, self.padding)
        return x_shape[0], self.n_filters, height, width
class Dense(Layer, ParamMixin):
    def __init__(self, output_dim, parameters=None):
        """A fully connected layer.

        Parameters
        ----------
        output_dim : int
            Number of output units.
        parameters : Parameters instance, default None
            Parameter container; a new ``Parameters()`` is created when None.
        """
        self._params = parameters
        self.output_dim = output_dim
        self.last_input = None

        if parameters is None:
            self._params = Parameters()

    def setup(self, x_shape):
        """Set up weights of shape (x_shape[1], output_dim)."""
        self._params.setup_weights((x_shape[1], self.output_dim))

    def forward_pass(self, X):
        """Remember X for the backward pass and return the affine output."""
        self.last_input = X
        return self.weight(X)

    def weight(self, X):
        """Return X . W + b."""
        W = np.dot(X, self._params['W'])
        return W + self._params['b']

    def backward_pass(self, delta):
        """Store the gradients for 'W' and 'b' and return the input gradient."""
        dW = np.dot(self.last_input.T, delta)
        db = np.sum(delta, axis=0)

        # Update gradient values
        self._params.update_grad('W', dW)
        self._params.update_grad('b', db)
        return np.dot(delta, self._params['W'].T)

    def shape(self, x_shape):
        """Return the output shape (x_shape[0], output_dim)."""
        return x_shape[0], self.output_dim
class Dense(Layer, ParamMixin):
    def __init__(self, output_dim, parameters=None, ):
        """A fully connected layer.

        Parameters
        ----------
        output_dim : int
            Number of output units.
        parameters : Parameters instance, default None
            Parameter container; a new ``Parameters()`` is created when None.
        """
        self._params = parameters
        self.output_dim = output_dim
        self.last_input = None

        if parameters is None:
            self._params = Parameters()

    def setup(self, x_shape):
        """Set up weights of shape (x_shape[1], output_dim)."""
        self._params.setup_weights((x_shape[1], self.output_dim))

    def forward_pass(self, X):
        """Remember X for the backward pass and return the affine output."""
        self.last_input = X
        return self.weight(X)

    def weight(self, X):
        """Return X . W + b."""
        W = np.dot(X, self._params['W'])
        return W + self._params['b']

    def backward_pass(self, delta):
        """Store the gradients for 'W' and 'b' and return the input gradient."""
        dW = np.dot(self.last_input.T, delta)
        db = np.sum(delta, axis=0)

        # Update gradient values
        self._params.update_grad('W', dW)
        self._params.update_grad('b', db)
        return np.dot(delta, self._params['W'].T)

    def shape(self, x_shape):
        """Return the output shape (x_shape[0], output_dim)."""
        return x_shape[0], self.output_dim
class Convolution(Layer, ParamMixin):
    def __init__(self,
                 n_filters=8,
                 filter_shape=(3, 3),
                 padding=(0, 0),
                 stride=(1, 1),
                 parameters=None):
        """A 2D convolutional layer.
        Input shape: (n_images, n_channels, height, width)

        Parameters
        ----------
        n_filters : int, default 8
            The number of filters (kernels).
        filter_shape : tuple(int, int), default (3, 3)
            The shape of the filters. (height, width)
        parameters : Parameters instance, default None
            Parameter container; a new ``Parameters()`` is created when None.
        stride : tuple(int, int), default (1, 1)
            The step of the convolution. (height, width).
        padding : tuple(int, int), default (0, 0)
            The number of pixels to add to each side of the input. (height, width)
        """
        self.padding = padding
        self._params = parameters
        self.stride = stride
        self.filter_shape = filter_shape
        self.n_filters = n_filters
        if self._params is None:
            self._params = Parameters()

    def setup(self, X_shape):
        """Record the input height/width and set up the weights and biases.

        X_shape is read as (n_images, n_channels, height, width).
        """
        n_channels, self.height, self.width = X_shape[1:]

        # tuple() lets a list-valued filter_shape be concatenated as well.
        W_shape = (self.n_filters, n_channels) + tuple(self.filter_shape)
        # One bias per filter; this is a plain int, exactly as before.
        b_shape = self.n_filters
        self._params.setup_weights(W_shape, b_shape)

    def forward_pass(self, X):
        """Convolve X with the filters through a single matrix product.

        Returns an array laid out as (n_images, n_filters, out_height, out_width).
        """
        n_images, _, height, width = self.shape(X.shape)
        self.last_input = X
        # Cached for backward_pass: the unrolled input and the flattened filters.
        self.col = image_to_column(X, self.filter_shape, self.stride,
                                   self.padding)
        self.col_W = self._params['W'].reshape(self.n_filters, -1).T

        out = np.dot(self.col, self.col_W) + self._params['b']
        # Filters are the last axis after the product; move them to axis 1.
        out = out.reshape(n_images, height, width, -1).transpose(0, 3, 1, 2)
        return out

    def backward_pass(self, delta):
        """Store the gradients for 'W' and 'b' and return the input gradient."""
        # Put filters last and flatten so rows line up with rows of self.col.
        delta = delta.transpose(0, 2, 3, 1).reshape(-1, self.n_filters)

        d_W = np.dot(self.col.T,
                     delta).transpose(1, 0).reshape(self._params['W'].shape)
        d_b = np.sum(delta, axis=0)
        self._params.update_grad('b', d_b)
        self._params.update_grad('W', d_W)

        d_c = np.dot(delta, self.col_W.T)
        return column_to_image(d_c, self.last_input.shape, self.filter_shape,
                               self.stride, self.padding)

    def shape(self, x_shape):
        """Return the output shape (n_images, n_filters, out_height, out_width).

        Relies on self.height and self.width, which are assigned in setup().
        """
        height, width = convoltuion_shape(self.height, self.width,
                                          self.filter_shape, self.stride,
                                          self.padding)
        return x_shape[0], self.n_filters, height, width
class LSTM(Layer, ParamMixin):
    def __init__(self,
                 hidden_dim,
                 activation='tanh',
                 inner_init='orthogonal',
                 parameters=None,
                 return_sequences=True):
        """An LSTM recurrent layer.

        Parameters
        ----------
        hidden_dim : int
            Size of the hidden state (and of every gate).
        activation : default 'tanh'
            Passed to ``get_activation``; applied to the cell candidate and
            to the cell state when producing the output.
        inner_init : default 'orthogonal'
            Passed to ``get_initializer``; used for hidden-to-hidden weights.
        parameters : Parameters instance, default None
            Parameter container; a new ``Parameters()`` is created when None.
        return_sequences : bool, default True
            Return the output of every time step rather than only the last.
        """
        self.return_sequences = return_sequences
        self.hidden_dim = hidden_dim
        self.inner_init = get_initializer(inner_init)
        self.activation = get_activation(activation)
        self.activation_d = elementwise_grad(self.activation)
        self.sigmoid_d = elementwise_grad(sigmoid)

        if parameters is None:
            self._params = Parameters()
        else:
            self._params = parameters

        self.last_input = None
        self.states = None
        self.outputs = None
        self.gates = None
        self.hprev = None
        self.input_dim = None
        self.W = None
        self.U = None

    def setup(self, x_shape):
        """
        Naming convention:
        i : input gate
        f : forget gate
        c : cell
        o : output gate

        Parameters
        ----------
        x_shape : np.array(batch size, time steps, input shape)
        """
        self.input_dim = x_shape[2]
        # Input -> Hidden
        W_params = ['W_i', 'W_f', 'W_o', 'W_c']
        # Hidden -> Hidden
        U_params = ['U_i', 'U_f', 'U_o', 'U_c']
        # Bias terms
        b_params = ['b_i', 'b_f', 'b_o', 'b_c']

        # Initialize params
        for param in W_params:
            self._params[param] = self._params.init(
                (self.input_dim, self.hidden_dim))

        for param in U_params:
            self._params[param] = self.inner_init(
                (self.hidden_dim, self.hidden_dim))

        for param in b_params:
            self._params[param] = np.full((self.hidden_dim, ),
                                          self._params.initial_bias)

        # Combine weights for simplicity
        self.W = [self._params[param] for param in W_params]
        self.U = [self._params[param] for param in U_params]

        # NOTE(review): a gradient-array initialisation step is announced by the
        # original comment here, but no call is visible in this source; confirm
        # whether Parameters needs one before update_grad() is used.

        self.hprev = np.zeros((x_shape[0], self.hidden_dim))
        self.oprev = np.zeros((x_shape[0], self.hidden_dim))

    def forward_pass(self, X):
        """Run the LSTM over every time step of X (batch, time steps, features).

        Returns the per-step outputs when ``return_sequences`` is true,
        otherwise only the output of the last time step.
        """
        n_samples, n_timesteps, input_shape = X.shape
        p = self._params
        self.last_input = X

        # One extra slot at the end holds the state carried over from the
        # previous call, so index ``i - 1`` at i == 0 wraps around to it.
        self.states = np.zeros((n_samples, n_timesteps + 1, self.hidden_dim))
        self.outputs = np.zeros((n_samples, n_timesteps + 1, self.hidden_dim))
        self.gates = {
            k: np.zeros((n_samples, n_timesteps, self.hidden_dim))
            for k in ['i', 'f', 'o', 'c']
        }

        self.states[:, -1, :] = self.hprev
        self.outputs[:, -1, :] = self.oprev

        for i in range(n_timesteps):
            # self.W / self.U are lists of four matrices, so the products
            # carry a gate axis at position 1 (order: i, f, o, c).
            t_gates = np.dot(X[:, i, :], self.W) + np.dot(
                self.outputs[:, i - 1, :], self.U)

            # Input
            self.gates['i'][:, i, :] = sigmoid(t_gates[:, 0, :] + p['b_i'])
            # Forget
            self.gates['f'][:, i, :] = sigmoid(t_gates[:, 1, :] + p['b_f'])
            # Output
            self.gates['o'][:, i, :] = sigmoid(t_gates[:, 2, :] + p['b_o'])
            # Cell
            self.gates['c'][:, i, :] = self.activation(t_gates[:, 3, :] +
                                                       p['b_c'])

            # (previous state * forget) + input * cell
            self.states[:, i, :] = self.states[:, i - 1, :] * self.gates['f'][:, i, :] + \
                                   self.gates['i'][:, i, :] * self.gates['c'][:, i, :]
            self.outputs[:, i, :] = self.gates['o'][:, i, :] * self.activation(
                self.states[:, i, :])

        # Carry the final state and output over to the next call.
        self.hprev = self.states[:, n_timesteps - 1, :].copy()
        self.oprev = self.outputs[:, n_timesteps - 1, :].copy()

        if self.return_sequences:
            return self.outputs[:, 0:-1, :]
        else:
            return self.outputs[:, -2, :]

    def backward_pass(self, delta):
        """Accumulate parameter gradients by backpropagation through time.

        The returned array is all zeros: propagation of the error to the
        layer's input is not implemented (see the TODO below).
        """
        if len(delta.shape) == 2:
            # Only the last step was returned; give delta a time axis.
            delta = delta[:, np.newaxis, :]

        n_samples, n_timesteps, input_shape = delta.shape

        # Temporal gradient arrays
        grad = {k: np.zeros_like(self._params[k]) for k in self._params.keys()}

        dh_next = np.zeros((n_samples, input_shape))
        output = np.zeros((n_samples, n_timesteps, self.input_dim))

        # Backpropagation through time
        for i in reversed(range(n_timesteps)):
            dhi = delta[:,
                        i, :] * self.gates['o'][:, i, :] * self.activation_d(
                            self.states[:, i, :]) + dh_next

            og = delta[:, i, :] * self.activation(self.states[:, i, :])
            de_o = og * self.sigmoid_d(self.gates['o'][:, i, :])

            grad['W_o'] += np.dot(self.last_input[:, i, :].T, de_o)
            grad['U_o'] += np.dot(self.outputs[:, i - 1, :].T, de_o)
            grad['b_o'] += de_o.sum(axis=0)

            de_f = (dhi * self.states[:, i - 1, :]) * self.sigmoid_d(
                self.gates['f'][:, i, :])
            grad['W_f'] += np.dot(self.last_input[:, i, :].T, de_f)
            grad['U_f'] += np.dot(self.outputs[:, i - 1, :].T, de_f)
            grad['b_f'] += de_f.sum(axis=0)

            de_i = (dhi * self.gates['c'][:, i, :]) * self.sigmoid_d(
                self.gates['i'][:, i, :])
            grad['W_i'] += np.dot(self.last_input[:, i, :].T, de_i)
            grad['U_i'] += np.dot(self.outputs[:, i - 1, :].T, de_i)
            grad['b_i'] += de_i.sum(axis=0)

            de_c = (dhi * self.gates['i'][:, i, :]) * self.activation_d(
                self.gates['c'][:, i, :])
            grad['W_c'] += np.dot(self.last_input[:, i, :].T, de_c)
            grad['U_c'] += np.dot(self.outputs[:, i - 1, :].T, de_c)
            grad['b_c'] += de_c.sum(axis=0)

            dh_next = dhi * self.gates['f'][:, i, :]

        # TODO: propagate error to the next layer

        # Change actual gradient arrays
        for k in grad.keys():
            self._params.update_grad(k, grad[k])
        return output

    def shape(self, x_shape):
        """Return the output shape for an input of shape x_shape."""
        if self.return_sequences:
            return x_shape[0], x_shape[1], self.hidden_dim
        else:
            return x_shape[0], self.hidden_dim
class BatchNormalization(Layer, ParamMixin, PhaseMixin):
    def __init__(self, momentum=0.9, eps=1e-5, parameters=None):
        """Batch normalization layer.

        Parameters
        ----------
        momentum : float, default 0.9
            Weight of the previous value in the running mean/variance.
        eps : float, default 1e-5
            Added to the variance before the square root.
        parameters : Parameters instance, default None
            Parameter container; a new ``Parameters()`` is created when None.
        """
        self._params = parameters
        if self._params is None:
            self._params = Parameters()
        self.momentum = momentum
        self.eps = eps
        self.ema_mean = None
        self.ema_var = None

    def setup(self, x_shape):
        """Set up weights of shape (1, x_shape[1]); 'W' is gamma, 'b' is beta."""
        self._params.setup_weights((1, x_shape[1]))

    def _forward_pass(self, X):
        """Normalize a 2D array X of shape (N, D) feature-wise.

        In testing mode the stored running mean/variance are used; otherwise
        the batch statistics are used and the running averages are updated.
        """
        gamma = self._params["W"]
        beta = self._params["b"]

        if self.is_testing:
            mu = self.ema_mean
            xmu = X - mu
            var = self.ema_var
            sqrtvar = np.sqrt(var + self.eps)
            ivar = 1.0 / sqrtvar
            xhat = xmu * ivar
            gammax = gamma * xhat
            return gammax + beta

        N, D = X.shape

        # step1: calculate mean
        mu = 1.0 / N * np.sum(X, axis=0)

        # step2: subtract the mean vector from every training example
        xmu = X - mu

        # step3: following the lower branch - calculate the denominator
        sq = xmu**2

        # step4: calculate variance
        var = 1.0 / N * np.sum(sq, axis=0)

        # step5: add eps for numerical stability, then sqrt
        sqrtvar = np.sqrt(var + self.eps)

        # step6: invert sqrtvar
        ivar = 1.0 / sqrtvar

        # step7: execute normalization
        xhat = xmu * ivar

        # step8: scale by gamma
        gammax = gamma * xhat

        # step9: shift by beta
        out = gammax + beta

        # store running averages of mean and variance during training for use during testing
        if self.ema_mean is None or self.ema_var is None:
            self.ema_mean = mu
            self.ema_var = var
        else:
            self.ema_mean = self.momentum * self.ema_mean + (
                1 - self.momentum) * mu
            self.ema_var = self.momentum * self.ema_var + (1 -
                                                           self.momentum) * var
        # store intermediate
        self.cache = (xhat, gamma, xmu, ivar, sqrtvar, var)

        return out

    def forward_pass(self, X):
        """Normalize a 2D input directly, or a 4D input per channel.

        Raises NotImplementedError for any other number of dimensions.
        """
        if len(X.shape) == 2:
            # input is a regular layer
            return self._forward_pass(X)
        elif len(X.shape) == 4:
            # input is a convolution layer
            N, C, H, W = X.shape
            # Move channels last and flatten so each channel is one column.
            x_flat = X.transpose(0, 2, 3, 1).reshape(-1, C)
            out_flat = self._forward_pass(x_flat)
            return out_flat.reshape(N, H, W, C).transpose(0, 3, 1, 2)
        else:
            raise NotImplementedError(
                "Unknown model with dimensions = {}".format(len(X.shape)))

    def _backward_pass(self, delta):
        """Backpropagate a 2D delta through the steps of _forward_pass.

        Stores the gradients for gamma ('W') and beta ('b') and returns the
        gradient with respect to the input. Uses the cache written by the
        last training-mode _forward_pass call.
        """
        # unfold the variables stored in cache
        xhat, gamma, xmu, ivar, sqrtvar, var = self.cache

        # get the dimensions of the input/output
        N, D = delta.shape

        # step9
        dbeta = np.sum(delta, axis=0)
        dgammax = delta  # not necessary, but more understandable

        # step8
        dgamma = np.sum(dgammax * xhat, axis=0)
        dxhat = dgammax * gamma

        # step7
        divar = np.sum(dxhat * xmu, axis=0)
        dxmu1 = dxhat * ivar

        # step6
        dsqrtvar = -1.0 / (sqrtvar**2) * divar

        # step5
        dvar = 0.5 * 1.0 / np.sqrt(var + self.eps) * dsqrtvar

        # step4
        dsq = 1.0 / N * np.ones((N, D)) * dvar

        # step3
        dxmu2 = 2 * xmu * dsq

        # step2
        dx1 = dxmu1 + dxmu2
        dmu = -1 * np.sum(dxmu1 + dxmu2, axis=0)

        # step1
        dx2 = 1.0 / N * np.ones((N, D)) * dmu

        # step0
        dx = dx1 + dx2

        # Update gradient values
        self._params.update_grad("W", dgamma)
        self._params.update_grad("b", dbeta)

        return dx

    def backward_pass(self, X):
        """Backpropagate a 2D delta directly, or a 4D delta per channel.

        Raises NotImplementedError for any other number of dimensions.
        """
        if len(X.shape) == 2:
            # input is a regular layer
            return self._backward_pass(X)
        elif len(X.shape) == 4:
            # input is a convolution layer
            N, C, H, W = X.shape
            x_flat = X.transpose(0, 2, 3, 1).reshape(-1, C)
            out_flat = self._backward_pass(x_flat)
            return out_flat.reshape(N, H, W, C).transpose(0, 3, 1, 2)
        else:
            raise NotImplementedError("Unknown model shape: {}".format(
                X.shape))

    def shape(self, x_shape):
        """Return x_shape unchanged; normalization keeps the input shape."""
        return x_shape
class RNN(Layer, ParamMixin):
    """Vanilla RNN."""
    def __init__(self,
                 hidden_dim,
                 activation='tanh',
                 inner_init='orthogonal',
                 parameters=None,
                 return_sequences=True):
        """Store the layer configuration.

        Parameters
        ----------
        hidden_dim : int
            Size of the hidden state.
        activation : default 'tanh'
            Passed to ``get_activation``; its derivative is used in
            backward_pass.
        inner_init : default 'orthogonal'
            Passed to ``get_initializer``; used for hidden-to-hidden weights.
        parameters : Parameters instance, default None
            Parameter container; a new ``Parameters()`` is created when None.
        return_sequences : bool, default True
            Return the state of every time step rather than only the last.
        """
        self.return_sequences = return_sequences
        self.hidden_dim = hidden_dim
        self.inner_init = get_initializer(inner_init)
        self.activation = get_activation(activation)
        self.activation_d = elementwise_grad(self.activation)
        if parameters is None:
            self._params = Parameters()
        else:
            self._params = parameters
        self.last_input = None
        self.states = None
        self.hprev = None
        self.input_dim = None

    def setup(self, x_shape):
        """
        Parameters
        ----------
        x_shape : np.array(batch size, time steps, input shape)
        """
        self.input_dim = x_shape[2]

        # Input -> Hidden
        self._params["W"] = self._params.init(
            (self.input_dim, self.hidden_dim))
        # Bias
        self._params["b"] = np.full((self.hidden_dim, ),
                                    self._params.initial_bias)
        # Hidden -> Hidden layer
        self._params["U"] = self.inner_init((self.hidden_dim, self.hidden_dim))

        # NOTE(review): a gradient-array initialisation step is announced by the
        # original comment here, but no call is visible in this source; confirm
        # whether Parameters needs one before update_grad() is used.

        self.hprev = np.zeros((x_shape[0], self.hidden_dim))

    def forward_pass(self, X):
        """Run the recurrence over every time step of X (batch, time steps, features).

        Returns the per-step states when ``return_sequences`` is true,
        otherwise only the state of the last time step.
        """
        self.last_input = X
        n_samples, n_timesteps, input_shape = X.shape
        # One extra slot at the end holds the state carried over from the
        # previous call, so index ``i - 1`` at i == 0 wraps around to it.
        states = np.zeros((n_samples, n_timesteps + 1, self.hidden_dim))
        states[:, -1, :] = self.hprev.copy()
        p = self._params

        for i in range(n_timesteps):
            # NOTE(review): np.tanh is applied regardless of the configured
            # activation, while backward_pass uses self.activation_d - confirm.
            states[:, i, :] = np.tanh(
                np.dot(X[:, i, :], p["W"]) +
                np.dot(states[:, i - 1, :], p["U"]) + p["b"])

        self.states = states
        self.hprev = states[:, n_timesteps - 1, :].copy()
        if self.return_sequences:
            return states[:, 0:-1, :]
        else:
            return states[:, -2, :]

    def backward_pass(self, delta):
        """Accumulate parameter gradients by backpropagation through time.

        Returns an array of shape (n_samples, n_timesteps, input_dim).
        """
        if len(delta.shape) == 2:
            # Only the last step was returned; give delta a time axis.
            delta = delta[:, np.newaxis, :]
        n_samples, n_timesteps, input_shape = delta.shape
        p = self._params

        # Temporal gradient arrays
        grad = {k: np.zeros_like(p[k]) for k in p.keys()}

        dh_next = np.zeros((n_samples, input_shape))
        output = np.zeros((n_samples, n_timesteps, self.input_dim))

        # Backpropagation through time
        for i in reversed(range(n_timesteps)):
            dhi = self.activation_d(
                self.states[:, i, :]) * (delta[:, i, :] + dh_next)

            grad["W"] += np.dot(self.last_input[:, i, :].T, dhi)
            # NOTE(review): the bias gradient sums the raw delta whereas W and
            # U use dhi - confirm this is intended.
            grad["b"] += delta[:, i, :].sum(axis=0)
            grad["U"] += np.dot(self.states[:, i - 1, :].T, dhi)

            dh_next = np.dot(dhi, p["U"].T)

            d = np.dot(delta[:, i, :], p["U"].T)
            output[:, i, :] = np.dot(d, p["W"].T)

        # Change actual gradient arrays
        for k in grad.keys():
            self._params.update_grad(k, grad[k])
        return output

    def shape(self, x_shape):
        """Return the output shape for an input of shape x_shape."""
        if self.return_sequences:
            return x_shape[0], x_shape[1], self.hidden_dim
        else:
            return x_shape[0], self.hidden_dim
class LSTM(Layer, ParamMixin):
    def __init__(self, hidden_dim, activation='tanh', inner_init='orthogonal', parameters=None, return_sequences=True):
        """An LSTM recurrent layer.

        Parameters
        ----------
        hidden_dim : int
            Size of the hidden state (and of every gate).
        activation : default 'tanh'
            Passed to ``get_activation``; applied to the cell candidate and
            to the cell state when producing the output.
        inner_init : default 'orthogonal'
            Passed to ``get_initializer``; used for hidden-to-hidden weights.
        parameters : Parameters instance, default None
            Parameter container; a new ``Parameters()`` is created when None.
        return_sequences : bool, default True
            Return the output of every time step rather than only the last.
        """
        self.return_sequences = return_sequences
        self.hidden_dim = hidden_dim
        self.inner_init = get_initializer(inner_init)
        self.activation = get_activation(activation)
        self.activation_d = elementwise_grad(self.activation)
        self.sigmoid_d = elementwise_grad(sigmoid)

        if parameters is None:
            self._params = Parameters()
        else:
            self._params = parameters

        self.last_input = None
        self.states = None
        self.outputs = None
        self.gates = None
        self.hprev = None
        self.input_dim = None
        self.W = None
        self.U = None

    def setup(self, x_shape):
        """
        Naming convention:
        i : input gate
        f : forget gate
        c : cell
        o : output gate

        Parameters
        ----------
        x_shape : np.array(batch size, time steps, input shape)
        """
        self.input_dim = x_shape[2]
        # Input -> Hidden
        W_params = ['W_i', 'W_f', 'W_o', 'W_c']
        # Hidden -> Hidden
        U_params = ['U_i', 'U_f', 'U_o', 'U_c']
        # Bias terms
        b_params = ['b_i', 'b_f', 'b_o', 'b_c']

        # Initialize params
        for param in W_params:
            self._params[param] = self._params.init((self.input_dim, self.hidden_dim))

        for param in U_params:
            self._params[param] = self.inner_init((self.hidden_dim, self.hidden_dim))

        for param in b_params:
            self._params[param] = np.full((self.hidden_dim,), self._params.initial_bias)

        # Combine weights for simplicity
        self.W = [self._params[param] for param in W_params]
        self.U = [self._params[param] for param in U_params]

        # NOTE(review): a gradient-array initialisation step is announced by the
        # original comment here, but no call is visible in this source; confirm
        # whether Parameters needs one before update_grad() is used.

        self.hprev = np.zeros((x_shape[0], self.hidden_dim))
        self.oprev = np.zeros((x_shape[0], self.hidden_dim))

    def forward_pass(self, X):
        """Run the LSTM over every time step of X (batch, time steps, features).

        Returns the per-step outputs when ``return_sequences`` is true,
        otherwise only the output of the last time step.
        """
        n_samples, n_timesteps, input_shape = X.shape
        p = self._params
        self.last_input = X

        # One extra slot at the end holds the state carried over from the
        # previous call, so index ``i - 1`` at i == 0 wraps around to it.
        self.states = np.zeros((n_samples, n_timesteps + 1, self.hidden_dim))
        self.outputs = np.zeros((n_samples, n_timesteps + 1, self.hidden_dim))
        self.gates = {k: np.zeros((n_samples, n_timesteps, self.hidden_dim)) for k in ['i', 'f', 'o', 'c']}

        self.states[:, -1, :] = self.hprev
        self.outputs[:, -1, :] = self.oprev

        for i in range(n_timesteps):
            # self.W / self.U are lists of four matrices, so the products
            # carry a gate axis at position 1 (order: i, f, o, c).
            t_gates = np.dot(X[:, i, :], self.W) + np.dot(self.outputs[:, i - 1, :], self.U)

            # Input
            self.gates['i'][:, i, :] = sigmoid(t_gates[:, 0, :] + p['b_i'])
            # Forget
            self.gates['f'][:, i, :] = sigmoid(t_gates[:, 1, :] + p['b_f'])
            # Output
            self.gates['o'][:, i, :] = sigmoid(t_gates[:, 2, :] + p['b_o'])
            # Cell
            self.gates['c'][:, i, :] = self.activation(t_gates[:, 3, :] + p['b_c'])

            # (previous state * forget) + input * cell
            self.states[:, i, :] = self.states[:, i - 1, :] * self.gates['f'][:, i, :] + \
                                   self.gates['i'][:, i, :] * self.gates['c'][:, i, :]
            self.outputs[:, i, :] = self.gates['o'][:, i, :] * self.activation(self.states[:, i, :])

        # Carry the final state and output over to the next call.
        self.hprev = self.states[:, n_timesteps - 1, :].copy()
        self.oprev = self.outputs[:, n_timesteps - 1, :].copy()

        if self.return_sequences:
            return self.outputs[:, 0:-1, :]
        else:
            return self.outputs[:, -2, :]

    def backward_pass(self, delta):
        """Accumulate parameter gradients by backpropagation through time.

        The returned array is all zeros: propagation of the error to the
        layer's input is not implemented (see the TODO below).
        """
        if len(delta.shape) == 2:
            # Only the last step was returned; give delta a time axis.
            delta = delta[:, np.newaxis, :]

        n_samples, n_timesteps, input_shape = delta.shape

        # Temporal gradient arrays
        grad = {k: np.zeros_like(self._params[k]) for k in self._params.keys()}

        dh_next = np.zeros((n_samples, input_shape))
        output = np.zeros((n_samples, n_timesteps, self.input_dim))

        # Backpropagation through time
        for i in reversed(range(n_timesteps)):
            dhi = delta[:, i, :] * self.gates['o'][:, i, :] * self.activation_d(self.states[:, i, :]) + dh_next

            og = delta[:, i, :] * self.activation(self.states[:, i, :])
            de_o = og * self.sigmoid_d(self.gates['o'][:, i, :])

            grad['W_o'] += np.dot(self.last_input[:, i, :].T, de_o)
            grad['U_o'] += np.dot(self.outputs[:, i - 1, :].T, de_o)
            grad['b_o'] += de_o.sum(axis=0)

            de_f = (dhi * self.states[:, i - 1, :]) * self.sigmoid_d(self.gates['f'][:, i, :])
            grad['W_f'] += np.dot(self.last_input[:, i, :].T, de_f)
            grad['U_f'] += np.dot(self.outputs[:, i - 1, :].T, de_f)
            grad['b_f'] += de_f.sum(axis=0)

            de_i = (dhi * self.gates['c'][:, i, :]) * self.sigmoid_d(self.gates['i'][:, i, :])
            grad['W_i'] += np.dot(self.last_input[:, i, :].T, de_i)
            grad['U_i'] += np.dot(self.outputs[:, i - 1, :].T, de_i)
            grad['b_i'] += de_i.sum(axis=0)

            de_c = (dhi * self.gates['i'][:, i, :]) * self.activation_d(self.gates['c'][:, i, :])
            grad['W_c'] += np.dot(self.last_input[:, i, :].T, de_c)
            grad['U_c'] += np.dot(self.outputs[:, i - 1, :].T, de_c)
            grad['b_c'] += de_c.sum(axis=0)

            dh_next = dhi * self.gates['f'][:, i, :]

        # TODO: propagate error to the next layer

        # Change actual gradient arrays
        for k in grad.keys():
            self._params.update_grad(k, grad[k])
        return output

    def shape(self, x_shape):
        """Return the output shape for an input of shape x_shape."""
        if self.return_sequences:
            return x_shape[0], x_shape[1], self.hidden_dim
        else:
            return x_shape[0], self.hidden_dim
class RNN(Layer, ParamMixin):
    """Vanilla RNN."""

    def __init__(self, hidden_dim, activation='tanh', inner_init='orthogonal', parameters=None, return_sequences=True):
        """Store the layer configuration.

        Parameters
        ----------
        hidden_dim : int
            Size of the hidden state.
        activation : default 'tanh'
            Passed to ``get_activation``; its derivative is used in
            backward_pass.
        inner_init : default 'orthogonal'
            Passed to ``get_initializer``; used for hidden-to-hidden weights.
        parameters : Parameters instance, default None
            Parameter container; a new ``Parameters()`` is created when None.
        return_sequences : bool, default True
            Return the state of every time step rather than only the last.
        """
        self.return_sequences = return_sequences
        self.hidden_dim = hidden_dim
        self.inner_init = get_initializer(inner_init)
        self.activation = get_activation(activation)
        self.activation_d = elementwise_grad(self.activation)
        if parameters is None:
            self._params = Parameters()
        else:
            self._params = parameters
        self.last_input = None
        self.states = None
        self.hprev = None
        self.input_dim = None

    def setup(self, x_shape):
        """
        Parameters
        ----------
        x_shape : np.array(batch size, time steps, input shape)
        """
        self.input_dim = x_shape[2]

        # Input -> Hidden
        self._params['W'] = self._params.init((self.input_dim, self.hidden_dim))
        # Bias
        self._params['b'] = np.full((self.hidden_dim,), self._params.initial_bias)
        # Hidden -> Hidden layer
        self._params['U'] = self.inner_init((self.hidden_dim, self.hidden_dim))

        # NOTE(review): a gradient-array initialisation step is announced by the
        # original comment here, but no call is visible in this source; confirm
        # whether Parameters needs one before update_grad() is used.

        self.hprev = np.zeros((x_shape[0], self.hidden_dim))

    def forward_pass(self, X):
        """Run the recurrence over every time step of X (batch, time steps, features).

        Returns the per-step states when ``return_sequences`` is true,
        otherwise only the state of the last time step.
        """
        self.last_input = X
        n_samples, n_timesteps, input_shape = X.shape
        # One extra slot at the end holds the state carried over from the
        # previous call, so index ``i - 1`` at i == 0 wraps around to it.
        states = np.zeros((n_samples, n_timesteps + 1, self.hidden_dim))
        states[:, -1, :] = self.hprev.copy()
        p = self._params

        for i in range(n_timesteps):
            # NOTE(review): np.tanh is applied regardless of the configured
            # activation, while backward_pass uses self.activation_d - confirm.
            states[:, i, :] = np.tanh(np.dot(X[:, i, :], p['W']) + np.dot(states[:, i - 1, :], p['U']) + p['b'])

        self.states = states
        self.hprev = states[:, n_timesteps - 1, :].copy()
        if self.return_sequences:
            return states[:, 0:-1, :]
        else:
            return states[:, -2, :]

    def backward_pass(self, delta):
        """Accumulate parameter gradients by backpropagation through time.

        Returns an array of shape (n_samples, n_timesteps, input_dim).
        """
        if len(delta.shape) == 2:
            # Only the last step was returned; give delta a time axis.
            delta = delta[:, np.newaxis, :]
        n_samples, n_timesteps, input_shape = delta.shape
        p = self._params

        # Temporal gradient arrays
        grad = {k: np.zeros_like(p[k]) for k in p.keys()}

        dh_next = np.zeros((n_samples, input_shape))
        output = np.zeros((n_samples, n_timesteps, self.input_dim))

        # Backpropagation through time
        for i in reversed(range(n_timesteps)):
            dhi = self.activation_d(self.states[:, i, :]) * (delta[:, i, :] + dh_next)

            grad['W'] += np.dot(self.last_input[:, i, :].T, dhi)
            # NOTE(review): the bias gradient sums the raw delta whereas W and
            # U use dhi - confirm this is intended.
            grad['b'] += delta[:, i, :].sum(axis=0)
            grad['U'] += np.dot(self.states[:, i - 1, :].T, dhi)

            dh_next = np.dot(dhi, p['U'].T)

            d = np.dot(delta[:, i, :], p['U'].T)
            output[:, i, :] = np.dot(d, p['W'].T)

        # Change actual gradient arrays
        for k in grad.keys():
            self._params.update_grad(k, grad[k])
        return output

    def shape(self, x_shape):
        """Return the output shape for an input of shape x_shape."""
        if self.return_sequences:
            return x_shape[0], x_shape[1], self.hidden_dim
        else:
            return x_shape[0], self.hidden_dim