forked from petere/plpydbapi
/
plpydbapi.py
358 lines (252 loc) · 8.1 KB
/
plpydbapi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
"""
A Python DB-API compatible (sort of) interface on top of PL/Python
"""
__author__ = "Peter Eisentraut <peter@eisentraut.org>"
import decimal
import numbers
import sys
import time

import plpy
## Module Interface

def connect():
    """Return a new :class:`Connection` (PEP 249 ``connect``).

    Takes no arguments: there is nothing to connect to, statements run
    through ``plpy`` in the current PL/Python session.
    """
    conn = Connection()
    return conn
# PEP 249 module globals.
apilevel = '2.0'
threadsafety = 0  # Threads may not share the module.
paramstyle = 'format'  # printf-style markers, e.g. "... WHERE id = %s"

# Python 3 compatibility: ``long`` and ``StandardError`` are builtins only on
# Python 2 but are used by this module.  Compare the numeric version tuple
# rather than the first character of the version string, which would
# misclassify any future major version.
if sys.version_info[0] >= 3:
    long = int
    StandardError = Exception
class Warning(StandardError):
    """PEP 249 ``Warning`` for important warnings.

    Deliberately shadows the builtin ``Warning`` inside this module because
    PEP 249 requires a module-level name ``Warning``.
    """


class Error(StandardError):
    """Base class of all error exceptions raised by this module.

    :param spierror: the underlying cause, typically the ``plpy.SPIError``
        raised by PL/Python, or None.  It is stored in :attr:`spierror` and,
        when given, also passed to the base exception so that ``str(err)``,
        ``err.args`` and tracebacks show the database's message instead of
        an empty string.
    """

    def __init__(self, spierror=None):
        if spierror is None:
            super(Error, self).__init__()
        else:
            super(Error, self).__init__(spierror)
        self.spierror = spierror
# PEP 249 exception hierarchy.  The code in this module itself only raises
# the base Error; the subclasses exist for DB-API compatibility.

class InterfaceError(Error):
    """Error related to the database interface rather than the database."""


class DatabaseError(Error):
    """Error related to the database itself."""


class DataError(DatabaseError):
    """Error caused by problems with the processed data."""


class OperationalError(DatabaseError):
    """Error related to the database's operation."""


class IntegrityError(DatabaseError):
    """Error raised when relational integrity is affected."""


class InternalError(DatabaseError):
    """Error raised when the database encounters an internal error."""


class ProgrammingError(DatabaseError):
    """Error caused by a programming mistake, e.g. bad SQL."""


class NotSupportedError(DatabaseError):
    """Error raised for an API or database feature that is not supported."""
## Connection Objects

class Connection:
    """PEP 249 connection object backed by the current PL/Python session.

    No separate database connection is opened: statements issued through
    cursors run via ``plpy`` in the backend that called the PL/Python
    function.  A "transaction" is emulated with a ``plpy.subtransaction()``
    that is entered lazily by :meth:`_ensure_transaction` and finished by
    :meth:`commit` or :meth:`rollback`.
    """

    # PEP 249 optional extension: the exception classes are also reachable
    # as connection attributes (e.g. ``conn.Error``).
    Warning = Warning
    Error = Error
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    DataError = DataError
    OperationalError = OperationalError
    IntegrityError = IntegrityError
    InternalError = InternalError
    ProgrammingError = ProgrammingError
    NotSupportedError = NotSupportedError

    closed = False  # True once close() has been called
    _subxact = None  # the currently entered plpy subtransaction, or None

    def __init__(self):
        pass

    def _check_open(self):
        """Raise :class:`Error` if this connection has been closed."""
        if self.closed:
            raise Error()

    def close(self):
        """Roll back pending work and make the connection unusable.

        :raises Error: if the connection is already closed.
        """
        self._check_open()
        try:
            self.rollback()
        finally:
            # PEP 249: the connection is unusable after close(), even if the
            # rollback itself failed.
            self.closed = True

    def _ensure_transaction(self):
        """Enter a new subtransaction unless one is already open."""
        if self._subxact is None:
            self._subxact = plpy.subtransaction()
            self._subxact.enter()

    def commit(self):
        """Commit the emulated transaction, if one is open.

        :raises Error: if the connection is closed.
        """
        self._check_open()
        if self._subxact is not None:
            # Exiting without an exception keeps the subtransaction's work.
            self._subxact.exit(None, None, None)
            self._subxact = None

    def rollback(self):
        """Roll back the emulated transaction, if one is open.

        :raises Error: if the connection is closed.
        """
        self._check_open()
        if self._subxact is not None:
            # Relies on plpy rolling back when exit() receives a non-None
            # exception type; no real exception is involved.
            self._subxact.exit('fake exception', None, None)
            self._subxact = None

    def cursor(self):
        """Return a new :class:`Cursor` bound to this connection.

        :raises Error: if the connection is closed (PEP 249 requires every
            operation on a closed connection to raise ``Error``).
        """
        self._check_open()
        newcursor = Cursor()
        newcursor.connection = self
        return newcursor
## Cursor Objects

class Cursor:
    """PEP 249 cursor that runs statements through ``plpy``.

    Each :meth:`execute` prepares and runs the statement in full; the rows
    of a SELECT are buffered in memory as lists of column values and handed
    out by the ``fetch*`` methods.
    """

    description = None  # list of PEP 249 7-tuples for the last SELECT, else None
    rowcount = -1  # rows returned/affected by the last execute; -1 if unknown
    arraysize = 1  # default batch size for fetchmany()
    rownumber = None  # index of the next row to fetch; None without a result set
    closed = False
    connection = None  # set by Connection.cursor()
    _execute_result = None  # buffered rows of the last SELECT, or None

    # PostgreSQL SPI result codes as returned by plpy's result.status().
    _SPI_OK_UTILITY = 4
    _SPI_OK_SELECT = 5

    def __init__(self):
        pass

    def close(self):
        """Mark the cursor closed; later execute() calls raise Error."""
        self.closed = True

    def _is_closed(self):
        """Return True if this cursor or its connection has been closed."""
        return self.closed or self.connection.closed

    def execute(self, operation, parameters=None):
        """Prepare and run *operation* with *parameters*.

        *operation* uses the ``format`` paramstyle: each ``%s`` is replaced
        by a ``$n`` placeholder and the values are passed separately to
        ``plpy.prepare``/``plpy.execute``, so they are never spliced into
        the SQL text.  Literal percent signs must be written as ``%%``.

        :raises Error: if the cursor or connection is closed, or if the
            statement fails (the ``plpy.SPIError`` is kept in
            ``Error.spierror``).
        """
        if self._is_closed():
            raise Error()
        self.connection._ensure_transaction()

        values = list(parameters or [])
        types = [self.py_param_to_pg_type(value) for value in values]
        placeholders = tuple("$%d" % (i + 1) for i in range(len(values)))
        # Must be a tuple: ``str % list`` treats the list as ONE argument,
        # so statements with two or more parameters used to raise TypeError.
        query = operation % placeholders

        # Forget the previous result up front so that a failing statement
        # does not leave stale rows fetchable.
        self._execute_result = None
        self.rownumber = None
        self.description = None
        self.rowcount = -1

        try:
            plan = plpy.prepare(query, types)
            res = plpy.execute(plan, values)
        except plpy.SPIError as e:
            raise Error(e)

        status = res.status()
        if status == self._SPI_OK_SELECT:
            self._execute_result = [[row[col] for col in row] for row in res]
            self.rownumber = 0
            self.description = self._describe(res)
        if status == self._SPI_OK_UTILITY:
            self.rowcount = -1
        else:
            self.rowcount = res.nrows()

    @staticmethod
    def _describe(res):
        """Build the PEP 249 ``description`` list for SELECT result *res*."""
        if 'colnames' in res.__class__.__dict__:
            # PG 9.2+: use .colnames() and .coltypes() methods
            return [(name, get_type_obj(typeoid), None, None, None, None, None)
                    for name, typeoid in zip(res.colnames(), res.coltypes())]
        if len(res) > 0:
            # else get at least the column names from the row keys
            return [(name, None, None, None, None, None, None)
                    for name in res[0].keys()]
        # else we know nothing
        return [(None, None, None, None, None, None, None)]

    @staticmethod
    def py_param_to_pg_type(param):
        """Return the PostgreSQL type name used to prepare *param*.

        Integers get the narrowest of ``int`` (int4), ``int8`` or
        ``numeric`` that can hold the value, so large Python ints no longer
        overflow an int4 parameter.  Anything not recognised (including
        None) is prepared as ``text``.
        """
        # bool must be tested before integers: bool is an int subclass.
        if isinstance(param, bool):
            return 'bool'
        if isinstance(param, decimal.Decimal):
            return 'numeric'
        if isinstance(param, float):
            return 'float8'
        if isinstance(param, numbers.Integral):
            # numbers.Integral covers int, and long on Python 2.
            if -2 ** 31 <= param < 2 ** 31:
                return 'int'
            if -2 ** 63 <= param < 2 ** 63:
                return 'int8'
            return 'numeric'
        # TODO: dates, times, bytes, ... are currently prepared as text.
        return 'text'

    def executemany(self, operation, seq_of_parameters):
        """Run *operation* once for each parameter set.

        ``rowcount`` becomes the sum of the individual row counts, or -1 as
        soon as any single execution reports an unknown count.
        """
        # We can't reuse saved plans here, because we have no way of
        # knowing whether all parameter sets will be of the same type.
        totalcount = 0
        for parameters in seq_of_parameters:
            self.execute(operation, parameters)
            if totalcount == -1 or self.rowcount == -1:
                totalcount = -1
            else:
                totalcount += self.rowcount
        self.rowcount = totalcount

    def _check_result(self):
        """Raise :class:`Error` if the last execute produced no result set."""
        if self._execute_result is None:
            raise Error()

    def fetchone(self):
        """Return the next row, or None when the result set is exhausted.

        :raises Error: if there is no result set to fetch from.
        """
        self._check_result()
        if self.rownumber >= len(self._execute_result):
            return None
        result = self._execute_result[self.rownumber]
        self.rownumber += 1
        return result

    def fetchmany(self, size=None):
        """Return up to *size* (default :attr:`arraysize`) further rows.

        :raises Error: if there is no result set to fetch from.
        """
        self._check_result()
        if size is None:
            size = self.arraysize
        result = self._execute_result[self.rownumber:self.rownumber + size]
        # Advance only by the rows actually returned so rownumber never
        # runs past the end of the result set.
        self.rownumber += len(result)
        return result

    def fetchall(self):
        """Return all remaining rows.

        :raises Error: if there is no result set to fetch from.
        """
        self._check_result()
        result = self._execute_result[self.rownumber:]
        self.rownumber = len(self._execute_result)
        return result

    def next(self):
        """Return the next row; raise StopIteration when exhausted."""
        result = self.fetchone()
        if result is None:
            raise StopIteration
        return result

    # Python 3 iterator protocol name; ``next`` serves Python 2.
    __next__ = next

    def scroll(self, value, mode='relative'):
        """Move the fetch position by *value* rows, or to row *value*.

        :raises Error: if there is no result set.
        :raises ValueError: if *mode* is neither 'relative' nor 'absolute'.
        :raises IndexError: if the new position would leave the result set.
        """
        self._check_result()
        if mode == 'relative':
            newpos = self.rownumber + value
        elif mode == 'absolute':
            newpos = value
        else:
            raise ValueError("Invalid mode")
        if newpos < 0 or newpos > len(self._execute_result):
            raise IndexError("scroll operation would leave result set")
        self.rownumber = newpos

    def setinputsizes(self, sizes):
        """No-op, permitted by PEP 249."""
        pass

    def setoutputsize(self, size, column=None):
        """No-op, permitted by PEP 249."""
        pass

    def __iter__(self):
        return self
## Type Objects and Constructors

def Date(year, month, day):
    """PEP 249 ``Date``: return the literal 'YYYY-MM-DD' as a string."""
    fields = (year, month, day)
    return '%04d-%02d-%02d' % fields


def Time(hour, minute, second):
    """PEP 249 ``Time``: return the literal 'HH:MM:SS' as a string."""
    fields = (hour, minute, second)
    return '%02d:%02d:%02d' % fields


def Timestamp(year, month, day, hour, minute, second):
    """PEP 249 ``Timestamp``: return 'YYYY-MM-DD HH:MM:SS' as a string."""
    fields = (year, month, day, hour, minute, second)
    return '%04d-%02d-%02d %02d:%02d:%02d' % fields
def DateFromTicks(ticks):
    """Return :func:`Date` for *ticks* seconds since the epoch, local time."""
    local = time.localtime(ticks)
    return Date(local.tm_year, local.tm_mon, local.tm_mday)


def TimeFromTicks(ticks):
    """Return :func:`Time` for *ticks* seconds since the epoch, local time."""
    local = time.localtime(ticks)
    return Time(local.tm_hour, local.tm_min, local.tm_sec)


def TimestampFromTicks(ticks):
    """Return :func:`Timestamp` for *ticks* seconds since the epoch, local time."""
    local = time.localtime(ticks)
    return Timestamp(local.tm_year, local.tm_mon, local.tm_mday,
                     local.tm_hour, local.tm_min, local.tm_sec)
def Binary(string):
    """PEP 249 ``Binary`` constructor; the value is passed through as-is."""
    value = string
    return value
# PEP 249 type objects.  get_type_obj() returns these classes themselves as
# the type_code in Cursor.description, so they compare by identity.

class STRING:
    """Type object for string-based columns."""


class BINARY:
    """Type object for binary (long) columns."""


class NUMBER:
    """Type object for numeric columns."""


class DATETIME:
    """Type object for date/time columns."""


class ROWID:
    """Type object for row-ID columns."""
# pg_type.typname -> type object, consulted only when the type's category
# has no entry below.
_typname_typeobjs = dict(bytea=BINARY)

# pg_type.typcategory -> type object (D = date/time, N = numeric,
# S = string, T = timespan).
_typcategory_typeobjs = dict(
    D=DATETIME,
    N=NUMBER,
    S=STRING,
    T=DATETIME,
)

# Cache filled lazily by get_type_obj(): type OID -> type object.
_typoid_typeobjs = dict()
def get_type_obj(typeoid):
    """Map a PostgreSQL type OID to its PEP 249 type object.

    When the module-level cache ``_typoid_typeobjs`` is empty it is filled
    from ``pg_type``.  Each type is classified by its typcategory first and
    by its typname second.  OIDs matching neither map to None.
    """
    if not _typoid_typeobjs:
        plan = plpy.prepare("SELECT oid, typname, typcategory FROM pg_type")
        for pgtype in plpy.execute(plan):
            typeobj = _typcategory_typeobjs.get(pgtype['typcategory'])
            if typeobj is None:
                typeobj = _typname_typeobjs.get(pgtype['typname'])
            if typeobj is not None:
                _typoid_typeobjs[int(pgtype['oid'])] = typeobj
    return _typoid_typeobjs.get(typeoid)