def test_entities_creation(self):
    """Round-trip feed, agency, routes and calendar through the DAO."""
    dao = Dao()
    f1 = FeedInfo("F1")
    a1 = Agency("F1", "A1", "Agency 1", agency_url="http://www.agency.fr/",
                agency_timezone="Europe/Paris")
    r1 = Route("F1", "R1", "A1", 3, route_short_name="R1", route_long_name="Route 1")
    r2 = Route("F1", "R2", "A1", 3, route_short_name="R2")
    c1 = Calendar("F1", "C1")
    c1.dates = [CalendarDate.ymd(2015, 11, 13), CalendarDate.ymd(2015, 11, 14)]
    dao.add_all([f1, a1, r1, r2, c1])

    self.assertEqual(len(dao.feeds()), 1)
    self.assertEqual(len(dao.agencies()), 1)

    a1b = dao.agency("A1", feed_id="F1", prefetch_routes=True)
    self.assertEqual(a1b.agency_name, "Agency 1")
    self.assertEqual(len(a1b.routes), 2)

    r1b = dao.route("R1", feed_id="F1")
    self.assertEqual(r1b.route_short_name, "R1")
    self.assertEqual(r1b.route_long_name, "Route 1")
    self.assertEqual(r1b.route_type, 3)

    # An unknown route ID is expected to yield None.
    r42 = dao.route("R42", feed_id="F1")
    self.assertIsNone(r42)
def test_stop_station(self):
    """Filter stops by serving agency and by parent station."""
    dao = Dao()
    f1 = FeedInfo("F1")
    sa = Stop("F1", "SA", "Station A", 45.0000, 0.0000)
    sa1 = Stop("F1", "SA1", "Stop A1", 45.0001, 0.0001)
    sa1.parent_station_id = 'SA'
    sb = Stop("F1", "SB", "Station B", 45.0002, 0.0002)
    sb1 = Stop("F1", "SB1", "Stop B1", 45.0003, 0.0003)
    sb1.parent_station_id = 'SB'
    sb2 = Stop("F1", "SB2", "Stop B2", 45.0002, 0.0003)
    sb2.parent_station_id = 'SB'
    a1 = Agency("F1", "A1", "Agency 1", "url1", "Europe/Paris")
    r1 = Route("F1", "R1", "A1", Route.TYPE_BUS)
    c1 = Calendar("F1", "C1")
    t1 = Trip("F1", "T1", "R1", "C1")
    st1a = StopTime("F1", "T1", "SA1", 0, None, 3600, 0.0)
    st1b = StopTime("F1", "T1", "SB1", 1, 3800, None, 100.0)
    dao.add_all([f1, sa, sa1, sb, sb1, sb2, a1, r1, c1, t1, st1a, st1b])

    # Only SA1 and SB1 appear in a stop time of a trip run by agency A1.
    stops = list(dao.stops(fltr=(Agency.agency_id == 'A1')))
    self.assertEqual(len(stops), 2)
    self.assertIn(sa1, stops)
    self.assertIn(sb1, stops)

    stops = list(dao.stops(fltr=(Stop.parent_station_id == 'SA')))
    self.assertEqual(len(stops), 1)
    self.assertIn(sa1, stops)

    stops = list(dao.stops(fltr=(Stop.parent_station_id == 'SB')))
    self.assertEqual(len(stops), 2)
    self.assertIn(sb1, stops)
    self.assertIn(sb2, stops)
def test_route_agency_multi_feed(self):
    """Agencies/routes sharing IDs across two feeds must not be mixed up."""
    dao = Dao()
    fa = FeedInfo("FA")
    aa1 = Agency("FA", "A", "Agency A", agency_url="http://www.agency.fr/",
                 agency_timezone="Europe/Paris")
    ar1 = Route("FA", "R", "A", 3, route_short_name="RA", route_long_name="Route A")
    ar2 = Route("FA", "R2", "A", 3, route_short_name="RA2", route_long_name="Route A2")
    fb = FeedInfo("FB")
    ba1 = Agency("FB", "A", "Agency B", agency_url="http://www.agency.fr/",
                 agency_timezone="Europe/Paris")
    br1 = Route("FB", "R", "A", 3, route_short_name="RB", route_long_name="Route B")
    dao.add_all([fa, aa1, ar1, ar2, fb, ba1, br1])

    fa = dao.feed("FA")
    self.assertEqual(len(fa.agencies), 1)
    for a in fa.agencies:
        self.assertEqual(a.agency_name, "Agency A")
    self.assertEqual(len(fa.routes), 2)
    for r in fa.routes:
        self.assertTrue(r.route_short_name.startswith("RA"))
        self.assertEqual(r.agency.agency_name, "Agency A")
def test_shapes(self):
    """A trip linked to a shape exposes its points; an unlinked trip has none."""
    dao = Dao()
    f1 = FeedInfo("")
    a1 = Agency("", "A1", "Agency 1", agency_url="http://www.agency.fr/",
                agency_timezone="Europe/Paris")
    r1 = Route("", "R1", "A1", 3, route_short_name="R1", route_long_name="Route 1")
    c1 = Calendar("", "C1")
    c1.dates = [CalendarDate.ymd(2016, 1, 31), CalendarDate.ymd(2016, 2, 1)]
    s1 = Stop("", "S1", "Stop 1", 45.0, 0.0)
    s2 = Stop("", "S2", "Stop 2", 45.1, 0.1)
    s3 = Stop("", "S3", "Stop 3", 45.2, 0.2)
    t1 = Trip("", "T1", "R1", "C1")
    t1.stop_times = [
        StopTime(None, None, "S1", 0, 28800, 28800, 0.0),
        StopTime(None, None, "S2", 1, 29400, 29400, 2.0),
        StopTime(None, None, "S3", 2, 30000, 30000, 4.0),
    ]
    t2 = Trip("", "T2", "R1", "C1")
    t2.stop_times = [
        StopTime(None, None, "S2", 0, 30600, 30600, 0.0),
        StopTime(None, None, "S1", 1, 31000, 31000, 1.0),
    ]
    sh1 = Shape("", "Sh1")
    sh1.points = [
        ShapePoint(None, None, 0, 45.00, 0.00, 0.0),
        ShapePoint(None, None, 1, 45.05, 0.10, 1.0),
        ShapePoint(None, None, 2, 45.10, 0.10, 2.0),
        ShapePoint(None, None, 3, 45.15, 0.20, 3.0),
        ShapePoint(None, None, 4, 45.20, 0.20, 4.0),
    ]
    t1.shape = sh1
    dao.add_all([f1, a1, r1, c1, s1, s2, s3, t1, t2, sh1])
    dao.commit()

    t = dao.trip("T1")
    self.assertEqual(t.shape.shape_id, "Sh1")
    self.assertEqual(len(t.shape.points), 5)
    t = dao.trip("T2")
    self.assertIsNone(t.shape)
def _convert_gtfs_model(feed_id, gtfs, dao, lenient=False, disable_normalization=False):
    """Import every entity of a parsed GTFS source into ``dao`` under ``feed_id``.

    Entities are imported in dependency order: feed info, agencies, zones/
    stations/stops, transfers, routes, fares, calendars, shapes, trips and stop
    times. Then (unless disabled) trips are normalized, and finally frequencies
    are expanded into concrete trips.

    :param feed_id: feed identifier attached to every imported entity.
    :param gtfs: source object exposing ``feedinfo()``, ``agencies()``,
        ``stops()``, ``transfers()``, ``routes()``, ``fare_attributes()``,
        ``fare_rules()``, ``calendars()``, ``calendar_dates()``, ``shapes()``,
        ``trips()``, ``stop_times()`` and ``frequencies()``. Each yields
        dict-like rows (presumably raw GTFS string fields -- confirm).
    :param dao: DAO used to add/flush/commit the imported entities.
    :param lenient: when True, invalid rows are logged and skipped (or
        repaired) instead of raising.
    :param disable_normalization: when True, skip the shape/trip
        normalization pass (stop sequences, distances, interpolated times).
    :raises KeyError: (non-lenient) on a dangling reference or a duplicated
        fare rule.
    :raises ValueError: (non-lenient) when a stop or station lacks lat/lon.

    The DAO is committed once after the stop times are imported and again
    after frequency expansion.
    """
    # Sentinel marking "value absent" for times and distances until normalization.
    NO_VALUE = -999999

    feedinfo2 = None
    logger.info("Importing feed ID '%s'" % feed_id)
    n_feedinfo = 0
    for feedinfo in gtfs.feedinfo():
        if n_feedinfo > 0:
            # Only the first feed info row is kept.
            logger.error("Feed info should be unique if defined. Taking first one.")
            break
        n_feedinfo += 1
        # TODO Automatically compute from calendar range if missing?
        feedinfo['feed_start_date'] = _todate(feedinfo.get('feed_start_date'))
        feedinfo['feed_end_date'] = _todate(feedinfo.get('feed_end_date'))
        feedinfo2 = FeedInfo(feed_id, **feedinfo)
    if feedinfo2 is None:
        # Optional, generate empty feed info
        feedinfo2 = FeedInfo(feed_id)
    dao.add(feedinfo2)
    dao.flush()
    logger.info("Imported %d feedinfo" % n_feedinfo)

    logger.info("Importing agencies...")
    n_agencies = 0
    single_agency = None
    agency_ids = set()
    for agency in gtfs.agencies():
        # agency_id is optional only if we have a single agency
        if n_agencies == 0 and agency.get('agency_id') is None:
            agency['agency_id'] = ''
        agency2 = Agency(feed_id, **agency)
        # Remember the agency only while it is the sole one seen so far.
        single_agency = agency2 if n_agencies == 0 else None
        n_agencies += 1
        dao.add(agency2)
        agency_ids.add(agency2.agency_id)
    dao.flush()
    logger.info("Imported %d agencies" % n_agencies)

    def import_stop(stop, stoptype, zone_ids, item_ids, station_ids=None):
        """Import ``stop`` if its location_type equals ``stoptype``.

        Zones referenced by any row are created lazily, even for rows skipped
        because of their type. When ``station_ids`` is given (not None), the
        row's parent station must belong to it. Returns 1 if imported, else 0.
        """
        zone_id = stop.get('zone_id')
        if zone_id and zone_id not in zone_ids:
            # Lazy-creation of zone
            zone = Zone(feed_id, zone_id)
            zone_ids.add(zone_id)
            dao.add(zone)
        stop['location_type'] = _toint(stop.get('location_type'), Stop.TYPE_STOP)
        if stop['location_type'] != stoptype:
            return 0
        stop['wheelchair_boarding'] = _toint(stop.get('wheelchair_boarding'),
                                             Stop.WHEELCHAIR_UNKNOWN)
        lat = _tofloat(stop.get('stop_lat'), None)
        lon = _tofloat(stop.get('stop_lon'), None)
        if lat is None or lon is None:
            if lenient:
                logger.error("Missing lat/lon for '%s', set to default (0,0)" % (stop, ))
                if lat is None:
                    lat = 0
                if lon is None:
                    lon = 0
            else:
                raise ValueError("Missing mandatory lat/lon for '%s'." % (stop, ))
        stop['stop_lat'] = lat
        stop['stop_lon'] = lon
        # This field has been renamed for consistency
        parent_id = stop.get('parent_station')
        stop['parent_station_id'] = parent_id if parent_id else None
        # Validate against the known stations even when that set is empty:
        # a feed without stations cannot have valid parent references.
        if parent_id and station_ids is not None and parent_id not in station_ids:
            if lenient:
                logger.error(
                    "Parent station ID '%s' in '%s' is invalid, resetting." % (parent_id, stop))
                stop['parent_station_id'] = None
            else:
                raise KeyError("Parent station ID '%s' in '%s' is invalid." % (parent_id, stop))
        stop.pop('parent_station', None)
        stop2 = Stop(feed_id, **stop)
        dao.add(stop2)
        item_ids.add(stop2.stop_id)
        return 1

    stop_ids = set()
    station_ids = set()
    zone_ids = set()
    logger.info("Importing zones, stations and stops...")
    n_stations = n_stops = 0
    # Stations first, so that stops can be validated against their parents.
    for station in gtfs.stops():
        n_stations += import_stop(station, Stop.TYPE_STATION, zone_ids, station_ids)
    for stop in gtfs.stops():
        n_stops += import_stop(stop, Stop.TYPE_STOP, zone_ids, stop_ids, station_ids)
    dao.flush()
    logger.info("Imported %d zones, %d stations and %d stops"
                % (len(zone_ids), n_stations, n_stops))

    logger.info("Importing transfers...")
    n_transfers = 0
    for transfer in gtfs.transfers():
        from_stop_id = transfer.get('from_stop_id')
        to_stop_id = transfer.get('to_stop_id')
        transfer['transfer_type'] = _toint(transfer.get('transfer_type'), 0)
        is_valid = True
        for stop_id in (from_stop_id, to_stop_id):
            if stop_id not in station_ids and stop_id not in stop_ids:
                if lenient:
                    logger.error("Stop ID '%s' in '%s' is invalid, skipping."
                                 % (stop_id, transfer))
                    is_valid = False
                    break
                else:
                    raise KeyError("Stop ID '%s' in '%s' is invalid." % (stop_id, transfer))
        if not is_valid:
            # Skip the whole transfer, not just the offending stop ID.
            continue
        transfer2 = Transfer(feed_id, **transfer)
        n_transfers += 1
        dao.add(transfer2)
    dao.flush()
    logger.info("Imported %d transfers" % (n_transfers))

    logger.info("Importing routes...")
    n_routes = 0
    route_ids = set()
    for route in gtfs.routes():
        route['route_type'] = int(route.get('route_type'))
        agency_id = route.get('agency_id')
        if not agency_id and single_agency is not None:
            # Route.agency is optional if only a single agency exists.
            agency_id = route['agency_id'] = single_agency.agency_id
        if agency_id not in agency_ids:
            if lenient:
                logger.error(
                    "Agency ID '%s' in '%s' is invalid, skipping route." % (agency_id, route))
                continue
            else:
                raise KeyError("agency ID '%s' in '%s' is invalid." % (agency_id, route))
        route2 = Route(feed_id, **route)
        dao.add(route2)
        route_ids.add(route2.route_id)
        n_routes += 1
    dao.flush()
    logger.info("Imported %d routes" % n_routes)

    logger.info("Importing fares...")
    n_fares = 0
    for fare_attr in gtfs.fare_attributes():
        fare_id = fare_attr.get('fare_id')
        fare_price = _tofloat(fare_attr.get('price'))
        currency_type = fare_attr.get('currency_type')
        payment_method = _toint(fare_attr.get('payment_method'))
        # Distinct name: must not clobber the transfers.txt counter above.
        fare_transfers = None
        if fare_attr.get('transfers') is not None:
            fare_transfers = _toint(fare_attr.get('transfers'))
        transfer_duration = None
        if fare_attr.get('transfer_duration') is not None:
            transfer_duration = _toint(fare_attr.get('transfer_duration'))
        fare = FareAttribute(feed_id, fare_id, fare_price, currency_type, payment_method,
                             fare_transfers, transfer_duration)
        dao.add(fare)
        n_fares += 1
    dao.flush()
    fare_rules = set()
    for fare_rule in gtfs.fare_rules():
        fare_rule2 = FareRule(feed_id, **fare_rule)
        if fare_rule2 in fare_rules:
            if lenient:
                logger.error("Duplicated fare rule (%s), skipping." % (fare_rule2))
                continue
            else:
                raise KeyError("Duplicated fare rule (%s)" % (fare_rule2))
        dao.add(fare_rule2)
        fare_rules.add(fare_rule2)
    dao.flush()
    logger.info("Imported %d fare and %d rules" % (n_fares, len(fare_rules)))

    logger.info("Importing calendars...")
    # service_id -> (Calendar, set of active CalendarDate)
    calanddates2 = {}
    for calendar in gtfs.calendars():
        calid = calendar.get('service_id')
        calendar2 = Calendar(feed_id, calid)
        start_date = CalendarDate.fromYYYYMMDD(calendar.get('start_date'))
        end_date = CalendarDate.fromYYYYMMDD(calendar.get('end_date'))
        # end_date is inclusive, hence next_day() as the range bound.
        dates2 = {d for d in CalendarDate.range(start_date, end_date.next_day())
                  if int(calendar.get(DOW_NAMES[d.dow()]))}
        calanddates2[calid] = (calendar2, dates2)

    logger.info("Normalizing calendar dates...")
    for caldate in gtfs.calendar_dates():
        calid = caldate.get('service_id')
        date2 = CalendarDate.fromYYYYMMDD(caldate.get('date'))
        addremove = int(caldate.get('exception_type'))
        if calid in calanddates2:
            calendar2, dates2 = calanddates2[calid]
        else:
            calendar2 = Calendar(feed_id, calid)
            dates2 = set()
            calanddates2[calid] = (calendar2, dates2)
        if addremove == 1:
            dates2.add(date2)
        elif addremove == 2:
            dates2.discard(date2)
    n_calendars = 0
    n_caldates = 0
    calendar_ids = set()
    for (calendar2, dates2) in calanddates2.values():
        calendar2.dates = list(dates2)
        dao.add(calendar2)
        calendar_ids.add(calendar2.service_id)
        n_calendars += 1
        n_caldates += len(calendar2.dates)
    dao.flush()
    logger.info("Imported %d calendars and %d dates" % (n_calendars, n_caldates))

    logger.info("Importing shapes...")
    n_shape_pts = 0
    shape_ids = set()
    shapepts_q = []
    for shpt in gtfs.shapes():
        shape_id = shpt.get('shape_id')
        if shape_id not in shape_ids:
            # Flush the parent shape before its points are bulk-saved.
            dao.add(Shape(feed_id, shape_id))
            dao.flush()
            shape_ids.add(shape_id)
        pt_seq = _toint(shpt.get('shape_pt_sequence'))
        # This field is optional
        dist_traveled = _tofloat(shpt.get('shape_dist_traveled'), NO_VALUE)
        lat = _tofloat(shpt.get('shape_pt_lat'))
        lon = _tofloat(shpt.get('shape_pt_lon'))
        shape_point = ShapePoint(feed_id, shape_id, pt_seq, lat, lon, dist_traveled)
        shapepts_q.append(shape_point)
        n_shape_pts += 1
        if n_shape_pts % 100000 == 0:
            logger.info("%d shape points" % n_shape_pts)
            dao.bulk_save_objects(shapepts_q)
            dao.flush()
            shapepts_q = []
    dao.bulk_save_objects(shapepts_q)
    dao.flush()
    logger.info("Imported %d shapes and %d points" % (len(shape_ids), n_shape_pts))

    logger.info("Importing trips...")
    n_trips = 0
    trips_q = []
    trip_ids = set()
    for trip in gtfs.trips():
        trip['wheelchair_accessible'] = _toint(
            trip.get('wheelchair_accessible'), Trip.WHEELCHAIR_UNKNOWN)
        trip['bikes_allowed'] = _toint(trip.get('bikes_allowed'), Trip.BIKES_UNKNOWN)
        cal_id = trip.get('service_id')
        if cal_id not in calendar_ids:
            if lenient:
                logger.error(
                    "Calendar ID '%s' in '%s' is invalid. Skipping trip." % (cal_id, trip))
                continue
            else:
                raise KeyError("Calendar ID '%s' in '%s' is invalid." % (cal_id, trip))
        route_id = trip.get('route_id')
        if route_id not in route_ids:
            if lenient:
                logger.error(
                    "Route ID '%s' in '%s' is invalid. Skipping trip." % (route_id, trip))
                continue
            else:
                raise KeyError("Route ID '%s' in trip '%s' is invalid." % (route_id, trip))
        trip2 = Trip(feed_id, frequency_generated=False, **trip)
        trips_q.append(trip2)
        n_trips += 1
        if n_trips % 10000 == 0:
            dao.bulk_save_objects(trips_q)
            dao.flush()
            logger.info('%s trips' % n_trips)
            trips_q = []
        trip_ids.add(trip.get('trip_id'))
    dao.bulk_save_objects(trips_q)
    dao.flush()
    logger.info("Imported %d trips" % n_trips)

    logger.info("Importing stop times...")
    n_stoptimes = 0
    stoptimes_q = []
    for stoptime in gtfs.stop_times():
        stopseq = _toint(stoptime.get('stop_sequence'))
        # Mark times to interpolate later on
        arrtime = _timetoint(stoptime.get('arrival_time'), NO_VALUE)
        deptime = _timetoint(stoptime.get('departure_time'), NO_VALUE)
        if arrtime == NO_VALUE:
            arrtime = deptime
        if deptime == NO_VALUE:
            deptime = arrtime
        interp = arrtime < 0 and deptime < 0
        shpdist = _tofloat(stoptime.get('shape_dist_traveled'), NO_VALUE)
        pkptype = _toint(stoptime.get('pickup_type'), StopTime.PICKUP_DROPOFF_REGULAR)
        drptype = _toint(stoptime.get('drop_off_type'), StopTime.PICKUP_DROPOFF_REGULAR)
        trip_id = stoptime.get('trip_id')
        if trip_id not in trip_ids:
            if lenient:
                logger.error(
                    "Trip ID '%s' in '%s' is invalid. Skipping stop time." % (trip_id, stoptime))
                continue
            else:
                raise KeyError("Trip ID '%s' in '%s' is invalid." % (trip_id, stoptime))
        stop_id = stoptime.get('stop_id')
        if stop_id not in stop_ids:
            if lenient:
                logger.error(
                    "Stop ID '%s' in '%s' is invalid. Skipping stop time." % (stop_id, stoptime))
                continue
            else:
                raise KeyError("Stop ID '%s' in stoptime '%s' is invalid." % (stop_id, stoptime))
        stoptime2 = StopTime(feed_id, trip_id, stop_id,
                             stop_sequence=stopseq, arrival_time=arrtime,
                             departure_time=deptime, shape_dist_traveled=shpdist,
                             interpolated=interp, pickup_type=pkptype,
                             drop_off_type=drptype,
                             stop_headsign=stoptime.get('stop_headsign'))
        stoptimes_q.append(stoptime2)
        n_stoptimes += 1
        # Flush queued stop times in batches of 50000
        if n_stoptimes % 50000 == 0:
            logger.info("%d stop times" % n_stoptimes)
            dao.bulk_save_objects(stoptimes_q)
            dao.flush()
            stoptimes_q = []
    dao.bulk_save_objects(stoptimes_q)
    logger.info("Imported %d stop times" % n_stoptimes)
    logger.info("Committing")
    dao.flush()
    # TODO Add option to enable/disable this commit
    # to ensure import is transactionnal
    dao.commit()
    logger.info("Commit done")

    def normalize_trip(trip, odometer):
        """Renumber stop sequences, compute distances and fill missing times.

        Stop times flagged ``interpolated`` get times linearly interpolated on
        distance between the surrounding timed stops. First arrival and last
        departure are forced to NULL.
        """
        stopseq = 0
        n_stoptimes = len(trip.stop_times)
        last_stoptime_with_time = None
        to_interpolate = []
        odometer.reset()
        for stoptime in trip.stop_times:
            stoptime.stop_sequence = stopseq
            stoptime.shape_dist_traveled = odometer.dist_traveled(
                stoptime.stop,
                stoptime.shape_dist_traveled
                if stoptime.shape_dist_traveled != NO_VALUE else None)
            if stopseq == 0:
                # Force first arrival time to NULL
                stoptime.arrival_time = None
            if stopseq == n_stoptimes - 1:
                # Force last departure time to NULL
                stoptime.departure_time = None
            if stoptime.interpolated:
                to_interpolate.append(stoptime)
            else:
                if to_interpolate:
                    if last_stoptime_with_time is None:
                        logger.error(
                            "Cannot interpolate missing time at trip start: %s" % trip)
                        for stti in to_interpolate:
                            # Use first defined time as fallback value.
                            stti.arrival_time = stoptime.arrival_time
                            stti.departure_time = stoptime.arrival_time
                    else:
                        base_dist = last_stoptime_with_time.shape_dist_traveled
                        base_time = last_stoptime_with_time.departure_time
                        tdist = stoptime.shape_dist_traveled - base_dist
                        ttime = stoptime.arrival_time - base_time
                        n_interp = len(to_interpolate)
                        for i, stti in enumerate(to_interpolate):
                            if tdist != 0:
                                fdist = stti.shape_dist_traveled - base_dist
                                t = base_time + ttime * fdist // tdist
                            else:
                                # Zero-length span (e.g. co-located stops):
                                # spread evenly by index instead of dividing by 0.
                                t = base_time + ttime * (i + 1) // (n_interp + 1)
                            stti.arrival_time = t
                            stti.departure_time = t
                    to_interpolate = []
                last_stoptime_with_time = stoptime
            stopseq += 1
        if to_interpolate:
            # Should not happen, but handle the case, we never know
            if last_stoptime_with_time is None:
                logger.error(
                    "Cannot interpolate missing time, no time at all: %s" % trip)
                # Keep times NULL (TODO: or remove the trip?)
            else:
                logger.error(
                    "Cannot interpolate missing time at trip end: %s" % trip)
                for stti in to_interpolate:
                    # Use last defined time as fallback value
                    stti.arrival_time = last_stoptime_with_time.departure_time
                    stti.departure_time = last_stoptime_with_time.departure_time

    if disable_normalization:
        logger.info("Skipping shapes and trips normalization")
    else:
        logger.info("Normalizing shapes and trips...")
        nshapes = 0
        ntrips = 0
        odometer = _Odometer()
        # Process shapes and associated trips
        for shape in dao.shapes(fltr=Shape.feed_id == feed_id, prefetch_points=True,
                                batch_size=50):
            # Shape will be registered in the normalize
            odometer.normalize_and_register_shape(shape)
            for trip in dao.trips(fltr=(Trip.feed_id == feed_id)
                                  & (Trip.shape_id == shape.shape_id),
                                  prefetch_stop_times=False, prefetch_stops=False,
                                  batch_size=800):
                normalize_trip(trip, odometer)
                ntrips += 1
                if ntrips % 1000 == 0:
                    logger.info("%d trips, %d shapes" % (ntrips, nshapes))
                    dao.flush()
            nshapes += 1
        # Process trips w/o shapes
        for trip in dao.trips(fltr=(Trip.feed_id == feed_id)
                              & (Trip.shape_id == None),  # noqa: E711 -- SQL IS NULL
                              prefetch_stop_times=False, prefetch_stops=False,
                              batch_size=800):
            odometer.register_noshape()
            normalize_trip(trip, odometer)
            ntrips += 1
            if ntrips % 1000 == 0:
                logger.info("%d trips" % ntrips)
                dao.flush()
        dao.flush()
        logger.info("Normalized %d trips and %d shapes" % (ntrips, nshapes))

    # Note: we expand frequencies *after* normalization
    # for performances purpose only: that minimize the
    # number of trips to normalize. We can do that since
    # the expansion is neutral trip-normalization-wise.
    logger.info("Expanding frequencies...")
    n_freq = 0
    n_exp_trips = 0
    # trip_id -> template trip; keyed so a trip used by several frequencies
    # is deleted only once.
    trips_to_delete = {}
    for frequency in gtfs.frequencies():
        trip_id = frequency.get('trip_id')
        if trip_id not in trip_ids:
            if lenient:
                logger.error(
                    "Trip ID '%s' in '%s' is invalid. Skipping frequency." % (trip_id, frequency))
                continue
            else:
                raise KeyError("Trip ID '%s' in '%s' is invalid." % (trip_id, frequency))
        trip = dao.trip(trip_id, feed_id=feed_id)
        start_time = _timetoint(frequency.get('start_time'))
        end_time = _timetoint(frequency.get('end_time'))
        headway_secs = _toint(frequency.get('headway_secs'))
        exact_times = _toint(frequency.get('exact_times'), Trip.TIME_APPROX)
        for trip_dep_time in range(start_time, end_time, headway_secs):
            # Here we assume departure time are all different.
            # That's a requirement in the GTFS specs, but this may break.
            # TODO Make the expanded trip ID generation parametrable.
            trip_id2 = trip.trip_id + "@" + fmttime(trip_dep_time)
            trip2 = Trip(feed_id, trip_id2, trip.route_id, trip.service_id,
                         wheelchair_accessible=trip.wheelchair_accessible,
                         bikes_allowed=trip.bikes_allowed,
                         exact_times=exact_times,
                         frequency_generated=True,
                         trip_headsign=trip.trip_headsign,
                         trip_short_name=trip.trip_short_name,
                         direction_id=trip.direction_id,
                         block_id=trip.block_id,
                         # Keep the template's geometry on the expanded trip.
                         shape_id=trip.shape_id)
            trip2.stop_times = []
            base_time = trip.stop_times[0].departure_time
            for stoptime in trip.stop_times:
                arrtime = (None if stoptime.arrival_time is None
                           else stoptime.arrival_time - base_time + trip_dep_time)
                deptime = (None if stoptime.departure_time is None
                           else stoptime.departure_time - base_time + trip_dep_time)
                stoptime2 = StopTime(
                    feed_id, trip_id2, stoptime.stop_id, stoptime.stop_sequence,
                    arrival_time=arrtime, departure_time=deptime,
                    shape_dist_traveled=stoptime.shape_dist_traveled,
                    interpolated=stoptime.interpolated,
                    timepoint=stoptime.timepoint,
                    pickup_type=stoptime.pickup_type,
                    drop_off_type=stoptime.drop_off_type,
                    stop_headsign=stoptime.stop_headsign)
                trip2.stop_times.append(stoptime2)
            n_exp_trips += 1
            # This will add the associated stop times
            dao.add(trip2)
        # Do not delete trip now, as two frequency can refer to same trip
        trips_to_delete[trip.trip_id] = trip
        n_freq += 1
    for trip in trips_to_delete.values():
        # This also delete the associated stop times
        dao.delete(trip)
    dao.flush()
    dao.commit()
    logger.info("Expanded %d frequencies to %d trips." % (n_freq, n_exp_trips))

    logger.info("Feed '%s': import done." % feed_id)
def test_fares(self):
    """Fare attributes/rules queries, plus FareRule equality and hashing."""
    dao = Dao()
    f1 = FeedInfo("")
    a1 = Agency("", "A1", "Agency 1", agency_url="http://www.agency.fr/",
                agency_timezone="Europe/Paris")
    r1 = Route("", "R1", "A1", 3, route_short_name="R1", route_long_name="Route 1")
    r2 = Route("", "R2", "A1", 3, route_short_name="R2", route_long_name="Route 2")
    z1 = Zone("", "Z1")
    z2 = Zone("", "Z2")
    fare1 = FareAttribute("", "F1", 1.0, "EUR", FareAttribute.PAYMENT_ONBOARD, None, None)
    fare2 = FareAttribute("", "F2", 2.0, "EUR", FareAttribute.PAYMENT_BEFOREBOARDING, 3, 3600)
    rule1 = FareRule("", "F1", route_id="R1")
    rule2 = FareRule("", "F1", origin_id="Z1", destination_id="Z2")
    dao.add_all([f1, a1, r1, r2, z1, z2, fare1, fare2, rule1, rule2])
    dao.commit()

    self.assertEqual(len(dao.fare_attributes()), 2)
    self.assertEqual(len(dao.fare_rules(fltr=(FareRule.route == r1))), 1)
    self.assertEqual(len(dao.fare_rules(fltr=(FareRule.route == r2))), 0)
    self.assertEqual(len(dao.fare_rules(fltr=(FareRule.origin == z1))), 1)
    fare = dao.fare_attribute("F1")
    self.assertEqual(len(fare.fare_rules), 2)
    fare = dao.fare_attribute("F2")
    self.assertEqual(len(fare.fare_rules), 0)

    # Test equivalence and hash on primary keys for rule
    fr1a = FareRule("", "F1", route_id="R", origin_id="ZO", destination_id="ZD",
                    contains_id=None)
    fr1b = FareRule("", "F1", route_id="R", origin_id="ZO", destination_id="ZD",
                    contains_id=None)
    fr2 = FareRule("", "F1", route_id="R", origin_id="ZO", destination_id=None,
                   contains_id="ZD")
    ruleset = set()
    ruleset.add(fr1a)
    ruleset.add(fr2)
    ruleset.add(fr1b)
    self.assertEqual(len(ruleset), 2)
    self.assertEqual(fr1a, fr1b)
    self.assertNotEqual(fr1a, fr2)
def test_transfers(self):
    """Query transfers by type and by from/to stop, and via stop relations."""
    dao = Dao()
    f1 = FeedInfo("F1")
    s1 = Stop("F1", "S1", "Stop 1", 45.0000, 0.0000)
    s2 = Stop("F1", "S2", "Stop 2", 45.0001, 0.0001)
    s3 = Stop("F1", "S3", "Stop 3", 45.0002, 0.0002)
    t12 = Transfer("F1", "S1", "S2")
    t21 = Transfer("F1", "S2", "S1")
    t23 = Transfer("F1", "S2", "S3", transfer_type=Transfer.TRANSFER_TIMED,
                   min_transfer_time=180)
    t32 = Transfer("F1", "S3", "S2", transfer_type=Transfer.TRANSFER_TIMED,
                   min_transfer_time=120)
    t13 = Transfer("F1", "S1", "S3", transfer_type=Transfer.TRANSFER_NONE)
    a1 = Agency("F1", "A1", "Agency 1", "url1", "Europe/Paris")
    a2 = Agency("F1", "A2", "Agency 2", "url2", "Europe/London")
    r1 = Route("F1", "R1", "A1", Route.TYPE_BUS)
    r2 = Route("F1", "R2", "A2", Route.TYPE_BUS)
    c1 = Calendar("F1", "C1")
    t1 = Trip("F1", "T1", "R1", "C1")
    t2 = Trip("F1", "T2", "R2", "C1")
    st1a = StopTime("F1", "T1", "S1", 0, None, 3600, 0.0)
    st1b = StopTime("F1", "T1", "S2", 1, 3800, None, 100.0)
    st2a = StopTime("F1", "T2", "S1", 0, None, 4600, 0.0)
    st2b = StopTime("F1", "T2", "S3", 1, 4800, None, 100.0)
    dao.add_all([
        f1, s1, s2, s3, t12, t21, t23, t32, t13, a1, a2, r1, r2, c1, t1, t2,
        st1a, st1b, st2a, st2b
    ])

    self.assertEqual(len(dao.transfers()), 5)

    timed_transfers = dao.transfers(
        fltr=(Transfer.transfer_type == Transfer.TRANSFER_TIMED))
    self.assertEqual(len(timed_transfers), 2)
    for transfer in timed_transfers:
        self.assertEqual(transfer.transfer_type, Transfer.TRANSFER_TIMED)

    s1_from_transfers = dao.transfers(
        fltr=(dao.transfer_from_stop().stop_name == "Stop 1"))
    self.assertEqual(len(s1_from_transfers), 2)
    for transfer in s1_from_transfers:
        self.assertEqual(transfer.from_stop.stop_name, "Stop 1")

    s1_fromto_transfers = dao.transfers(
        fltr=((dao.transfer_from_stop().stop_name == "Stop 1")
              | (dao.transfer_to_stop().stop_name == "Stop 1")))
    self.assertEqual(len(s1_fromto_transfers), 3)
    for transfer in s1_fromto_transfers:
        self.assertTrue(transfer.from_stop.stop_name == "Stop 1"
                        or transfer.to_stop.stop_name == "Stop 1")

    s1 = dao.stop("S1", feed_id="F1")
    self.assertEqual(len(s1.from_transfers), 2)
    self.assertEqual(len(s1.to_transfers), 1)
    for transfer in s1.from_transfers:
        if transfer.to_stop.stop_id == "S2":
            self.assertEqual(transfer.transfer_type, Transfer.TRANSFER_DEFAULT)
        elif transfer.to_stop.stop_id == "S3":
            self.assertEqual(transfer.transfer_type, Transfer.TRANSFER_NONE)

    a1_stops = list(dao.stops(fltr=(Agency.agency_id == 'A1')))
    self.assertEqual(len(a1_stops), 2)
    self.assertIn(s1, a1_stops)
    self.assertIn(s2, a1_stops)
def test_trip(self):
    """Trip/stop-time relations, stop-time ordering, hops and direction filter."""
    dao = Dao()
    f1 = FeedInfo("F1")
    a1 = Agency("F1", "A1", "Agency 1", agency_url="http://www.agency.fr/",
                agency_timezone="Europe/Paris")
    r1 = Route("F1", "R1", "A1", 3, route_short_name="R1", route_long_name="Route 1")
    c1 = Calendar("F1", "C1")
    c1.dates = list(CalendarDate.range(CalendarDate.ymd(2016, 1, 1),
                                       CalendarDate.ymd(2016, 1, 31).next_day()))
    s1 = Stop("F1", "S1", "Stop 1", 45.0, 0.0)
    s2 = Stop("F1", "S2", "Stop 2", 45.1, 0.1)
    s3 = Stop("F1", "S3", "Stop 3", 45.2, 0.2)
    t1 = Trip("F1", "T1", "R1", "C1")
    t1.direction_id = 0
    t11 = StopTime("F1", "T1", "S1", 0, 28800, 28800, 0.0)
    t12 = StopTime("F1", "T1", "S2", 1, 29400, 29400, 0.0)
    t13 = StopTime("F1", "T1", "S3", 2, 30000, 30000, 0.0)
    t2 = Trip("F1", "T2", "R1", "C1")
    t2.direction_id = 1
    # Order is not important for now
    t2.stop_times.append(StopTime(None, None, "S1", 1, 31000, 31000, 0.0))
    t2.stop_times.append(StopTime(None, None, "S2", 0, 30600, 30600, 0.0))
    dao.add_all([f1, a1, r1, c1, s1, s2, s3, t1, t11, t12, t13, t2])
    # Commit is needed to re-order stop times of T2
    dao.commit()

    cal = dao.calendar("C1", feed_id="F1")
    for trip in cal.trips:
        self.assertEqual(trip.calendar.service_id, "C1")
        for stoptime in trip.stop_times:
            self.assertEqual(stoptime.trip.calendar.service_id, "C1")

    stop = dao.stop("S2", feed_id="F1")
    for stoptime in stop.stop_times:
        self.assertEqual(stoptime.stop.stop_id, "S2")
        self.assertTrue(stoptime.trip.trip_id.startswith("T"))

    trip = dao.trip("T1", feed_id="F1")
    self.assertEqual(len(trip.stop_times), 3)
    trip = dao.trip("T2", feed_id="F1")
    self.assertEqual(len(trip.stop_times), 2)

    # Stop times must come back strictly ordered by stop_sequence.
    for trip in dao.trips(prefetch_stop_times=True):
        last_stop_seq = -1
        for stoptime in trip.stop_times:
            self.assertGreater(stoptime.stop_sequence, last_stop_seq)
            last_stop_seq = stoptime.stop_sequence

    for trip in dao.trips():
        for stoptime1, stoptime2 in trip.hops():
            self.assertEqual(stoptime1.trip, stoptime2.trip)
            self.assertEqual(stoptime1.stop_sequence + 1, stoptime2.stop_sequence)

    trips = list(dao.trips(fltr=Trip.direction_id == 0))
    self.assertEqual(len(trips), 1)
    trips = list(dao.trips(fltr=Trip.direction_id == 1))
    self.assertEqual(len(trips), 1)