from typing import List, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter


class GRUCell(nn.Module):
    def __init__(self,
                 input_size,
                 hidden_size,
                 gate_act="sigmoid",
                 state_act="tanh"):
        super(GRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        if gate_act == "sigmoid":
            self.gate_activation = torch.sigmoid
        elif gate_act == "relu":
            self.gate_activation = torch.relu
        else:
            raise ValueError("unsupported gate_act: {}".format(gate_act))

        if state_act == "tanh":
            self.state_activation = torch.tanh
        elif state_act == "relu":
            self.state_activation = torch.relu
        else:
            raise ValueError("unsupported state_act: {}".format(state_act))

        # order: w_u, w_r, w_c
        self.weight_i = Parameter(torch.randn(input_size, 3 * hidden_size))
        self.weight_h = Parameter(torch.randn(hidden_size, 3 * hidden_size))

        # order: b_u, b_r, b_c
        self.bias = Parameter(torch.randn(3 * hidden_size))
        self.bn = nn.BatchNorm1d(3 * hidden_size)

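    # Update rule implemented by forward() (PaddlePaddle convention), where
    # bn(.) is applied only to the input projection x @ weight_i:
    #   [u_in, r_in, c_in] = bn(x @ weight_i) + bias
    #   u  = gate_act(u_in + h @ u_h)           # update gate
    #   r  = gate_act(r_in + h @ r_h)           # reset gate
    #   c  = state_act(c_in + (r * h) @ c_h)    # candidate state
    #   hy = (1 - u) * h + u * c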
    def forward(self, input, state):
        # type: (Tensor, List[Tensor]) -> Tuple[Tensor, List[Tensor]]
        # check that every hidden element has the same batch size as the input
        for i in range(len(state)):
            assert input.shape[0] == state[i].shape[0], \
                "input batch: {}, {}th hidden element batch: {}".format(
                    input.shape[0], i, state[i].shape[0])
        assert input.shape[1] == self.weight_i.shape[0]
        hx = state[0]

        gates_input = torch.mm(input, self.weight_i)
        # DeepSpeech only normalizes the input projection, not the hidden part
        gates_input = self.bn(gates_input)
        gates_input += self.bias
        u, r, c = gates_input.chunk(3, 1)
        u_h, r_h, c_h = self.weight_h.chunk(3, 1)

        u += torch.mm(hx, u_h)
        r += torch.mm(hx, r_h)
        u = self.gate_activation(u)
        r = self.gate_activation(r)

        c += torch.mm((r * hx), c_h)
        c = self.state_activation(c)

        # This is how PaddlePaddle implements the GRU. It differs from the more
        # common formulation "hy = (u * hx) + ((1.0 - u) * c)", which swaps the
        # roles of u and (1 - u).
        hy = ((1 - u) * hx) + (u * c)

        # keep the same output format as the LSTM cell for future extension
        return hy, [hy]
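

def _gru_cell_demo():
    # Usage sketch (illustrative assumption, not part of the original module):
    # run GRUCell for a single time step on a toy batch. The sizes below
    # (batch=4, input_size=8, hidden_size=16) are arbitrary.
    cell = GRUCell(input_size=8, hidden_size=16)
    x = torch.randn(4, 8)      # one time step of input features
    h0 = torch.zeros(4, 16)    # initial hidden state
    hy, state = cell(x, [h0])  # state is [hy], mirroring the LSTM output format
    return hy, state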


class GRU_hiddenCell(nn.Module):
    '''
    A GRU cell that leaves the input-to-hidden projection (input @ weight_i)
    outside the cell. This mimics PaddlePaddle's dynamic GRU; the only
    difference is that this implementation is based on a GRUCell rather than
    on a complete GRU layer.
    '''
    def __init__(self, hidden_size, gate_act="sigmoid", state_act="tanh"):
        super(GRU_hiddenCell, self).__init__()
        self.hidden_size = hidden_size
        if gate_act == "sigmoid":
            self.gate_activation = torch.sigmoid
        elif gate_act == "relu":
            self.gate_activation = torch.relu
        else:
            raise ValueError("unsupported gate_act: {}".format(gate_act))

        if state_act == "tanh":
            self.state_activation = torch.tanh
        elif state_act == "relu":
            self.state_activation = torch.relu
        else:
            raise ValueError("unsupported state_act: {}".format(state_act))

        # order: w_u, w_r, w_c
        self.weight_h = Parameter(torch.randn(hidden_size, 3 * hidden_size))

        # order: b_u, b_r, b_c
        self.bias = Parameter(torch.randn(3 * hidden_size))

    def forward(self, input: Tensor,
                state: List[Tensor]) -> Tuple[Tensor, List[Tensor]]:
        # check that the projected input width matches 3 * hidden_size
        assert input.shape[1] == self.weight_h.shape[1], \
            "input width ({}) should equal 3 * hidden_size ({})".format(
                input.shape[1], self.weight_h.shape[1])
        hx = state[0]

        gates_input = input + self.bias
        u, r, c = gates_input.chunk(3, 1)
        u_h, r_h, c_h = self.weight_h.chunk(3, 1)

        u += torch.mm(hx, u_h)
        r += torch.mm(hx, r_h)
        u = self.gate_activation(u)
        r = self.gate_activation(r)

        c += torch.mm((r * hx), c_h)
        c = self.state_activation(c)

        # This is how PaddlePaddle implements the GRU. It differs from the more
        # common formulation "hy = (u * hx) + ((1.0 - u) * c)", which swaps the
        # roles of u and (1 - u).
        hy = ((1 - u) * hx) + (u * c)

        # keep the same output format as the LSTM cell for future extension
        return hy, [hy]
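

def _gru_hidden_cell_demo():
    # Usage sketch (illustrative assumption, not part of the original module):
    # precompute the input-to-hidden projection outside the cell, as
    # PaddlePaddle's dynamic GRU does, then step GRU_hiddenCell once. weight_i
    # here stands in for whatever projection the surrounding model provides.
    batch, input_size, hidden_size = 4, 8, 16
    cell = GRU_hiddenCell(hidden_size)
    weight_i = torch.randn(input_size, 3 * hidden_size)
    x = torch.randn(batch, input_size)
    projected = torch.mm(x, weight_i)  # shape: (batch, 3 * hidden_size)
    h0 = torch.zeros(batch, hidden_size)
    hy, state = cell(projected, [h0])
    return hy, state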