    def __init__(self,
                 input_shape,
                 out_channel,
                 kernel_size,
                 stride,
                 learning_rate,
                 activate_func: str = None):
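        """Convolution layer without padding.

        input_shape   : (height, width, channels) of the input feature map
        out_channel   : number of output channels (filters)
        kernel_size   : (kernel_h, kernel_w)
        stride        : (stride_h, stride_w)
        activate_func : 'relu', 'sigmoid', 'tanh', or None
        """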
        self.in_h, self.in_w, self.in_channel = input_shape
        self.learning_rate = learning_rate
        self.out_channel = out_channel
        self.kernel_h, self.kernel_w = kernel_size
        self.stride_h, self.stride_w = stride
        # no padding in this version: the kernel must tile the input exactly
        # under the given stride
        assert (self.in_h - self.kernel_h) % self.stride_h == 0
        assert (self.in_w - self.kernel_w) % self.stride_w == 0
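        # output spatial dims of a valid (no-padding) convolution:
        #   out = (in - kernel) // stride + 1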
        self.out_h = (self.in_h - self.kernel_h) // self.stride_h + 1
        self.out_w = (self.in_w - self.kernel_w) // self.stride_w + 1

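        # filters are laid out as (kernel_h, kernel_w, in_channel, out_channel);
        # weights and biases start from a standard normal draw, and the gradient
        # buffer is pre-allocated with the same shape as the filters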
        self.filters = torch.randn(
            (self.kernel_h, self.kernel_w, self.in_channel, out_channel),
            dtype=floatX)
        self.biases = torch.randn((self.out_channel, ), dtype=floatX)
        self.filters_gradient = torch.empty(
            (self.kernel_h, self.kernel_w, self.in_channel, out_channel),
            dtype=floatX)

        if activate_func == 'relu':
            self.activation = other.Relu()
        elif activate_func == 'sigmoid':
            self.activation = other.Sigmoid(100)
        elif activate_func == 'tanh':
            self.activation = other.Tanh(100)
        else:
            self.activation = None

    def __init__(self,
                 input_shape,
                 out_channel,
                 kernel_size,
                 stride,
                 learning_rate,
                 padding=(0, 0),
                 activate_func: str = None):
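        """Convolution layer with zero padding and Xavier (Glorot) uniform init.

        padding       : (pad_h, pad_w) zeros added to both sides of each spatial dim
        activate_func : 'relu', 'mfm', or None
        """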
        self.in_h, self.in_w, self.in_channel = input_shape
        self.learning_rate = learning_rate
        self.out_channel = out_channel
        self.kernel_h, self.kernel_w = kernel_size
        self.stride_h, self.stride_w = stride
        self.in_pad_h, self.in_pad_w = padding

        # zero padding: keep the original (unpadded) input size, then fold the
        # padding into the effective in_h / in_w
        self.real_in_h = self.in_h
        self.real_in_w = self.in_w
        self.in_h += 2 * self.in_pad_h
        self.in_w += 2 * self.in_pad_w
        assert (self.in_h - self.kernel_h) % self.stride_h == 0
        assert (self.in_w - self.kernel_w) % self.stride_w == 0

        self.out_h = (self.in_h - self.kernel_h) // self.stride_h + 1
        self.out_w = (self.in_w - self.kernel_w) // self.stride_w + 1

        # Xavier (Glorot) uniform initialization: bound = sqrt(6 / (n_in + n_out))
        n_in = self.in_h * self.in_w * self.in_channel
        n_out = self.out_h * self.out_w * self.out_channel
        coe = math.sqrt(6) / math.sqrt(n_in + n_out)
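        # note: n_in / n_out above are full feature-map element counts; the usual
        # Glorot fan-in / fan-out for a conv layer would be
        # kernel_h * kernel_w * in_channel (and kernel_h * kernel_w * out_channel)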
        print('conv coe', coe)
        self.filters = torch.tensor(np.random.uniform(
            -coe,
            coe,
            size=(self.kernel_h, self.kernel_w, self.in_channel, out_channel)),
                                    dtype=floatX,
                                    device=device)
        self.biases = torch.tensor(np.random.uniform(
            -coe, coe, size=(self.out_channel, )),
                                   dtype=floatX,
                                   device=device)
        # self.filters = torch.randn(
        #     (self.kernel_h, self.kernel_w, self.in_channel, out_channel),
        #     dtype=floatX, device=device)
        # self.biases = torch.randn((self.out_channel,), dtype=floatX, device=device)
        # self.filters_gradient = torch.empty(
        #     (self.kernel_h, self.kernel_w, self.in_channel, out_channel),
        #     dtype=floatX, device=device)

        if activate_func == 'relu':
            self.activation = other.Relu()
        elif activate_func == 'mfm':
            self.activation = other.MFM()
        else:
            self.activation = None
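
    # Minimal usage sketch (assuming this class is named ConvLayer and that the
    # forward/backward passes are defined elsewhere in this file):
    #   conv = ConvLayer(input_shape=(28, 28, 1), out_channel=8,
    #                    kernel_size=(3, 3), stride=(1, 1), learning_rate=0.01,
    #                    padding=(1, 1), activate_func='relu')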