def test_lstm_cell_operator(dtype):
    """Check node type, output count and output shapes of ``ng.lstm_cell``.

    Two LSTMCell nodes are built from the same parameters: one with only the
    required arguments, and one that additionally passes activations,
    activation alpha/beta values and a clip threshold. Each node must report
    type ``"LSTMCell"``, two outputs, and ``[batch_size, hidden_size]`` as the
    shape of both outputs.

    Args:
        dtype: element type forwarded to every ``ng.parameter`` call
            (presumably supplied by a pytest parametrization outside this
            view -- TODO confirm).
    """
    # NOTE(review): a second function with this exact name is defined later in
    # this file. At import time the later definition replaces this one, so
    # pytest collects only one of them -- confirm which copy should remain.
    batch_size = 1
    input_size = 16
    hidden_size = 128

    X_shape = [batch_size, input_size]
    H_t_shape = [batch_size, hidden_size]
    C_t_shape = [batch_size, hidden_size]
    # Weight/recurrence/bias tensors stack 4 * hidden_size rows (presumably one
    # slice per LSTM gate -- confirm against the op specification).
    W_shape = [4 * hidden_size, input_size]
    R_shape = [4 * hidden_size, hidden_size]
    B_shape = [4 * hidden_size]

    parameter_X = ng.parameter(X_shape, name="X", dtype=dtype)
    parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=dtype)
    parameter_C_t = ng.parameter(C_t_shape, name="C_t", dtype=dtype)
    parameter_W = ng.parameter(W_shape, name="W", dtype=dtype)
    parameter_R = ng.parameter(R_shape, name="R", dtype=dtype)
    parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)

    # Derived from the sizes above rather than hard-coded, so the expectation
    # stays correct if batch_size or hidden_size is changed.
    expected_shape = [batch_size, hidden_size]

    def check_node(node):
        """Assert the type, output count and both output shapes of *node*."""
        assert node.get_type_name() == "LSTMCell"
        assert node.get_output_size() == 2
        assert list(node.get_output_shape(0)) == expected_shape
        assert list(node.get_output_shape(1)) == expected_shape

    node_default = ng.lstm_cell(
        parameter_X,
        parameter_H_t,
        parameter_C_t,
        parameter_W,
        parameter_R,
        parameter_B,
        hidden_size,
    )
    check_node(node_default)

    # The mixed-case activation names are kept exactly as in the original test.
    activations = ["tanh", "Sigmoid", "RELU"]
    activation_alpha = [1.0, 2.0, 3.0]
    activation_beta = [3.0, 2.0, 1.0]
    clip = 0.5

    node_param = ng.lstm_cell(
        parameter_X,
        parameter_H_t,
        parameter_C_t,
        parameter_W,
        parameter_R,
        parameter_B,
        hidden_size,
        activations,
        activation_alpha,
        activation_beta,
        clip,
    )
    check_node(node_param)
def test_lstm_cell_operator(dtype):
    """Check node type, output count and output shapes of ``ng.lstm_cell``.

    Two LSTMCell nodes are built from the same parameters: one with only the
    required arguments, and one that additionally passes activations,
    activation alpha/beta values and a clip threshold. Each node must report
    type ``'LSTMCell'``, two outputs, and ``[batch_size, hidden_size]`` as the
    shape of both outputs.

    Args:
        dtype: element type forwarded to every ``ng.parameter`` call
            (presumably supplied by a pytest parametrization outside this
            view -- TODO confirm).
    """
    # NOTE(review): a function with this exact name is also defined earlier in
    # this file. This later definition replaces it at import time, so pytest
    # collects only one of them -- confirm which copy should remain.
    batch_size = 1
    input_size = 16
    hidden_size = 128

    X_shape = [batch_size, input_size]
    H_t_shape = [batch_size, hidden_size]
    C_t_shape = [batch_size, hidden_size]
    # Weight/recurrence/bias tensors stack 4 * hidden_size rows (presumably one
    # slice per LSTM gate -- confirm against the op specification).
    W_shape = [4 * hidden_size, input_size]
    R_shape = [4 * hidden_size, hidden_size]
    B_shape = [4 * hidden_size]

    parameter_X = ng.parameter(X_shape, name='X', dtype=dtype)
    parameter_H_t = ng.parameter(H_t_shape, name='H_t', dtype=dtype)
    parameter_C_t = ng.parameter(C_t_shape, name='C_t', dtype=dtype)
    parameter_W = ng.parameter(W_shape, name='W', dtype=dtype)
    parameter_R = ng.parameter(R_shape, name='R', dtype=dtype)
    parameter_B = ng.parameter(B_shape, name='B', dtype=dtype)

    # Derived from the sizes above rather than hard-coded, so the expectation
    # stays correct if batch_size or hidden_size is changed.
    expected_shape = [batch_size, hidden_size]

    def check_node(node):
        """Assert the type, output count and both output shapes of *node*."""
        assert node.get_type_name() == 'LSTMCell'
        assert node.get_output_size() == 2
        assert list(node.get_output_shape(0)) == expected_shape
        assert list(node.get_output_shape(1)) == expected_shape

    node_default = ng.lstm_cell(parameter_X, parameter_H_t, parameter_C_t,
                                parameter_W, parameter_R, parameter_B,
                                hidden_size)
    check_node(node_default)

    # The mixed-case activation names are kept exactly as in the original test.
    activations = ['tanh', 'Sigmoid', 'RELU']
    activation_alpha = [1.0, 2.0, 3.0]
    activation_beta = [3.0, 2.0, 1.0]
    clip = 0.5

    node_param = ng.lstm_cell(parameter_X, parameter_H_t, parameter_C_t,
                              parameter_W, parameter_R, parameter_B,
                              hidden_size, activations, activation_alpha,
                              activation_beta, clip)
    check_node(node_param)