def _set_stop_area_locality(connection):
    """ Assign a locality to each stop area from the localities of the stop
        points it contains.

        Stop points are counted per stop area and locality, and
        ``_find_stop_area_mode`` picks a locality for each area from those
        counts. Areas it reports as ambiguous are passed on to
        ``_find_locality_distance`` and any results are added to the updates.

        :param connection: Connection for population.
    """
    # Count stop points for every (stop area, locality) pair
    area_points = models.StopArea.__table__.join(
        models.StopPoint,
        models.StopArea.code == models.StopPoint.stop_area_ref
    )
    count_localities = (
        db.select([
            models.StopArea.code.label("code"),
            models.StopPoint.locality_ref.label("ref"),
            db.func.count(models.StopPoint.locality_ref).label("count")
        ])
        .select_from(area_points)
        .group_by(models.StopArea.code, models.StopPoint.locality_ref)
    )

    with connection.begin():
        rows = connection.execute(count_localities).fetchall()
        # Locality holding the most stops wins for each stop area
        updates, unresolved = _find_stop_area_mode(rows, "locality_ref")

        # Still ambiguous: fall back on distance between area and localities
        if unresolved:
            updates.extend(
                _find_locality_distance(connection, unresolved.keys())
            )

        utils.logger.info("Adding locality codes to stop areas")
        for entry in updates:
            statement = (
                db.update(models.StopArea)
                .values({"locality_ref": entry["locality_ref"]})
                .where(models.StopArea.code == entry["code"])
            )
            connection.execute(statement)
def __init__(self, connection):
    """ Load the ATCO codes of all stop points from the database.

        :param connection: Connection used to query the stop points.
        :raises ValueError: if the query returns no stop points.
    """
    select_codes = db.select([models.StopPoint.atco_code])
    rows = connection.execute(select_codes)

    self.stops = {row.atco_code for row in rows}
    self.not_exists = set()

    if not self.stops:
        raise ValueError(
            "No stop points were found. The TNDS dataset requires the "
            "database to be populated from NaPTAN data first."
        )
def _set_tram_admin_area(connection):
    """ Set admin area ref for tram stops and areas to be the same as their
        localities.

        Stop points whose admin area is the code held in ``tram_area`` get
        the admin area of their locality. Stop areas with that code are then
        given an admin area chosen by ``_find_stop_area_mode`` from the
        counts of their stop points' admin areas. Areas left ambiguous are
        only logged as warnings and are not updated.

        :param connection: Connection for population.
    """
    tram_area = "147"

    with connection.begin():
        # Update stop points: correlated subquery picking the admin area of
        # each stop point's locality
        admin_area_ref = (
            db.select([models.Locality.admin_area_ref])
            .where(models.Locality.code == models.StopPoint.locality_ref)
            .as_scalar()
        )

        utils.logger.info("Updating tram stops with admin area ref")
        connection.execute(
            db.update(models.StopPoint)
            .values({models.StopPoint.admin_area_ref: admin_area_ref})
            .where(models.StopPoint.admin_area_ref == tram_area)
        )

        # Find stop areas with associated admin area codes
        stop_areas = connection.execute(
            db.select([
                models.StopArea.code.label("code"),
                models.StopPoint.admin_area_ref.label("ref"),
                db.func.count(models.StopPoint.admin_area_ref).label("count")
            ])
            .select_from(
                models.StopArea.__table__.join(
                    models.StopPoint,
                    models.StopArea.code == models.StopPoint.stop_area_ref
                )
            )
            .where(models.StopArea.admin_area_ref == tram_area)
            .group_by(models.StopArea.code, models.StopPoint.admin_area_ref)
        )
        areas, ambiguous = _find_stop_area_mode(stop_areas.fetchall(),
                                                "admin_area_ref")

        # This step writes admin area refs, so the message says so (it used
        # to repeat the message from the locality update)
        utils.logger.info("Adding admin area codes to stop areas")
        for a in areas:
            connection.execute(
                db.update(models.StopArea)
                .values({"admin_area_ref": a["admin_area_ref"]})
                .where(models.StopArea.code == a["code"])
            )

        # Separate name for the candidates so the list of updates above is
        # not shadowed
        for area, candidates in ambiguous.items():
            utils.logger.warning(
                f"Area {area}: ambiguous admin areas {candidates}"
            )
def populate_tnds_data(connection, path=None, delete=True, warn=False):
    """ Commits TNDS data to database.

        Each file in each regional archive is transformed with the TNDS XSLT
        stylesheet and the result is added to the database. A file raising
        ``RowIdError`` during the transform is logged and skipped.

        :param connection: Connection for population.
        :param path: Path for zip files with TNDS XML documents and named
        after region codes. Global expansion is supported - all unique files
        matching region codes will be used. The archives will be downloaded if
        this is None.
        :param delete: Truncate all data from TNDS tables before populating.
        :param warn: Log warning if no FTP credentials exist. If False an error
        will be raised instead.
        :raises ValueError: if no operators exist in the database.
    """
    archives = _get_archives(connection, path, warn)
    if archives is None:
        return

    # Operators must already be present before TNDS data can be added
    any_operator = db.exists(db.select([models.Operator.code])).select()
    if not connection.execute(any_operator).scalar():
        raise ValueError(
            "No operators were found. The TNDS dataset requires the database "
            "to be populated from NOC data first."
        )

    row_ids = setup_row_ids(connection, check_existing=not delete)
    setup_stop_exists(connection)
    setup_service_codes()

    # NOC tables are excluded so existing operator data is never deleted
    excluded = models.Operator, models.LocalOperator
    metadata = utils.reflect_metadata(connection)
    with open_binary("nextbus.populate", "tnds.xslt") as stylesheet:
        xslt = et.XSLT(et.parse(stylesheet))

    # Deletion is only requested for the first file that is populated
    delete_pending = delete
    for region, archive in archives.items():
        for file_ in file_ops.iter_archive(archive):
            file_path = os.path.join(os.path.basename(archive), file_.name)
            utils.logger.info(f"Parsing file {file_path!r}")
            try:
                transformed = utils.xslt_transform(
                    file_, xslt, region=region, file=file_.name
                )
            except RowIdError:
                # IDs do not match in XML file; log error and move on
                utils.logger.error(f"Invalid IDs in file {file_path!r}",
                                   exc_info=1)
                continue

            utils.populate_database(
                connection,
                utils.collect_xml_data(transformed),
                metadata=metadata,
                delete=delete_pending,
                exclude=excluded
            )
            row_ids.clear()
            delete_pending = False
def _get_sequence(self, model_name):
    """ Find the next ID to use for a model.

        :param model_name: Name of a model class looked up on ``models``.
        :returns: One more than the largest ID in the model's table when
        ``self.existing`` is set and the model exists; otherwise one, which
        is also used if the table has no IDs.
    """
    model = getattr(models, model_name, None)
    if not self.existing or model is None:
        return 1

    # Continue from the largest ID already stored in the table
    select_largest = db.select([db.func.max(model.id)])
    largest = self._conn.execute(select_largest).scalar()

    return 1 if largest is None else largest + 1
def _get_regions(connection):
    """ Collect the codes of every region in the database except GB.

        :param connection: Connection used to query the regions.
        :returns: List of region codes.
        :raises ValueError: if no such regions are found.
    """
    select_codes = (
        db.select([models.Region.code])
        .where(models.Region.code != "GB")
    )
    codes = [row[0] for row in connection.execute(select_codes)]

    if codes:
        return codes

    raise ValueError("NPTG data not populated yet.")
def _remove_districts(connection):
    """ Delete every district that has no locality referring to it.

        :param connection: Connection used to execute the deletion.
    """
    # Outer join keeps districts without localities; those rows have a null
    # locality code
    districts_localities = models.District.__table__.outerjoin(
        models.Locality,
        models.District.code == models.Locality.district_ref
    )
    orphan_districts = (
        db.select([models.District.code])
        .select_from(districts_localities)
        .where(models.Locality.code.is_(None))
        .alias("orphan_districts")
    )

    utils.logger.info("Deleting orphaned districts")
    delete_orphans = db.delete(models.District.__table__).where(
        models.District.code.in_(orphan_districts)
    )
    connection.execute(delete_orphans)
def _replace_row(connection, model, element):
    """ Replaces values for rows in tables matching attributes from this
        element.

        Each child of the element names a column by its tag, gives the new
        value as its text and may give the expected current value in an
        ``old`` attribute. Columns already holding the new value in every
        matching row are skipped with a warning; a mismatch with ``old`` is
        only warned about and the column is still updated.

        :param connection: Connection for population.
        :param model: The database model class.
        :param element: A ``replace`` XML element.
        :returns: Number of rows replaced.
        :raises ValueError: if the element has no XML attributes.
    """
    name = model.__name__
    if not element.keys():
        raise ValueError("Each <replace> element requires at least one XML "
                         "attribute to filter rows.")

    select_rows = db.select([model.__table__]).where(
        _match_attr(model, element.attrib)
    )
    rows = connection.execute(select_rows).fetchall()
    if not rows:
        logger.warning(f"{name}: No rows matching {element.attrib} found.")
        return 0

    changes = {}
    for child in element:
        column = child.tag
        expected = child.get("old")
        replacement = child.text
        current = {getattr(row, column) for row in rows}

        # Nothing to do if every row already holds the new value
        if current == {replacement}:
            logger.warning(f"{name}: {column} {replacement!r} for "
                           f"{element.attrib} already matches.")
            continue

        # Warn when the expected old value differs from what is stored, as
        # the dataset may have changed since the replacement was written
        if expected and any(value != expected for value in current):
            if len(current) > 1:
                found = f"values {sorted(current)}"
            else:
                found = f"value {next(iter(current))!r}"
            logger.warning(f"{name}: {column} {expected!r} for "
                           f"{element.attrib} does not match existing "
                           f"{found}.")

        changes[column] = replacement

    if not changes:
        return 0

    # Update matched entries
    update_rows = db.update(model).values(changes).where(
        _match_attr(model, element.attrib)
    )
    return connection.execute(update_rows).rowcount
def insert_fts_rows(connection):
    """ Build the list of statements selecting full text search rows, with
        the service statement split into batches by service ID.

        The statements come from ``_select_fts_vectors``. Its last statement
        is removed and replaced by copies of itself, each restricted to a
        range of ``step`` consecutive service IDs.

        The ranges span the smallest to the largest service ID. They used to
        span zero to the number of services, which left out any service whose
        ID was not below that count rounded up to a multiple of ``step``
        (eg ID 1000 when exactly 1000 services have IDs 1 to 1000).

        :param connection: Connection used to find the range of service IDs.
        :returns: List of select statements; no service statements are
        included if there are no services.
    """
    statements = _select_fts_vectors()
    service = statements.pop(-1)

    # Aggregates give a single row, holding nulls if the table is empty
    id_bounds = db.select([
        db.func.min(Service.id),
        db.func.max(Service.id)
    ])
    first_id, last_id = connection.execute(id_bounds).fetchone()
    if last_id is None:
        return statements

    step = 1000
    # Upper bound is inclusive of the largest ID
    for offset in range(first_id, last_id + 1, step):
        range_ = (Service.id >= offset) & (Service.id < offset + step)
        statements.append(service.where(range_))

    return statements
def _remove_stop_areas(connection):
    """ Delete every stop area that has no stop point referring to it.

        :param connection: Connection used to execute the deletion.
    """
    # Outer join keeps stop areas without stop points; those rows have a
    # null ATCO code
    areas_points = models.StopArea.__table__.outerjoin(
        models.StopPoint,
        models.StopArea.code == models.StopPoint.stop_area_ref
    )
    orphan_stop_areas = (
        db.select([models.StopArea.code])
        .select_from(areas_points)
        .where(models.StopPoint.atco_code.is_(None))
        .alias("orphan_stop_areas")
    )

    utils.logger.info("Deleting orphaned stop areas")
    delete_orphans = db.delete(models.StopArea).where(
        models.StopArea.code.in_(orphan_stop_areas)
    )
    connection.execute(delete_orphans)
def _find_locality_distance(connection, ambiguous_areas):
    """ Pick a locality for stop areas with ambiguous localities by finding
        the locality closest to each area.

        Distances are calculated from the eastings and northings of stop
        areas and the localities of their stop points. An area is only
        resolved when exactly one locality has the minimum distance.

        :param connection: Connection used to query distances.
        :param ambiguous_areas: Codes of the stop areas to check.
        :returns: List of dictionaries with keys ``code`` and
        ``locality_ref`` for each stop area that was resolved.
    """
    delta_east = models.StopArea.easting - models.Locality.easting
    delta_north = models.StopArea.northing - models.Locality.northing
    distance = db.func.sqrt(
        db.func.power(delta_east, 2) + db.func.power(delta_north, 2)
    )

    # Do another query over list of areas to find distance
    points_localities_areas = (
        models.StopPoint.__table__
        .join(models.Locality,
              models.StopPoint.locality_ref == models.Locality.code)
        .join(models.StopArea,
              models.StopPoint.stop_area_ref == models.StopArea.code)
    )
    select_distances = (
        db.select([
            models.StopArea.code.label("code"),
            models.Locality.code.label("locality"),
            distance.label("distance")
        ])
        .distinct(models.StopArea.code, models.Locality.code)
        .select_from(points_localities_areas)
        .where(models.StopPoint.stop_area_ref.in_(ambiguous_areas))
    )
    query_distances = connection.execute(select_distances)

    # Map each stop area to the distance of each of its localities
    distances = collections.defaultdict(dict)
    for row in query_distances:
        distances[row.code][row.locality] = row.distance

    resolved = []
    for area_code, by_locality in distances.items():
        shortest = min(by_locality.values())
        nearest = [
            locality for locality, dist in by_locality.items()
            if dist == shortest
        ]

        # Localities much further away than the nearest may be wrong
        for locality, dist in by_locality.items():
            if dist > 2 * shortest and dist > 1000:
                utils.logger.warning(
                    f"Area {area_code}: {dist:.0f} m away from {locality}"
                )

        # Only a single nearest locality can be assigned to the area
        if len(nearest) != 1:
            utils.logger.warning(
                f"Area {area_code}: ambiguous localities, {shortest}"
            )
            continue

        utils.logger.debug(f"Area {area_code} set to locality {nearest[0]}, "
                           f"dist {shortest:.0f} m")
        resolved.append({"code": area_code, "locality_ref": nearest[0]})

    return resolved
def _populate_table(connection, metadata, model, entries, overwrite=False,
                    delete_first=False):
    """ Fill a table by copying the entries into an intermediate table and
        then inserting from it, so an ON CONFLICT clause can be used.

        :param connection: Connection for population.
        :param metadata: Metadata holding the table for the model and,
        if already created, its intermediate table.
        :param model: Model whose table will be filled.
        :param entries: Rows to add; nothing is done if this is empty.
        :param overwrite: Update existing rows on conflict with the primary
        key instead of skipping them. Ignored if ``delete_first`` is set.
        :param delete_first: Truncate the table before adding the entries.
    """
    if not entries:
        return

    table = metadata.tables[model.__table__.name]
    temp_name = _temp_table_name(table)
    temp_table = (
        metadata.tables[temp_name] if temp_name in metadata.tables
        else _temp_table(metadata, table)
    )

    with connection.begin():
        # Start from an empty intermediate table
        temp_table.create(connection, checkfirst=True)
        truncate(connection, temp_table)

        if delete_first:
            logger.debug(f"Truncating table {table.name}")
            truncate(connection, table)

        logger.debug(f"Copying {len(entries)} rows to table {table.name} via "
                     f"{temp_table.name}")
        # Add entries to temporary table using COPY
        _copy(connection, temp_table, entries)

        # Insert entries from temporary table into main table avoiding
        # conflicts
        insert = postgresql.insert(table).from_select(
            [column.name for column in temp_table.columns],
            db.select([temp_table])
        )
        if delete_first or not overwrite:
            insert = insert.on_conflict_do_nothing()
        else:
            replaced = {
                column.name: getattr(insert.excluded, column.name)
                for column in temp_table.columns
            }
            insert = insert.on_conflict_do_update(
                constraint=table.primary_key.name,
                set_=replaced
            )

        connection.execute(insert)
def _array_lines(code):
    """ Create an array expression holding the distinct, ordered lines of
        all services serving a stop.

        :param code: Stop point reference the journey links are matched with.
        :returns: SQL ``array`` function call wrapping the scalar subquery.
    """
    # Services joined through their patterns to the patterns' links
    service_links = (
        _service
        .join(_pattern, _pattern.c.service_ref == _service.c.id)
        .join(_link, _link.c.pattern_ref == _pattern.c.id)
    )
    # Grouping by line removes duplicates
    lines = (
        db.select([_service.c.line])
        .select_from(service_links)
        .where(_link.c.stop_point_ref == code)
        .group_by(_service.c.line)
        .order_by(_service.c.line)
        .as_scalar()
    )

    return db.func.array(lines)
class StopArea(db.Model):
    """ NaPTAN stop areas, eg bus interchanges.

        Each stop area must refer to an admin area and may refer to a
        locality. Eastings, northings, the modified date and the stop count
        are deferred columns, so they are only loaded when accessed.
    """
    __tablename__ = "stop_area"

    # Primary key, up to 12 characters
    code = db.Column(db.VARCHAR(12), primary_key=True)
    name = db.Column(db.Text, nullable=False, index=True)
    # Required admin area; the foreign key cascades on delete
    admin_area_ref = db.Column(
        db.VARCHAR(3),
        db.ForeignKey("admin_area.code", ondelete="CASCADE"),
        nullable=False, index=True
    )
    # Optional locality (nullable); the foreign key cascades on delete
    locality_ref = db.Column(
        db.VARCHAR(8),
        db.ForeignKey("locality.code", ondelete="CASCADE"),
        index=True
    )
    stop_area_type = db.Column(db.VARCHAR(4), nullable=False)
    active = db.Column(db.Boolean, nullable=False, index=True)
    latitude = db.Column(db.Float, nullable=False)
    longitude = db.Column(db.Float, nullable=False)
    easting = db.deferred(db.Column(db.Integer, nullable=False))
    northing = db.deferred(db.Column(db.Integer, nullable=False))
    modified = db.deferred(db.Column(db.DateTime))

    # Number of stop points associated with this stop area. Only active stop
    # points are counted and the count is cast to text.
    stop_count = db.deferred(
        db.select([db.cast(db.func.count(), db.Text)])
        .where((_stop_point.c.stop_area_ref == code) & _stop_point.c.active)
        .scalar_subquery()
    )

    # Stop points in this area, ordered by name then short indicator. With
    # lazy="raise" the relationship must be loaded explicitly in a query;
    # accessing it unloaded raises instead of emitting a lazy load.
    stop_points = db.relationship(
        "StopPoint",
        backref="stop_area",
        order_by="StopPoint.name, StopPoint.short_ind",
        lazy="raise"
    )

    def __repr__(self):
        """ Show the stop area code. """
        return f"<StopArea({self.code!r})>"
class Service(db.Model):
    """ Service group.

        A service has journey patterns, and is linked through them to
        operators and regions. All relationships use lazy="raise", so they
        must be loaded explicitly in a query before being accessed.
    """
    __tablename__ = "service"

    # IDs are supplied when populating and not generated by the database
    id = db.Column(db.Integer, primary_key=True, autoincrement=False)
    # Optional code, unique where set
    code = db.Column(db.Text, index=True, nullable=True, unique=True)
    filename = db.Column(db.Text)
    # Line label using the "utf8_numeric" collation
    line = db.Column(db.Text(collation="utf8_numeric"), nullable=False)
    description = db.Column(db.Text, nullable=False)
    short_description = db.Column(db.Text, nullable=False)
    mode = db.Column(
        db.Integer,
        db.ForeignKey("service_mode.id"),
        nullable=False,
        index=True
    )

    # Get mode name for service with a deferred correlated subquery
    mode_name = db.deferred(
        db.select([ServiceMode.name])
        .where(ServiceMode.id == mode)
        .scalar_subquery()
    )

    patterns = db.relationship("JourneyPattern", backref="service",
                               innerjoin=True, lazy="raise")
    # Read-only link to operators: service -> journey pattern -> local
    # operator (matched on both code and region) -> operator
    operators = db.relationship(
        "Operator",
        backref=db.backref("services", uselist=True, viewonly=True,
                           order_by="Service.line, Service.description"),
        primaryjoin="Service.id == JourneyPattern.service_ref",
        secondary="join(JourneyPattern, LocalOperator, "
                  "(JourneyPattern.local_operator_ref == LocalOperator.code) & "
                  "(JourneyPattern.region_ref == LocalOperator.region_ref))",
        secondaryjoin="LocalOperator.operator_ref == Operator.code",
        order_by="Operator.name",
        viewonly=True,
        lazy="raise"
    )
    # Link to regions using the journey pattern table as the secondary table
    regions = db.relationship(
        "Region",
        backref=db.backref("services", uselist=True,
                           order_by="Service.line, Service.description"),
        primaryjoin="Service.id == JourneyPattern.service_ref",
        secondary="journey_pattern",
        secondaryjoin="JourneyPattern.region_ref == Region.code",
        order_by="Region.name",
        lazy="raise"
    )

    def has_mirror(self, selected=None):
        """ Checks directions for all patterns for a service and return the
            right one.

            Reads ``self.patterns``, which must already be loaded as the
            relationship is set to raise on lazy loading.

            :param selected: Direction initially selected.
            :returns: Tuple of the direction to use and a boolean indicating
            whether a mirror exists. If the patterns have both directions,
            the direction is the truth value of ``selected`` (False if it is
            None) and a mirror exists. Otherwise the direction is one taken
            from the patterns and no mirror exists.
            :raises KeyError: if the service has no patterns, from popping
            the empty set of directions.
        """
        set_dir = {p.direction for p in self.patterns}
        if set_dir == {True, False}:
            reverse = bool(selected) if selected is not None else False
            has_mirror = True
        else:
            # Not both directions: use whichever direction is present
            reverse = set_dir.pop()
            has_mirror = False

        return reverse, has_mirror

    def similar(self, direction=None, threshold=None):
        """ Find all services sharing stops with this service in a direction.

            Pairs are read from ``_pair`` with this service on either side,
            so the two halves are queried separately and combined.

            :param direction: Service direction, or None to include both.
            The filter applies to the direction of this service in each
            pair, not to the direction of the other service.
            :param threshold: Minimum similarity value, or None to include
            all. Only pairs with similarity strictly greater than this are
            included.
            :returns: List of rows, each with the other service, its
            direction, and the distinct origins and destinations of its
            patterns in that direction joined with ' / '. Rows are ordered
            by service line, description and direction.
        """
        id_ = db.bindparam("id", self.id)
        # Other service is first in the pair, this service is second
        similar0 = (
            db.session.query(_pair.c.service0.label("id"),
                             _pair.c.direction0.label("direction"))
            .filter(_pair.c.service1 == id_)
        )
        # Other service is second in the pair, this service is first
        similar1 = (
            db.session.query(_pair.c.service1.label("id"),
                             _pair.c.direction1.label("direction"))
            .filter(_pair.c.service0 == id_)
        )

        if direction is not None:
            dir_ = db.bindparam("dir", direction)
            similar0 = similar0.filter(_pair.c.direction1 == dir_)
            similar1 = similar1.filter(_pair.c.direction0 == dir_)

        if threshold is not None:
            value = db.bindparam("threshold", threshold)
            similar0 = similar0.filter(_pair.c.similarity > value)
            similar1 = similar1.filter(_pair.c.similarity > value)

        service = db.aliased(Service, name="service")
        similar = db.union_all(similar0, similar1).alias()

        # Only patterns in the matching direction contribute to the origins
        # and destinations of each similar service
        return (
            db.session.query(
                service,
                JourneyPattern.direction,
                db.func.string_agg(JourneyPattern.origin.distinct(), ' / ')
                .label("origin"),
                db.func.string_agg(JourneyPattern.destination.distinct(), ' / ')
                .label("destination")
            )
            .join(similar, similar.c.id == service.id)
            .join(JourneyPattern,
                  (service.id == JourneyPattern.service_ref) &
                  (similar.c.direction == JourneyPattern.direction))
            .group_by(service, similar.c.direction, JourneyPattern.direction)
            .order_by(service.line, service.description, similar.c.direction)
            .all()
        )
def _filter_journey_dates(query, date):
    """ Join multiple tables used to filter journeys by valid dates (eg week
        days, bank holidays or organisation working days).

        It is assumed the Journey and JourneyPattern models are in the FROM
        clause or joined, and the query is grouped by the journey ID.

        Three lateral subqueries, each aggregating to a single row, are
        joined to the query: special periods, bank holidays and organisation
        periods matching the date. A WHERE clause keeps candidate rows and a
        HAVING clause makes the final decision for each group.

        :param query: Query to add the joins and filters to.
        :param date: Date to check journeys against. It is only used within
        SQL expressions here - presumably it can be either a date value or a
        SQL expression; confirm against callers.
        :returns: The query with joins, filters and HAVING clause added.
    """
    # Aggregate the matching bank holidays and operational periods before
    # joining laterally. If they were joined first the query planner may pick a
    # slower plan to compensate for the row count 'blowing up', but in practice
    # the actual number of matching rows is very low.

    # Match special period if they fall within inclusive date range. The
    # aggregate is null when the journey has no matching special periods.
    matching_periods = (
        db.select([
            db.func.bool_and(models.SpecialPeriod.operational)
            .label("is_operational")
        ])
        .select_from(models.SpecialPeriod)
        .where(
            models.SpecialPeriod.journey_ref == models.Journey.id,
            models.SpecialPeriod.date_start <= date,
            models.SpecialPeriod.date_end >= date,
        )
        .lateral("matching_periods")
    )
    query = query.join(matching_periods, db.true())

    # Match bank holidays on the same day. Both aggregates are null when no
    # bank holiday falls on the date.
    matching_bank_holidays = (
        db.select([
            db.func.bool_and(
                _bit_array_contains(
                    models.Journey.include_holidays,
                    models.BankHolidayDate.holiday_ref,
                ),
            ).label("is_operational"),
            db.func.bool_or(
                _bit_array_contains(
                    models.Journey.exclude_holidays,
                    models.BankHolidayDate.holiday_ref,
                ),
            ).label("not_operational")
        ])
        .select_from(models.BankHolidayDate)
        .where(models.BankHolidayDate.date == date)
        .lateral("matching_bank_holidays")
    )
    query = query.join(matching_bank_holidays, db.true())

    # Match organisations working/holiday periods - can be operational
    # during holiday or working periods associated with organisation so
    # working attributes need to match (eg journey running during holidays
    # must match with operating periods for holidays or vice versa)
    matching_organisations = (
        db.select([
            db.func.bool_and(
                models.OperatingPeriod.id.isnot(None) &
                models.ExcludedDate.id.is_(None) &
                models.Organisations.operational
            ).label("is_operational"),
            db.func.bool_or(
                models.OperatingPeriod.id.isnot(None) &
                models.ExcludedDate.id.is_(None) &
                db.not_(models.Organisations.operational)
            ).label("not_operational")
        ])
        .select_from(models.Organisations)
        .join(
            models.Organisation,
            models.Organisations.org_ref == models.Organisation.code,
        )
        # Operating period must contain the date; end date may be unbounded
        .join(
            models.OperatingPeriod,
            db.and_(
                models.Organisation.code == models.OperatingPeriod.org_ref,
                models.Organisations.working == models.OperatingPeriod.working,
                models.OperatingPeriod.date_start <= date,
                models.OperatingPeriod.date_end.is_(None) |
                (models.OperatingPeriod.date_end >= date),
            ),
        )
        # Outer join, so a null ID means the date is not excluded
        .outerjoin(
            models.ExcludedDate,
            db.and_(
                models.Organisation.code == models.ExcludedDate.org_ref,
                models.Organisations.working == models.ExcludedDate.working,
                models.ExcludedDate.date == date,
            )
        )
        .where(models.Organisations.journey_ref == models.Journey.id)
        .lateral("matching_organisations")
    )
    query = query.join(matching_organisations, db.true())

    # Find week of month (0 to 4) and day of week (Monday 1 to Sunday 7)
    # NOTE(review): the day of month is divided by 7 without subtracting one
    # first, so with integer division days 1-6 give 0 and day 7 gives 1 -
    # confirm this matches how Journey.weeks is populated.
    week = db.cast(db.extract("DAY", date), db.Integer) / db.literal_column("7")
    weekday = db.cast(db.extract("ISODOW", date), db.Integer)

    query = query.filter(
        # Date must be within range for journey pattern, may be unbounded
        models.JourneyPattern.date_start <= date,
        models.JourneyPattern.date_end.is_(None) |
        (models.JourneyPattern.date_end >= date),
        # In order of precedence:
        # - Do not run on special days
        # - Do not run on bank holidays
        # - Run on special days
        # - Run on bank holidays
        # - Do not run during organisation working or holiday periods
        # - Run during organisation working or holiday periods
        # - Run or not run on specific weeks of month
        # - Run or not run on specific days of week
        # As & binds tighter than |, a row is kept if a special period
        # matched, or a bank holiday matched, or both the week and the day of
        # week match
        matching_periods.c.is_operational.isnot(None) |
        matching_bank_holidays.c.is_operational.isnot(None) |
        (models.Journey.weeks.is_(None) |
         _bit_array_contains(models.Journey.weeks, week)) &
        _bit_array_contains(models.Journey.days, weekday)
    )

    # Bank holidays and special dates have precedence over others so only
    # include journeys if all references are either null or are operational.
    # Include non-null references in WHERE so they can be checked here.
    # Check organisation working/holiday periods here after grouping as
    # there can be multiple periods for an organisation.
    # The first matching condition in the CASE decides each row, and a row
    # with no matching condition counts as operational.
    query = query.having(db.func.bool_and(
        db.case([
            (matching_periods.c.is_operational.isnot(None),
             matching_periods.c.is_operational),
            (matching_bank_holidays.c.not_operational, db.false()),
            (matching_bank_holidays.c.is_operational, db.true()),
            (matching_organisations.c.not_operational, db.false()),
            (matching_organisations.c.is_operational, db.true()),
        ], else_=db.true())
    ))

    return query
def _query_next_services(atco_code, timestamp=None, interval=None):
    """ Creates query for getting all services stopping at this stop point in
        an interval.

        :param atco_code: Code of the stop point, passed to
        ``_query_journeys_at_stop``.
        :param timestamp: Start of the interval. The database's current time
        is used if None, and a timestamp without timezone info is converted
        in the database with the ``_GB_TZ`` timezone.
        :param interval: Length of the interval, cast to a SQL interval.
        ``_ONE_HOUR`` is used if None.
        :returns: Query with columns ``line``, ``origin``, ``destination``,
        ``op_code``, ``op_name``, ``expected`` and ``seconds``, ordered by
        the expected time. ``seconds`` is the whole number of seconds
        from the timestamp to the expected time.
    """
    if timestamp is None:
        p_timestamp = db.func.now()
    elif timestamp.tzinfo is None:
        # Assume this is a local timestamp with GB timezone.
        p_timestamp = db.func.timezone(_GB_TZ,
                                       db.bindparam("timestamp", timestamp))
    else:
        p_timestamp = db.bindparam("timestamp", timestamp)

    if interval is None:
        p_interval = _ONE_HOUR
    else:
        param = db.bindparam("interval", interval)
        p_interval = db.cast(param, db.Interval)

    journey_match = _query_journeys_at_stop(atco_code).cte("journey_match")

    # Shift the interval back by each journey's offset at this stop, so it
    # can be compared with the departure time of the journey itself
    time_start = p_timestamp - journey_match.c.t_offset
    time_end = time_start + p_interval
    times = (
        db.select([
            time_start.label("utc_start"),
            time_end.label("utc_end"),
            db.func.timezone(_GB_TZ, time_start).label("local_start"),
            db.func.timezone(_GB_TZ, time_end).label("local_end"),
        ])
        .correlate(journey_match)
        .lateral("times")
    )

    local_start_date = db.cast(times.c.local_start, db.Date)
    local_end_date = db.cast(times.c.local_end, db.Date)
    local_start_time = db.cast(times.c.local_start, db.Time)
    local_end_time = db.cast(times.c.local_end, db.Time)

    journey_departure = (
        db.session.query(
            journey_match.c.id,
            journey_match.c.t_offset,
            times.c.utc_start,
            times.c.utc_end,
            # Use the start date unless the interval crosses midnight and
            # the departure time is not after the start time, in which case
            # the journey departs on the end date
            db.case(
                (
                    (local_start_date == local_end_date) |
                    (journey_match.c.departure > local_start_time),
                    local_start_date
                ),
                else_=local_end_date,
            ).label("date"),
            journey_match.c.departure.label("time"),
        )
        .select_from(journey_match)
        .join(times, db.true())
        # Interval within one day: departure time must be between the start
        # and end times (inclusive). Interval crossing midnight: departure
        # time must be after the start time or before the end time.
        .filter(
            (local_start_date == local_end_date) &
            db.between(
                journey_match.c.departure,
                local_start_time,
                local_end_time
            ) |
            (local_start_date < local_end_date) &
            (
                (journey_match.c.departure > local_start_time) |
                (journey_match.c.departure < local_end_time)
            )
        )
        .cte("journey_departure")
    )

    # Combine local date and time, with the result of the helper referred to
    # below by the column name passed to it
    utc_departures = _get_departure_range(
        journey_departure.c.date + journey_departure.c.time,
        "utc_departure",
    )
    utc_departure = db.column("utc_departure")

    # Keep departures within the shifted interval, then filter journeys by
    # the dates they run on. The expected time at this stop is the departure
    # time of the journey plus its offset.
    journey_filter = _filter_journey_dates(
        db.session.query(
            journey_departure.c.id,
            (utc_departure + journey_departure.c.t_offset).label("expected")
        )
        .select_from(journey_departure)
        .join(
            utc_departures,
            db.between(
                utc_departure,
                journey_departure.c.utc_start,
                journey_departure.c.utc_end,
            )
        )
        .join(models.Journey, journey_departure.c.id == models.Journey.id)
        .join(models.Journey.pattern)
        .group_by(
            journey_departure.c.id,
            journey_departure.c.t_offset,
            utc_departure
        ),
        journey_departure.c.date,
    ).cte("journey_filter")

    query = (
        db.session.query(
            models.Service.line.label("line"),
            models.JourneyPattern.origin.label("origin"),
            models.JourneyPattern.destination.label("destination"),
            models.Operator.code.label("op_code"),
            models.Operator.name.label("op_name"),
            journey_filter.c.expected,
            db.cast(db.extract("EPOCH", journey_filter.c.expected - p_timestamp),
                    db.Integer).label("seconds")
        )
        .select_from(journey_filter)
        .join(models.Journey, journey_filter.c.id == models.Journey.id)
        .join(models.Journey.pattern)
        .join(models.JourneyPattern.service)
        .join(models.JourneyPattern.operator)
        .order_by(journey_filter.c.expected)
    )

    return query
def _create_journey_data_query(region):
    """ Create a statement giving, for each journey in a region, a JSON array
        of the stops called at with arrival and departure intervals.

        :param region: Region code the journey patterns are filtered by.
        :returns: Select statement with the journey ID and a ``data`` column
        holding the JSON array, grouped by journey ID.
    """
    zero = db.cast("0", db.Interval)

    def link_interval(attribute):
        # Journey-specific link value first, then the journey pattern link's
        # value, then zero if both are null
        return db.func.coalesce(
            getattr(JourneySpecificLink, attribute),
            getattr(JourneyLink, attribute),
            zero,
        )

    # Running and wait intervals added up for each link
    total_link_time = db.func.sum(
        link_interval("run_time")
        + link_interval("wait_arrive")
        + link_interval("wait_leave")
    )
    window = {"partition_by": Journey.id, "order_by": JourneyLink.sequence}

    # Arrival: everything from preceding rows plus this row's running
    # interval
    arrive = (
        total_link_time.over(rows=(None, -1), **window)
        + link_interval("run_time")
    )
    # Departure: everything from preceding rows and this row
    depart = total_link_time.over(rows=(None, 0), **window)

    last_sequence = (
        db.func.max(JourneyLink.sequence).over(partition_by=Journey.id)
    )

    jl_start = db.aliased(JourneyLink)
    jl_end = db.aliased(JourneyLink)

    # Journey may call or not call at this stop point
    stopping = db.func.coalesce(
        JourneySpecificLink.stopping,
        JourneyLink.stopping
    )
    # No arrival time at the first stop, no departure time at the last stop
    arrive_or_null = db.case([(JourneyLink.sequence == 1, None)],
                             else_=arrive)
    depart_or_null = db.case([(JourneyLink.sequence == last_sequence, None)],
                             else_=depart)

    times = (
        db.select([
            Journey.id,
            JourneyLink.stop_point_ref,
            JourneyLink.timing_point,
            stopping.label("stopping"),
            JourneyLink.sequence,
            arrive_or_null.label("arrive"),
            depart_or_null.label("depart"),
        ])
        .select_from(Journey)
        .join(Journey.pattern)
        .join(JourneyPattern.links)
        .outerjoin(jl_start, Journey.start_run == jl_start.id)
        .outerjoin(jl_end, Journey.end_run == jl_end.id)
        .outerjoin(
            JourneySpecificLink,
            (Journey.id == JourneySpecificLink.journey_ref) &
            (JourneyLink.id == JourneySpecificLink.link_ref)
        )
        # Links outside the starting or ending dead runs of a journey are
        # left out
        .where(
            JourneyPattern.region_ref == region,
            jl_start.id.is_(None) |
            (JourneyLink.sequence >= jl_start.sequence),
            jl_end.id.is_(None) |
            (JourneyLink.sequence <= jl_end.sequence)
        )
        .cte("times")
    )

    # One JSON object per row of the CTE, collected into an array per journey
    call = db.func.jsonb_build_object(
        "stop_point_ref", times.c.stop_point_ref,
        "timing_point", times.c.timing_point,
        "stopping", times.c.stopping,
        "sequence", times.c.sequence,
        "arrive", times.c.arrive,
        "depart", times.c.depart
    )
    calls = db.func.jsonb_agg(call).label("data")

    return db.select([times.c.id, calls]).group_by(times.c.id)
def populate_naptan_data(connection, archive=None, list_files=None, split=True):
    """ Convert NaPTAN data (stop points and areas) to database objects and
        commit them to the application database.

        If neither an archive nor a list of files is given the data is
        downloaded and zipped in the temporary directory first. Existing data
        is only deleted when populating from the first file.

        :param connection: Connection for population
        :param archive: Path to zipped archive file for NaPTAN XML files.
        :param list_files: List of file paths for NaPTAN XML files.
        :param split: Splits NaPTAN XML files in archive by admin area code. Has
        no effect if list_files is used.
        :raises ValueError: if there are no admin areas or localities, if the
        TEMP_DIRECTORY setting is not defined, or if both an archive and a
        list of files are given.
    """
    # Get complete list of ATCO admin areas and localities from NPTG data
    select_areas = db.select([models.AdminArea.code])
    select_localities = db.select([models.Locality.code])
    query_area = connection.execute(select_areas)
    query_local = connection.execute(select_localities)
    areas = [row[0] for row in query_area]
    localities = [row[0] for row in query_local]

    if not (areas and localities):
        raise ValueError("NPTG tables are not populated; stop point data "
                         "cannot be added without the required locality data. "
                         "Populate the database with NPTG data first.")

    temp = current_app.config.get("TEMP_DIRECTORY")
    if not temp:
        raise ValueError("TEMP_DIRECTORY is not defined.")

    if archive is not None and list_files is not None:
        raise ValueError("Can't specify both archive file and list of files.")

    if archive is not None:
        path = archive
    elif list_files is not None:
        path = None
    else:
        downloaded = file_ops.download(NAPTAN_URL, directory=temp,
                                       params={"dataFormat": "XML"})
        utils.logger.info(f"Zipping {downloaded!r}")
        # The downloaded file is not zipped. Move it into an archive
        path = os.path.join(temp, "NaPTAN.zip")
        with zipfile.ZipFile(path, "w",
                             compression=zipfile.ZIP_DEFLATED) as zip_file:
            zip_file.write(downloaded)
        os.remove(downloaded)

    if path is not None and split:
        split_path = os.path.join(temp, "NaPTAN_split.zip")
        _split_naptan_data(areas, path, split_path)
        path = split_path

    if path is None:
        iter_files = iter(list_files)
    else:
        iter_files = file_ops.iter_archive(path)

    # Go through data and create objects for committing to database
    _setup_naptan_functions()
    metadata = utils.reflect_metadata(connection)
    with open_binary("nextbus.populate", "naptan.xslt") as stylesheet:
        xslt = et.XSLT(et.parse(stylesheet))

    delete_first = True
    for file_ in iter_files:
        # Files from an archive have a name; listed files are paths
        file_name = getattr(file_, "name", file_)
        utils.logger.info(f"Parsing file {file_name!r}")
        transformed = utils.xslt_transform(file_, xslt)
        utils.populate_database(
            connection,
            utils.collect_xml_data(transformed),
            metadata=metadata,
            delete=delete_first
        )
        delete_first = False
def _select_fts_vectors():
    """ Build the select statements feeding the full text search
        materialized view.

        Core expressions are used because the session has not been set up at
        this point; ORM models and attributes are still usable as column
        sources.

        :returns: List of seven select statements, one each for regions,
        admin areas, districts, localities, stop areas, stop points and
        services, in that order. Every statement exposes the same thirteen
        column labels.
    """
    null = db.literal_column("NULL")

    def blanks(*labels):
        # One NULL placeholder column per label, in the order given.
        return [null.label(label) for label in labels]

    def text_array(expression):
        # Cast an array expression to a PostgreSQL array of text.
        return db.cast(expression, pg.ARRAY(db.Text))

    def distinct_names(column):
        # Space-separated aggregate of the distinct values of a column.
        return db.func.string_agg(db.distinct(column), " ")

    # Regions, leaving out the 'GB' region
    region_columns = [
        utils.table_name(Region).label("table_name"),
        db.cast(Region.code, db.Text).label("code"),
        Region.name.label("name"),
        *blanks("indicator", "street", "stop_type", "stop_area_ref",
                "locality_name", "district_name", "admin_area_ref",
                "admin_area_name"),
        text_array(pg.array(())).label("admin_areas"),
        _tsvector_column((Region.name, "A")).label("vector"),
    ]
    region = db.select(region_columns).where(Region.code != "GB")

    # Admin areas, leaving out those belonging to the 'GB' region
    # NOTE(review): unlike the other selects, this array is not cast to
    # ARRAY(Text) - confirm whether that is intentional.
    admin_area_columns = [
        utils.table_name(AdminArea).label("table_name"),
        db.cast(AdminArea.code, db.Text).label("code"),
        AdminArea.name.label("name"),
        *blanks("indicator", "street", "stop_type", "stop_area_ref",
                "locality_name", "district_name"),
        AdminArea.code.label("admin_area_ref"),
        AdminArea.name.label("admin_area_name"),
        pg.array((AdminArea.code,)).label("admin_areas"),
        _tsvector_column((AdminArea.name, "A")).label("vector"),
    ]
    admin_area = (
        db.select(admin_area_columns).where(AdminArea.region_ref != "GB")
    )

    # Districts with their admin area
    district_columns = [
        utils.table_name(District).label("table_name"),
        db.cast(District.code, db.Text).label("code"),
        District.name.label("name"),
        *blanks("indicator", "street", "stop_type", "stop_area_ref",
                "locality_name", "district_name"),
        AdminArea.code.label("admin_area_ref"),
        AdminArea.name.label("admin_area_name"),
        text_array(pg.array((AdminArea.code,))).label("admin_areas"),
        _tsvector_column(
            (District.name, "A"),
            (AdminArea.name, "C")
        ).label("vector"),
    ]
    district_from = District.__table__.join(
        AdminArea, AdminArea.code == District.admin_area_ref
    )
    district = db.select(district_columns).select_from(district_from)

    # Localities, restricted to those with at least one stop point
    locality_columns = [
        utils.table_name(Locality).label("table_name"),
        db.cast(Locality.code, db.Text).label("code"),
        Locality.name.label("name"),
        *blanks("indicator", "street", "stop_type", "stop_area_ref",
                "locality_name"),
        District.name.label("district_name"),
        AdminArea.code.label("admin_area_ref"),
        AdminArea.name.label("admin_area_name"),
        text_array(pg.array((AdminArea.code,))).label("admin_areas"),
        _tsvector_column(
            (Locality.name, "A"),
            (db.func.coalesce(District.name, ""), "C"),
            (AdminArea.name, "C")
        ).label("vector"),
    ]
    locality_from = Locality.__table__.outerjoin(
        District, District.code == Locality.district_ref
    )
    locality_from = locality_from.join(
        AdminArea, AdminArea.code == Locality.admin_area_ref
    )
    locality_has_stops = db.exists([StopPoint.atco_code]).where(
        StopPoint.locality_ref == Locality.code
    )
    locality = (
        db.select(locality_columns)
        .select_from(locality_from)
        .where(locality_has_stops)
    )

    # Active stop areas; the indicator is the number of active stops joined
    stop_area_columns = [
        utils.table_name(StopArea).label("table_name"),
        db.cast(StopArea.code, db.Text).label("code"),
        StopArea.name.label("name"),
        db.cast(db.func.count(StopPoint.atco_code),
                db.Text).label("indicator"),
        *blanks("street"),
        StopArea.stop_area_type.label("stop_type"),
        *blanks("stop_area_ref"),
        Locality.name.label("locality_name"),
        District.name.label("district_name"),
        AdminArea.code.label("admin_area_ref"),
        AdminArea.name.label("admin_area_name"),
        text_array(pg.array((AdminArea.code,))).label("admin_areas"),
        _tsvector_column(
            (StopArea.name, "B"),
            (db.func.coalesce(Locality.name, ""), "C"),
            (db.func.coalesce(District.name, ""), "D"),
            (AdminArea.name, "D")
        ).label("vector"),
    ]
    stop_area_from = StopArea.__table__.join(
        StopPoint,
        (StopArea.code == StopPoint.stop_area_ref) & StopPoint.active
    )
    stop_area_from = stop_area_from.outerjoin(
        Locality, Locality.code == StopArea.locality_ref
    )
    stop_area_from = stop_area_from.outerjoin(
        District, District.code == Locality.district_ref
    )
    stop_area_from = stop_area_from.join(
        AdminArea, AdminArea.code == StopArea.admin_area_ref
    )
    stop_area = (
        db.select(stop_area_columns)
        .select_from(stop_area_from)
        .where(StopArea.active)
        .group_by(StopArea.code, Locality.name, District.name,
                  AdminArea.code)
    )

    # Active stop points with locality, district and admin area
    stop_point_columns = [
        utils.table_name(StopPoint).label("table_name"),
        db.cast(StopPoint.atco_code, db.Text).label("code"),
        StopPoint.name.label("name"),
        StopPoint.short_ind.label("indicator"),
        StopPoint.street.label("street"),
        StopPoint.stop_type.label("stop_type"),
        StopPoint.stop_area_ref.label("stop_area_ref"),
        Locality.name.label("locality_name"),
        District.name.label("district_name"),
        AdminArea.code.label("admin_area_ref"),
        AdminArea.name.label("admin_area_name"),
        text_array(pg.array((AdminArea.code,))).label("admin_areas"),
        _tsvector_column(
            (StopPoint.name, "B"),
            (db.func.coalesce(StopPoint.street, ""), "B"),
            (Locality.name, "C"),
            (db.func.coalesce(District.name, ""), "D"),
            (AdminArea.name, "D")
        ).label("vector"),
    ]
    stop_point_from = StopPoint.__table__.join(
        Locality, Locality.code == StopPoint.locality_ref
    )
    stop_point_from = stop_point_from.outerjoin(
        District, District.code == Locality.district_ref
    )
    stop_point_from = stop_point_from.join(
        AdminArea, AdminArea.code == StopPoint.admin_area_ref
    )
    stop_point = (
        db.select(stop_point_columns)
        .select_from(stop_point_from)
        .where(StopPoint.active)
    )

    # Services, aggregating operators and the areas of the active stops
    # they call at
    # NOTE(review): Service.code is not cast to Text here, unlike the code
    # column of every other select - confirm whether that is intentional.
    service_columns = [
        utils.table_name(Service).label("table_name"),
        Service.code.label("code"),
        Service.short_description.label("name"),
        Service.line.label("indicator"),
        *blanks("street", "stop_type", "stop_area_ref", "locality_name",
                "district_name", "admin_area_ref", "admin_area_name"),
        text_array(
            db.func.array_agg(db.distinct(AdminArea.code))
        ).label("admin_areas"),
        _tsvector_column(
            (Service.line, "B"),
            (Service.description, "B"),
            (db.func.coalesce(distinct_names(Operator.name), ""), "C"),
            (distinct_names(Locality.name), "C"),
            (db.func.coalesce(distinct_names(District.name), ""), "D"),
            (distinct_names(AdminArea.name), "D")
        ).label("vector"),
    ]
    service_from = Service.__table__.join(
        JourneyPattern, Service.id == JourneyPattern.service_ref
    )
    service_from = service_from.join(
        LocalOperator,
        (JourneyPattern.local_operator_ref == LocalOperator.code) &
        (JourneyPattern.region_ref == LocalOperator.region_ref)
    )
    service_from = service_from.outerjoin(
        Operator, LocalOperator.operator_ref == Operator.code
    )
    service_from = service_from.join(
        JourneyLink, JourneyPattern.id == JourneyLink.pattern_ref
    )
    service_from = service_from.join(
        StopPoint,
        (JourneyLink.stop_point_ref == StopPoint.atco_code) & StopPoint.active
    )
    service_from = service_from.join(
        Locality, StopPoint.locality_ref == Locality.code
    )
    service_from = service_from.outerjoin(
        District, Locality.district_ref == District.code
    )
    service_from = service_from.join(
        AdminArea, Locality.admin_area_ref == AdminArea.code
    )
    service = (
        db.select(service_columns)
        .select_from(service_from)
        .group_by(Service.id)
    )

    return [
        region, admin_area, district, locality, stop_area, stop_point, service
    ]
def insert_service_pairs(connection):
    """ Prepare the query that pairs up similar services from existing
        service data.

        A temporary ``service_stops`` table (dropped on commit) is created and
        filled on the given connection with every service, the stops it calls
        at and whether each stop is served outbound and/or inbound. The
        returned query compares every pair of services sharing at least one
        stop, for each of the four direction combinations.

        :param connection: Connection used to create and fill the temporary
        table.
        :returns: List holding a single select statement. Its columns are a
        row number, first service ID, its direction, its stop count, second
        service ID, its direction, its stop count and the number of shared
        stops divided by the smaller of the two stop counts.
    """
    services = Service.__table__
    patterns = JourneyPattern.__table__
    links = JourneyLink.__table__

    # Temporary table for all services, direction and stops they call at
    temp_columns = [
        db.Column("id", db.Integer, autoincrement=False, index=True),
        db.Column("stop_point_ref", db.Text, nullable=True, index=True),
        db.Column("outbound", db.Boolean, nullable=False, index=True),
        db.Column("inbound", db.Boolean, nullable=False, index=True),
    ]
    service_stops = db.Table(
        "service_stops",
        db.MetaData(),
        *temp_columns,
        prefixes=["TEMPORARY"],
        postgresql_on_commit="DROP"
    )

    # One row per service and stop; direction flags are OR-ed over patterns
    calls_from = services.join(
        patterns, services.c.id == patterns.c.service_ref
    )
    calls_from = calls_from.join(links, patterns.c.id == links.c.pattern_ref)
    select_calls = (
        db.select([
            services.c.id,
            db.cast(links.c.stop_point_ref, db.Text),
            db.func.bool_or(~patterns.c.direction),
            db.func.bool_or(patterns.c.direction)
        ])
        .select_from(calls_from)
        .where(links.c.stop_point_ref.isnot(None))
        .group_by(services.c.id, db.cast(links.c.stop_point_ref, db.Text))
    )
    fill_service_stops = service_stops.insert().from_select(
        ["id", "stop_point_ref", "outbound", "inbound"],
        select_calls
    )

    # Find all services sharing at least one stop
    ss0 = service_stops.alias("ss0")
    ss1 = service_stops.alias("ss1")
    same_stop = (
        (ss0.c.id < ss1.c.id) &
        (ss0.c.stop_point_ref == ss1.c.stop_point_ref)
    )
    shared = (
        db.select([ss0.c.id.label("id0"), ss1.c.id.label("id1")])
        .distinct()
        .select_from(ss0.join(ss1, same_stop))
        .alias("t")
    )

    # Iterate over possible combinations of directions: bit 1 of d picks the
    # direction of the first service, bit 2 that of the second
    series = db.func.generate_series(0, 3).alias("d")
    directions = (
        db.select([db.column("d", db.Integer)])
        .select_from(series)
        .alias("d")
    )

    def masked(bit):
        # Fresh 'd & bit' expression on every call.
        return directions.c.d.op("&")(bit)

    def stops_for(service_id, bit):
        # Stops of one service in the direction selected by the given bit:
        # inbound when the bit is set, outbound otherwise.
        in_direction = (
            (masked(bit) > 0) & service_stops.c.inbound |
            (masked(bit) == 0) & service_stops.c.outbound
        )
        return (
            db.select([service_stops.c.stop_point_ref])
            .where((service_stops.c.id == service_id) & in_direction)
            .correlate(shared, directions)
        )

    # For each service, find all stops and count them
    select_a = stops_for(shared.c.id0, 1)
    select_b = stops_for(shared.c.id1, 2)
    select_c = select_a.intersect(select_b)

    count = db.func.count(db.literal_column("*")).label("count")
    la = db.select([count]).select_from(select_a.alias("a")).lateral("a")
    lb = db.select([count]).select_from(select_b.alias("b")).lateral("b")
    lc = db.select([count]).select_from(select_c.alias("c")).lateral("c")

    utils.logger.info(
        "Querying all services and stops they call at to find similar services"
    )
    service_stops.create(connection)
    connection.execute(fill_service_stops)

    pair_columns = [
        db.func.row_number().over(),
        shared.c.id0,
        (masked(1) > 0).label("dir0"),
        la.c.count,
        shared.c.id1,
        (masked(2) > 0).label("dir1"),
        lb.c.count,
        db.cast(lc.c.count, db.Float) / db.func.least(la.c.count, lb.c.count)
    ]
    # Keep only pairs where both services have stops and share some of them
    all_non_empty = (la.c.count > 0) & (lb.c.count > 0) & (lc.c.count > 0)

    return [db.select(pair_columns).where(all_non_empty)]
def insert_journey_data(connection):
    """ Build the journey data queries, one per region.

        :param connection: Connection used to look up the region codes.
        :returns: List with the result of ``_create_journey_data_query`` for
        each region code, in the order the rows were returned.
    """
    # Split the queries into each region
    region_rows = connection.execute(db.select([Region.code])).all()

    queries = []
    for row in region_rows:
        queries.append(_create_journey_data_query(row[0]))

    return queries