Example No. 1
    def test_filter_by_agency(self):
        FilterExtract(self.G, self.fname_copy,
                      agency_ids_to_preserve=['DTA']).create_filtered_copy()
        hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
        self.assertNotEqual(self.hash_orig, hash_copy)
        G_copy = GTFS(self.fname_copy)
        agency_table = G_copy.get_table("agencies")
        assert "EXA" not in agency_table[
            'agency_id'].values, "EXA agency should not be preserved"
        assert "DTA" in agency_table[
            'agency_id'].values, "DTA agency should be preserved"
        routes_table = G_copy.get_table("routes")
        assert "EXR1" not in routes_table[
            'route_id'].values, "EXR1 route_id should not be preserved"
        assert "AB" in routes_table[
            'route_id'].values, "AB route_id should be preserved"
        trips_table = G_copy.get_table("trips")
        assert "EXT1" not in trips_table[
            'trip_id'].values, "EXT1 trip_id should not be preserved"
        assert "AB1" in trips_table[
            'trip_id'].values, "AB1 trip_id should be preserved"
        calendar_table = G_copy.get_table("calendar")
        assert "FULLW" in calendar_table[
            'service_id'].values, "FULLW service_id should be preserved"
        # stop_times
        stop_times_table = G_copy.get_table("stop_times")
        # 01:32:45 corresponds to 3600 + (32 * 60) + 45 = 5565 [in day seconds]
        assert 3600 + (32 * 60) + 45 not in stop_times_table['arr_time'].values
        os.remove(self.fname_copy)
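As a side note on the day-seconds arithmetic in the last assertion, a tiny standalone helper can make the conversion explicit. This is only an illustrative sketch; the function name to_day_seconds is made up and is not part of gtfspy:

def to_day_seconds(hhmmss):
    # Convert an "HH:MM:SS" string into seconds since the start of the service day.
    hours, minutes, seconds = (int(part) for part in hhmmss.split(":"))
    return hours * 3600 + minutes * 60 + seconds

assert to_day_seconds("01:32:45") == 3600 + (32 * 60) + 45  # = 5565 day seconds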
Example No. 2
    def test_frequencyLoader(self):
        import_gtfs(self.fdict, self.conn, preserve_connection=True)
        # "\nfrequency_route, freq_service, freq_trip, going north, freq_name, shape_es1" \
        keys = ["trip_I", "start_time", "end_time", "headway_secs", "exact_times", "start_time_ds", "end_time_ds"]
        self.setDictConn()
        rows = self.conn.execute("SELECT * FROM frequencies").fetchall()
        row = rows[0]
        for key in keys:
            assert key in row
        for row in rows:
            if row["start_time_ds"] == 14 * 3600:
                self.assertEqual(row["exact_times"], 1)
        # there should be twelve trips with service_id 'freq_service'
        count = self.conn.execute("SELECT count(*) AS count FROM trips JOIN calendar "
                                  "USING(service_I) WHERE service_id='freq_service'").fetchone()['count']

        assert count == 12, count
        rows = self.conn.execute("SELECT trip_I FROM trips JOIN calendar "
                                 "USING(service_I) WHERE service_id='freq_service'").fetchall()
        for row in rows:
            trip_I = row['trip_I']
            res = self.conn.execute("SELECT * FROM stop_times WHERE trip_I={trip_I}".format(trip_I=trip_I)).fetchall()
            assert len(res) > 1, res
        self.setRowConn()
        g = GTFS(self.conn)
        print("Stop times: \n\n ", g.get_table("stop_times"))
        print("Frequencies: \n\n ", g.get_table("frequencies"))
Example No. 3
    def test_filter_spatially(self):
        # test that the db is split by a given spatial boundary
        FilterExtract(self.G,
                      self.fname_copy,
                      buffer_lat=36.914893,
                      buffer_lon=-116.76821,
                      buffer_distance_km=50).create_filtered_copy()
        G_copy = GTFS(self.fname_copy)

        stops_table = G_copy.get_table("stops")
        self.assertNotIn("FUR_CREEK_RES", stops_table['stop_id'].values)
        self.assertIn("AMV", stops_table['stop_id'].values)
        self.assertEqual(len(stops_table['stop_id'].values), 8)

        conn_copy = sqlite3.connect(self.fname_copy)
        stop_ids_df = pandas.read_sql(
            'SELECT stop_id from stop_times '
            'left join stops '
            'on stops.stop_I = stop_times.stop_I', conn_copy)
        stop_ids = stop_ids_df["stop_id"].values

        self.assertNotIn("FUR_CREEK_RES", stop_ids)
        self.assertIn("AMV", stop_ids)

        trips_table = G_copy.get_table("trips")
        self.assertNotIn("BFC1", trips_table['trip_id'].values)

        routes_table = G_copy.get_table("routes")
        self.assertNotIn("BFC", routes_table['route_id'].values)
Example No. 4
    def test_filter_end_date_not_included(self):
        # the end date should not be included:
        FilterExtract(self.G,
                      self.fname_copy,
                      start_date="2007-01-02",
                      end_date="2010-12-31").create_filtered_copy()

        hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
        self.assertNotEqual(self.hash_orig, hash_copy)
        G_copy = GTFS(self.fname_copy)
        dsut_end = G_copy.get_day_start_ut("2010-12-31")
        dsut_to_trip_I = G_copy.get_tripIs_within_range_by_dsut(
            dsut_end, dsut_end + 24 * 3600)
        self.assertEqual(len(dsut_to_trip_I), 0)

        calendar_copy = G_copy.get_table("calendar")
        max_date_calendar = max([
            datetime.datetime.strptime(el, "%Y-%m-%d")
            for el in calendar_copy["end_date"].values
        ])
        min_date_calendar = min([
            datetime.datetime.strptime(el, "%Y-%m-%d")
            for el in calendar_copy["start_date"].values
        ])
        end_date_not_included = datetime.datetime.strptime(
            "2010-12-31", "%Y-%m-%d")
        start_date_not_included = datetime.datetime.strptime(
            "2007-01-01", "%Y-%m-%d")
        self.assertLess(max_date_calendar,
                        end_date_not_included,
                        msg="the last date should not be included in calendar")
        self.assertLess(start_date_not_included, min_date_calendar)
        os.remove(self.fname_copy)
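The behaviour tested above amounts to a half-open date interval: start_date is kept while end_date is cut away. A minimal sketch of that comparison logic (not the actual FilterExtract implementation, just the invariant the test relies on):

import datetime

start_date = datetime.datetime.strptime("2007-01-02", "%Y-%m-%d")
end_date = datetime.datetime.strptime("2010-12-31", "%Y-%m-%d")

def is_kept(date_string):
    # A date survives the filter if it falls in [start_date, end_date).
    date = datetime.datetime.strptime(date_string, "%Y-%m-%d")
    return start_date <= date < end_date

assert is_kept("2007-01-02")
assert is_kept("2010-12-30")
assert not is_kept("2010-12-31")   # the end date itself is excluded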
Example No. 5
def calc_transfers(conn, threshold_meters=1000):
    geohash_precision = _get_geo_hash_precision(threshold_meters / 1000.)
    geo_index = GeoGridIndex(precision=geohash_precision)
    g = GTFS(conn)
    stops = g.get_table("stops")
    stop_geopoints = []
    cursor = conn.cursor()

    for stop in stops.itertuples():
        stop_geopoint = GeoPoint(stop.lat, stop.lon, ref=stop.stop_I)
        geo_index.add_point(stop_geopoint)
        stop_geopoints.append(stop_geopoint)
    for stop_geopoint in stop_geopoints:
        nearby_stop_geopoints = geo_index.get_nearest_points_dirty(
            stop_geopoint, threshold_meters / 1000.0, "km")
        from_stop_I = int(stop_geopoint.ref)
        from_lat = stop_geopoint.latitude
        from_lon = stop_geopoint.longitude

        to_stop_Is = []
        distances = []
        for nearby_stop_geopoint in nearby_stop_geopoints:
            to_stop_I = int(nearby_stop_geopoint.ref)
            if to_stop_I == from_stop_I:
                continue
            to_lat = nearby_stop_geopoint.latitude
            to_lon = nearby_stop_geopoint.longitude
            distance = math.ceil(
                wgs84_distance(from_lat, from_lon, to_lat, to_lon))
            if distance <= threshold_meters:
                to_stop_Is.append(to_stop_I)
                distances.append(distance)

        n_pairs = len(to_stop_Is)
        from_stop_Is = [from_stop_I] * n_pairs
        cursor.executemany(
            'INSERT OR REPLACE INTO stop_distances VALUES (?, ?, ?, ?, ?, ?);',
            zip(from_stop_Is, to_stop_Is, distances, [None] * n_pairs,
                [None] * n_pairs, [None] * n_pairs))
    # The index only needs to be created once, not on every loop iteration.
    cursor.execute(
        'CREATE INDEX IF NOT EXISTS idx_sd_fsid ON stop_distances (from_stop_I);')
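The bulk insert above relies on sqlite3's executemany with an iterable of parameter tuples. A self-contained sketch against an in-memory database is shown below; the toy stop_distances schema (from_stop_I, to_stop_I, d plus three nullable columns) is only assumed from the six-value INSERT statement, not taken from gtfspy's real schema:

import sqlite3

conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
cursor.execute("CREATE TABLE stop_distances "
               "(from_stop_I INT, to_stop_I INT, d INT, d_walk INT, "
               "min_transfer_time INT, timed_transfer INT, "
               "PRIMARY KEY (from_stop_I, to_stop_I))")

from_stop_Is, to_stop_Is, distances = [1, 1, 2], [2, 3, 3], [120, 480, 350]
n_pairs = len(to_stop_Is)
cursor.executemany("INSERT OR REPLACE INTO stop_distances VALUES (?, ?, ?, ?, ?, ?);",
                   zip(from_stop_Is, to_stop_Is, distances,
                       [None] * n_pairs, [None] * n_pairs, [None] * n_pairs))
cursor.execute("CREATE INDEX IF NOT EXISTS idx_sd_fsid ON stop_distances (from_stop_I);")
conn.commit()
print(cursor.execute("SELECT * FROM stop_distances WHERE from_stop_I = 1").fetchall())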
Example No. 6
    def test_filter_spatially_2(self):
        n_rows_before = {
            "routes": 4,
            "stop_times": 14,
            "trips": 4,
            "stops": 6,
            "shapes": 4
        }
        n_rows_after_1000 = {  # within "soft buffer" in the feed data
            "routes": 1,
            "stop_times": 2,
            "trips": 1,
            "stops": 2,
            "shapes": 0
        }
        n_rows_after_3000 = {  # within "hard buffer" in the feed data
            "routes": len(["t1", "t3", "t4"]),
            "stop_times": 11,
            "trips": 4,
            "stops": len({"P", "H", "V", "L", "B"}),
            # for some reason, the first "shapes": 4
        }
        paris_lat = 48.832781
        paris_lon = 2.360734

        SELECT_MIN_MAX_SHAPE_BREAKS_BY_TRIP_I_SQL = \
            "SELECT trips.trip_I, shape_id, " \
            "min(shape_break) as min_shape_break, " \
            "max(shape_break) as max_shape_break " \
            "FROM trips, stop_times " \
            "WHERE trips.trip_I=stop_times.trip_I " \
            "GROUP BY trips.trip_I"
        trip_min_max_shape_seqs = pandas.read_sql(
            SELECT_MIN_MAX_SHAPE_BREAKS_BY_TRIP_I_SQL, self.G_filter_test.conn)

        for distance_km, n_rows_after in zip(
            [1000, 3000], [n_rows_after_1000, n_rows_after_3000]):
            try:
                os.remove(self.fname_copy)
            except FileNotFoundError:
                pass
            FilterExtract(
                self.G_filter_test,
                self.fname_copy,
                buffer_lat=paris_lat,
                buffer_lon=paris_lon,
                buffer_distance_km=distance_km).create_filtered_copy()
            for table_name, n_rows in n_rows_before.items():
                self.assertEqual(
                    len(self.G_filter_test.get_table(table_name)), n_rows,
                    "Row counts before differ in " + table_name +
                    ", distance: " + str(distance_km))
            G_copy = GTFS(self.fname_copy)
            for table_name, n_rows in n_rows_after.items():
                table = G_copy.get_table(table_name)
                self.assertEqual(
                    len(table), n_rows,
                    "Row counts after differ in " + table_name +
                    ", distance: " + str(distance_km) + "\n" + str(table))

            # assert that stop_times are resequenced starting from one
            counts = pandas.read_sql(
                "SELECT count(*) FROM stop_times GROUP BY trip_I ORDER BY trip_I",
                G_copy.conn)
            max_values = pandas.read_sql(
                "SELECT max(seq) FROM stop_times GROUP BY trip_I ORDER BY trip_I",
                G_copy.conn)
            self.assertTrue((counts.values == max_values.values).all())
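The final check above uses the fact that if every trip's stop sequence is renumbered contiguously starting from one, the per-trip row count equals the per-trip maximum of seq. A small pandas sketch of the same invariant on toy data:

import pandas as pd

stop_times = pd.DataFrame({
    "trip_I": [1, 1, 1, 2, 2],
    "seq":    [1, 2, 3, 1, 2],   # resequenced: starts at 1, no gaps
})

counts = stop_times.groupby("trip_I")["seq"].count()
max_values = stop_times.groupby("trip_I")["seq"].max()
print((counts.values == max_values.values).all())   # -> True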
Example No. 7
class TestGTFSFilter(unittest.TestCase):
    def setUp(self):
        self.gtfs_source_dir = os.path.join(os.path.dirname(__file__),
                                            "test_data")
        self.gtfs_source_dir_filter_test = os.path.join(
            self.gtfs_source_dir, "filter_test_feed/")

        # self.G = GTFS.from_directory_as_inmemory_db(self.gtfs_source_dir)

        # some preparations:
        self.fname = self.gtfs_source_dir + "/test_gtfs.sqlite"
        self.fname_copy = self.gtfs_source_dir + "/test_gtfs_copy.sqlite"
        self.fname_filter = self.gtfs_source_dir + "/test_gtfs_filter_test.sqlite"

        self._remove_temporary_files()
        self.assertFalse(os.path.exists(self.fname_copy))

        conn = sqlite3.connect(self.fname)
        import_gtfs(self.gtfs_source_dir,
                    conn,
                    preserve_connection=True,
                    print_progress=False)
        conn_filter = sqlite3.connect(self.fname_filter)
        import_gtfs(self.gtfs_source_dir_filter_test,
                    conn_filter,
                    preserve_connection=True,
                    print_progress=False)

        self.G = GTFS(conn)
        self.G_filter_test = GTFS(conn_filter)

        self.hash_orig = hashlib.md5(open(self.fname, 'rb').read()).hexdigest()

    def _remove_temporary_files(self):
        for fn in [self.fname, self.fname_copy, self.fname_filter]:
            if os.path.exists(fn) and os.path.isfile(fn):
                os.remove(fn)

    def tearDown(self):
        self._remove_temporary_files()

    def test_copy(self):
        # do a simple copy
        FilterExtract(self.G, self.fname_copy,
                      update_metadata=False).create_filtered_copy()

        # check that the copying has been properly performed:
        hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
        self.assertTrue(os.path.exists(self.fname_copy))
        self.assertEqual(self.hash_orig, hash_copy)

    def test_filter_change_metadata(self):
        # A simple test that setting update_metadata to True does indeed update the metadata:
        FilterExtract(self.G, self.fname_copy,
                      update_metadata=True).create_filtered_copy()
        # check that the copying has been properly performed:
        hash_orig = hashlib.md5(open(self.fname, 'rb').read()).hexdigest()
        hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
        self.assertTrue(os.path.exists(self.fname_copy))
        self.assertNotEqual(hash_orig, hash_copy)
        os.remove(self.fname_copy)

    def test_filter_by_agency(self):
        FilterExtract(self.G, self.fname_copy,
                      agency_ids_to_preserve=['DTA']).create_filtered_copy()
        hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
        self.assertNotEqual(self.hash_orig, hash_copy)
        G_copy = GTFS(self.fname_copy)
        agency_table = G_copy.get_table("agencies")
        assert "EXA" not in agency_table[
            'agency_id'].values, "EXA agency should not be preserved"
        assert "DTA" in agency_table[
            'agency_id'].values, "DTA agency should be preserved"
        routes_table = G_copy.get_table("routes")
        assert "EXR1" not in routes_table[
            'route_id'].values, "EXR1 route_id should not be preserved"
        assert "AB" in routes_table[
            'route_id'].values, "AB route_id should be preserved"
        trips_table = G_copy.get_table("trips")
        assert "EXT1" not in trips_table[
            'trip_id'].values, "EXR1 route_id should not be preserved"
        assert "AB1" in trips_table[
            'trip_id'].values, "AB1 route_id should be preserved"
        calendar_table = G_copy.get_table("calendar")
        assert "FULLW" in calendar_table[
            'service_id'].values, "FULLW service_id should be preserved"
        # stop_times
        stop_times_table = G_copy.get_table("stop_times")
        # 01:32:45 corresponds to 3600 + (32 * 60) + 45 = 5565 [in day seconds]
        assert 3600 + (32 * 60) + 45 not in stop_times_table['arr_time'].values
        os.remove(self.fname_copy)

    def test_filter_by_start_and_end_full_range(self):
        # untested tables with filtering: stops, shapes
        # test filtering by start and end time, copy full range
        FilterExtract(self.G,
                      self.fname_copy,
                      start_date=u"2007-01-01",
                      end_date=u"2011-01-01",
                      update_metadata=False).create_filtered_copy()
        G_copy = GTFS(self.fname_copy)
        dsut_end = G_copy.get_day_start_ut("2010-12-31")
        dsut_to_trip_I = G_copy.get_tripIs_within_range_by_dsut(
            dsut_end, dsut_end + 24 * 3600)
        self.assertGreater(len(dsut_to_trip_I), 0)
        os.remove(self.fname_copy)

    def test_filter_end_date_not_included(self):
        # the end date should not be included:
        FilterExtract(self.G,
                      self.fname_copy,
                      start_date="2007-01-02",
                      end_date="2010-12-31").create_filtered_copy()

        hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
        self.assertNotEqual(self.hash_orig, hash_copy)
        G_copy = GTFS(self.fname_copy)
        dsut_end = G_copy.get_day_start_ut("2010-12-31")
        dsut_to_trip_I = G_copy.get_tripIs_within_range_by_dsut(
            dsut_end, dsut_end + 24 * 3600)
        self.assertEqual(len(dsut_to_trip_I), 0)

        calendar_copy = G_copy.get_table("calendar")
        max_date_calendar = max([
            datetime.datetime.strptime(el, "%Y-%m-%d")
            for el in calendar_copy["end_date"].values
        ])
        min_date_calendar = min([
            datetime.datetime.strptime(el, "%Y-%m-%d")
            for el in calendar_copy["start_date"].values
        ])
        end_date_not_included = datetime.datetime.strptime(
            "2010-12-31", "%Y-%m-%d")
        start_date_not_included = datetime.datetime.strptime(
            "2007-01-01", "%Y-%m-%d")
        self.assertLess(max_date_calendar,
                        end_date_not_included,
                        msg="the last date should not be included in calendar")
        self.assertLess(start_date_not_included, min_date_calendar)
        os.remove(self.fname_copy)

    def test_filter_spatially(self):
        # test that the db is split by a given spatial boundary
        FilterExtract(self.G,
                      self.fname_copy,
                      buffer_lat=36.914893,
                      buffer_lon=-116.76821,
                      buffer_distance_km=50).create_filtered_copy()
        G_copy = GTFS(self.fname_copy)

        stops_table = G_copy.get_table("stops")
        self.assertNotIn("FUR_CREEK_RES", stops_table['stop_id'].values)
        self.assertIn("AMV", stops_table['stop_id'].values)
        self.assertEqual(len(stops_table['stop_id'].values), 8)

        conn_copy = sqlite3.connect(self.fname_copy)
        stop_ids_df = pandas.read_sql(
            'SELECT stop_id from stop_times '
            'left join stops '
            'on stops.stop_I = stop_times.stop_I', conn_copy)
        stop_ids = stop_ids_df["stop_id"].values

        self.assertNotIn("FUR_CREEK_RES", stop_ids)
        self.assertIn("AMV", stop_ids)

        trips_table = G_copy.get_table("trips")
        self.assertNotIn("BFC1", trips_table['trip_id'].values)

        routes_table = G_copy.get_table("routes")
        self.assertNotIn("BFC", routes_table['route_id'].values)
        # cases:
        # whole trip excluded
        # whole route excluded
        # whole agency excluded
        # part of trip excluded
        # part of route excluded
        # part of agency excluded
        # not removing stops from a trip that returns into area

        # test higher-order removals
        # stop A preserved
        # -> stop B preserved
        # -> stop C preserved

    def test_filter_spatially_2(self):
        n_rows_before = {
            "routes": 4,
            "stop_times": 14,
            "trips": 4,
            "stops": 6,
            "shapes": 4
        }
        n_rows_after_1000 = {  # within "soft buffer" in the feed data
            "routes": 1,
            "stop_times": 2,
            "trips": 1,
            "stops": 2,
            "shapes": 0
        }
        n_rows_after_3000 = {  # within "hard buffer" in the feed data
            "routes": len(["t1", "t3", "t4"]),
            "stop_times": 11,
            "trips": 4,
            "stops": len({"P", "H", "V", "L", "B"}),
            # for some reason, the first "shapes": 4
        }
        paris_lat = 48.832781
        paris_lon = 2.360734

        SELECT_MIN_MAX_SHAPE_BREAKS_BY_TRIP_I_SQL = \
            "SELECT trips.trip_I, shape_id, " \
            "min(shape_break) as min_shape_break, " \
            "max(shape_break) as max_shape_break " \
            "FROM trips, stop_times " \
            "WHERE trips.trip_I=stop_times.trip_I " \
            "GROUP BY trips.trip_I"
        trip_min_max_shape_seqs = pandas.read_sql(
            SELECT_MIN_MAX_SHAPE_BREAKS_BY_TRIP_I_SQL, self.G_filter_test.conn)

        for distance_km, n_rows_after in zip(
            [1000, 3000], [n_rows_after_1000, n_rows_after_3000]):
            try:
                os.remove(self.fname_copy)
            except FileNotFoundError:
                pass
            FilterExtract(
                self.G_filter_test,
                self.fname_copy,
                buffer_lat=paris_lat,
                buffer_lon=paris_lon,
                buffer_distance_km=distance_km).create_filtered_copy()
            for table_name, n_rows in n_rows_before.items():
                self.assertEqual(
                    len(self.G_filter_test.get_table(table_name)), n_rows,
                    "Row counts before differ in " + table_name +
                    ", distance: " + str(distance_km))
            G_copy = GTFS(self.fname_copy)
            for table_name, n_rows in n_rows_after.items():
                table = G_copy.get_table(table_name)
                self.assertEqual(
                    len(table), n_rows,
                    "Row counts after differ in " + table_name +
                    ", distance: " + str(distance_km) + "\n" + str(table))

            # assert that stop_times are resequenced starting from one
            counts = pandas.read_sql(
                "SELECT count(*) FROM stop_times GROUP BY trip_I ORDER BY trip_I",
                G_copy.conn)
            max_values = pandas.read_sql(
                "SELECT max(seq) FROM stop_times GROUP BY trip_I ORDER BY trip_I",
                G_copy.conn)
            self.assertTrue((counts.values == max_values.values).all())

    def test_remove_all_trips_fully_outside_buffer(self):
        stops = self.G.stops()
        stop_1 = stops[stops['stop_I'] == 1]

        n_trips_before = len(self.G.get_table("trips"))

        remove_all_trips_fully_outside_buffer(self.G.conn, float(stop_1.lat),
                                              float(stop_1.lon), 100000)
        self.assertEqual(len(self.G.get_table("trips")), n_trips_before)

        # 0.002 km (= at most 2 meters from the stop); rounding errors can take place...
        remove_all_trips_fully_outside_buffer(self.G.conn, float(stop_1.lat),
                                              float(stop_1.lon), 0.002)
        self.assertEqual(len(self.G.get_table("trips")),
                         2)  # value "2" comes from the data
Example No. 8
class ImportValidator(object):
    def __init__(self, gtfssource, gtfs):
        """
        Parameters
        ----------
        gtfs_sources: list of strings
        gtfs: GTFS, or path to a GTFS object
            A GTFS object
        """
        self.df_freq_dict = {}
        if isinstance(gtfssource, string_types + (dict, )):
            self.gtfs_sources = [gtfssource]
        else:
            assert isinstance(gtfssource, list)
            self.gtfs_sources = gtfssource
        assert len(
            self.gtfs_sources
        ) > 0, "There must be at least one source file for validating an import"

        if not isinstance(gtfs, GTFS):
            self.gtfs = GTFS(gtfs)
        else:
            self.gtfs = gtfs

        self.location = self.gtfs.get_location_name()
        self.warnings_container = WarningsContainer()

    def get_warnings(self):
        self.warnings_container.clear()
        self._validate_table_counts()
        self._validate_no_nulls()
        self._validate_danglers()
        self.warnings_container.print_summary()
        return self.warnings_container

    def _validate_table_counts(self):
        """
        Imports source .txt files, checks row counts and then compares the rowcounts with the gtfsobject
        :return:
        """
        for table_name_txt, db_table_name, row_warning in zip(
                SOURCE_TABLE_NAMES, DB_TABLE_NAMES, ROW_WARNINGS):
            source_row_count = 0

            for gtfs_source in self.gtfs_sources:
                frequencies_in_source = source_table_txt_to_pandas(
                    gtfs_source, 'frequencies.txt')
                try:
                    if table_name_txt == 'trips' and not frequencies_in_source.empty:
                        source_row_count += self._frequency_generated_trips(
                            gtfs_source, table_name_txt)

                    elif table_name_txt == 'stop_times' and not frequencies_in_source.empty:
                        source_row_count += self._frequency_generated_stop_times(
                            gtfs_source, table_name_txt)
                    else:
                        df = source_table_txt_to_pandas(
                            gtfs_source, table_name_txt)

                        source_row_count += len(df.index)
                except IOError as e:
                    print(e)

            # Result from GTFSobj:
            database_row_count = self.gtfs.get_row_count(db_table_name)
            if source_row_count == database_row_count:
                print("Row counts match for " + table_name_txt +
                      " between the source and database (" +
                      str(database_row_count) + ")")

            else:
                difference = database_row_count - source_row_count
                print('Row counts do not match for ' + str(table_name_txt) +
                      ': (source=' + str(source_row_count) + ', database=' +
                      str(database_row_count) + ")")
                if table_name_txt == "calendar" and difference > 0:
                    query = "SELECT count(*) FROM (SELECT * FROM calendar ORDER BY service_I DESC LIMIT " \
                            + str(int(difference)) + \
                            ") WHERE start_date=end_date AND m=0 AND t=0 AND w=0 AND th=0 AND f=0 AND s=0 AND su=0"
                    number_of_entries_added_by_calendar_dates_loader = self.gtfs.execute_custom_query(
                        query).fetchone()[0]
                    if number_of_entries_added_by_calendar_dates_loader == difference:
                        print(
                            "    But don't worry, the extra entries seem to be just dummy entries added by the calendar_dates loader"
                        )
                    else:
                        print("    Reason for this is unknown.")
                        self.warnings_container.add_warning(
                            self.location, row_warning, difference)
                else:
                    self.warnings_container.add_warning(
                        self.location, row_warning, difference)

    def _validate_no_nulls(self):
        """
        Loads the tables from the gtfs object and counts the number of rows that have null values in
        fields that should not be null. Stores the number of null rows in warnings_container
        """
        for table, null_warning in zip(DB_TABLE_NAMES, NULL_WARNINGS):
            # TODO: make this validation source by source
            df = self.gtfs.get_table(table)
            df.drop(FIELDS_WHERE_NULL_OK[table], inplace=True, axis=1)
            # print(df.to_string())
            len_table = len(df.index)
            df.dropna(inplace=True, axis=0)
            len_non_null = len(df.index)
            nullrows = len_table - len_non_null
            if nullrows > 0:
                # print('Warning: Null values detected in table ' + table)
                self.warnings_container.add_warning(self.location,
                                                    null_warning,
                                                    value=nullrows)

    def _validate_danglers(self):
        """
        Checks for rows that are not referenced in the tables that should be linked

        stops <> stop_times using stop_I
        stop_times <> trips <> days, using trip_I
        trips <> routes, using route_I
        :return:
        """
        for query, warning in zip(DANGLER_QUERIES, DANGLER_WARNINGS):
            dangler_count = self.gtfs.execute_custom_query(query).fetchone()[0]
            if dangler_count > 0:
                print(str(dangler_count) + " " + warning)
                self.warnings_container.add_warning(self.location,
                                                    warning,
                                                    value=dangler_count)

    def _frequency_generated_trips(self, source, txt):
        """
        This function calculates the equivalent row counts for trips when
        taking into account the frequency-generated rows in the gtfs object
        :param source: path to the source file
        :param txt: txt file in question
        :return: sum of all trips
        """
        df_freq = source_table_txt_to_pandas(source, u'frequencies.txt')
        df_trips = source_table_txt_to_pandas(source, txt)
        df_freq['n_trips'] = df_freq.apply(lambda row: len(
            range(str_time_to_day_seconds(row['start_time']),
                  str_time_to_day_seconds(row['end_time']), row['headway_secs']
                  )),
                                           axis=1)
        self.df_freq_dict[source] = df_freq
        df_trips_freq = pd.merge(df_freq, df_trips, how='outer', on='trip_id')

        return int(df_trips_freq['n_trips'].fillna(1).sum(axis=0))

    def _frequency_generated_stop_times(self, source, txt):
        """
        Same as above, but for the stop_times table.
        :param source:
        :param txt:
        :return:
        """
        df_stop_times = source_table_txt_to_pandas(source, txt)
        df_freq = self.df_freq_dict[source]
        df_stop_freq = pd.merge(df_freq,
                                df_stop_times,
                                how='outer',
                                on='trip_id')

        return int(df_stop_freq['n_trips'].fillna(1).sum(axis=0))
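The row-count reconciliation above hinges on an outer merge: trips that appear in frequencies.txt contribute their expanded n_trips, while all other trips fall back to 1 via fillna(1). A toy sketch of that accounting (the trip ids are made up):

import pandas as pd

df_freq = pd.DataFrame({"trip_id": ["freq_trip"], "n_trips": [12]})
df_trips = pd.DataFrame({"trip_id": ["freq_trip", "plain_trip_1", "plain_trip_2"]})

df_trips_freq = pd.merge(df_freq, df_trips, how="outer", on="trip_id")
# freq_trip counts as 12 generated rows, the two plain trips as 1 row each.
print(int(df_trips_freq["n_trips"].fillna(1).sum(axis=0)))   # -> 14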
Example No. 9
for city_id in ALL_CITIES:
    copy_dir_name = os.path.join("copies_from_hammer", city_id)
    try:
        directory_listing = os.listdir(copy_dir_name)
    except FileNotFoundError as e:
        print(e)
        continue

    date_regexp = re.compile("....-..-..")
    for directory_candidate in directory_listing:
        if date_regexp.match(directory_candidate) is not None:
            sqlite_fname = os.path.join(copy_dir_name, directory_candidate,
                                        "week.sqlite")

            if not os.path.exists(sqlite_fname):
                print(sqlite_fname + " does not exist!")
                continue
            G = GTFS(sqlite_fname)
            stop_distances_df = G.get_table("stop_distances")

            print(len(stop_distances_df))

            fig = plt.figure()
            ax = fig.add_subplot(111)
            ax.scatter(stop_distances_df['d'],
                       stop_distances_df['d_walk'],
                       s=1)
            ax.set_title(city_id)
            plt.show()
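Equivalently, the same d versus d_walk data can be pulled straight from the sqlite file with pandas.read_sql, without going through the GTFS wrapper. A minimal sketch (the path is a placeholder; substitute one of the week.sqlite files found by the loop above):

import sqlite3
import pandas

path_to_week_sqlite = "week.sqlite"   # hypothetical path
conn = sqlite3.connect(path_to_week_sqlite)
stop_distances_df = pandas.read_sql("SELECT d, d_walk FROM stop_distances", conn)
print(stop_distances_df.head())
conn.close()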
Example No. 10
class ImportValidator(object):
    def __init__(self, gtfssource, gtfs, verbose=True):
        """
        Parameters
        ----------
        gtfssource: list, string, or dict
            path(s) to the GTFS source data, or a dictionary containing the GTFS data directly
        gtfs: gtfspy.gtfs.GTFS, or path to a relevant .sqlite GTFS database
        verbose: bool
            Whether or not to print warnings on-the-fly.
        """
        if isinstance(gtfssource, string_types + (dict, )):
            self.gtfs_sources = [gtfssource]
        else:
            assert isinstance(gtfssource, list)
            self.gtfs_sources = gtfssource
        assert len(
            self.gtfs_sources
        ) > 0, "There must be at least one source file for validating an import"

        if not isinstance(gtfs, GTFS):
            self.gtfs = GTFS(gtfs)
        else:
            self.gtfs = gtfs

        self.location = self.gtfs.get_location_name()
        self.warnings_container = WarningsContainer()
        self.verbose = verbose

    def validate_and_get_warnings(self):
        self.warnings_container.clear()
        self._validate_table_row_counts()
        self._validate_no_null_values()
        self._validate_danglers()
        return self.warnings_container

    def _validate_table_row_counts(self):
        """
        Imports source .txt files, checks row counts and then compares the rowcounts with the gtfsobject
        :return:
        """
        for db_table_name in DB_TABLE_NAME_TO_SOURCE_FILE.keys():
            table_name_source_file = DB_TABLE_NAME_TO_SOURCE_FILE[
                db_table_name]
            row_warning_str = DB_TABLE_NAME_TO_ROWS_MISSING_WARNING[
                db_table_name]

            # Row count in GTFS object:
            database_row_count = self.gtfs.get_row_count(db_table_name)

            # Row counts in source files:
            source_row_count = 0
            for gtfs_source in self.gtfs_sources:
                frequencies_in_source = source_csv_to_pandas(
                    gtfs_source, 'frequencies.txt')
                try:
                    if table_name_source_file == 'trips' and not frequencies_in_source.empty:
                        source_row_count += self._frequency_generated_trips_rows(
                            gtfs_source)

                    elif table_name_source_file == 'stop_times' and not frequencies_in_source.empty:
                        source_row_count += self._compute_number_of_frequency_generated_stop_times(
                            gtfs_source)
                    else:
                        df = source_csv_to_pandas(gtfs_source,
                                                  table_name_source_file)

                        source_row_count += len(df.index)
                except IOError as e:
                    if hasattr(e, "filename") and db_table_name in e.filename:
                        pass
                    else:
                        raise e

            if source_row_count == database_row_count and self.verbose:
                print("Row counts match for " + table_name_source_file +
                      " between the source and database (" +
                      str(database_row_count) + ")")
            else:
                difference = database_row_count - source_row_count
                print('Row counts do not match for ' + str(table_name_source_file) +
                      ': (source=' + str(source_row_count) + ', database=' +
                      str(database_row_count) + ")")
                if table_name_source_file == "calendar" and difference > 0:
                    query = "SELECT count(*) FROM (SELECT * FROM calendar ORDER BY service_I DESC LIMIT " \
                            + str(int(difference)) + \
                            ") WHERE start_date=end_date AND m=0 AND t=0 AND w=0 AND th=0 AND f=0 AND s=0 AND su=0"
                    number_of_entries_added_by_calendar_dates_loader = self.gtfs.execute_custom_query(
                        query).fetchone()[0]
                    if number_of_entries_added_by_calendar_dates_loader == difference and self.verbose:
                        print(
                            "    But don't worry, the extra entries seem to be just dummy entries added by the calendar_dates loader"
                        )
                    else:
                        if self.verbose:
                            print("    Reason for this is unknown.")
                        self.warnings_container.add_warning(
                            row_warning_str, self.location, difference)
                else:
                    self.warnings_container.add_warning(
                        row_warning_str, self.location, difference)

    def _validate_no_null_values(self):
        """
        Loads the tables from the gtfs object and counts the number of rows that have null values in
        fields that should not be null. Stores the number of null rows in warnings_container
        """
        for table in DB_TABLE_NAMES:
            null_not_ok_warning = "Null values in must-have columns in table {table}".format(
                table=table)
            null_warn_warning = "Null values in good-to-have columns in table {table}".format(
                table=table)
            null_not_ok_fields = DB_TABLE_NAME_TO_FIELDS_WHERE_NULL_NOT_OK[
                table]
            null_warn_fields = DB_TABLE_NAME_TO_FIELDS_WHERE_NULL_OK_BUT_WARN[
                table]

            # CW, TODO: make this validation source by source
            df = self.gtfs.get_table(table)

            for warning, fields in zip(
                [null_not_ok_warning, null_warn_warning],
                [null_not_ok_fields, null_warn_fields]):
                null_unwanted_df = df[fields]
                rows_having_null = null_unwanted_df.isnull().any(axis=1)
                if sum(rows_having_null) > 0:
                    rows_having_unwanted_null = df[rows_having_null.values]
                    self.warnings_container.add_warning(
                        warning, rows_having_unwanted_null,
                        len(rows_having_unwanted_null))

    def _validate_danglers(self):
        """
        Checks for rows that are not referenced in the tables that should be linked

        stops <> stop_times using stop_I
        stop_times <> trips <> days, using trip_I
        trips <> routes, using route_I
        :return:
        """
        for query, warning in zip(DANGLER_QUERIES, DANGLER_WARNINGS):
            dangler_count = self.gtfs.execute_custom_query(query).fetchone()[0]
            if dangler_count > 0:
                if self.verbose:
                    print(str(dangler_count) + " " + warning)
                self.warnings_container.add_warning(warning,
                                                    self.location,
                                                    count=dangler_count)

    def _frequency_generated_trips_rows(self,
                                        gtfs_source_path,
                                        return_df_freq=False):
        """
        Calculate the equivalent row count for trips, taking into account
        the rows generated from frequencies.txt in the gtfs object.

        Parameters
        ----------
        gtfs_source_path: path to the source file
        return_df_freq: if True, return the merged frequencies/trips DataFrame
            instead of the row count

        Returns
        -------
        int or pandas.DataFrame
        """
        df_freq = source_csv_to_pandas(gtfs_source_path, 'frequencies')
        df_trips = source_csv_to_pandas(gtfs_source_path, "trips")
        df_freq['n_trips'] = df_freq.apply(lambda row: len(
            range(str_time_to_day_seconds(row['start_time']),
                  str_time_to_day_seconds(row['end_time']), row['headway_secs']
                  )),
                                           axis=1)
        df_trips_freq = pd.merge(df_freq, df_trips, how='outer', on='trip_id')
        n_freq_generated_trips = int(
            df_trips_freq['n_trips'].fillna(1).sum(axis=0))
        if return_df_freq:
            return df_trips_freq
        else:
            return n_freq_generated_trips

    def _compute_number_of_frequency_generated_stop_times(
            self, gtfs_source_path):
        """
        Parameters
        ----------
        Same as for "_frequency_generated_trips_rows" but for stop times table
        gtfs_source_path:
        table_name:

        Return
        ------
        """
        df_freq = self._frequency_generated_trips_rows(gtfs_source_path,
                                                       return_df_freq=True)
        df_stop_times = source_csv_to_pandas(gtfs_source_path, "stop_times")
        df_stop_freq = pd.merge(df_freq,
                                df_stop_times,
                                how='outer',
                                on='trip_id')
        return int(df_stop_freq['n_trips'].fillna(1).sum(axis=0))
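The null validation in _validate_no_null_values boils down to a row-wise isnull check over the columns that must not be null. A toy pandas sketch of that selection (the column names are invented for illustration):

import pandas as pd

df = pd.DataFrame({
    "stop_id": ["S1", "S2", None],
    "lat": [60.17, None, 60.19],
    "name": [None, "Central", "Harbour"],   # a good-to-have column; nulls here only warrant a warning
})

null_not_ok_fields = ["stop_id", "lat"]
rows_having_null = df[null_not_ok_fields].isnull().any(axis=1)
rows_having_unwanted_null = df[rows_having_null.values]
print(len(rows_having_unwanted_null))   # -> 2 rows violate the must-have columns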