def test_lstm_forward_training_fuzz(ops, args):
    """Compare ``ops.lstm_forward_training`` with a fresh reference ``Ops()``.

    ``args`` must unpack into ``(params, H0, C0, X, size_at_t)``. It is
    presumably generated by a fuzzing/property-based strategy (TODO confirm:
    any decorator supplying it is outside this view).

    The test checks the output ``Y`` and entries 2 and 1 of the forward state
    against the reference. All comparisons use ``atol=1e-4, rtol=1e-3``.
    """
    params, H0, C0, X, size_at_t = args
    lstm_inputs = (params, H0, C0, X, size_at_t)

    expected = Ops().lstm_forward_training(*lstm_inputs)
    Y, fwd_state = ops.lstm_forward_training(*lstm_inputs)

    tolerance = {"atol": 1e-4, "rtol": 1e-3}
    # Check order: state[2], then state[1], then the output Y.
    # NOTE(review): fwd_state[0] is not compared -- confirm this is intentional.
    for actual, wanted in (
        (fwd_state[2], expected[1][2]),
        (fwd_state[1], expected[1][1]),
        (Y, expected[0]),
    ):
        assert_allclose(actual, wanted, **tolerance)
def test_lstm_forward_training(ops, depth, dirs, nO, batch_size, nI):
    """Compare ``ops.lstm_forward_training`` with a reference ``Ops()``.

    The LSTM inputs come from ``get_lstm_args(depth, dirs, nO, batch_size, nI)``.
    The size arguments are presumably pytest-parametrized (TODO confirm:
    decorators are outside this view).

    The test checks the output ``Y`` and entries 2 and 1 of the forward state
    against the reference. All comparisons use ``atol=1e-4, rtol=1e-3``.
    """
    reference_backend = Ops()
    params, H0, C0, X, size_at_t = get_lstm_args(depth, dirs, nO, batch_size, nI)
    lstm_inputs = (params, H0, C0, X, size_at_t)

    expected = reference_backend.lstm_forward_training(*lstm_inputs)
    Y, fwd_state = ops.lstm_forward_training(*lstm_inputs)

    tolerance = {"atol": 1e-4, "rtol": 1e-3}
    # Check order: state[2], then state[1], then the output Y.
    # NOTE(review): fwd_state[0] is not compared -- confirm this is intentional.
    for actual, wanted in (
        (fwd_state[2], expected[1][2]),
        (fwd_state[1], expected[1][1]),
        (Y, expected[0]),
    ):
        assert_allclose(actual, wanted, **tolerance)