def test_lstm_cell_operator(dtype):
    """Check the type name and output shapes of nodes built by ``ng.lstm_cell``.

    Six parameters (X, H_t, C_t, W, R, B) are created with ``ng.parameter``
    and fed to ``ng.lstm_cell`` twice: once with only the required arguments
    and once with explicit activations, alpha/beta values and a clip value.
    Both nodes must report the type name "LSTMCell" and two outputs of shape
    ``[batch_size, hidden_size]``.

    NOTE(review): a second function with this exact name is defined later in
    this file; it rebinds the name at import time, so only one of the two
    definitions is reachable by name.

    :param dtype: element type forwarded unchanged to every ``ng.parameter`` call.
    """
    batch_size, input_size, hidden_size = 1, 16, 128
    stacked_rows = 4 * hidden_size

    # Insertion order matters: it is both the creation order of the
    # parameters and the positional order expected by ng.lstm_cell.
    shapes_by_name = {
        "X": [batch_size, input_size],
        "H_t": [batch_size, hidden_size],
        "C_t": [batch_size, hidden_size],
        "W": [stacked_rows, input_size],
        "R": [stacked_rows, hidden_size],
        "B": [stacked_rows],
    }
    cell_inputs = [
        ng.parameter(shape, name=label, dtype=dtype)
        for label, shape in shapes_by_name.items()
    ]

    expected_shape = [batch_size, hidden_size]

    def check_lstm_cell(node):
        # Shared assertions for both the default and the customised node.
        assert node.get_type_name() == "LSTMCell"
        assert node.get_output_size() == 2
        for output_index in (0, 1):
            assert list(node.get_output_shape(output_index)) == expected_shape

    # Node built from the required arguments only.
    check_lstm_cell(ng.lstm_cell(*cell_inputs, hidden_size))

    # Node built with every optional argument supplied explicitly.
    activations = ["tanh", "Sigmoid", "RELU"]
    activation_alpha = [1.0, 2.0, 3.0]
    activation_beta = [3.0, 2.0, 1.0]
    clip = 0.5
    check_lstm_cell(
        ng.lstm_cell(
            *cell_inputs,
            hidden_size,
            activations,
            activation_alpha,
            activation_beta,
            clip,
        )
    )
def test_lstm_cell_operator(dtype):
    """Verify ``ng.lstm_cell`` nodes for default and fully specified arguments.

    Creates the X, H_t, C_t, W, R and B parameters with ``ng.parameter`` and
    builds an LSTMCell node first without, then with, the optional
    activations / activation_alpha / activation_beta / clip arguments. Each
    node is expected to be of type 'LSTMCell' with two outputs, both of
    shape ``[1, 128]``.

    :param dtype: element type passed through to each ``ng.parameter`` call.
    """
    batch_size = 1
    input_size = 16
    hidden_size = 128
    quadruple_hidden = 4 * hidden_size

    # (name, shape) pairs listed in the positional order ng.lstm_cell takes them.
    parameter_specs = (
        ('X', [batch_size, input_size]),
        ('H_t', [batch_size, hidden_size]),
        ('C_t', [batch_size, hidden_size]),
        ('W', [quadruple_hidden, input_size]),
        ('R', [quadruple_hidden, hidden_size]),
        ('B', [quadruple_hidden]),
    )
    parameters = []
    for param_name, param_shape in parameter_specs:
        parameters.append(ng.parameter(param_shape, name=param_name, dtype=dtype))

    expected_shape = [1, 128]

    activations = ['tanh', 'Sigmoid', 'RELU']
    activation_alpha = [1.0, 2.0, 3.0]
    activation_beta = [3.0, 2.0, 1.0]
    clip = 0.5

    # First pass: required arguments only; second pass: all optional ones too.
    optional_argument_sets = (
        (),
        (activations, activation_alpha, activation_beta, clip),
    )
    for optional_arguments in optional_argument_sets:
        lstm_node = ng.lstm_cell(*parameters, hidden_size, *optional_arguments)

        assert lstm_node.get_type_name() == 'LSTMCell'
        assert lstm_node.get_output_size() == 2
        assert list(lstm_node.get_output_shape(0)) == expected_shape
        assert list(lstm_node.get_output_shape(1)) == expected_shape