Example #1
    def test_non_overlapping_feeds(self):
        dao = Dao(DAO_URL, sql_logging=SQL_LOG)
        # Load the same data twice under two distinct namespaces
        dao.load_gtfs(DUMMY_GTFS, feed_id='A')
        dao.load_gtfs(DUMMY_GTFS, feed_id='B')

        # Check that each feed only returns its own data
        feed_a = dao.feed('A')
        self.assertTrue(feed_a.feed_id == 'A')
        feed_b = dao.feed('B')
        self.assertTrue(feed_b.feed_id == 'B')
        self.assertTrue(len(dao.agencies()) == 4)
        self.assertTrue(len(feed_a.agencies) == 2)
        self.assertTrue(len(feed_b.agencies) == 2)
        self.assertTrue(len(feed_a.routes) * 2 == len(dao.routes()))
        self.assertTrue(len(feed_b.routes) * 2 == len(dao.routes()))
        self.assertTrue(len(feed_a.stops) * 2 == len(list(dao.stops())))
        self.assertTrue(len(feed_b.stops) * 2 == len(list(dao.stops())))
        self.assertTrue(len(feed_a.calendars) * 2 == len(dao.calendars()))
        self.assertTrue(len(feed_b.calendars) * 2 == len(dao.calendars()))
        self.assertTrue(len(feed_a.trips) * 2 == len(list(dao.trips())))
        self.assertTrue(len(feed_b.trips) * 2 == len(list(dao.trips())))
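
The pattern above relies on feed_id namespacing: global accessors such as dao.routes() or dao.stops() span every loaded feed, while the collections on a feed object (dao.feed('A').routes, ...) are scoped to that feed. A minimal sketch of the idiom, assuming the gtfslib.dao module layout and an illustrative GTFS path:

from gtfslib.dao import Dao

dao = Dao()  # assumption: a no-argument Dao uses a transient in-memory database
dao.load_gtfs("dummy.gtfs.zip", feed_id="A")  # illustrative file path
dao.load_gtfs("dummy.gtfs.zip", feed_id="B")

print(len(dao.routes()))          # counts routes from both feeds
print(len(dao.feed("A").routes))  # counts routes from feed "A" only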
Example #2
    def test_route_agency_multi_feed(self):
        dao = Dao()
        fa = FeedInfo("FA")
        aa1 = Agency("FA",
                     "A",
                     "Agency A",
                     agency_url="http://www.agency.fr/",
                     agency_timezone="Europe/Paris")
        ar1 = Route("FA",
                    "R",
                    "A",
                    3,
                    route_short_name="RA",
                    route_long_name="Route A")
        ar2 = Route("FA",
                    "R2",
                    "A",
                    3,
                    route_short_name="RA2",
                    route_long_name="Route A2")
        fb = FeedInfo("FB")
        ba1 = Agency("FB",
                     "A",
                     "Agency B",
                     agency_url="http://www.agency.fr/",
                     agency_timezone="Europe/Paris")
        br1 = Route("FB",
                    "R",
                    "A",
                    3,
                    route_short_name="RB",
                    route_long_name="Route B")
        dao.add_all([fa, aa1, ar1, ar2, fb, ba1, br1])

        fa = dao.feed("FA")
        self.assertTrue(len(fa.agencies) == 1)
        for a in fa.agencies:
            self.assertTrue(a.agency_name == "Agency A")
        self.assertTrue(len(fa.routes) == 2)
        for r in fa.routes:
            self.assertTrue(r.route_short_name.startswith("RA"))
            self.assertTrue(r.agency.agency_name == "Agency A")
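
The same namespacing can be exercised without loading a GTFS file by constructing model objects directly, as this test does. A minimal sketch, assuming the gtfslib.model layout and the (feed_id, object_id, ...) positional constructors used above:

from gtfslib.dao import Dao
from gtfslib.model import Agency, FeedInfo, Route

dao = Dao()
dao.add_all([
    FeedInfo("F"),
    Agency("F", "A1", "Demo Agency",
           agency_url="http://example.com/",
           agency_timezone="Europe/Paris"),
    Route("F", "R1", "A1", 3,  # 3 = bus, as in the routes above
          route_short_name="1", route_long_name="Demo Route"),
])
dao.commit()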
Example #3
    def test_gtfs_data(self):
        dao = Dao(DAO_URL, sql_logging=SQL_LOG)
        dao.load_gtfs(MINI_GTFS)

        # Check feed
        feed = dao.feed()
        self.assertTrue(feed.feed_id == "")
        self.assertTrue(feed.feed_publisher_name is None)
        self.assertTrue(feed.feed_publisher_url is None)
        self.assertTrue(feed.feed_contact_email is None)
        self.assertTrue(feed.feed_contact_url is None)
        self.assertTrue(feed.feed_start_date is None)
        self.assertTrue(feed.feed_end_date is None)
        self.assertTrue(len(dao.agencies()) == 1)
        self.assertTrue(len(dao.routes()) == 1)
        self.assertTrue(len(feed.agencies) == 1)
        self.assertTrue(len(feed.routes) == 1)

        # Check if optional route agency is set
        a = dao.agency("A")
        self.assertTrue(a is not None)
        self.assertTrue(len(a.routes) == 1)

        # Check for frequency-generated trips
        # They should all have the same delta
        trips = dao.trips(fltr=(Trip.frequency_generated == True),
                          prefetch_stop_times=True)
        n_trips = 0
        deltas = {}
        for trip in trips:
            original_trip_id = trip.trip_id.rsplit('@', 1)[0]
            delta1 = []
            for st1, st2 in trip.hops():
                delta1.append(st2.arrival_time - st1.departure_time)
            delta2 = deltas.get(original_trip_id)
            if delta2 is not None:
                self.assertTrue(delta1 == delta2)
            else:
                deltas[original_trip_id] = delta1
            n_trips += 1
        self.assertTrue(n_trips == 8)
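
The loop checks that every trip expanded from one frequencies.txt entry keeps identical hop durations; the expansion encodes the original trip ID before the last '@' of the generated trip_id. A standalone sketch of that grouping step:

from collections import defaultdict

def group_by_original_trip(trips):
    # Frequency-expanded trips are named "ORIGINAL@1", "ORIGINAL@2", ...
    groups = defaultdict(list)
    for trip in trips:
        groups[trip.trip_id.rsplit('@', 1)[0]].append(trip)
    return groups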
Example #4
    def test_all_gtfs(self):

        if not ENABLE:
            print("This test is disabled as it is very time-consuming.")
            print("If you want to enable it, please see in the code.")
            return

        # Create temporary directory if not there
        if not os.path.isdir(DIR):
            os.mkdir(DIR)

        # Create a DAO. Re-use any existing one if present.
        logging.basicConfig(level=logging.INFO)
        dao = Dao("%s/all_gtfs.sqlite" % (DIR))

        deids = IDS_TO_LOAD
        if deids is None:
            print("Downloading meta-info for all agencies...")
            resource_url = "http://www.gtfs-data-exchange.com/api/agencies?format=json"
            response = requests.get(resource_url).json()
            if response.get('status_code') != 200:
                raise IOError()
            deids = []
            for entry in response.get('data'):
                deid = entry.get('dataexchange_id')
                deids.append(deid)
            # Shuffle the list, otherwise we would always load ABCBus first, then ...
            random.shuffle(deids)

        for deid in deids:
            try:
                local_filename = "%s/%s.gtfs.zip" % (DIR, deid)
                if os.path.exists(local_filename) and SKIP_EXISTING:
                    print("Skipping [%s], GTFS already present." % (deid))
                    continue

                print("Downloading meta-info for ID [%s]" % (deid))
                resource_url = "http://www.gtfs-data-exchange.com/api/agency?agency=%s&format=json" % deid
                response = requests.get(resource_url).json()
                status_code = response.get('status_code')
                if status_code != 200:
                    raise IOError("Error %d (%s)" %
                                  (status_code, response.get('status_txt')))
                data = response.get('data')
                agency_data = data.get('agency')
                agency_name = agency_data.get('name')
                agency_area = agency_data.get('area')
                agency_country = agency_data.get('country')

                print("Processing [%s] %s (%s / %s)" %
                      (deid, agency_name, agency_country, agency_area))
                date_max = 0.0
                file_url = None
                file_size = 0
                file_md5 = None
                for datafile in data.get('datafiles'):
                    date_added = datafile.get('date_added')
                    if date_added > date_max:
                        date_max = date_added
                        file_url = datafile.get('file_url')
                        file_size = datafile.get('size')
                        file_md5 = datafile.get('md5sum')
                if file_url is None:
                    print("No datafile available, skipping.")
                    continue

                if file_size > MAX_GTFS_SIZE:
                    print("GTFS too large (%d bytes > max %d), skipping." %
                          (file_size, MAX_GTFS_SIZE))
                    continue

                # If the file is already present with a matching MD5, skip the download.
                try:
                    with open(local_filename, 'rb') as existing_file:
                        existing_md5 = hashlib.md5(
                            existing_file.read()).hexdigest()
                except OSError:
                    existing_md5 = None
                if existing_md5 == file_md5:
                    print("Using existing file '%s': MD5 checksum matches." %
                          (local_filename))
                else:
                    print("Downloading file '%s' to '%s' (%d bytes)" %
                          (file_url, local_filename, file_size))
                    cnx = requests.get(file_url, stream=True)
                    with open(local_filename, 'wb') as local_file:
                        for block in cnx.iter_content(1024):
                            local_file.write(block)
                    cnx.close()

                feed = dao.feed(deid)
                if feed is not None:
                    print("Removing existing data for feed [%s]" % (deid))
                    dao.delete_feed(deid)
                print("Importing into DAO as ID [%s]" % (deid))
                try:
                    dao.load_gtfs("%s/%s.gtfs.zip" % (DIR, deid), feed_id=deid)
                except Exception:
                    error_filename = "%s/%s.error" % (DIR, deid)
                    print("Import of [%s]: FAILED. Logging error to '%s'" %
                          (deid, error_filename))
                    with open(error_filename, 'w') as errfile:
                        errfile.write(traceback.format_exc())
                    raise
                print("Import of [%s]: OK." % (deid))

            except Exception as error:
                logging.exception(error)
                continue
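
The heart of the loop above is a cache-aware download: fetch a file only when the published MD5 differs from the local copy. A self-contained sketch of that piece; the function name and parameters are illustrative:

import hashlib
import os

import requests

def download_if_changed(url, local_path, expected_md5):
    # Skip the download when a cached copy already matches the checksum.
    if os.path.exists(local_path):
        with open(local_path, 'rb') as existing:
            if hashlib.md5(existing.read()).hexdigest() == expected_md5:
                return False  # cached copy is current
    response = requests.get(url, stream=True)
    with open(local_path, 'wb') as out:
        for block in response.iter_content(1024):
            out.write(block)
    response.close()
    return True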
Example #5
    def test_gtfs_data(self):
        dao = Dao(DAO_URL, sql_logging=False)
        dao.load_gtfs(DUMMY_GTFS)

        # Check feed
        feed = dao.feed()
        self.assertTrue(feed.feed_id == "")
        self.assertTrue(feed.feed_publisher_name == "Mecatran")
        self.assertTrue(feed.feed_publisher_url == "http://www.mecatran.com/")
        self.assertTrue(feed.feed_contact_email == "*****@*****.**")
        self.assertTrue(feed.feed_lang == "fr")
        self.assertTrue(len(dao.agencies()) == 2)
        self.assertTrue(len(dao.routes()) == 3)
        self.assertTrue(len(feed.agencies) == 2)
        self.assertTrue(len(feed.routes) == 3)

        # Check agencies
        at = dao.agency("AT")
        self.assertTrue(at.agency_name == "Agency Train")
        self.assertTrue(len(at.routes) == 1)
        ab = dao.agency("AB")
        self.assertTrue(ab.agency_name == "Agency Bus")
        self.assertTrue(len(ab.routes) == 2)

        # Check calendars
        week = dao.calendar("WEEK")
        self.assertTrue(len(week.dates) == 253)
        summer = dao.calendar("SUMMER")
        self.assertTrue(len(summer.dates) == 42)
        mon = dao.calendar("MONDAY")
        self.assertTrue(len(mon.dates) == 49)
        sat = dao.calendar("SAT")
        self.assertTrue(len(sat.dates) == 53)
        for date in mon.dates:
            self.assertTrue(date.dow() == 0)
        for date in sat.dates:
            self.assertTrue(date.dow() == 5)
        for date in week.dates:
            self.assertTrue(date.dow() >= 0 and date.dow() <= 4)
        for date in summer.dates:
            self.assertTrue(date >= CalendarDate.ymd(2016, 7, 1)
                            and date <= CalendarDate.ymd(2016, 8, 31))
        empty = dao.calendars(
            func.date(CalendarDate.date) == datetime.date(2016, 5, 1))
        # OR USE: empty = dao.calendars(CalendarDate.date == "2016-05-01")
        self.assertTrue(len(empty) == 0)
        july4 = CalendarDate.ymd(2016, 7, 4)
        summer_mon = dao.calendars(func.date(CalendarDate.date) == july4.date)
        n = 0
        for cal in summer_mon:
            self.assertTrue(july4 in cal.dates)
            n += 1
        self.assertTrue(n == 3)

        # Check stops
        sbq = dao.stop("BQ")
        self.assertAlmostEqual(sbq.stop_lat, 44.844, places=2)
        self.assertAlmostEqual(sbq.stop_lon, -0.573, places=2)
        self.assertTrue(sbq.stop_name == "Bordeaux Quinconces")
        n = 0
        for stop in dao.stops(Stop.stop_name.like("Gare%")):
            self.assertTrue(stop.stop_name.startswith("Gare"))
            n += 1
        self.assertTrue(n == 7)
        n = 0
        for stop in dao.stops(
                fltr=dao.in_area(RectangularArea(44.7, -0.6, 44.9, -0.4))):
            self.assertTrue(stop.stop_lat >= 44.7 and stop.stop_lat <= 44.9
                            and stop.stop_lon >= -0.6
                            and stop.stop_lon <= -0.4)
            n += 1
        self.assertTrue(n == 16)
        for station in dao.stops(Stop.location_type == Stop.TYPE_STATION):
            self.assertTrue(station.location_type == Stop.TYPE_STATION)
            self.assertTrue(len(station.sub_stops) >= 2)
            for stop in station.sub_stops:
                self.assertTrue(stop.parent_station == station)

        # Check zones
        z_inexistant = dao.zone("ZX")
        self.assertTrue(z_inexistant is None)
        z1 = dao.zone("Z1")
        self.assertEqual(16, len(z1.stops))
        z2 = dao.zone("Z2")
        self.assertEqual(4, len(z2.stops))

        # Check transfers
        transfers = dao.transfers()
        self.assertTrue(len(transfers) == 3)
        transfers = dao.transfers(
            fltr=(dao.transfer_from_stop().stop_id == 'GBSJB'))
        self.assertTrue(len(transfers) == 1)
        self.assertTrue(transfers[0].from_stop.stop_id == 'GBSJB')

        # Check routes
        tgv = dao.route("TGVBP")
        self.assertTrue(tgv.agency == at)
        self.assertTrue(tgv.route_type == 2)
        r1 = dao.route("BR")
        self.assertTrue(r1.route_short_name == "R1")
        self.assertTrue(r1.route_long_name == "Bus Red")
        n = 0
        for route in dao.routes(Route.route_type == 3):
            self.assertTrue(route.route_type == 3)
            n += 1
        self.assertTrue(n == 2)

        # Check trip for route
        n = 0
        trips = dao.trips(fltr=Route.route_type == Route.TYPE_BUS)
        for trip in trips:
            self.assertTrue(trip.route.route_type == Route.TYPE_BUS)
            n += 1
        self.assertTrue(n > 20)

        # Check trips on date
        trips = dao.trips(fltr=func.date(CalendarDate.date) == july4.date,
                          prefetch_calendars=True)
        n = 0
        for trip in trips:
            self.assertTrue(july4 in trip.calendar.dates)
            n += 1
        self.assertTrue(n > 30)
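
Most of these checks follow a single query idiom: hand a SQLAlchemy filter expression to a Dao accessor, either positionally or via fltr=. A minimal sketch reusing only accessors that appear above (the database path is illustrative):

from gtfslib.dao import Dao
from gtfslib.model import Route, Stop

dao = Dao("db.sqlite")  # an already-loaded database
for stop in dao.stops(Stop.stop_name.like("Gare%")):
    print(stop.stop_name)
for route in dao.routes(Route.route_type == Route.TYPE_BUS):
    print(route.route_short_name)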
Example #6
class Exporter:
    def __init__(self, arguments):
        self.logger = logging.getLogger('gtfsexporter')
        self._arguments = arguments

        if arguments['--id'] is None:
            arguments['--id'] = "default"

        database_path = os.path.join(__cwd_path__,
                                     arguments['--id'] + ".sqlite")

        self._dao = Dao(database_path,
                        sql_logging=arguments['--logsql'],
                        schema=arguments['--schema'])

        if arguments['--list']:
            for feed in self._dao.feeds():
                print(feed.feed_id if feed.feed_id != "" else "(default)")

        if arguments['--delete']:
            feed_id = arguments['--id']
            existing_feed = self._dao.feed(feed_id)
            if existing_feed:
                self.logger.warning("Deleting existing feed ID '%s'" % feed_id)
                self._dao.delete_feed(feed_id)
                self._dao.commit()

    @property
    def dao(self):
        return self._dao

    def load(self, provider: DataProvider):
        self.logger.info("Importing data from provided source")
        provider.load_data_source(self._dao)

    def process(self, processor: Processor = None):
        self.logger.info("Processing data from provided source")

        # Ideally a rule processor would be used here for more flexibility, e.g.:
        # processor.process(ruleset)

        for route in self._dao.routes():
            print("updating route [%s] setting correct color" %
                  route.route_long_name)

            route.route_text_color = "FFFFFF"

            if route.route_type == Route.TYPE_BUS:
                route.route_color = "195BAD"
            elif route.route_type == Route.TYPE_TRAM:
                route.route_color = "FFAD33"
            elif route.route_type == Route.TYPE_RAIL:
                route.route_color = "FF5B33"
            elif route.route_type == Route.TYPE_CABLECAR:
                route.route_color = "FF8433"
            elif route.route_type == Route.TYPE_SUBWAY:
                route.route_color = "D13333"
            elif route.route_type == Route.TYPE_FERRY:
                route.route_color = "62A9DD"

        self._dao.commit()

    def export(self, bundle=False, out_path: str = __output_path__) -> str:
        self.logger.info(
            f"Generating archive with name gtfs-{self._arguments['--id']}.zip")

        class __Args:
            filter = None

        context = Context(self._dao, __Args(), out_path)

        if bundle:
            w = Writer(context, bundle=f"gtfs-{self._arguments['--id']}.zip")
        else:
            w = Writer(context)

        w.run()

        return out_path
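
A hypothetical driver for this class, assuming docopt-style arguments as the '--id'/'--logsql' keys suggest (the flag set shown is an assumption, not the full CLI):

arguments = {'--id': 'demo', '--logsql': False, '--schema': None,
             '--list': False, '--delete': False}
exporter = Exporter(arguments)
# exporter.load(provider)  # any DataProvider implementation goes here
exporter.process()  # recolors routes by route_type, then commits
out_dir = exporter.export(bundle=True)  # writes gtfs-demo.zip into out_dir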