コード例 #1
0
ファイル: pooling.py プロジェクト: Satssuki/zeta-learn
    def pass_backward(self, grad):
        """Back-propagate ``grad`` through the pooling layer.

        The upstream gradient is flattened into a column vector, routed back
        to the pooled positions by ``self.pool_backward`` (using the cache
        saved in ``self.pool_cache``), and then folded back into image form
        with ``col2im_indices``.

        Args:
            grad: upstream gradient; a 4-D array (its axes are reordered with
                ``transpose(2, 3, 0, 1)`` before flattening).

        Returns:
            The gradient with respect to the layer input, reshaped to
            ``self.inputs.shape``.
        """
        # self.inputs is 4-D; this code names its axes num, depth, height, width.
        batch_size, channels, height, width = self.inputs.shape

        # Zero buffer with the same layout as the cached column matrix; the
        # pool-specific backward fills in where the gradient flows.
        zero_cols = np.zeros_like(self.input_col)
        flat_grad = grad.transpose(2, 3, 0, 1).ravel()
        col_grad = self.pool_backward(zero_cols, flat_grad, self.pool_cache)

        kernel_h = self.pool_size[0]
        kernel_w = self.pool_size[1]
        # Every (sample, channel) pair is folded back as its own
        # single-channel image, hence the (num * depth, 1, H, W) target shape.
        folded_shape = (batch_size * channels, 1, height, width)
        image_grad = col2im_indices(
            col_grad,
            folded_shape,
            kernel_h,
            kernel_w,
            padding=(self.pad_height, self.pad_width),
            stride=self.strides[0])

        return image_grad.reshape(self.inputs.shape)
コード例 #2
0
    def pass_backward(self, grad):
        """Back-propagate ``grad`` through the convolution layer.

        The gradient with respect to the layer input is computed from the
        weights that were used in the forward pass, i.e. *before* this call
        applies any optimizer update. When ``self.is_trainable`` is true, the
        weights and bias are then updated through
        ``optimizer(self.weight_optimizer)``.

        Args:
            grad: upstream gradient; a 4-D array whose axis 1 is treated as
                the filter axis (it is summed over axes 0, 2, 3 for the bias).
                Presumably shaped (batch, filter_num, out_height, out_width);
                confirm against the forward pass.

        Returns:
            The gradient with respect to the layer input, as produced by
            ``col2im_indices`` for ``self.input_shape``.
        """
        _, _, input_height, input_width = self.input_shape

        # Reshape the upstream gradient to (filter_num, -1). This is needed for
        # the input gradient even when the layer is not trainable, so it must
        # be computed outside the trainable branch.
        doutput_reshaped = grad.transpose(1, 2, 3,
                                          0).reshape(self.filter_num, -1)

        # Input gradient must use the forward-pass weights, so compute it
        # before the optimizer below replaces self.weights.
        weight_reshape = self.weights.reshape(self.filter_num, -1)
        dinput_col = weight_reshape.T @ doutput_reshaped

        if self.is_trainable:
            # Bias gradient: sum over every axis except the filter axis.
            dbias = np.sum(grad, axis=(0, 2, 3))
            dbias = dbias.reshape(self.filter_num, -1)

            # self.input_col is presumably the im2col matrix cached by the
            # forward pass (its transpose must have filter-compatible rows).
            dweights = doutput_reshaped @ self.input_col.T
            dweights = dweights.reshape(self.weights.shape)

            # optimize the weights and bias
            self.weights = optimizer(self.weight_optimizer).update(
                self.weights, dweights)
            self.bias = optimizer(self.weight_optimizer).update(
                self.bias, dbias)

        # endif self.is_trainable

        pad_height, pad_width = get_pad(self.padding, input_height,
                                        input_width, self.strides[0],
                                        self.strides[1], self.kernel_size[0],
                                        self.kernel_size[1])

        # Fold the column-form input gradient back into the input's shape.
        dinputs = col2im_indices(dinput_col,
                                 self.input_shape,
                                 self.kernel_size[0],
                                 self.kernel_size[1],
                                 padding=(pad_height, pad_width),
                                 stride=self.strides[0])

        return dinputs