def test_lstm_sequence_operator_forward(dtype):
    """Create opset1 LSTMSequence nodes with direction "forward" and check them.

    The same seven input parameters feed two nodes: one built without the
    optional trailing arguments, and one that also passes activations,
    activation alpha/beta lists and a clip value. Each node is expected to
    report the type name "LSTMSequence" and three outputs.
    """
    batch, steps, features, hidden, directions = 2, 2, 4, 3, 1
    hidden_x4 = 4 * hidden

    # Operands are created in the order the operator consumes them positionally.
    operands = [
        ng.parameter([batch, steps, features], name="X", dtype=dtype),
        ng.parameter([batch, directions, hidden], name="H_t", dtype=dtype),
        ng.parameter([batch, directions, hidden], name="C_t", dtype=dtype),
        # Sequence lengths are always int32, independent of the tested dtype.
        ng.parameter([batch], name="seq_len", dtype=np.int32),
        ng.parameter([directions, hidden_x4, features], name="W", dtype=dtype),
        ng.parameter([directions, hidden_x4, hidden], name="R", dtype=dtype),
        ng.parameter([directions, hidden_x4], name="B", dtype=dtype),
    ]
    mode = "forward"

    # Variant 1: only the required arguments.
    plain_node = ng_opset1.lstm_sequence(*operands, hidden, mode)

    assert plain_node.get_type_name() == "LSTMSequence"
    assert plain_node.get_output_size() == 3

    # Variant 2: activations, alpha, beta and clip given explicitly.
    tuned_node = ng_opset1.lstm_sequence(
        *operands,
        hidden,
        mode,
        ["RELU", "tanh", "Sigmoid"],
        [2.0],
        [1.0],
        0.5,
    )

    assert tuned_node.get_type_name() == "LSTMSequence"
    assert tuned_node.get_output_size() == 3
def test_lstm_sequence_operator_bidirectional(dtype):
    """Create opset1 LSTMSequence nodes with direction "BIDIRECTIONAL" and check them.

    Input shapes use two directions. One node is built from the required
    arguments only; a second one also receives activations, three-element
    activation alpha/beta lists and a clip value. Each node is expected to
    report the type name "LSTMSequence" and three outputs.
    """
    batch, steps, features, hidden, directions = 1, 2, 16, 128, 2
    hidden_x4 = 4 * hidden

    def make_input(shape, label, kind=dtype):
        # Thin wrapper so every operand is declared the same way.
        return ng.parameter(shape, name=label, dtype=kind)

    # Operands are created in the order the operator consumes them positionally.
    operands = (
        make_input([batch, steps, features], "X"),
        make_input([batch, directions, hidden], "H_t"),
        make_input([batch, directions, hidden], "C_t"),
        # Sequence lengths are always int32, independent of the tested dtype.
        make_input([batch], "seq_len", np.int32),
        make_input([directions, hidden_x4, features], "W"),
        make_input([directions, hidden_x4, hidden], "R"),
        make_input([directions, hidden_x4], "B"),
    )
    mode = "BIDIRECTIONAL"

    # Variant 1: only the required arguments.
    basic = ng_opset1.lstm_sequence(*operands, hidden, mode)

    assert basic.get_type_name() == "LSTMSequence"
    assert basic.get_output_size() == 3

    # Variant 2: activations, alpha, beta and clip given explicitly.
    extras = (
        ["RELU", "tanh", "Sigmoid"],
        [1.0, 2.0, 3.0],
        [3.0, 2.0, 1.0],
        1.22,
    )
    configured = ng_opset1.lstm_sequence(*operands, hidden, mode, *extras)

    assert configured.get_type_name() == "LSTMSequence"
    assert configured.get_output_size() == 3