-
Notifications
You must be signed in to change notification settings - Fork 0
/
import-xml.py
489 lines (399 loc) · 15.1 KB
/
import-xml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
#!/usr/bin/env python2
from __future__ import unicode_literals
import sys
import itertools
from copy import deepcopy
from operator import itemgetter
from lxml.etree import parse as parse_xml
import sqlalchemy
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.exc import NoResultFound
sql_lower = sqlalchemy.func.lower
from db import Encounter, EncounterCondition, EncounterConditionValue
from db import EncounterMethod, EncounterTerrain, EncounterSlot
from db import Location, LocationArea
from db import Version
SessionClass = sessionmaker()
# Globals are evil, but practically everything needs to access the session.
# The alternative is to reformulate the whole module as a class, which brings
# its own problems.
# The session is created unbound here; main() assigns its engine before use.
session = SessionClass()
class memoize(object):
    """Cache a function's results, keyed by its positional arguments.

    The wrapped function can only be called with positional arguments, and
    every argument must be hashable because the argument tuple is used
    directly as the cache key.  Cached entries are never evicted.
    """

    def __init__(self, func):
        self.func = func
        # Maps each tuple of call arguments to the value computed for it.
        self.memo = {}

    def __call__(self, *args):
        cache = self.memo
        if args not in cache:
            cache[args] = self.func(*args)
        return cache[args]
@memoize
def get_condition(cond_value):
    """Fetch the EncounterConditionValue for a given (cond, value) pair.

    ``cond_value`` is a single 2-item sequence ``(cond, value)`` holding the
    identifier of the condition and the identifier of one of its values.  It
    is passed as one argument (not two) so callers can hand over dict items
    directly; it must be hashable because the result is memoized.

    Raises whatever ``Query.one()`` raises when the pair does not identify
    exactly one row.
    """
    # Explicit unpacking replaces the Python-2-only tuple parameter syntax.
    cond, value = cond_value
    query = (
        session.query(EncounterConditionValue)
        .join(EncounterCondition)
        .filter(
            EncounterCondition.identifier == unicode(cond),
            EncounterConditionValue.identifier == unicode(value),
        )
    )
    return query.one()
@memoize
def get_conditions():
    """Return the identifiers of all EncounterCondition rows, ordered by id."""
    rows = (
        session.query(EncounterCondition.identifier)
        .order_by(EncounterCondition.id)
        .all()
    )
    return [row.identifier for row in rows]
@memoize
def get_condition_values(cond):
    """Return the value identifiers of condition ``cond``, ordered by id.

    Raises ValueError(cond) when the query yields no values at all.
    """
    query = (
        session.query(EncounterConditionValue.identifier)
        .join(EncounterCondition)
        .filter(EncounterCondition.identifier == unicode(cond))
        .order_by(EncounterConditionValue.id)
    )
    identifiers = [row.identifier for row in query.all()]
    if identifiers:
        return identifiers
    # Nothing came back: report the offending condition name to the caller.
    raise ValueError(cond)
@memoize
def get_version(version):
    """Fetch the Version object whose name matches ``version``, ignoring case.

    The stored name is lower-cased in SQL, so the requested name is
    lower-cased as well; otherwise any input containing an upper-case letter
    could never match.  ``None`` is passed through untouched.

    Raises whatever ``Query.one()`` raises when there is not exactly one
    matching row.
    """
    if version is not None:
        version = version.lower()
    query = session.query(Version).filter(sql_lower(Version.name) == version)
    return query.one()
@memoize
def get_terrain_id(terrain):
    """Return the id of the EncounterTerrain whose identifier is ``terrain``.

    Raises whatever ``Query.one()`` raises unless exactly one row matches.
    """
    row = (
        session.query(EncounterTerrain.id)
        .filter_by(identifier=terrain)
        .one()
    )
    return row.id
@memoize
def get_method_id(method):
    """Return the id of the EncounterMethod whose identifier is ``method``.

    Raises whatever ``Query.one()`` raises unless exactly one row matches.
    """
    row = (
        session.query(EncounterMethod.id)
        .filter_by(identifier=method)
        .one()
    )
    return row.id
@memoize
def _get_or_create_encounter_slot(slot, method, version_group_id, terrain):
    """Return the EncounterSlot for the given key, building one if missing.

    ``method`` and ``terrain`` are identifiers that are resolved to ids
    first; ``terrain`` may be None, in which case the terrain id is None.
    A newly built slot is not added to the session by this function.
    Results are memoized, so the same arguments always yield the same object.
    """
    terrain_id = get_terrain_id(terrain) if terrain is not None else None
    method_id = get_method_id(method)
    fields = dict(
        slot=slot,
        version_group_id=version_group_id,
        encounter_method_id=method_id,
        encounter_terrain_id=terrain_id,
    )
    try:
        return session.query(EncounterSlot).filter_by(**fields).one()
    except NoResultFound:
        pass
    # No stored row matches: create a fresh slot carrying the same fields.
    encounter_slot = EncounterSlot()
    encounter_slot.slot = slot
    encounter_slot.version_group_id = version_group_id
    encounter_slot.encounter_method_id = method_id
    encounter_slot.encounter_terrain_id = terrain_id
    return encounter_slot
def get_or_create_encounter_slot(slot, method, version_group_id, terrain=None):
    """Keyword-friendly front end for _get_or_create_encounter_slot().

    The memoized implementation only takes positional arguments, so this
    wrapper accepts keywords (and a default ``terrain`` of None) and forwards
    everything positionally.
    """
    key = (slot, method, version_group_id, terrain)
    return _get_or_create_encounter_slot(*key)
def parse_range(s):
    """Parse a level range such as ``'5'`` or ``'3-7'`` into ``(low, high)``.

    The text is split at the first ``'-'``.  When there is no upper bound
    (``'5'`` or ``'5-'``) the single value is used for both ends.

    Returns a 2-tuple of ints.  Raises ValueError if either part is not a
    valid integer.
    """
    # Local names deliberately avoid shadowing the builtins min() and max().
    low_text, _, high_text = s.partition('-')
    low = int(low_text)
    high = int(high_text) if high_text else low
    return low, high
def lift_conditions(e):
    """Return the condition attributes present on element ``e`` as a dict.

    Only attribute names that get_conditions() reports are considered; the
    result maps each such name found in ``e.attrib`` to its attribute value.
    """
    attributes = e.attrib
    return {
        name: attributes[name]
        for name in get_conditions()
        if name in attributes
    }
def groupby(it, key):
    """Like itertools.groupby(), but sort the items by ``key`` first.

    Sorting up front means all items sharing a key land in a single group
    regardless of the input order.
    """
    ordered = sorted(it, key=key)
    return itertools.groupby(ordered, key)
def identical(it, key=lambda x: x):
    """Return True when the items of ``it`` all share exactly one key value.

    An empty iterable has no key values at all and therefore yields False.
    Key values must be hashable.
    """
    distinct = {key(item) for item in it}
    return len(distinct) == 1
def visit_monsters(elem):
    """Convert a <monsters> element, including nested groups, into a dict.

    The result always has the keys 'conditions', 'pokemon' and 'subgroups';
    'method', 'terrain' and 'rate' (an int) appear only when the element
    carries the attribute of the same name.
    """
    assert elem.tag == 'monsters'
    attributes = elem.attrib
    obj = {}
    for name in ('method', 'terrain'):
        if name in attributes:
            obj[name] = elem.get(name)
    if 'rate' in attributes:
        obj['rate'] = int(elem.get('rate'))
    obj['conditions'] = lift_conditions(elem)
    # Direct <pokemon> children get default slot numbers 1, 2, 3, ... in
    # document order.
    obj['pokemon'] = [
        visit_pokemon(pokemon, position)
        for position, pokemon in enumerate(elem.xpath('./pokemon'), 1)
    ]
    # Nested <monsters> children are converted recursively.
    obj['subgroups'] = [
        visit_monsters(monsters) for monsters in elem.xpath('./monsters')
    ]
    return obj
def visit_pokemon(elem, slot=None):
    """Convert a <pokemon> element into a dict.

    ``slot`` is the default slot number; a 'slot' attribute on the element
    overrides it.  The 'number' attribute becomes the integer 'pokemon_id',
    'form' is copied as-is, and 'levels' is parsed with parse_range().
    """
    assert elem.tag == 'pokemon'
    explicit_slot = elem.get('slot')
    if explicit_slot is not None:
        slot = int(explicit_slot)
    return {
        'slot': slot,
        'pokemon_id': int(elem.get('number')),
        'form': elem.get('form'),
        'levels': parse_range(elem.get('levels')),
    }
def lift(root):
    """Lift encounter xml into plain (json-like) dicts.

    Every child of ``root`` is converted with visit_monsters(); the results
    are returned as a list in document order.
    """
    return [visit_monsters(monsters) for monsters in root]
class EncounterMonkey(object):
    """Hold a list of encounter dicts and reduce it to the essentials.

    Each encounter is a dict with at least the keys 'conditions' (a dict of
    condition name -> value), 'pokemon_id', 'form' and 'levels'.  The list is
    kept ordered by _sort_key(), highest first, and get_encounter() returns
    the first entry compatible with a given set of conditions.
    """

    def __init__(self, encounters):
        # Work on a deep copy of the encounter list, so we aren't affected by
        # our caller's shenanigans (and never modify the caller's dicts).
        ordered = deepcopy(encounters)
        # Sort by the priority tuple from _sort_key(), highest first.
        ordered.sort(key=self._sort_key, reverse=True)
        self.encounters = ordered

    def strip_redundancy(self):
        """Remove encounters that a later, equivalent entry already covers.

        An entry is dropped when any entry further down the list is
        compatible with its conditions and has the same pokemon, form and
        levels.
        """
        same = self._eq_key
        entries = self.encounters
        # 1. Mark.  All entries are examined before any is removed, so an
        # entry that is itself redundant can still make an earlier one
        # redundant.
        for index, entry in enumerate(entries):
            wanted = entry['conditions']
            later = entries[index + 1:]
            if any(self._match(other, wanted) and same(entry) == same(other)
                   for other in later):
                entry['is_redundant'] = True
        # 2. Sweep
        self.encounters = [entry for entry in entries
                           if not entry.get('is_redundant', False)]

    def collapse_encounters(self):
        """Collapse the 'season' and 'time' conditions where possible."""
        # We really only have to check for season and time, since the others
        # should be caught in strip_redundancy().
        collapsed = self.encounters
        for cond in ('season', 'time'):
            collapsed = self._collapse_encounters(collapsed, cond)
        collapsed.sort(key=self._sort_key, reverse=True)
        self.encounters = collapsed

    def _collapse_encounters(self, encounters, cond):
        """Merge entries that differ only in the value of ``cond``.

        Entries are grouped by their conditions other than ``cond``.  A group
        is replaced by its first entry, with ``cond`` removed from that
        entry's conditions, when it has more than one entry, its ``cond``
        values are exactly the set of all values of ``cond``, and all its
        entries share the same pokemon, form and levels.  Everything else is
        passed through unchanged.
        """
        all_values = set(get_condition_values(cond))

        def other_conditions(e):
            # Grouping key: the conditions minus `cond`, or None for entries
            # that do not mention `cond` at all.
            # NOTE: groupby() sorts on this key, i.e. on dicts and None,
            # which relies on Python 2's ordering of arbitrary objects.
            if cond not in e['conditions']:
                return None
            remaining = e['conditions'].copy()
            del remaining[cond]
            return remaining

        def can_collapse(group):
            values = set(e['conditions'][cond] for e in group)
            return (values == all_values and
                    identical(group, key=self._eq_key))

        new_encounters = []
        for common_conds, group in groupby(encounters, other_conditions):
            group = list(group)
            if (common_conds is None or
                    len(group) <= 1 or
                    not can_collapse(group)):
                new_encounters.extend(group)
                continue
            # Keep one representative, now independent of `cond`.
            survivor = group[0]
            survivor['conditions'] = common_conds
            new_encounters.append(survivor)
        return new_encounters

    def get_encounter(self, conditions):
        """Return the first encounter compatible with ``conditions``.

        Returns None when no encounter matches.
        """
        candidates = (e for e in self.encounters
                      if self._match(e, conditions))
        return next(candidates, None)

    def get_conditions(self):
        """Return all conditions which can apply to an encounter"""
        seen = {}
        for e in self.encounters:
            for name, value in e['conditions'].items():
                seen.setdefault(name, set()).add(value)
        # Only the condition names are returned; the collected values are
        # currently unused.
        return seen.keys() #XXX

    @staticmethod
    def _match(encounter, conditions):
        """Tell whether ``encounter`` is compatible with ``conditions``.

        A condition the encounter does not mention never conflicts; one it
        does mention must have the same value.
        """
        e_conditions = encounter['conditions']
        return all(
            cond not in e_conditions or e_conditions[cond] == value
            for cond, value in conditions.items()
        )

    @staticmethod
    def _sort_key(e):
        """Return the sort key of an encounter.

        The key holds, for each priority group, how many of the group's
        conditions the encounter uses, with the last group first in the
        tuple.
        """
        priority_groups = [
            ['season'],
            ['time'],
            ['spots'],
            ['swarm', 'radar'], #XXX etc
        ]
        present = e['conditions']
        counts = [sum(1 for c in group if c in present)
                  for group in priority_groups]
        return tuple(reversed(counts))

    # Two encounters are considered equivalent when these fields agree.
    _eq_key = staticmethod(itemgetter('pokemon_id', 'form', 'levels'))
def reduce_encounters(root):
    """Yield reduced encounter dicts for the <monsters> groups under ``root``.

    Every yielded dict is a copy of an encounter whose 'conditions' entry has
    been replaced by one concrete combination of condition values.
    """
    # Here's what happens:
    # 1. We group by method (since different methods have different slots,
    #    encounters cannot reasonably be compared across methods).
    # 2. We group by terrain (since we can't collapse terrains).
    # 3. We group by slot. Each slot will be individually examined for
    #    reduction. We might collapse a condition (season, say) in slot 1, but
    #    not in slot 2. A condition like time or season can be collapsed if the
    #    encounter data for each of its values is identical. A higher
    #    condition, like swarm or slot2, can be collapsed if the data is
    #    identical *and* every lower condition can be collapsed.
    # 4. Once we figure out which conditions to keep, we take the cartesian
    #    product of the values of those conditions. This is the set of sets of
    #    condition values which will be represented in the database.
    # 5. For each set of condition values, the encounter data for the slot is
    #    computed under those conditions.
    # 6. Said data is yielded, so that the caller can add it to the database.
    encounters_list = flatten_encounters(lift(root))
    by_method = groupby(encounters_list, itemgetter('method'))
    for _method, method_group in by_method:
        by_terrain = groupby(method_group, itemgetter('terrain'))
        for _terrain, terrain_group in by_terrain:
            by_slot = groupby(terrain_group, itemgetter('slot'))
            for _slot, slot_group in by_slot:
                monkey = EncounterMonkey(list(slot_group))
                monkey.collapse_encounters()
                monkey.strip_redundancy()

                # One list of (condition, value) pairs per remaining
                # condition, covering every value of that condition.
                condition_values = [
                    [(cond, value) for value in get_condition_values(cond)]
                    for cond in monkey.get_conditions()
                ]
                for combination in itertools.product(*condition_values):
                    condition_set = dict(combination)
                    found = monkey.get_encounter(condition_set)
                    if not found:
                        # Nothing applies under this combination.
                        continue
                    result = found.copy()
                    result['conditions'] = condition_set
                    yield result
def flatten_encounters(encountergroups):
    """Flatten nested encounter groups into a single list of pokemon dicts.

    Each top-level group is walked with _flatten(), starting from a fresh
    context that has no conditions.
    """
    return [
        pokemon
        for group in encountergroups
        for pokemon in _flatten({'conditions': {}}, group)
    ]
def _flatten(context, group):
context = context.copy()
for key in ('method', 'terrain', 'rate'):
if key in group:
context[key] = group[key]
if group['conditions']:
c = context['conditions'].copy()
c.update(group['conditions'])
context['conditions'] = c
for p in group['pokemon']:
p['conditions'] = context['conditions'].copy()
p['method'] = context['method']
p['terrain'] = context.get('terrain')
yield p
for g in group['subgroups']:
for p in _flatten(context, g):
yield p
def insert_encounters(encounters, ctx):
    """Build an Encounter for every dict in ``encounters`` and add it to the session."""
    pending = (make_encounter(e, ctx) for e in encounters)
    session.add_all(pending)
def make_encounter(obj, ctx):
    """Make a db.Encounter object from a dict.

    ``obj`` supplies 'slot', 'method', 'terrain', 'pokemon_id', 'levels' and
    'conditions'; ``ctx`` supplies the current 'version' and 'area'.

    Raises ValueError if 'levels' has neither one nor two items.
    """
    encounter_slot = get_or_create_encounter_slot(
        slot=obj['slot'],
        version_group_id=ctx['version'].version_group_id,
        method=obj['method'],
        terrain=obj['terrain'],
    )

    encounter = Encounter()
    # NOTE: obj['form'] is not written to the Encounter here.
    encounter.pokemon_id = obj['pokemon_id']
    encounter.version_id = ctx['version'].id
    encounter.location_area = ctx['area']
    encounter.slot = encounter_slot

    # Levels: a single value serves as both minimum and maximum.
    levels = obj['levels']
    if len(levels) == 1:
        low = high = levels[0]
    elif len(levels) == 2:
        low, high = levels
    else:
        raise ValueError(levels)
    encounter.min_level = low
    encounter.max_level = high

    # Conditions: each (condition, value) pair is resolved to its database row.
    for cond_value in obj['conditions'].items():
        encounter.condition_values.append(get_condition(cond_value))
    return encounter
def get_or_create_location(loc_elem, ctx):
    """Return the Location named by ``loc_elem`` in the current region.

    The location is looked up by name and by the id of ctx['region'].  When
    no row is found, a new Location is created and added to the session.
    """
    # NOTE: a missing 'name' attribute becomes the text u'None' here.
    name = unicode(loc_elem.get('name'))
    region = ctx['region']
    query = session.query(Location).filter_by(name=name, region_id=region.id)
    try:
        return query.one()
    except NoResultFound:
        pass
    created = Location()
    created.name = name
    created.region = region
    session.add(created)
    return created
def create_area(area_elem, ctx):
    """Create a LocationArea for ``area_elem`` and add it to the session.

    The area's name is the element's 'name' attribute as unicode, or None
    when the attribute is absent.  The 'internal_id' attribute is stored as
    an int, and the area is attached to ctx['location'].
    """
    raw_name = area_elem.get('name')
    area_name = None if raw_name is None else unicode(raw_name)

    new_area = LocationArea()
    new_area.name = area_name
    new_area.internal_id = int(area_elem.get('internal_id'))
    new_area.location = ctx['location']
    session.add(new_area)
    return new_area
def main():
engine = create_engine('sqlite:///test.sqlite')
session.bind = engine
session.autoflush = False
filename = sys.argv[1]
xml = parse_xml(filename)
ctx = {}
for game in xml.xpath('/wild/game'):
ctx['version'] = get_version(game.get('version'))
# XXX region should be set based on the location
ctx['region'] = ctx['version'].version_group.generation.main_region
for loc in game.xpath('location'):
ctx['location'] = get_or_create_location(loc, ctx)
for area in loc.xpath('area'):
ctx['area'] = create_area(area, ctx)
if area.get('name', False):
print loc.get('name') + "/" + area.get('name')
else:
print loc.get('name')
encounters = list(reduce_encounters(area))
insert_encounters(encounters, ctx)
#for e in sorted(encounters,
# key=itemgetter('method', 'terrain')):
# print e
session.flush()
session.commit()
if __name__ == '__main__':
    # Run the import and report the wall-clock time it took on stderr.
    import time
    from sys import stdout, stderr

    started = time.time()
    main()
    elapsed = time.time() - started
    # Push out the progress lines before writing the timing to stderr.
    stdout.flush()
    print >>stderr, "{:.4f} seconds".format(elapsed)