Пример #1
0
    def backward_naive(self, dout: NPArray) -> tuple[NPArray, ...]:
        """
        Backpropagate through batch normalization, one graph node at a time.

        The intermediates are read from ``self.cache``, which must unpack to a
        pair ``(xn, std)``; presumably the normalized input and the
        per-feature standard deviation saved by the forward pass.

        In training mode the gradient also flows through the batch statistics
        (standard deviation, variance and mean). Otherwise ``std`` is treated
        as a constant and the input gradient is just the rescaled upstream
        gradient.

        Inputs:
        - dout: Upstream derivatives, of shape (N, D)

        Returns a tuple of:
        - dx: Gradient with respect to inputs x, of shape (N, D)
        - dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
        - dbeta: Gradient with respect to shift parameter beta, of shape (D,)
        """
        x_hat, sigma = self.cache
        training = self.train_mode

        # These three quantities are the same in both modes.
        grad_beta = dout.sum(axis=0)
        grad_gamma = np.sum(x_hat * dout, axis=0)
        grad_x_hat = self.gamma * dout

        if not training:
            # No dependence on batch statistics: only undo the scaling.
            return grad_x_hat / sigma, grad_gamma, grad_beta

        batch_size = dout.shape[0]

        # Direct path from the normalized value to the centered input.
        grad_centered = grad_x_hat / sigma

        # Path through the standard deviation, then the variance.
        grad_sigma = -np.sum((grad_x_hat * x_hat) / sigma, axis=0)
        grad_var = 0.5 * grad_sigma / sigma

        # x_hat * sigma reconstructs the centered input (x - mean).
        grad_centered += (2 / batch_size) * (x_hat * sigma) * grad_var

        # Path through the mean: subtract the per-feature average gradient.
        grad_mean = np.sum(grad_centered, axis=0)
        grad_x = grad_centered - grad_mean / batch_size

        return grad_x, grad_gamma, grad_beta
Пример #2
0
    def backward(self, dout: NPArray) -> tuple[NPArray, ...]:
        """
        Backpropagate through the temporal affine layer.

        The batch and time axes are merged so the layer can be treated as an
        ordinary affine map over N * T rows. The input saved in ``self.cache``
        and the parameters ``self.w`` / ``self.b`` supply the remaining
        operands.

        Input:
        - dout: Upstream gradients of shape (N, T, M)

        Returns a tuple of:
        - dx: Gradient of input, of shape (N, T, D)
        - dw: Gradient of weights, of shape (D, M)
        - db: Gradient of biases, of shape (M,)
        """
        (inputs,) = self.cache
        batch, steps, in_dim = inputs.shape
        out_dim = self.b.shape[0]
        rows = batch * steps

        # Collapse (N, T) into a single leading axis for the matrix products.
        flat_dout = dout.reshape(rows, out_dim)
        flat_inputs = inputs.reshape(rows, in_dim)

        grad_inputs = flat_dout.dot(self.w.T).reshape(batch, steps, in_dim)
        # (M, N*T) @ (N*T, D) gives (M, D); transpose to match w's (D, M).
        grad_weights = flat_dout.T.dot(flat_inputs).T
        # The bias is shared across batch and time, so sum over both axes.
        grad_bias = dout.sum(axis=(0, 1))

        return grad_inputs, grad_weights, grad_bias