Code Example #1
    def backward_with_dropout(self, A, Y, keep_prob):
        assert A.shape == Y.shape

        m = A.shape[1]
        L = len(self.params) // 2
        # shortcut for a sigmoid/softmax output with cross-entropy loss: dZ[L] = A[L] - Y
        dZ = A - Y
        for l in range(L, 0, -1):  # iterates over layers [L, L-1, ..., 2, 1]
            # from dZ[l] we compute dW[l] and db[l]
            A = self.caches['A' + str(l - 1)]
            self.grads['dW' + str(l)] = np.dot(dZ, A.T) / float(m)
            self.grads['db' + str(l)] = np.sum(dZ, axis=1, keepdims=True) / float(m)

            # stop at layer 1: there is no dW[0] or db[0], and dZ[0] (and dA[0]) are likewise not needed
            if l == 1:
                break

            # compute dZ[l-1] (implicitly dA[l-1] also) for the next iteration
            W = self.params['W' + str(l)]
            Z = self.caches['Z' + str(l - 1)]
            D = self.caches['D' + str(l - 1)]
            dA = np.dot(W.T, dZ)
            dA = dA * D           # apply the dropout mask stored during the forward pass
            dA = dA / keep_prob   # rescale (inverted dropout)
            self.grads['dA' + str(l - 1)] = dA
            dZ = np.multiply(dA, d_relu(Z))
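
Note: the stored mask D and the division by keep_prob above imply an inverted-dropout forward pass that saved Z, A, and D into self.caches. As a point of reference only, a minimal sketch of such a forward step under that assumption (the method name forward_layer_with_dropout is hypothetical, numpy is assumed imported as np as in the snippets here, and the cache keys mirror the ones read by the backward pass) could look like this:

    def forward_layer_with_dropout(self, A_prev, l, keep_prob):
        # Sketch (assumed, not from the original source): one hidden-layer
        # forward step with inverted dropout, producing the caches that
        # backward_with_dropout reads.
        W = self.params['W' + str(l)]
        b = self.params['b' + str(l)]
        Z = np.dot(W, A_prev) + b                   # pre-activation
        A = np.maximum(0, Z)                        # ReLU activation
        D = np.random.rand(*A.shape) < keep_prob    # boolean dropout mask
        A = A * D / keep_prob                       # drop units and rescale (inverted dropout)
        self.caches['Z' + str(l)] = Z
        self.caches['A' + str(l)] = A
        self.caches['D' + str(l)] = D
        return A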
Code Example #2
    def loss(self, x, y=None):
        """Compute loss and gradient for the fully-connected net."""
        if len(y.shape) == 2:
            N, M = y.shape  # N is n_samples, M is dims of each sample
        elif len(y.shape) == 1:
            M = y.shape[0]
        else:
            raise ValueError("y has incorrect shape")

        output, caches = self.prediction_save_cache(x)  # Forward pass
        grads = {}

        # Calculate the loss for the current batch =============================
        # Get the mean squared error loss (1/2 to simplify derivative)
        loss = 0.5 * np.mean((output - y)**2)
        # Add a regularization term
        for l in range(self.n_hidden):
            loss += 0.5 * self.reg * np.sum(self.params[f"w{l}"]**2)
            loss += 0.5 * self.reg * np.sum(self.params[f"b{l}"]**2)
        loss += 0.5 * self.reg * np.sum(self.params["w_out"]**2)
        loss += 0.5 * self.reg * np.sum(self.params["b_out"]**2)

        # Get the gradients through backprop ===================================
        # Gradient from the MSE loss
        dout = (output - y) / N
        # Backprop through output layer
        dout, dw, db = d_affine(dout, caches["affine_out"])
        grads["w_out"] = dw + self.reg * self.params["w_out"]
        grads["b_out"] = db + self.reg * self.params["b_out"]
        # Backprop through each hidden layer
        for l in reversed(range(self.n_hidden)):
            l = str(l)
            dout = d_dropout(dout, caches["dropout" + l])
            dout = d_relu(dout, caches["relu" + l])
            dout, dw, db = d_affine(dout, caches["affine" + l])

            # Save gradients into a dictionary where the key matches the param key
            grads["w" + l] = dw + self.reg * self.params["w" + l]
            grads["b" + l] = db + self.reg * self.params["b" + l]

        # Clip gradients if enabled - really helps stability! ==================
        if self.grad_clip:
            for key, grad in grads.items():
                grads[key] = np.clip(grad, -self.grad_clip, self.grad_clip)

        return loss, grads
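
The backward pass above relies on a d_affine helper that is not shown. Assuming the common convention that each "affine" cache stores the forward-pass inputs (x, w, b), with x of shape (N, D) and the upstream gradient dout of shape (N, M), a minimal sketch of such a helper could be:

    import numpy as np

    def d_affine(dout, cache):
        # Sketch (assumed, not from the original source): backward pass of an
        # affine layer whose cache holds the forward inputs (x, w, b).
        x, w, b = cache
        dx = np.dot(dout, w.T)        # gradient w.r.t. the layer input
        dw = np.dot(x.T, dout)        # gradient w.r.t. the weights w
        db = np.sum(dout, axis=0)     # gradient w.r.t. the biases b
        return dx, dw, db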
Code Example #3
    def backward(self, A, Y):
        assert A.shape == Y.shape

        L = len(self.params) // 2
        m = Y.shape[1]

        # shortcut for a sigmoid/softmax output with cross-entropy loss: dZ[L] = A[L] - Y
        dZ = A - Y
        for l in range(L, 0, -1):  # iterates over layers [L, L-1, ..., 2, 1]
            # from dZ[l] we compute dW[l] and db[l]
            A = self.caches['A' + str(l - 1)]
            self.grads['dW' + str(l)] = np.dot(dZ, A.T) / float(m)
            self.grads['db' + str(l)] = np.sum(dZ, axis=1, keepdims=True) / float(m)

            # stop at layer 1: there is no dW[0] or db[0], and dZ[0] (and dA[0]) are likewise not needed
            if l == 1:
                break

            # compute dZ[l-1] (implicitly dA[l-1] also) for the next iteration
            W = self.params['W' + str(l)]
            Z = self.caches['Z' + str(l - 1)]
            dZ = np.multiply(np.dot(W.T, dZ), d_relu(Z))  # dA = np.dot(W.T, dZ)
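
Both Example #1 and Example #3 assume a d_relu(Z) helper: the elementwise derivative of ReLU evaluated at the cached pre-activation Z. (Example #2 calls d_relu with a different, two-argument signature, so the two conventions are not interchangeable.) A minimal sketch of the single-argument version used here:

    import numpy as np

    def d_relu(Z):
        # Sketch (assumed, not from the original source): elementwise ReLU
        # derivative, 1 where Z > 0 and 0 elsewhere.
        return (Z > 0).astype(Z.dtype)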