Ejemplo n.º 1
0
    def test_entities_creation(self):
        """Persist a feed, an agency, two routes and a calendar, then read them back."""
        dao = Dao()
        f1 = FeedInfo("F1")
        a1 = Agency("F1",
                    "A1",
                    "Agency 1",
                    agency_url="http://www.agency.fr/",
                    agency_timezone="Europe/Paris")
        r1 = Route("F1",
                   "R1",
                   "A1",
                   3,
                   route_short_name="R1",
                   route_long_name="Route 1")
        r2 = Route("F1", "R2", "A1", 3, route_short_name="R2")
        c1 = Calendar("F1", "C1")
        c1.dates = [
            CalendarDate.ymd(2015, 11, 13),
            CalendarDate.ymd(2015, 11, 14),
        ]
        dao.add_all([f1, a1, r1, r2, c1])

        # assertEqual / assertIsNone report the actual values on failure,
        # unlike assertTrue(a == b) which only says "False is not true".
        self.assertEqual(len(dao.feeds()), 1)
        self.assertEqual(len(dao.agencies()), 1)
        a1b = dao.agency("A1", feed_id="F1", prefetch_routes=True)
        self.assertEqual(a1b.agency_name, "Agency 1")
        self.assertEqual(len(a1b.routes), 2)
        r1b = dao.route("R1", feed_id="F1")
        self.assertEqual(r1b.route_short_name, "R1")
        self.assertEqual(r1b.route_long_name, "Route 1")
        self.assertEqual(r1b.route_type, 3)
        # Looking up an unknown route ID is expected to yield None.
        r42 = dao.route("R42", feed_id="F1")
        self.assertIsNone(r42)
Ejemplo n.º 2
0
    def test_stop_station(self):
        """Check stop filtering by agency and by parent station."""
        dao = Dao()
        f1 = FeedInfo("F1")
        # Two stations (SA, SB), with one child stop under SA and two under SB.
        sa = Stop("F1", "SA", "Station A", 45.0000, 0.0000)
        sa1 = Stop("F1", "SA1", "Stop A1", 45.0001, 0.0001)
        sa1.parent_station_id = 'SA'
        sb = Stop("F1", "SB", "Station B", 45.0002, 0.0002)
        sb1 = Stop("F1", "SB1", "Stop B1", 45.0003, 0.0003)
        sb1.parent_station_id = 'SB'
        sb2 = Stop("F1", "SB2", "Stop B2", 45.0002, 0.0003)
        sb2.parent_station_id = 'SB'
        a1 = Agency("F1", "A1", "Agency 1", "url1", "Europe/Paris")
        r1 = Route("F1", "R1", "A1", Route.TYPE_BUS)
        c1 = Calendar("F1", "C1")
        t1 = Trip("F1", "T1", "R1", "C1")
        # Only SA1 and SB1 are served by a trip of agency A1.
        st1a = StopTime("F1", "T1", "SA1", 0, None, 3600, 0.0)
        st1b = StopTime("F1", "T1", "SB1", 1, 3800, None, 100.0)
        dao.add_all([f1, sa, sa1, sb, sb1, sb2, a1, r1, c1, t1, st1a, st1b])

        stops = list(dao.stops(fltr=(Agency.agency_id == 'A1')))
        self.assertEqual(len(stops), 2)
        self.assertIn(sa1, stops)
        self.assertIn(sb1, stops)

        stops = list(dao.stops(fltr=(Stop.parent_station_id == 'SA')))
        self.assertEqual(len(stops), 1)
        self.assertIn(sa1, stops)

        stops = list(dao.stops(fltr=(Stop.parent_station_id == 'SB')))
        self.assertEqual(len(stops), 2)
        self.assertIn(sb1, stops)
        self.assertIn(sb2, stops)
Ejemplo n.º 3
0
    def test_route_agency_multi_feed(self):
        """Same agency/route IDs in two feeds must stay isolated per feed."""
        dao = Dao()
        fa = FeedInfo("FA")
        aa1 = Agency("FA",
                     "A",
                     "Agency A",
                     agency_url="http://www.agency.fr/",
                     agency_timezone="Europe/Paris")
        ar1 = Route("FA",
                    "R",
                    "A",
                    3,
                    route_short_name="RA",
                    route_long_name="Route A")
        ar2 = Route("FA",
                    "R2",
                    "A",
                    3,
                    route_short_name="RA2",
                    route_long_name="Route A2")
        fb = FeedInfo("FB")
        # Feed FB deliberately reuses agency ID "A" and route ID "R".
        ba1 = Agency("FB",
                     "A",
                     "Agency B",
                     agency_url="http://www.agency.fr/",
                     agency_timezone="Europe/Paris")
        br1 = Route("FB",
                    "R",
                    "A",
                    3,
                    route_short_name="RB",
                    route_long_name="Route B")
        dao.add_all([fa, aa1, ar1, ar2, fb, ba1, br1])

        loaded_fa = dao.feed("FA")
        self.assertEqual(len(loaded_fa.agencies), 1)
        for agency in loaded_fa.agencies:
            self.assertEqual(agency.agency_name, "Agency A")
        self.assertEqual(len(loaded_fa.routes), 2)
        for route in loaded_fa.routes:
            self.assertTrue(route.route_short_name.startswith("RA"))
            self.assertEqual(route.agency.agency_name, "Agency A")
Ejemplo n.º 4
0
    def test_shapes(self):
        """A trip linked to a shape exposes its points; an unlinked trip has no shape."""
        dao = Dao()
        f1 = FeedInfo("")
        a1 = Agency("",
                    "A1",
                    "Agency 1",
                    agency_url="http://www.agency.fr/",
                    agency_timezone="Europe/Paris")
        r1 = Route("",
                   "R1",
                   "A1",
                   3,
                   route_short_name="R1",
                   route_long_name="Route 1")
        c1 = Calendar("", "C1")
        c1.dates = [
            CalendarDate.ymd(2016, 1, 31),
            CalendarDate.ymd(2016, 2, 1),
        ]
        s1 = Stop("", "S1", "Stop 1", 45.0, 0.0)
        s2 = Stop("", "S2", "Stop 2", 45.1, 0.1)
        s3 = Stop("", "S3", "Stop 3", 45.2, 0.2)
        t1 = Trip("", "T1", "R1", "C1")
        t1.stop_times = [
            StopTime(None, None, "S1", 0, 28800, 28800, 0.0),
            StopTime(None, None, "S2", 1, 29400, 29400, 2.0),
            StopTime(None, None, "S3", 2, 30000, 30000, 4.0),
        ]
        t2 = Trip("", "T2", "R1", "C1")
        t2.stop_times = [
            StopTime(None, None, "S2", 0, 30600, 30600, 0.0),
            StopTime(None, None, "S1", 1, 31000, 31000, 1.0),
        ]
        sh1 = Shape("", "Sh1")
        sh1.points = [
            ShapePoint(None, None, 0, 45.00, 0.00, 0.0),
            ShapePoint(None, None, 1, 45.05, 0.10, 1.0),
            ShapePoint(None, None, 2, 45.10, 0.10, 2.0),
            ShapePoint(None, None, 3, 45.15, 0.20, 3.0),
            ShapePoint(None, None, 4, 45.20, 0.20, 4.0),
        ]
        # Only T1 is attached to the shape; T2 is left without one.
        t1.shape = sh1
        dao.add_all([f1, a1, r1, c1, s1, s2, s3, t1, t2, sh1])
        dao.commit()

        t = dao.trip("T1")
        self.assertEqual(t.shape.shape_id, "Sh1")
        self.assertEqual(len(t.shape.points), 5)
        t = dao.trip("T2")
        self.assertIsNone(t.shape)
Ejemplo n.º 5
0
def _convert_gtfs_model(feed_id,
                        gtfs,
                        dao,
                        lenient=False,
                        disable_normalization=False):
    """Import every entity of a GTFS source into ``dao`` under ``feed_id``.

    Entities are imported in dependency order: feed info, agencies, zones,
    stations and stops, transfers, routes, fares, calendars (calendar.txt
    expanded to explicit dates, then patched with calendar_dates.txt), shapes,
    trips and stop times. The data is then committed. Unless disabled, trips
    are normalized (stop sequences renumbered from 0, shape distances
    computed through ``_Odometer``, missing times interpolated). Finally,
    trips referenced by frequencies.txt are expanded into one concrete trip
    per departure, the template trips are deleted, and the result is
    committed again.

    :param feed_id: Feed identifier stored on every imported entity.
    :param gtfs: Source object exposing one iterator per GTFS file
        (``feedinfo()``, ``agencies()``, ``stops()``, ...). Each row is
        read with ``.get()`` and mutated in place. Presumably these are
        plain dicts of raw string values; TODO confirm against the reader.
    :param dao: Data-access object used to add, flush, commit, query and
        delete model objects.
    :param lenient: If True, invalid references are logged and the offending
        row is skipped (or repaired). If False, they raise.
    :param disable_normalization: If True, skip the shape/trip normalization
        pass.
    :raises KeyError: On a dangling ID reference or a duplicated fare rule,
        when not ``lenient``.
    :raises ValueError: On a stop without lat/lon, when not ``lenient``.
    """

    feedinfo2 = None
    logger.info("Importing feed ID '%s'" % feed_id)
    n_feedinfo = 0
    for feedinfo in gtfs.feedinfo():
        if feedinfo2 is not None:
            # Only the first feed info row is kept. The message has no
            # placeholder, so it must not be %-formatted.
            logger.error(
                "Feed info should be unique if defined. Taking first one.")
            break
        # TODO Automatically compute from calendar range if missing?
        feedinfo['feed_start_date'] = _todate(feedinfo.get('feed_start_date'))
        feedinfo['feed_end_date'] = _todate(feedinfo.get('feed_end_date'))
        feedinfo2 = FeedInfo(feed_id, **feedinfo)
        n_feedinfo += 1
    if feedinfo2 is None:
        # Optional, generate empty feed info
        feedinfo2 = FeedInfo(feed_id)
    dao.add(feedinfo2)
    dao.flush()
    logger.info("Imported %d feedinfo" % n_feedinfo)

    logger.info("Importing agencies...")
    n_agencies = 0
    single_agency = None
    agency_ids = set()
    for agency in gtfs.agencies():
        # agency_id is optional only if we have a single agency
        if n_agencies == 0 and agency.get('agency_id') is None:
            agency['agency_id'] = ''
        agency2 = Agency(feed_id, **agency)
        # Remember the agency only while it is the sole one; routes
        # without agency_id fall back to it.
        if n_agencies == 0:
            single_agency = agency2
        else:
            single_agency = None
        n_agencies += 1
        dao.add(agency2)
        agency_ids.add(agency2.agency_id)
    dao.flush()
    logger.info("Imported %d agencies" % n_agencies)

    def import_stop(stop, stoptype, zone_ids, item_ids, station_ids=None):
        """Import ``stop`` if its location_type equals ``stoptype``.

        Zones referenced by the row are lazily created (even when the row is
        skipped). Accepted stop IDs are added to ``item_ids``. When
        ``station_ids`` is given, the row's parent_station must be in it.
        Returns 1 if the row was imported, else 0.
        """
        zone_id = stop.get('zone_id')
        if zone_id and zone_id not in zone_ids:
            # Lazy-creation of zone
            zone = Zone(feed_id, zone_id)
            zone_ids.add(zone_id)
            dao.add(zone)
        stop['location_type'] = _toint(stop.get('location_type'),
                                       Stop.TYPE_STOP)
        if stop['location_type'] != stoptype:
            return 0
        stop['wheelchair_boarding'] = _toint(stop.get('wheelchair_boarding'),
                                             Stop.WHEELCHAIR_UNKNOWN)
        lat = _tofloat(stop.get('stop_lat'), None)
        lon = _tofloat(stop.get('stop_lon'), None)
        if lat is None or lon is None:
            if lenient:
                logger.error("Missing lat/lon for '%s', set to default (0,0)" %
                             (stop, ))
                if lat is None:
                    lat = 0
                if lon is None:
                    lon = 0
            else:
                raise ValueError("Missing mandatory lat/lon for '%s'." %
                                 (stop, ))
        stop['stop_lat'] = lat
        stop['stop_lon'] = lon
        # This field has been renamed for consistency
        parent_id = stop.get('parent_station')
        stop['parent_station_id'] = parent_id if parent_id else None
        # Validate against the (possibly empty) set of imported stations.
        # An empty set must still reject references, hence "is not None".
        if parent_id and station_ids is not None and parent_id not in station_ids:
            if lenient:
                logger.error(
                    "Parent station ID '%s' in '%s' is invalid, resetting." %
                    (parent_id, stop))
                stop['parent_station_id'] = None
            else:
                raise KeyError("Parent station ID '%s' in '%s' is invalid." %
                               (parent_id, stop))
        stop.pop('parent_station', None)
        stop2 = Stop(feed_id, **stop)
        dao.add(stop2)
        item_ids.add(stop2.stop_id)
        return 1

    stop_ids = set()
    station_ids = set()
    zone_ids = set()
    logger.info("Importing zones, stations and stops...")
    n_stations = n_stops = 0
    # Two passes: stations first, so that stops can validate their parent.
    for station in gtfs.stops():
        n_stations += import_stop(station, Stop.TYPE_STATION, zone_ids,
                                  station_ids)
    for stop in gtfs.stops():
        n_stops += import_stop(stop, Stop.TYPE_STOP, zone_ids, stop_ids,
                               station_ids)
    dao.flush()
    logger.info("Imported %d zones, %d stations and %d stops" %
                (len(zone_ids), n_stations, n_stops))

    logger.info("Importing transfers...")
    n_transfers = 0
    for transfer in gtfs.transfers():
        from_stop_id = transfer.get('from_stop_id')
        to_stop_id = transfer.get('to_stop_id')
        transfer['transfer_type'] = _toint(transfer.get('transfer_type'), 0)
        valid = True
        for stop_id in (from_stop_id, to_stop_id):
            if stop_id not in station_ids and stop_id not in stop_ids:
                if lenient:
                    logger.error("Stop ID '%s' in '%s' is invalid, skipping." %
                                 (stop_id, transfer))
                    valid = False
                    break
                else:
                    raise KeyError("Stop ID '%s' in '%s' is invalid." %
                                   (stop_id, transfer))
        if not valid:
            # Skip the whole transfer, not just the offending stop ID.
            continue
        transfer2 = Transfer(feed_id, **transfer)
        n_transfers += 1
        dao.add(transfer2)
    dao.flush()
    logger.info("Imported %d transfers" % (n_transfers))

    logger.info("Importing routes...")
    n_routes = 0
    route_ids = set()
    for route in gtfs.routes():
        route['route_type'] = int(route.get('route_type'))
        agency_id = route.get('agency_id')
        if (agency_id is None
                or len(agency_id) == 0) and single_agency is not None:
            # Route.agency is optional if only a single agency exists.
            agency_id = route['agency_id'] = single_agency.agency_id
        if agency_id not in agency_ids:
            if lenient:
                logger.error(
                    "Agency ID '%s' in '%s' is invalid, skipping route." %
                    (agency_id, route))
                continue
            else:
                raise KeyError("agency ID '%s' in '%s' is invalid." %
                               (agency_id, route))
        route2 = Route(feed_id, **route)
        dao.add(route2)
        route_ids.add(route2.route_id)
        n_routes += 1
    dao.flush()
    logger.info("Imported %d routes" % n_routes)

    logger.info("Importing fares...")
    n_fares = 0
    for fare_attr in gtfs.fare_attributes():
        fare_id = fare_attr.get('fare_id')
        fare_price = _tofloat(fare_attr.get('price'))
        currency_type = fare_attr.get('currency_type')
        payment_method = _toint(fare_attr.get('payment_method'))
        # Number of transfers allowed by this fare (None = unspecified).
        # Kept distinct from the n_transfers counter above.
        fare_transfers = None
        if fare_attr.get('transfers') is not None:
            fare_transfers = _toint(fare_attr.get('transfers'))
        transfer_duration = None
        if fare_attr.get('transfer_duration') is not None:
            transfer_duration = _toint(fare_attr.get('transfer_duration'))
        fare = FareAttribute(feed_id, fare_id, fare_price, currency_type,
                             payment_method, fare_transfers, transfer_duration)
        dao.add(fare)
        n_fares += 1
    dao.flush()
    fare_rules = set()
    for fare_rule in gtfs.fare_rules():
        fare_rule2 = FareRule(feed_id, **fare_rule)
        if fare_rule2 in fare_rules:
            if lenient:
                logger.error("Duplicated fare rule (%s), skipping." %
                             (fare_rule2))
                continue
            else:
                raise KeyError("Duplicated fare rule (%s)" % (fare_rule2))
        dao.add(fare_rule2)
        fare_rules.add(fare_rule2)
    dao.flush()
    logger.info("Imported %d fare and %d rules" % (n_fares, len(fare_rules)))

    logger.info("Importing calendars...")
    # service_id -> (Calendar, set of active CalendarDate)
    calanddates2 = {}
    for calendar in gtfs.calendars():
        calid = calendar.get('service_id')
        calendar2 = Calendar(feed_id, calid)
        dates2 = []
        start_date = CalendarDate.fromYYYYMMDD(calendar.get('start_date'))
        end_date = CalendarDate.fromYYYYMMDD(calendar.get('end_date'))
        # end_date is inclusive, hence the next_day() upper bound.
        for d in CalendarDate.range(start_date, end_date.next_day()):
            if int(calendar.get(DOW_NAMES[d.dow()])):
                dates2.append(d)
        calanddates2[calid] = (calendar2, set(dates2))

    logger.info("Normalizing calendar dates...")
    for caldate in gtfs.calendar_dates():
        calid = caldate.get('service_id')
        date2 = CalendarDate.fromYYYYMMDD(caldate.get('date'))
        addremove = int(caldate.get('exception_type'))
        if calid in calanddates2:
            calendar2, dates2 = calanddates2[calid]
        else:
            calendar2 = Calendar(feed_id, calid)
            dates2 = set([])
            calanddates2[calid] = (calendar2, dates2)
        # exception_type: 1 = service added, 2 = service removed.
        if addremove == 1:
            dates2.add(date2)
        elif addremove == 2:
            if date2 in dates2:
                dates2.remove(date2)
    n_calendars = 0
    n_caldates = 0
    calendar_ids = set()
    for (calendar2, dates2) in calanddates2.values():
        calendar2.dates = list(dates2)
        dao.add(calendar2)
        calendar_ids.add(calendar2.service_id)
        n_calendars += 1
        n_caldates += len(calendar2.dates)
    dao.flush()
    logger.info("Imported %d calendars and %d dates" %
                (n_calendars, n_caldates))

    logger.info("Importing shapes...")
    n_shape_pts = 0
    shape_ids = set()
    shapepts_q = []
    for shpt in gtfs.shapes():
        shape_id = shpt.get('shape_id')
        if shape_id not in shape_ids:
            dao.add(Shape(feed_id, shape_id))
            dao.flush()
            shape_ids.add(shape_id)
        pt_seq = _toint(shpt.get('shape_pt_sequence'))
        # This field is optional; -999999 marks "missing"
        dist_traveled = _tofloat(shpt.get('shape_dist_traveled'), -999999)
        lat = _tofloat(shpt.get('shape_pt_lat'))
        lon = _tofloat(shpt.get('shape_pt_lon'))
        shape_point = ShapePoint(feed_id, shape_id, pt_seq, lat, lon,
                                 dist_traveled)
        shapepts_q.append(shape_point)
        n_shape_pts += 1
        # Bulk-save in batches to bound memory use.
        if n_shape_pts % 100000 == 0:
            logger.info("%d shape points" % n_shape_pts)
            dao.bulk_save_objects(shapepts_q)
            dao.flush()
            shapepts_q = []
    dao.bulk_save_objects(shapepts_q)
    dao.flush()
    logger.info("Imported %d shapes and %d points" %
                (len(shape_ids), n_shape_pts))

    logger.info("Importing trips...")
    n_trips = 0
    trips_q = []
    trip_ids = set()
    for trip in gtfs.trips():
        trip['wheelchair_accessible'] = _toint(
            trip.get('wheelchair_accessible'), Trip.WHEELCHAIR_UNKNOWN)
        trip['bikes_allowed'] = _toint(trip.get('bikes_allowed'),
                                       Trip.BIKES_UNKNOWN)
        cal_id = trip.get('service_id')
        if cal_id not in calendar_ids:
            if lenient:
                logger.error(
                    "Calendar ID '%s' in '%s' is invalid. Skipping trip." %
                    (cal_id, trip))
                continue
            else:
                raise KeyError("Calendar ID '%s' in '%s' is invalid." %
                               (cal_id, trip))
        route_id = trip.get('route_id')
        if route_id not in route_ids:
            if lenient:
                logger.error(
                    "Route ID '%s' in '%s' is invalid. Skipping trip." %
                    (route_id, trip))
                continue
            else:
                raise KeyError("Route ID '%s' in trip '%s' is invalid." %
                               (route_id, trip))
        trip2 = Trip(feed_id, frequency_generated=False, **trip)

        trips_q.append(trip2)
        n_trips += 1
        if n_trips % 10000 == 0:
            dao.bulk_save_objects(trips_q)
            dao.flush()
            logger.info('%s trips' % n_trips)
            trips_q = []

        trip_ids.add(trip.get('trip_id'))
    dao.bulk_save_objects(trips_q)
    dao.flush()

    logger.info("Imported %d trips" % n_trips)

    logger.info("Importing stop times...")
    n_stoptimes = 0
    stoptimes_q = []
    for stoptime in gtfs.stop_times():
        stopseq = _toint(stoptime.get('stop_sequence'))
        # Mark times to interpolate later on (-999999 = missing)
        arrtime = _timetoint(stoptime.get('arrival_time'), -999999)
        deptime = _timetoint(stoptime.get('departure_time'), -999999)
        if arrtime == -999999:
            arrtime = deptime
        if deptime == -999999:
            deptime = arrtime
        interp = arrtime < 0 and deptime < 0
        shpdist = _tofloat(stoptime.get('shape_dist_traveled'), -999999)
        pkptype = _toint(stoptime.get('pickup_type'),
                         StopTime.PICKUP_DROPOFF_REGULAR)
        drptype = _toint(stoptime.get('drop_off_type'),
                         StopTime.PICKUP_DROPOFF_REGULAR)
        trip_id = stoptime.get('trip_id')
        if trip_id not in trip_ids:
            if lenient:
                logger.error(
                    "Trip ID '%s' in '%s' is invalid. Skipping stop time." %
                    (trip_id, stoptime))
                continue
            else:
                raise KeyError("Trip ID '%s' in '%s' is invalid." %
                               (trip_id, stoptime))
        stop_id = stoptime.get('stop_id')
        if stop_id not in stop_ids:
            if lenient:
                logger.error(
                    "Stop ID '%s' in '%s' is invalid. Skipping stop time." %
                    (stop_id, stoptime))
                continue
            else:
                raise KeyError("Stop ID '%s' in stoptime '%s' is invalid." %
                               (stop_id, stoptime))
        stoptime2 = StopTime(feed_id,
                             trip_id,
                             stop_id,
                             stop_sequence=stopseq,
                             arrival_time=arrtime,
                             departure_time=deptime,
                             shape_dist_traveled=shpdist,
                             interpolated=interp,
                             pickup_type=pkptype,
                             drop_off_type=drptype,
                             stop_headsign=stoptime.get('stop_headsign'))
        stoptimes_q.append(stoptime2)
        n_stoptimes += 1
        # Commit every now and then
        if n_stoptimes % 50000 == 0:
            logger.info("%d stop times" % n_stoptimes)
            dao.bulk_save_objects(stoptimes_q)
            dao.flush()
            stoptimes_q = []
    dao.bulk_save_objects(stoptimes_q)

    logger.info("Imported %d stop times" % n_stoptimes)
    logger.info("Committing")
    dao.flush()
    # TODO Add option to enable/disable this commit
    # to ensure import is transactionnal
    dao.commit()
    logger.info("Commit done")

    def normalize_trip(trip, odometer):
        """Renumber, measure and time-fill the stop times of ``trip`` in place.

        Stop sequences are renumbered from 0, and shape distances are
        computed through ``odometer``. The first arrival and last departure
        are forced to None. Times of stop times flagged ``interpolated`` are
        linearly interpolated by shape distance between the surrounding
        timed stop times. If that span has zero length, they are spread
        evenly by position instead.
        """
        stopseq = 0
        n_stoptimes = len(trip.stop_times)
        last_stoptime_with_time = None
        to_interpolate = []
        odometer.reset()
        for stoptime in trip.stop_times:
            stoptime.stop_sequence = stopseq
            stoptime.shape_dist_traveled = odometer.dist_traveled(
                stoptime.stop, stoptime.shape_dist_traveled
                if stoptime.shape_dist_traveled != -999999 else None)
            if stopseq == 0:
                # Force first arrival time to NULL
                stoptime.arrival_time = None
            if stopseq == n_stoptimes - 1:
                # Force last departure time to NULL
                stoptime.departure_time = None
            if stoptime.interpolated:
                to_interpolate.append(stoptime)
            else:
                if len(to_interpolate) > 0:
                    # Interpolate
                    if last_stoptime_with_time is None:
                        logger.error(
                            "Cannot interpolate missing time at trip start: %s"
                            % trip)
                        for stti in to_interpolate:
                            # Use first defined time as fallback value.
                            stti.arrival_time = stoptime.arrival_time
                            stti.departure_time = stoptime.arrival_time
                    else:
                        prev = last_stoptime_with_time
                        tdist = stoptime.shape_dist_traveled - prev.shape_dist_traveled
                        ttime = stoptime.arrival_time - prev.departure_time
                        n_interp = len(to_interpolate)
                        for i, stti in enumerate(to_interpolate):
                            if tdist != 0:
                                fdist = stti.shape_dist_traveled - prev.shape_dist_traveled
                                t = prev.departure_time + ttime * fdist // tdist
                            else:
                                # Zero-length span: dividing by tdist would
                                # raise, so spread times evenly by position.
                                t = prev.departure_time + ttime * (i + 1) // (n_interp + 1)
                            stti.arrival_time = t
                            stti.departure_time = t
                to_interpolate = []
                last_stoptime_with_time = stoptime
            stopseq += 1
        if len(to_interpolate) > 0:
            # Should not happen, but handle the case, we never know
            if last_stoptime_with_time is None:
                logger.error(
                    "Cannot interpolate missing time, no time at all: %s" %
                    trip)
                # Keep times NULL (TODO: or remove the trip?)
            else:
                logger.error(
                    "Cannot interpolate missing time at trip end: %s" % trip)
                for stti in to_interpolate:
                    # Use last defined time as fallback value
                    stti.arrival_time = last_stoptime_with_time.departure_time
                    stti.departure_time = last_stoptime_with_time.departure_time

    if disable_normalization:
        logger.info("Skipping shapes and trips normalization")
    else:
        logger.info("Normalizing shapes and trips...")
        nshapes = 0
        ntrips = 0
        odometer = _Odometer()
        # Process shapes and associated trips
        for shape in dao.shapes(fltr=Shape.feed_id == feed_id,
                                prefetch_points=True,
                                batch_size=50):
            # Shape will be registered in the normalize
            odometer.normalize_and_register_shape(shape)
            for trip in dao.trips(fltr=(Trip.feed_id == feed_id) &
                                  (Trip.shape_id == shape.shape_id),
                                  prefetch_stop_times=False,
                                  prefetch_stops=False,
                                  batch_size=800):
                normalize_trip(trip, odometer)
                ntrips += 1
                if ntrips % 1000 == 0:
                    logger.info("%d trips, %d shapes" % (ntrips, nshapes))
                    dao.flush()
            nshapes += 1
        # Process trips w/o shapes
        for trip in dao.trips(fltr=(Trip.feed_id == feed_id) &
                              (Trip.shape_id == None),
                              prefetch_stop_times=False,
                              prefetch_stops=False,
                              batch_size=800):
            odometer.register_noshape()
            normalize_trip(trip, odometer)
            ntrips += 1
            if ntrips % 1000 == 0:
                logger.info("%d trips" % ntrips)
                dao.flush()
        dao.flush()
        logger.info("Normalized %d trips and %d shapes" % (ntrips, nshapes))

    # Note: we expand frequencies *after* normalization
    # for performances purpose only: that minimize the
    # number of trips to normalize. We can do that since
    # the expansion is neutral trip-normalization-wise.
    logger.info("Expanding frequencies...")
    n_freq = 0
    n_exp_trips = 0
    # Template trips to delete once all frequencies are expanded, keyed by
    # trip ID: several frequencies may refer to the same trip.
    trips_to_delete = {}
    for frequency in gtfs.frequencies():
        trip_id = frequency.get('trip_id')
        if trip_id not in trip_ids:
            if lenient:
                logger.error(
                    "Trip ID '%s' in '%s' is invalid. Skipping frequency." %
                    (trip_id, frequency))
                continue
            else:
                raise KeyError("Trip ID '%s' in '%s' is invalid." %
                               (trip_id, frequency))
        trip = dao.trip(trip_id, feed_id=feed_id)
        start_time = _timetoint(frequency.get('start_time'))
        end_time = _timetoint(frequency.get('end_time'))
        headway_secs = _toint(frequency.get('headway_secs'))
        exact_times = _toint(frequency.get('exact_times'), Trip.TIME_APPROX)
        for trip_dep_time in range(start_time, end_time, headway_secs):
            # Here we assume departure time are all different.
            # That's a requirement in the GTFS specs, but this may break.
            # TODO Make the expanded trip ID generation parametrable.
            trip_id2 = trip.trip_id + "@" + fmttime(trip_dep_time)
            # shape_id is copied so that expanded trips keep the shape
            # their (already normalized) stop times are measured against.
            trip2 = Trip(feed_id,
                         trip_id2,
                         trip.route_id,
                         trip.service_id,
                         wheelchair_accessible=trip.wheelchair_accessible,
                         bikes_allowed=trip.bikes_allowed,
                         exact_times=exact_times,
                         frequency_generated=True,
                         trip_headsign=trip.trip_headsign,
                         trip_short_name=trip.trip_short_name,
                         direction_id=trip.direction_id,
                         block_id=trip.block_id,
                         shape_id=trip.shape_id)
            trip2.stop_times = []
            # Shift every time so that the first departure is trip_dep_time.
            base_time = trip.stop_times[0].departure_time
            for stoptime in trip.stop_times:
                arrtime = None if stoptime.arrival_time is None else stoptime.arrival_time - base_time + trip_dep_time
                deptime = None if stoptime.departure_time is None else stoptime.departure_time - base_time + trip_dep_time
                stoptime2 = StopTime(
                    feed_id,
                    trip_id2,
                    stoptime.stop_id,
                    stoptime.stop_sequence,
                    arrival_time=arrtime,
                    departure_time=deptime,
                    shape_dist_traveled=stoptime.shape_dist_traveled,
                    interpolated=stoptime.interpolated,
                    timepoint=stoptime.timepoint,
                    pickup_type=stoptime.pickup_type,
                    drop_off_type=stoptime.drop_off_type,
                    stop_headsign=stoptime.stop_headsign)
                trip2.stop_times.append(stoptime2)
            n_exp_trips += 1
            # This will add the associated stop times
            dao.add(trip2)
        # Do not delete trip now, as two frequency can refer to same trip
        trips_to_delete[trip.trip_id] = trip
        n_freq += 1
    for trip in trips_to_delete.values():
        # This also delete the associated stop times
        dao.delete(trip)
    dao.flush()
    dao.commit()
    logger.info("Expanded %d frequencies to %d trips." % (n_freq, n_exp_trips))

    logger.info("Feed '%s': import done." % feed_id)
Ejemplo n.º 6
0
    def test_fares(self):
        """Check fare attributes, fare rule filtering, and FareRule equality/hash."""
        dao = Dao()
        f1 = FeedInfo("")
        a1 = Agency("",
                    "A1",
                    "Agency 1",
                    agency_url="http://www.agency.fr/",
                    agency_timezone="Europe/Paris")
        r1 = Route("",
                   "R1",
                   "A1",
                   3,
                   route_short_name="R1",
                   route_long_name="Route 1")
        r2 = Route("",
                   "R2",
                   "A1",
                   3,
                   route_short_name="R2",
                   route_long_name="Route 2")
        z1 = Zone("", "Z1")
        z2 = Zone("", "Z2")
        fare1 = FareAttribute("", "F1", 1.0, "EUR",
                              FareAttribute.PAYMENT_ONBOARD, None, None)
        fare2 = FareAttribute("", "F2", 2.0, "EUR",
                              FareAttribute.PAYMENT_BEFOREBOARDING, 3, 3600)
        # Both rules belong to fare F1; F2 has no rule.
        rule1 = FareRule("", "F1", route_id="R1")
        rule2 = FareRule("", "F1", origin_id="Z1", destination_id="Z2")
        dao.add_all([f1, a1, r1, r2, z1, z2, fare1, fare2, rule1, rule2])
        dao.commit()

        self.assertEqual(len(dao.fare_attributes()), 2)
        self.assertEqual(len(dao.fare_rules(fltr=(FareRule.route == r1))), 1)
        self.assertEqual(len(dao.fare_rules(fltr=(FareRule.route == r2))), 0)
        self.assertEqual(len(dao.fare_rules(fltr=(FareRule.origin == z1))), 1)
        fare = dao.fare_attribute("F1")
        self.assertEqual(len(fare.fare_rules), 2)
        fare = dao.fare_attribute("F2")
        self.assertEqual(len(fare.fare_rules), 0)

        # Test equivalence and hash on primary keys for rule
        fr1a = FareRule("",
                        "F1",
                        route_id="R",
                        origin_id="ZO",
                        destination_id="ZD",
                        contains_id=None)
        fr1b = FareRule("",
                        "F1",
                        route_id="R",
                        origin_id="ZO",
                        destination_id="ZD",
                        contains_id=None)
        fr2 = FareRule("",
                       "F1",
                       route_id="R",
                       origin_id="ZO",
                       destination_id=None,
                       contains_id="ZD")
        ruleset = set()
        ruleset.add(fr1a)
        ruleset.add(fr2)
        ruleset.add(fr1b)
        self.assertEqual(len(ruleset), 2)
        self.assertEqual(fr1a, fr1b)
        self.assertNotEqual(fr1a, fr2)
Ejemplo n.º 7
0
    def test_transfers(self):
        """Check Transfer persistence, filtering and stop-side relationships.

        Builds a single feed with three stops and five transfers: S1<->S2
        created without an explicit type (expected to be TRANSFER_DEFAULT),
        S2<->S3 timed with different minimum transfer times, and S1->S3 with
        TRANSFER_NONE. Two trips (T1 on agency A1, T2 on agency A2) are added
        so that stops can be filtered by agency. The test then verifies:

        * the total and type-filtered transfer counts;
        * filtering on the origin/destination stop through
          ``dao.transfer_from_stop()`` / ``dao.transfer_to_stop()``;
        * ``from_transfers`` / ``to_transfers`` on S1 fetched back via the DAO;
        * that filtering stops by agency A1 yields exactly S1 and S2
          (the stops served by trip T1).
        """
        dao = Dao()
        f1 = FeedInfo("F1")
        s1 = Stop("F1", "S1", "Stop 1", 45.0000, 0.0000)
        s2 = Stop("F1", "S2", "Stop 2", 45.0001, 0.0001)
        s3 = Stop("F1", "S3", "Stop 3", 45.0002, 0.0002)
        t12 = Transfer("F1", "S1", "S2")
        t21 = Transfer("F1", "S2", "S1")
        t23 = Transfer("F1",
                       "S2",
                       "S3",
                       transfer_type=Transfer.TRANSFER_TIMED,
                       min_transfer_time=180)
        t32 = Transfer("F1",
                       "S3",
                       "S2",
                       transfer_type=Transfer.TRANSFER_TIMED,
                       min_transfer_time=120)
        t13 = Transfer("F1", "S1", "S3", transfer_type=Transfer.TRANSFER_NONE)
        a1 = Agency("F1", "A1", "Agency 1", "url1", "Europe/Paris")
        a2 = Agency("F1", "A2", "Agency 2", "url2", "Europe/London")
        r1 = Route("F1", "R1", "A1", Route.TYPE_BUS)
        r2 = Route("F1", "R2", "A2", Route.TYPE_BUS)
        c1 = Calendar("F1", "C1")
        t1 = Trip("F1", "T1", "R1", "C1")
        t2 = Trip("F1", "T2", "R2", "C1")
        # T1 (agency A1) serves S1 -> S2; T2 (agency A2) serves S1 -> S3.
        st1a = StopTime("F1", "T1", "S1", 0, None, 3600, 0.0)
        st1b = StopTime("F1", "T1", "S2", 1, 3800, None, 100.0)
        st2a = StopTime("F1", "T2", "S1", 0, None, 4600, 0.0)
        st2b = StopTime("F1", "T2", "S3", 1, 4800, None, 100.0)
        dao.add_all([
            f1, s1, s2, s3, t12, t21, t23, t32, t13, a1, a2, r1, r2, c1, t1,
            t2, st1a, st1b, st2a, st2b
        ])

        self.assertEqual(len(dao.transfers()), 5)

        timed_transfers = dao.transfers(
            fltr=(Transfer.transfer_type == Transfer.TRANSFER_TIMED))
        self.assertEqual(len(timed_transfers), 2)
        for transfer in timed_transfers:
            self.assertEqual(transfer.transfer_type, Transfer.TRANSFER_TIMED)

        # Transfers leaving S1: S1->S2 and S1->S3.
        s1_from_transfers = dao.transfers(
            fltr=(dao.transfer_from_stop().stop_name == "Stop 1"))
        self.assertEqual(len(s1_from_transfers), 2)
        for transfer in s1_from_transfers:
            self.assertEqual(transfer.from_stop.stop_name, "Stop 1")

        # Transfers touching S1 on either side: the two above plus S2->S1.
        s1_fromto_transfers = dao.transfers(
            fltr=((dao.transfer_from_stop().stop_name == "Stop 1")
                  | (dao.transfer_to_stop().stop_name == "Stop 1")))
        self.assertEqual(len(s1_fromto_transfers), 3)
        for transfer in s1_fromto_transfers:
            self.assertTrue(transfer.from_stop.stop_name == "Stop 1"
                            or transfer.to_stop.stop_name == "Stop 1")

        # Fetch S1 back through the DAO and inspect its relationship lists.
        s1b = dao.stop("S1", feed_id="F1")
        self.assertEqual(len(s1b.from_transfers), 2)
        self.assertEqual(len(s1b.to_transfers), 1)
        for transfer in s1b.from_transfers:
            if transfer.to_stop.stop_id == "S2":
                self.assertEqual(transfer.transfer_type,
                                 Transfer.TRANSFER_DEFAULT)
            elif transfer.to_stop.stop_id == "S3":
                self.assertEqual(transfer.transfer_type,
                                 Transfer.TRANSFER_NONE)
            else:
                # Only S1->S2 and S1->S3 were created; anything else is a bug.
                self.fail("Unexpected transfer from S1 to %s"
                          % transfer.to_stop.stop_id)

        a1_stops = list(dao.stops(fltr=(Agency.agency_id == 'A1')))
        self.assertEqual(len(a1_stops), 2)
        self.assertIn(s1b, a1_stops)
        self.assertIn(s2, a1_stops)
Ejemplo n.º 8
0
    def test_trip(self):
        """Check Trip/StopTime relationships, ordering, hops and filtering.

        Builds one route with calendar C1 (dates from ``CalendarDate.range``
        over January 2016), three stops and two trips:

        * T1 (direction 0), whose stop times are added as standalone rows;
        * T2 (direction 1), whose stop times are appended to
          ``trip.stop_times`` out of ``stop_sequence`` order and without
          feed/trip ids.

        After commit it verifies navigation calendar -> trips -> stop times
        and stop -> stop times, the per-trip stop time counts, increasing
        ``stop_sequence`` order, ``Trip.hops()`` pairing of consecutive stop
        times, and filtering trips by ``direction_id``.
        """
        dao = Dao()
        f1 = FeedInfo("F1")
        a1 = Agency("F1",
                    "A1",
                    "Agency 1",
                    agency_url="http://www.agency.fr/",
                    agency_timezone="Europe/Paris")
        r1 = Route("F1",
                   "R1",
                   "A1",
                   3,
                   route_short_name="R1",
                   route_long_name="Route 1")
        c1 = Calendar("F1", "C1")
        c1.dates = list(CalendarDate.range(
            CalendarDate.ymd(2016, 1, 1),
            CalendarDate.ymd(2016, 1, 31).next_day()))
        s1 = Stop("F1", "S1", "Stop 1", 45.0, 0.0)
        s2 = Stop("F1", "S2", "Stop 2", 45.1, 0.1)
        s3 = Stop("F1", "S3", "Stop 3", 45.2, 0.2)
        t1 = Trip("F1", "T1", "R1", "C1")
        t1.direction_id = 0
        t11 = StopTime("F1", "T1", "S1", 0, 28800, 28800, 0.0)
        t12 = StopTime("F1", "T1", "S2", 1, 29400, 29400, 0.0)
        t13 = StopTime("F1", "T1", "S3", 2, 30000, 30000, 0.0)
        t2 = Trip("F1", "T2", "R1", "C1")
        t2.direction_id = 1
        # Order is not important for now: the stop times are deliberately
        # appended with decreasing stop_sequence (1 then 0), and feed/trip
        # ids are left to the relationship to fill in.
        t2.stop_times.append(StopTime(None, None, "S1", 1, 31000, 31000, 0.0))
        t2.stop_times.append(StopTime(None, None, "S2", 0, 30600, 30600, 0.0))

        dao.add_all([f1, a1, r1, c1, s1, s2, s3, t1, t11, t12, t13, t2])
        # Commit is needed to re-order stop times of T2
        dao.commit()

        cal = dao.calendar("C1", feed_id="F1")
        # Both trips use C1; guard against the loop below passing vacuously.
        self.assertEqual(len(cal.trips), 2)
        for trip in cal.trips:
            self.assertEqual(trip.calendar.service_id, "C1")
            for stoptime in trip.stop_times:
                self.assertEqual(stoptime.trip.calendar.service_id, "C1")

        stop = dao.stop("S2", feed_id="F1")
        for stoptime in stop.stop_times:
            self.assertEqual(stoptime.stop.stop_id, "S2")
            self.assertTrue(stoptime.trip.trip_id.startswith("T"))

        trip = dao.trip("T1", feed_id="F1")
        self.assertEqual(len(trip.stop_times), 3)

        trip = dao.trip("T2", feed_id="F1")
        self.assertEqual(len(trip.stop_times), 2)

        # Stop times must come back sorted by strictly increasing sequence,
        # including T2 whose stop times were inserted out of order.
        for trip in dao.trips(prefetch_stop_times=True):
            last_stop_seq = -1
            for stoptime in trip.stop_times:
                self.assertGreater(stoptime.stop_sequence, last_stop_seq)
                last_stop_seq = stoptime.stop_sequence

        # hops() yields pairs of consecutive stop times of the same trip.
        for trip in dao.trips():
            for stoptime1, stoptime2 in trip.hops():
                self.assertEqual(stoptime1.trip, stoptime2.trip)
                self.assertEqual(stoptime1.stop_sequence + 1,
                                 stoptime2.stop_sequence)

        trips = list(dao.trips(fltr=Trip.direction_id == 0))
        self.assertEqual(len(trips), 1)
        self.assertEqual(trips[0].trip_id, "T1")
        trips = list(dao.trips(fltr=Trip.direction_id == 1))
        self.assertEqual(len(trips), 1)
        self.assertEqual(trips[0].trip_id, "T2")