def test_custom_queries(self):
    dao = Dao(DAO_URL, sql_logging=SQL_LOG)
    dao.load_gtfs(DUMMY_GTFS)

    # A simple custom query: count the number of stops per type (stop/station)
    # SQL equivalent: SELECT stop.location_type, count(stop.location_type)
    #                 FROM stop GROUP BY stop.location_type
    for loc_type, stop_count in dao.session \
            .query(Stop.location_type, func.count(Stop.location_type)) \
            .group_by(Stop.location_type) \
            .all():
        # print("type %d : %d stops" % (loc_type, stop_count))
        if loc_type == Stop.TYPE_STATION:
            self.assertTrue(stop_count == 3)
        if loc_type == Stop.TYPE_STOP:
            self.assertTrue(15 < stop_count < 30)

    # A more complex custom query: count the number of trips
    # per calendar date per route in June/July
    from_date = CalendarDate.ymd(2016, 6, 1)
    to_date = CalendarDate.ymd(2016, 7, 31)
    for date, route, trip_count in dao.session \
            .query(CalendarDate.date, Route, func.count(Trip.trip_id)) \
            .select_from(Calendar).join(Trip).join(Route) \
            .filter((func.date(CalendarDate.date) >= from_date.date)
                    & (func.date(CalendarDate.date) <= to_date.date)) \
            .group_by(CalendarDate.date, Route.route_short_name) \
            .all():
        # print("%s / %20s : %d trips" % (date, route.route_short_name + " " + route.route_long_name, trip_count))
        self.assertTrue(date >= from_date.as_date())
        self.assertTrue(date <= to_date.as_date())
        self.assertTrue(trip_count > 0)
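# For reference, the second query above also has a rough SQL equivalent,
# in the spirit of the comment on the first one. A sketch of running it
# as raw SQL through the underlying SQLAlchemy session; the table and
# column names below are assumptions inferred from the ORM classes and
# the singular "stop" table named above, and may differ from the actual
# mapping:
#
#   from sqlalchemy import text
#   rough_sql = text("""
#       SELECT calendar_date.date, route.route_short_name, count(trip.trip_id)
#       FROM calendar_date
#           JOIN trip ON trip.service_id = calendar_date.service_id
#           JOIN route ON route.route_id = trip.route_id
#       WHERE calendar_date.date BETWEEN :d1 AND :d2
#       GROUP BY calendar_date.date, route.route_short_name""")
#   rows = dao.session.execute(rough_sql, {"d1": "2016-06-01", "d2": "2016-07-31"})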
def test_entities_creation(self):
    dao = Dao()
    f1 = FeedInfo("F1")
    a1 = Agency("F1", "A1", "Agency 1", agency_url="http://www.agency.fr/",
                agency_timezone="Europe/Paris")
    r1 = Route("F1", "R1", "A1", 3, route_short_name="R1",
               route_long_name="Route 1")
    r2 = Route("F1", "R2", "A1", 3, route_short_name="R2")
    c1 = Calendar("F1", "C1")
    c1.dates = [CalendarDate.ymd(2015, 11, 13), CalendarDate.ymd(2015, 11, 14)]
    dao.add_all([f1, a1, r1, r2, c1])
    self.assertTrue(len(dao.feeds()) == 1)
    self.assertTrue(len(dao.agencies()) == 1)
    a1b = dao.agency("A1", feed_id="F1", prefetch_routes=True)
    self.assertTrue(a1b.agency_name == "Agency 1")
    self.assertTrue(len(a1b.routes) == 2)
    r1b = dao.route("R1", feed_id="F1")
    self.assertTrue(r1b.route_short_name == "R1")
    self.assertTrue(r1b.route_long_name == "Route 1")
    self.assertTrue(r1b.route_type == 3)
    r42 = dao.route("R42", feed_id="F1")
    self.assertTrue(r42 is None)
def __load_services(self):
    self.service_ids = set()
    today = datetime.datetime.today()
    start_date = CalendarDate.fromYYYYMMDD(today.strftime('%Y%m%d'))
    end_date = CalendarDate.fromYYYYMMDD(
        datetime.date(today.year, 12, 31).strftime('%Y%m%d'))

    def save_calendar_for(service_id: str, dow: list):
        self.service_ids.add(service_id)
        service = Calendar(self.feed_id, service_id)
        dates = []
        for d in CalendarDate.range(start_date, end_date.next_day()):
            if dow[d.dow()] == 1:
                d.service_id = service.service_id
                d.feed_id = self.feed_id
                # Add to the list to be saved
                dates.append(d)
        self.dao.add(service)
        self.dao.bulk_save_objects(dates)

    save_calendar_for("lv", [1, 1, 1, 1, 1, 0, 0])
    save_calendar_for("s", [0, 0, 0, 0, 0, 1, 0])
    save_calendar_for("d", [0, 0, 0, 0, 0, 0, 1])
def test_calendar_date_range(self):
    d1 = CalendarDate.ymd(2016, 1, 1)
    d2 = CalendarDate.ymd(2016, 2, 1)
    n = 0
    for d in CalendarDate.range(d1, d2):
        self.assertTrue(d >= d1)
        self.assertTrue(d < d2)
        n += 1
    self.assertEqual(n, 31)
def test_calendar_date_set(self):
    d1 = CalendarDate.ymd(2015, 12, 31)
    d2 = CalendarDate.ymd(2016, 1, 1)
    dates = {d1, d2}
    self.assertTrue(len(dates) == 2)
    d2b = CalendarDate.ymd(2016, 1, 1)
    dates.add(d2b)
    self.assertTrue(len(dates) == 2)
    d1b = CalendarDate.ymd(2015, 12, 31)
    d4 = CalendarDate.ymd(2015, 1, 2)
    dates.add(d1b)
    dates.add(d4)
    self.assertTrue(len(dates) == 3)
def make_dow_service(feed_id: str, service_id: str, today, dow: list):
    # Build a Calendar plus the list of CalendarDates matching the given
    # days-of-week pattern, from 'today' to the end of the year.
    # Note: 'dow' is now a required parameter; the former 'dow: list = []'
    # default was a mutable-default bug and would crash on dow[d.dow()].
    dates = []
    service = Calendar(feed_id, service_id)
    start_date = CalendarDate.fromYYYYMMDD(today.strftime("%Y%m%d"))
    end_date = CalendarDate.fromYYYYMMDD(
        datetime.date(today.year, 12, 31).strftime("%Y%m%d"))
    for d in CalendarDate.range(start_date, end_date.next_day()):
        if dow[d.dow()] == 1:
            d.feed_id = feed_id
            d.service_id = service.service_id
            dates.append(d)
    return service, dates
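# A minimal usage sketch for make_dow_service. The feed ID "F1" and the
# persistence calls are assumptions borrowed from the surrounding tests
# and __load_services above, not part of this helper:
def _demo_make_dow_service():
    # Monday-to-Friday service from today until the end of the year
    service, dates = make_dow_service("F1", "weekday",
                                      datetime.date.today(),
                                      [1, 1, 1, 1, 1, 0, 0])
    dao = Dao()  # in-memory DAO, as used in the tests above
    dao.add(service)
    dao.bulk_save_objects(dates)
    dao.flush()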
def test_calendar_date(self):
    d1 = CalendarDate.ymd(2015, 12, 31)
    d2 = CalendarDate.ymd(2016, 1, 1)
    self.assertEqual(d1.next_day(), d2)
    self.assertTrue(d1 < d2)
    self.assertTrue(d1 <= d2)
    self.assertFalse(d1 > d2)
    self.assertFalse(d1 >= d2)
    self.assertFalse(d1 == d2)
    # A CalendarDate compares equal to the plain datetime.date it wraps
    dt = datetime.date(2015, 12, 31)
    self.assertTrue(d1 == dt)
    dates = [CalendarDate.ymd(2016, 1, 1), CalendarDate.ymd(2016, 1, 2)]
    self.assertTrue(CalendarDate.ymd(2016, 1, 2) in dates)
def test_shapes(self):
    dao = Dao()
    f1 = FeedInfo("")
    a1 = Agency("", "A1", "Agency 1", agency_url="http://www.agency.fr/",
                agency_timezone="Europe/Paris")
    r1 = Route("", "R1", "A1", 3, route_short_name="R1",
               route_long_name="Route 1")
    c1 = Calendar("", "C1")
    c1.dates = [CalendarDate.ymd(2016, 1, 31), CalendarDate.ymd(2016, 2, 1)]
    s1 = Stop("", "S1", "Stop 1", 45.0, 0.0)
    s2 = Stop("", "S2", "Stop 2", 45.1, 0.1)
    s3 = Stop("", "S3", "Stop 3", 45.2, 0.2)
    t1 = Trip("", "T1", "R1", "C1")
    t1.stop_times = [
        StopTime(None, None, "S1", 0, 28800, 28800, 0.0),
        StopTime(None, None, "S2", 1, 29400, 29400, 2.0),
        StopTime(None, None, "S3", 2, 30000, 30000, 4.0)
    ]
    t2 = Trip("", "T2", "R1", "C1")
    t2.stop_times = [
        StopTime(None, None, "S2", 0, 30600, 30600, 0.0),
        StopTime(None, None, "S1", 1, 31000, 31000, 1.0)
    ]
    sh1 = Shape("", "Sh1")
    sh1.points = [
        ShapePoint(None, None, 0, 45.00, 0.00, 0.0),
        ShapePoint(None, None, 1, 45.05, 0.10, 1.0),
        ShapePoint(None, None, 2, 45.10, 0.10, 2.0),
        ShapePoint(None, None, 3, 45.15, 0.20, 3.0),
        ShapePoint(None, None, 4, 45.20, 0.20, 4.0)
    ]
    t1.shape = sh1
    dao.add_all([f1, a1, r1, c1, s1, s2, s3, t1, t2, sh1])
    dao.commit()
    t = dao.trip("T1")
    self.assertTrue(t.shape.shape_id == "Sh1")
    self.assertTrue(len(t.shape.points) == 5)
    t = dao.trip("T2")
    self.assertIsNone(t.shape)
def _load_services(self):
    year = datetime.datetime.now().year
    start_date = CalendarDate.fromYYYYMMDD(f"{year}0101")
    end_date = CalendarDate.fromYYYYMMDD(f"{year}1231")

    def save_calendar_for(service_id: str, dow: list):
        service = Calendar(self.feed_id, service_id)
        # Attach the matching dates over the whole year; the original
        # snippet computed start/end dates and accepted 'dow' without
        # using them, so the generation mirrors __load_services above.
        dates = []
        for d in CalendarDate.range(start_date, end_date.next_day()):
            if dow[d.dow()] == 1:
                d.feed_id = self.feed_id
                d.service_id = service.service_id
                dates.append(d)
        self.dao.update(service)
        self.dao.bulk_save_objects(dates)
        self.dao.flush()

    save_calendar_for("LV", [1, 1, 1, 1, 1, 0, 0])
    save_calendar_for("SD", [0, 0, 0, 0, 0, 1, 1])
    self.service_id = "LV" if datetime.datetime.today().weekday() < 5 else "SD"
def test_calendar_date_out_of_range(self):
    with self.assertRaises(ValueError):
        CalendarDate.ymd(2015, 12, 32)
def test_trip(self):
    dao = Dao()
    f1 = FeedInfo("F1")
    a1 = Agency("F1", "A1", "Agency 1", agency_url="http://www.agency.fr/",
                agency_timezone="Europe/Paris")
    r1 = Route("F1", "R1", "A1", 3, route_short_name="R1",
               route_long_name="Route 1")
    c1 = Calendar("F1", "C1")
    c1.dates = [d for d in CalendarDate.range(
        CalendarDate.ymd(2016, 1, 1),
        CalendarDate.ymd(2016, 1, 31).next_day())]
    s1 = Stop("F1", "S1", "Stop 1", 45.0, 0.0)
    s2 = Stop("F1", "S2", "Stop 2", 45.1, 0.1)
    s3 = Stop("F1", "S3", "Stop 3", 45.2, 0.2)
    t1 = Trip("F1", "T1", "R1", "C1")
    t1.direction_id = 0
    t11 = StopTime("F1", "T1", "S1", 0, 28800, 28800, 0.0)
    t12 = StopTime("F1", "T1", "S2", 1, 29400, 29400, 0.0)
    t13 = StopTime("F1", "T1", "S3", 2, 30000, 30000, 0.0)
    t2 = Trip("F1", "T2", "R1", "C1")
    t2.direction_id = 1
    # Order is not important for now
    t2.stop_times.append(StopTime(None, None, "S1", 1, 31000, 31000, 0.0))
    t2.stop_times.append(StopTime(None, None, "S2", 0, 30600, 30600, 0.0))
    dao.add_all([f1, a1, r1, c1, s1, s2, s3, t1, t11, t12, t13, t2])
    # Commit is needed to re-order stop times of T2
    dao.commit()
    cal = dao.calendar("C1", feed_id="F1")
    for trip in cal.trips:
        self.assertTrue(trip.calendar.service_id == "C1")
        for stoptime in trip.stop_times:
            self.assertTrue(stoptime.trip.calendar.service_id == "C1")
    stop = dao.stop("S2", feed_id="F1")
    for stoptime in stop.stop_times:
        self.assertTrue(stoptime.stop.stop_id == "S2")
        self.assertTrue(stoptime.trip.trip_id.startswith("T"))
    trip = dao.trip("T1", feed_id="F1")
    self.assertTrue(len(trip.stop_times) == 3)
    trip = dao.trip("T2", feed_id="F1")
    self.assertTrue(len(trip.stop_times) == 2)
    for trip in dao.trips(prefetch_stop_times=True):
        last_stop_seq = -1
        for stoptime in trip.stop_times:
            self.assertTrue(stoptime.stop_sequence > last_stop_seq)
            last_stop_seq = stoptime.stop_sequence
    for trip in dao.trips():
        for stoptime1, stoptime2 in trip.hops():
            self.assertTrue(stoptime1.trip == stoptime2.trip)
            self.assertTrue(stoptime1.stop_sequence + 1 == stoptime2.stop_sequence)
    trips = list(dao.trips(fltr=Trip.direction_id == 0))
    self.assertTrue(len(trips) == 1)
    trips = list(dao.trips(fltr=Trip.direction_id == 1))
    self.assertTrue(len(trips) == 1)
def test_complex_queries(self):
    dao = Dao(DAO_URL, sql_logging=SQL_LOG)
    dao.load_gtfs(DUMMY_GTFS)

    # Get the list of departures:
    # 1) from "Porte de Bourgogne"
    # 2) on the 4th of July
    # 3) between 10:00 and 14:00
    # 4) on routes of type BUS
    # 5) not the last stop of the trip (only departing)
    porte_bourgogne = dao.stop("BBG")
    july4 = CalendarDate.ymd(2016, 7, 4)
    from_time = gtfstime(10, 00)
    to_time = gtfstime(14, 00)
    departures = dao.stoptimes(
        fltr=(StopTime.stop == porte_bourgogne)
        & (StopTime.departure_time >= from_time)
        & (StopTime.departure_time <= to_time)
        & (Route.route_type == Route.TYPE_BUS)
        & (func.date(CalendarDate.date) == july4.date),
        prefetch_trips=True)
    n = 0
    for dep in departures:
        self.assertTrue(dep.stop == porte_bourgogne)
        self.assertTrue(july4 in dep.trip.calendar.dates)
        self.assertTrue(dep.trip.route.route_type == Route.TYPE_BUS)
        self.assertTrue(from_time <= dep.departure_time <= to_time)
        n += 1
    self.assertTrue(n > 10)

    # Plage is a stop that is used only in summer (hence the name!)
    plage = dao.stop("BPG")
    # Get the list of stops used by some route:
    # 1) all year round
    route_red = dao.route("BR")
    stoplist_all = list(dao.stops(fltr=Trip.route == route_red))
    # 2) only in January
    from_date = CalendarDate.ymd(2016, 1, 1)
    to_date = CalendarDate.ymd(2016, 1, 31)
    stoplist_jan = list(
        dao.stops(fltr=(Trip.route == route_red)
                  & (func.date(CalendarDate.date) >= from_date.date)
                  & (func.date(CalendarDate.date) <= to_date.date)))
    # Now, make some tests
    self.assertTrue(len(stoplist_all) > 5)
    self.assertTrue(plage in stoplist_all)
    self.assertFalse(plage in stoplist_jan)
    stoplist = list(stoplist_all)
    stoplist.remove(plage)
    self.assertTrue(set(stoplist) == set(stoplist_jan))

    # Get all routes passing by the set of stops
    routes = dao.routes(fltr=or_(StopTime.stop == stop
                                 for stop in stoplist_jan))
    stopset = set()
    for route in routes:
        for trip in route.trips:
            for stoptime in trip.stop_times:
                stopset.add(stoptime.stop)
    self.assertTrue(set(stoplist_jan).issubset(stopset))
def test_calendar_date_convert(self):
    d1 = CalendarDate.fromYYYYMMDD("20151231")
    d2 = CalendarDate.ymd(2015, 12, 31)
    self.assertTrue(d1 == d2)
def test_gtfs_data(self):
    dao = Dao(DAO_URL, sql_logging=False)
    dao.load_gtfs(DUMMY_GTFS)

    # Check feed
    feed = dao.feed()
    self.assertTrue(feed.feed_id == "")
    self.assertTrue(feed.feed_publisher_name == "Mecatran")
    self.assertTrue(feed.feed_publisher_url == "http://www.mecatran.com/")
    self.assertTrue(feed.feed_contact_email == "*****@*****.**")
    self.assertTrue(feed.feed_lang == "fr")
    self.assertTrue(len(dao.agencies()) == 2)
    self.assertTrue(len(dao.routes()) == 3)
    self.assertTrue(len(feed.agencies) == 2)
    self.assertTrue(len(feed.routes) == 3)

    # Check agencies
    at = dao.agency("AT")
    self.assertTrue(at.agency_name == "Agency Train")
    self.assertTrue(len(at.routes) == 1)
    ab = dao.agency("AB")
    self.assertTrue(ab.agency_name == "Agency Bus")
    self.assertTrue(len(ab.routes) == 2)

    # Check calendars
    week = dao.calendar("WEEK")
    self.assertTrue(len(week.dates) == 253)
    summer = dao.calendar("SUMMER")
    self.assertTrue(len(summer.dates) == 42)
    mon = dao.calendar("MONDAY")
    self.assertTrue(len(mon.dates) == 49)
    sat = dao.calendar("SAT")
    self.assertTrue(len(sat.dates) == 53)
    for date in mon.dates:
        self.assertTrue(date.dow() == 0)
    for date in sat.dates:
        self.assertTrue(date.dow() == 5)
    for date in week.dates:
        self.assertTrue(date.dow() >= 0 and date.dow() <= 4)
    for date in summer.dates:
        self.assertTrue(CalendarDate.ymd(2016, 7, 1) <= date
                        <= CalendarDate.ymd(2016, 8, 31))
    empty = dao.calendars(
        func.date(CalendarDate.date) == datetime.date(2016, 5, 1))
    # OR USE: empty = dao.calendars(CalendarDate.date == "2016-05-01")
    self.assertTrue(len(empty) == 0)
    july4 = CalendarDate.ymd(2016, 7, 4)
    summer_mon = dao.calendars(func.date(CalendarDate.date) == july4.date)
    n = 0
    for cal in summer_mon:
        self.assertTrue(july4 in cal.dates)
        n += 1
    self.assertTrue(n == 3)

    # Check stops
    sbq = dao.stop("BQ")
    self.assertAlmostEqual(sbq.stop_lat, 44.844, places=2)
    self.assertAlmostEqual(sbq.stop_lon, -0.573, places=2)
    self.assertTrue(sbq.stop_name == "Bordeaux Quinconces")
    n = 0
    for stop in dao.stops(Stop.stop_name.like("Gare%")):
        self.assertTrue(stop.stop_name.startswith("Gare"))
        n += 1
    self.assertTrue(n == 7)
    n = 0
    for stop in dao.stops(
            fltr=dao.in_area(RectangularArea(44.7, -0.6, 44.9, -0.4))):
        self.assertTrue(44.7 <= stop.stop_lat <= 44.9
                        and -0.6 <= stop.stop_lon <= -0.4)
        n += 1
    self.assertTrue(n == 16)
    for station in dao.stops(Stop.location_type == Stop.TYPE_STATION):
        self.assertTrue(station.location_type == Stop.TYPE_STATION)
        self.assertTrue(len(station.sub_stops) >= 2)
        for stop in station.sub_stops:
            self.assertTrue(stop.parent_station == station)

    # Check zones
    z_inexistant = dao.zone("ZX")
    self.assertTrue(z_inexistant is None)
    z1 = dao.zone("Z1")
    self.assertEqual(16, len(z1.stops))
    z2 = dao.zone("Z2")
    self.assertEqual(4, len(z2.stops))

    # Check transfers
    transfers = dao.transfers()
    self.assertTrue(len(transfers) == 3)
    transfers = dao.transfers(
        fltr=(dao.transfer_from_stop().stop_id == 'GBSJB'))
    self.assertTrue(len(transfers) == 1)
    self.assertTrue(transfers[0].from_stop.stop_id == 'GBSJB')

    # Check routes
    tgv = dao.route("TGVBP")
    self.assertTrue(tgv.agency == at)
    self.assertTrue(tgv.route_type == 2)
    r1 = dao.route("BR")
    self.assertTrue(r1.route_short_name == "R1")
    self.assertTrue(r1.route_long_name == "Bus Red")
    n = 0
    for route in dao.routes(Route.route_type == 3):
        self.assertTrue(route.route_type == 3)
        n += 1
    self.assertTrue(n == 2)

    # Check trips for a route
    n = 0
    trips = dao.trips(fltr=Route.route_type == Route.TYPE_BUS)
    for trip in trips:
        self.assertTrue(trip.route.route_type == Route.TYPE_BUS)
        n += 1
    self.assertTrue(n > 20)

    # Check trips on a date
    trips = dao.trips(fltr=func.date(CalendarDate.date) == july4.date,
                      prefetch_calendars=True)
    n = 0
    for trip in trips:
        self.assertTrue(july4 in trip.calendar.dates)
        n += 1
    self.assertTrue(n > 30)
def test_demo(self):
    dao = Dao(DAO_URL, sql_logging=False)
    dao.load_gtfs(DUMMY_GTFS)

    print("List of stops named '...Bordeaux...':")
    stops_bordeaux = list(
        dao.stops(fltr=(Stop.stop_name.ilike('%Bordeaux%'))
                  & (Stop.location_type == Stop.TYPE_STOP)))
    for stop in stops_bordeaux:
        print(stop.stop_name)

    print("List of routes passing by those stops:")
    routes_bordeaux = dao.routes(fltr=or_(StopTime.stop == stop
                                          for stop in stops_bordeaux))
    for route in routes_bordeaux:
        print("%s - %s" % (route.route_short_name, route.route_long_name))

    july4 = CalendarDate.ymd(2016, 7, 4)
    print("All departures from those stops on %s:" % (july4.as_date()))
    departures = list(
        dao.stoptimes(fltr=(or_(StopTime.stop == stop
                                for stop in stops_bordeaux))
                      & (StopTime.departure_time != None)
                      & (func.date(CalendarDate.date) == july4.date)))
    print("There are %d departures" % (len(departures)))
    for departure in departures:
        print("%30.30s %10.10s %-20.20s > %s" %
              (departure.stop.stop_name, fmttime(departure.departure_time),
               departure.trip.route.route_long_name,
               departure.trip.trip_headsign))

    print("Number of departures and time range per stop on %s:" %
          (july4.as_date()))
    departure_by_stop = defaultdict(list)
    for departure in departures:
        departure_by_stop[departure.stop].append(departure)
    for stop, deps in departure_by_stop.items():
        min_dep = min(d.departure_time for d in deps)
        max_dep = max(d.departure_time for d in deps)
        print("%30.30s %3d departures (from %s to %s)" %
              (stop.stop_name, len(deps), fmttime(min_dep), fmttime(max_dep)))

    # Compute the average distance and time to the next stop by route type
    ntd = [[0, 0, 0.0] for _ in range(0, Route.TYPE_FUNICULAR + 1)]
    for departure in departures:
        # The following is guaranteed to succeed: the last stop time of a
        # trip has departure_time == NULL and was filtered out above.
        next_arrival = departure.trip.stop_times[departure.stop_sequence + 1]
        hop_dist = next_arrival.shape_dist_traveled - departure.shape_dist_traveled
        hop_time = next_arrival.arrival_time - departure.departure_time
        route_type = departure.trip.route.route_type
        ntd[route_type][0] += 1
        ntd[route_type][1] += hop_time
        ntd[route_type][2] += hop_dist
    for route_type in range(0, len(ntd)):
        n, t, d = ntd[route_type]
        if n > 0:
            print("The average distance to the next stop on those departures "
                  "for route type %d is %.2f meters" % (route_type, d / n))
            print("The average time in sec to the next stop on those departures "
                  "for route type %d is %s" % (route_type, fmttime(t / n)))
def _todate(s, default_value=None):
    if s is None or len(s) == 0:
        return default_value
    return CalendarDate.fromYYYYMMDD(s).as_date()
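# A quick illustration of _todate semantics (a sketch, assuming
# CalendarDate.as_date() returns a plain datetime.date):
#   _todate("20160704")            -> datetime.date(2016, 7, 4)
#   _todate("") / _todate(None)    -> None (the default)
#   _todate(None, "fallback")      -> "fallback"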
def _convert_gtfs_model(feed_id, gtfs, dao, lenient=False,
                        disable_normalization=False):
    feedinfo2 = None
    logger.info("Importing feed ID '%s'" % feed_id)
    n_feedinfo = 0
    for feedinfo in gtfs.feedinfo():
        n_feedinfo += 1
        if n_feedinfo > 1:
            logger.error("Feed info should be unique if defined. Taking first one.")
            break
        # TODO Automatically compute from calendar range if missing?
        feedinfo['feed_start_date'] = _todate(feedinfo.get('feed_start_date'))
        feedinfo['feed_end_date'] = _todate(feedinfo.get('feed_end_date'))
        feedinfo2 = FeedInfo(feed_id, **feedinfo)
    if feedinfo2 is None:
        # Optional, generate empty feed info
        feedinfo2 = FeedInfo(feed_id)
    dao.add(feedinfo2)
    dao.flush()
    logger.info("Imported %d feedinfo" % n_feedinfo)

    logger.info("Importing agencies...")
    n_agencies = 0
    single_agency = None
    agency_ids = set()
    for agency in gtfs.agencies():
        # agency_id is optional only if we have a single agency
        if n_agencies == 0 and agency.get('agency_id') is None:
            agency['agency_id'] = ''
        agency2 = Agency(feed_id, **agency)
        if n_agencies == 0:
            single_agency = agency2
        else:
            single_agency = None
        n_agencies += 1
        dao.add(agency2)
        agency_ids.add(agency2.agency_id)
    dao.flush()
    logger.info("Imported %d agencies" % n_agencies)

    def import_stop(stop, stoptype, zone_ids, item_ids, station_ids=None):
        zone_id = stop.get('zone_id')
        if zone_id and zone_id not in zone_ids:
            # Lazy-creation of zone
            zone = Zone(feed_id, zone_id)
            zone_ids.add(zone_id)
            dao.add(zone)
        stop['location_type'] = _toint(stop.get('location_type'),
                                       Stop.TYPE_STOP)
        if stop['location_type'] != stoptype:
            return 0
        stop['wheelchair_boarding'] = _toint(stop.get('wheelchair_boarding'),
                                             Stop.WHEELCHAIR_UNKNOWN)
        lat = _tofloat(stop.get('stop_lat'), None)
        lon = _tofloat(stop.get('stop_lon'), None)
        if lat is None or lon is None:
            if lenient:
                logger.error("Missing lat/lon for '%s', set to default (0,0)"
                             % (stop, ))
                if lat is None:
                    lat = 0
                if lon is None:
                    lon = 0
            else:
                raise ValueError("Missing mandatory lat/lon for '%s'."
                                 % (stop, ))
        stop['stop_lat'] = lat
        stop['stop_lon'] = lon
        # This field has been renamed for consistency
        parent_id = stop.get('parent_station')
        stop['parent_station_id'] = parent_id if parent_id else None
        if parent_id and station_ids and parent_id not in station_ids:
            if lenient:
                logger.error("Parent station ID '%s' in '%s' is invalid, resetting."
                             % (parent_id, stop))
                stop['parent_station_id'] = None
            else:
                raise KeyError("Parent station ID '%s' in '%s' is invalid."
                               % (parent_id, stop))
        stop.pop('parent_station', None)
        stop2 = Stop(feed_id, **stop)
        dao.add(stop2)
        item_ids.add(stop2.stop_id)
        return 1

    stop_ids = set()
    station_ids = set()
    zone_ids = set()
    logger.info("Importing zones, stations and stops...")
    n_stations = n_stops = 0
    for station in gtfs.stops():
        n_stations += import_stop(station, Stop.TYPE_STATION, zone_ids,
                                  station_ids)
    for stop in gtfs.stops():
        n_stops += import_stop(stop, Stop.TYPE_STOP, zone_ids, stop_ids,
                               station_ids)
    dao.flush()
    logger.info("Imported %d zones, %d stations and %d stops"
                % (len(zone_ids), n_stations, n_stops))

    logger.info("Importing transfers...")
    n_transfers = 0
    for transfer in gtfs.transfers():
        from_stop_id = transfer.get('from_stop_id')
        to_stop_id = transfer.get('to_stop_id')
        transfer['transfer_type'] = _toint(transfer.get('transfer_type'), 0)
        for stop_id in (from_stop_id, to_stop_id):
            if stop_id not in station_ids and stop_id not in stop_ids:
                if lenient:
                    logger.error("Stop ID '%s' in '%s' is invalid, skipping."
                                 % (stop_id, transfer))
                    break
                else:
                    raise KeyError("Stop ID '%s' in '%s' is invalid."
                                   % (stop_id, transfer))
        else:
            # The else-clause runs only when both stop IDs are valid
            # (no break), so invalid transfers are really skipped.
            transfer2 = Transfer(feed_id, **transfer)
            n_transfers += 1
            dao.add(transfer2)
    dao.flush()
    logger.info("Imported %d transfers" % (n_transfers))

    logger.info("Importing routes...")
    n_routes = 0
    route_ids = set()
    for route in gtfs.routes():
        route['route_type'] = int(route.get('route_type'))
        agency_id = route.get('agency_id')
        if (agency_id is None or len(agency_id) == 0) and single_agency is not None:
            # Route.agency is optional if only a single agency exists.
            agency_id = route['agency_id'] = single_agency.agency_id
        if agency_id not in agency_ids:
            if lenient:
                logger.error("Agency ID '%s' in '%s' is invalid, skipping route."
                             % (agency_id, route))
                continue
            else:
                raise KeyError("Agency ID '%s' in '%s' is invalid."
                               % (agency_id, route))
        route2 = Route(feed_id, **route)
        dao.add(route2)
        route_ids.add(route2.route_id)
        n_routes += 1
    dao.flush()
    logger.info("Imported %d routes" % n_routes)

    logger.info("Importing fares...")
    n_fares = 0
    for fare_attr in gtfs.fare_attributes():
        fare_id = fare_attr.get('fare_id')
        fare_price = _tofloat(fare_attr.get('price'))
        currency_type = fare_attr.get('currency_type')
        payment_method = _toint(fare_attr.get('payment_method'))
        n_transfers = None
        if fare_attr.get('transfers') is not None:
            n_transfers = _toint(fare_attr.get('transfers'))
        transfer_duration = None
        if fare_attr.get('transfer_duration') is not None:
            transfer_duration = _toint(fare_attr.get('transfer_duration'))
        fare = FareAttribute(feed_id, fare_id, fare_price, currency_type,
                             payment_method, n_transfers, transfer_duration)
        dao.add(fare)
        n_fares += 1
    dao.flush()
    fare_rules = set()
    for fare_rule in gtfs.fare_rules():
        fare_rule2 = FareRule(feed_id, **fare_rule)
        if fare_rule2 in fare_rules:
            if lenient:
                logger.error("Duplicated fare rule (%s), skipping."
                             % (fare_rule2))
                continue
            else:
                raise KeyError("Duplicated fare rule (%s)" % (fare_rule2))
        dao.add(fare_rule2)
        fare_rules.add(fare_rule2)
    dao.flush()
    logger.info("Imported %d fares and %d rules" % (n_fares, len(fare_rules)))

    logger.info("Importing calendars...")
    calanddates2 = {}
    for calendar in gtfs.calendars():
        calid = calendar.get('service_id')
        calendar2 = Calendar(feed_id, calid)
        dates2 = []
        start_date = CalendarDate.fromYYYYMMDD(calendar.get('start_date'))
        end_date = CalendarDate.fromYYYYMMDD(calendar.get('end_date'))
        for d in CalendarDate.range(start_date, end_date.next_day()):
            if int(calendar.get(DOW_NAMES[d.dow()])):
                dates2.append(d)
        calanddates2[calid] = (calendar2, set(dates2))
    logger.info("Normalizing calendar dates...")
    for caldate in gtfs.calendar_dates():
        calid = caldate.get('service_id')
        date2 = CalendarDate.fromYYYYMMDD(caldate.get('date'))
        addremove = int(caldate.get('exception_type'))
        if calid in calanddates2:
            calendar2, dates2 = calanddates2[calid]
        else:
            calendar2 = Calendar(feed_id, calid)
            dates2 = set([])
            calanddates2[calid] = (calendar2, dates2)
        if addremove == 1:
            dates2.add(date2)
        elif addremove == 2:
            if date2 in dates2:
                dates2.remove(date2)
    n_calendars = 0
    n_caldates = 0
    calendar_ids = set()
    for (calendar2, dates2) in calanddates2.values():
        calendar2.dates = [d for d in dates2]
        dao.add(calendar2)
        calendar_ids.add(calendar2.service_id)
        n_calendars += 1
        n_caldates += len(calendar2.dates)
    dao.flush()
    logger.info("Imported %d calendars and %d dates"
                % (n_calendars, n_caldates))

    logger.info("Importing shapes...")
    n_shape_pts = 0
    shape_ids = set()
    shapepts_q = []
    for shpt in gtfs.shapes():
        shape_id = shpt.get('shape_id')
        if shape_id not in shape_ids:
            dao.add(Shape(feed_id, shape_id))
            dao.flush()
            shape_ids.add(shape_id)
        pt_seq = _toint(shpt.get('shape_pt_sequence'))
        # This field is optional
        dist_traveled = _tofloat(shpt.get('shape_dist_traveled'), -999999)
        lat = _tofloat(shpt.get('shape_pt_lat'))
        lon = _tofloat(shpt.get('shape_pt_lon'))
        shape_point = ShapePoint(feed_id, shape_id, pt_seq, lat, lon,
                                 dist_traveled)
        shapepts_q.append(shape_point)
        n_shape_pts += 1
        if n_shape_pts % 100000 == 0:
            logger.info("%d shape points" % n_shape_pts)
            dao.bulk_save_objects(shapepts_q)
            dao.flush()
            shapepts_q = []
    dao.bulk_save_objects(shapepts_q)
    dao.flush()
    logger.info("Imported %d shapes and %d points"
                % (len(shape_ids), n_shape_pts))

    logger.info("Importing trips...")
    n_trips = 0
    trips_q = []
    trip_ids = set()
    for trip in gtfs.trips():
        trip['wheelchair_accessible'] = _toint(trip.get('wheelchair_accessible'),
                                               Trip.WHEELCHAIR_UNKNOWN)
        trip['bikes_allowed'] = _toint(trip.get('bikes_allowed'),
                                       Trip.BIKES_UNKNOWN)
        cal_id = trip.get('service_id')
        if cal_id not in calendar_ids:
            if lenient:
                logger.error("Calendar ID '%s' in '%s' is invalid. Skipping trip."
                             % (cal_id, trip))
                continue
            else:
                raise KeyError("Calendar ID '%s' in '%s' is invalid."
                               % (cal_id, trip))
        route_id = trip.get('route_id')
        if route_id not in route_ids:
            if lenient:
                logger.error("Route ID '%s' in '%s' is invalid. Skipping trip."
                             % (route_id, trip))
                continue
            else:
                raise KeyError("Route ID '%s' in trip '%s' is invalid."
                               % (route_id, trip))
        trip2 = Trip(feed_id, frequency_generated=False, **trip)
        trips_q.append(trip2)
        n_trips += 1
        if n_trips % 10000 == 0:
            dao.bulk_save_objects(trips_q)
            dao.flush()
            logger.info('%d trips' % n_trips)
            trips_q = []
        trip_ids.add(trip.get('trip_id'))
    dao.bulk_save_objects(trips_q)
    dao.flush()
    logger.info("Imported %d trips" % n_trips)

    logger.info("Importing stop times...")
    n_stoptimes = 0
    stoptimes_q = []
    for stoptime in gtfs.stop_times():
        stopseq = _toint(stoptime.get('stop_sequence'))
        # Mark times to interpolate later on
        arrtime = _timetoint(stoptime.get('arrival_time'), -999999)
        deptime = _timetoint(stoptime.get('departure_time'), -999999)
        if arrtime == -999999:
            arrtime = deptime
        if deptime == -999999:
            deptime = arrtime
        interp = arrtime < 0 and deptime < 0
        shpdist = _tofloat(stoptime.get('shape_dist_traveled'), -999999)
        pkptype = _toint(stoptime.get('pickup_type'),
                         StopTime.PICKUP_DROPOFF_REGULAR)
        drptype = _toint(stoptime.get('drop_off_type'),
                         StopTime.PICKUP_DROPOFF_REGULAR)
        trip_id = stoptime.get('trip_id')
        if trip_id not in trip_ids:
            if lenient:
                logger.error("Trip ID '%s' in '%s' is invalid. Skipping stop time."
                             % (trip_id, stoptime))
                continue
            else:
                raise KeyError("Trip ID '%s' in '%s' is invalid."
                               % (trip_id, stoptime))
        stop_id = stoptime.get('stop_id')
        if stop_id not in stop_ids:
            if lenient:
                logger.error("Stop ID '%s' in '%s' is invalid. Skipping stop time."
                             % (stop_id, stoptime))
                continue
            else:
                raise KeyError("Stop ID '%s' in stoptime '%s' is invalid."
                               % (stop_id, stoptime))
        stoptime2 = StopTime(feed_id, trip_id, stop_id,
                             stop_sequence=stopseq,
                             arrival_time=arrtime,
                             departure_time=deptime,
                             shape_dist_traveled=shpdist,
                             interpolated=interp,
                             pickup_type=pkptype,
                             drop_off_type=drptype,
                             stop_headsign=stoptime.get('stop_headsign'))
        stoptimes_q.append(stoptime2)
        n_stoptimes += 1
        # Commit every now and then
        if n_stoptimes % 50000 == 0:
            logger.info("%d stop times" % n_stoptimes)
            dao.bulk_save_objects(stoptimes_q)
            dao.flush()
            stoptimes_q = []
    dao.bulk_save_objects(stoptimes_q)
    logger.info("Imported %d stop times" % n_stoptimes)
    logger.info("Committing")
    dao.flush()
    # TODO Add option to enable/disable this commit
    # to ensure import is transactional
    dao.commit()
    logger.info("Commit done")

    def normalize_trip(trip, odometer):
        stopseq = 0
        n_stoptimes = len(trip.stop_times)
        last_stoptime_with_time = None
        to_interpolate = []
        odometer.reset()
        for stoptime in trip.stop_times:
            stoptime.stop_sequence = stopseq
            stoptime.shape_dist_traveled = odometer.dist_traveled(
                stoptime.stop,
                stoptime.shape_dist_traveled
                if stoptime.shape_dist_traveled != -999999 else None)
            if stopseq == 0:
                # Force first arrival time to NULL
                stoptime.arrival_time = None
            if stopseq == n_stoptimes - 1:
                # Force last departure time to NULL
                stoptime.departure_time = None
            if stoptime.interpolated:
                to_interpolate.append(stoptime)
            else:
                if len(to_interpolate) > 0:
                    # Interpolate
                    if last_stoptime_with_time is None:
                        logger.error("Cannot interpolate missing time at trip start: %s"
                                     % trip)
                        for stti in to_interpolate:
                            # Use first defined time as fallback value.
                            stti.arrival_time = stoptime.arrival_time
                            stti.departure_time = stoptime.arrival_time
                    else:
                        tdist = stoptime.shape_dist_traveled \
                            - last_stoptime_with_time.shape_dist_traveled
                        ttime = stoptime.arrival_time \
                            - last_stoptime_with_time.departure_time
                        for stti in to_interpolate:
                            fdist = stti.shape_dist_traveled \
                                - last_stoptime_with_time.shape_dist_traveled
                            t = last_stoptime_with_time.departure_time \
                                + ttime * fdist // tdist
                            stti.arrival_time = t
                            stti.departure_time = t
                to_interpolate = []
                last_stoptime_with_time = stoptime
            stopseq += 1
        if len(to_interpolate) > 0:
            # Should not happen, but handle the case, we never know
            if last_stoptime_with_time is None:
                logger.error("Cannot interpolate missing time, no time at all: %s"
                             % trip)
                # Keep times NULL (TODO: or remove the trip?)
            else:
                logger.error("Cannot interpolate missing time at trip end: %s"
                             % trip)
                for stti in to_interpolate:
                    # Use last defined time as fallback value
                    stti.arrival_time = last_stoptime_with_time.departure_time
                    stti.departure_time = last_stoptime_with_time.departure_time

    if disable_normalization:
        logger.info("Skipping shapes and trips normalization")
    else:
        logger.info("Normalizing shapes and trips...")
        nshapes = 0
        ntrips = 0
        odometer = _Odometer()
        # Process shapes and associated trips
        for shape in dao.shapes(fltr=Shape.feed_id == feed_id,
                                prefetch_points=True, batch_size=50):
            # The shape is registered in the odometer while being normalized
            odometer.normalize_and_register_shape(shape)
            for trip in dao.trips(fltr=(Trip.feed_id == feed_id)
                                  & (Trip.shape_id == shape.shape_id),
                                  prefetch_stop_times=True,
                                  prefetch_stops=True, batch_size=800):
                normalize_trip(trip, odometer)
                ntrips += 1
                if ntrips % 1000 == 0:
                    logger.info("%d trips, %d shapes" % (ntrips, nshapes))
                    dao.flush()
            nshapes += 1
            # odometer._debug_cache()
        # Process trips w/o shapes
        for trip in dao.trips(fltr=(Trip.feed_id == feed_id)
                              & (Trip.shape_id == None),
                              prefetch_stop_times=True,
                              prefetch_stops=True, batch_size=800):
            odometer.register_noshape()
            normalize_trip(trip, odometer)
            ntrips += 1
            if ntrips % 1000 == 0:
                logger.info("%d trips" % ntrips)
                dao.flush()
        dao.flush()
        logger.info("Normalized %d trips and %d shapes" % (ntrips, nshapes))

    # Note: we expand frequencies *after* normalization,
    # for performance purposes only: that minimizes the
    # number of trips to normalize. We can do that since
    # the expansion is neutral trip-normalization-wise.
    logger.info("Expanding frequencies...")
    n_freq = 0
    n_exp_trips = 0
    trips_to_delete = []
    for frequency in gtfs.frequencies():
        trip_id = frequency.get('trip_id')
        if trip_id not in trip_ids:
            if lenient:
                logger.error("Trip ID '%s' in '%s' is invalid. Skipping frequency."
                             % (trip_id, frequency))
                continue
            else:
                raise KeyError("Trip ID '%s' in '%s' is invalid."
                               % (trip_id, frequency))
        trip = dao.trip(trip_id, feed_id=feed_id)
        start_time = _timetoint(frequency.get('start_time'))
        end_time = _timetoint(frequency.get('end_time'))
        headway_secs = _toint(frequency.get('headway_secs'))
        exact_times = _toint(frequency.get('exact_times'), Trip.TIME_APPROX)
        for trip_dep_time in range(start_time, end_time, headway_secs):
            # Here we assume departure times are all different.
            # That's a requirement of the GTFS specs, but this may break.
            # TODO Make the expanded trip ID generation parametrable.
            trip_id2 = trip.trip_id + "@" + fmttime(trip_dep_time)
            trip2 = Trip(feed_id, trip_id2, trip.route_id, trip.service_id,
                         wheelchair_accessible=trip.wheelchair_accessible,
                         bikes_allowed=trip.bikes_allowed,
                         exact_times=exact_times,
                         frequency_generated=True,
                         trip_headsign=trip.trip_headsign,
                         trip_short_name=trip.trip_short_name,
                         direction_id=trip.direction_id,
                         block_id=trip.block_id)
            trip2.stop_times = []
            base_time = trip.stop_times[0].departure_time
            for stoptime in trip.stop_times:
                arrtime = (None if stoptime.arrival_time is None
                           else stoptime.arrival_time - base_time + trip_dep_time)
                deptime = (None if stoptime.departure_time is None
                           else stoptime.departure_time - base_time + trip_dep_time)
                stoptime2 = StopTime(feed_id, trip_id2, stoptime.stop_id,
                                     stoptime.stop_sequence,
                                     arrival_time=arrtime,
                                     departure_time=deptime,
                                     shape_dist_traveled=stoptime.shape_dist_traveled,
                                     interpolated=stoptime.interpolated,
                                     timepoint=stoptime.timepoint,
                                     pickup_type=stoptime.pickup_type,
                                     drop_off_type=stoptime.drop_off_type)
                trip2.stop_times.append(stoptime2)
            n_exp_trips += 1
            # This will add the associated stop times
            dao.add(trip2)
        # Do not delete the trip now, as two frequencies can refer to the same trip
        trips_to_delete.append(trip)
        n_freq += 1
    for trip in trips_to_delete:
        # This also deletes the associated stop times
        dao.delete(trip)
    dao.flush()
    dao.commit()
    logger.info("Expanded %d frequencies to %d trips." % (n_freq, n_exp_trips))
    logger.info("Feed '%s': import done." % feed_id)
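# End-to-end usage sketch: in the tests above the import is driven through
# Dao.load_gtfs (see test_gtfs_data), which is assumed here to wrap
# _convert_gtfs_model. The import path and file names are illustrative
# assumptions, not part of this module:
#
#   from gtfslib.dao import Dao   # assumed import path
#   dao = Dao("gtfs.sqlite", sql_logging=False)
#   dao.load_gtfs("feed.zip")     # runs the import + normalization above
#   for route in dao.routes():
#       print(route.route_short_name, route.route_long_name)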