/
test_db.py
152 lines (102 loc) · 3.08 KB
/
test_db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from db import Database, LineInput
import unittest
class DummyDatabase(object):
    """Stand-in database used to drive LineInput in isolation.

    It exposes exactly one command, ``test``, which takes a single
    argument and hands it straight back.
    """

    def test(self, first):
        """Echo *first* back to the caller unchanged."""
        echoed = first
        return echoed
class TestLineInput(unittest.TestCase):
    """Check LineInput.consume against a DummyDatabase backend."""

    def setUp(self):
        self.input = LineInput(DummyDatabase())

    def _assert_consumes(self, line, expected):
        """Feed *line* to the parser and compare the (status, payload) result."""
        self.assertEqual(expected, self.input.consume(line))

    def test_input_ok(self):
        # Presumably dispatches to DummyDatabase.test('test'), which echoes it.
        self._assert_consumes('test test', ('ok', 'test'))

    def test_input_no_method(self):
        # DummyDatabase has no `nothing` method, so this must be rejected.
        self._assert_consumes('nothing command', ('error', ''))

    def test_input_bad_args(self):
        # DummyDatabase.test takes one argument; two are supplied here.
        self._assert_consumes('test test test', ('error', ''))
class TestDatabase(unittest.TestCase):
    """Tests for Database: SET/GET/UNSET, EQUALTO and transactions.

    The expected results used throughout are the sentinel strings this
    suite requires from Database: 'NULL' for a missing key, 'NONE' when
    no key holds a value, 'INVALID ROLLBACK' when there is no open
    transaction, and otherwise the matching key names joined by spaces.
    """

    def setUp(self):
        # Each test starts from a fresh, empty database.
        self.db = Database()

    def test_get_and_set(self):
        self.db.set('a', 10)
        self.assertEqual(self.db.get('a'), 10)

    def test_get_null(self):
        self.assertEqual(self.db.get('nonexistent'), 'NULL')

    def test_unset(self):
        self.db.set('a', 10)
        self.db.unset('a')
        self.assertEqual(self.db.get('a'), 'NULL')

    def test_equalto(self):
        # Keys are inserted out of order; the expected output is sorted.
        self.db.set('a', 10)
        self.db.set('c', 10)
        self.db.set('b', 10)
        self.assertEqual(self.db.equalto(10), 'a b c')

    def test_equalto_none(self):
        self.assertEqual(self.db.equalto(10), 'NONE')

    def test_equalto_unset(self):
        self.db.set('a', 10)
        self.db.set('c', 10)
        self.db.set('b', 10)
        self.db.unset('b')
        self.assertEqual(self.db.equalto(10), 'a c')

    def test_equalto_change(self):
        # Overwriting a key must remove it from its old value's bucket.
        self.db.set('a', 10)
        self.db.set('a', 20)
        self.assertEqual(self.db.equalto(10), 'NONE')
        self.assertEqual(self.db.equalto(20), 'a')

    def test_program(self):
        """Replay a short command script end to end."""
        # Equivalent command script:
        #   SET a 10
        #   SET b 10
        #   EQUALTO 10
        #   EQUALTO 20
        #   UNSET a
        #   EQUALTO 10
        #   SET b 30
        #   EQUALTO 10
        self.db.set('a', 10)
        self.db.set('b', 10)
        self.assertEqual(self.db.equalto(10), 'a b')
        self.assertEqual(self.db.equalto(20), 'NONE')
        self.db.unset('a')
        self.assertEqual(self.db.equalto(10), 'b')
        self.db.set('b', 30)
        self.assertEqual(self.db.equalto(10), 'NONE')

    def test_tx_commit(self):
        self.db.set('a', 10)
        self.db.begin()
        self.assertEqual(self.db.get('a'), 10)
        self.db.set('a', 20)
        self.assertEqual(self.db.get('a'), 20)
        self.db.commit()
        self.assertEqual(self.db.get('a'), 20)

    def test_tx_full_commit(self):
        # Test that commit commits *all* open transactions.
        self.db.begin()
        self.db.set('a', 30)
        self.db.begin()
        self.db.set('a', 40)
        self.db.commit()
        self.assertEqual(self.db.get('a'), 40)
        self.assertEqual(self.db.rollback(), 'INVALID ROLLBACK')

    def test_tx_rollback(self):
        self.db.set('a', 10)
        self.db.begin()
        self.assertEqual(self.db.get('a'), 10)
        self.db.set('a', 20)
        self.assertEqual(self.db.get('a'), 20)
        self.db.rollback()
        self.assertEqual(self.db.get('a'), 10)

    def test_tx_rollback_equalto(self):
        """ROLLBACK must also restore what EQUALTO reports."""
        self.db.set('a', 10)
        self.db.begin()
        self.db.set('a', 20)
        self.db.set('b', 10)
        self.assertEqual(self.db.equalto(10), 'b')
        self.db.rollback()
        # 'a' is back at 10 and 'b', created inside the transaction, is gone.
        self.assertEqual(self.db.equalto(10), 'a')
        self.assertEqual(self.db.equalto(20), 'NONE')
        self.assertEqual(self.db.get('b'), 'NULL')

    def test_tx_invalid_rollback(self):
        self.assertEqual(self.db.rollback(), 'INVALID ROLLBACK')

    def test_tx_begin_rollback_commit(self):
        # Rolling back the inner transaction keeps the outer one's writes.
        self.db.set('a', 50)
        self.db.begin()
        self.assertEqual(self.db.get('a'), 50)
        self.db.set('a', 60)
        self.db.begin()
        self.db.unset('a')
        self.assertEqual(self.db.get('a'), 'NULL')
        self.db.rollback()
        self.assertEqual(self.db.get('a'), 60)
        self.db.commit()
        self.assertEqual(self.db.get('a'), 60)
if __name__ == '__main__':
    # Run every TestCase in this module through the stdlib test runner.
    unittest.main()