def sync(self, parsed_stops: typing.Iterable[parse.Stop]):
    """
    Persist the parsed stops and then wire up each stop's parent link.

    The parent links are assigned by hand after the merge rather than through
    the ORM relationship: the models do not have PKs before the flush, and
    SQLAlchemy's cascades would otherwise produce duplicate entries in the DB.

    :param parsed_stops: the parsed stops to persist
    :return: two-tuple of the added and updated counts reported by
        ``_merge_entities``
    """
    parent_id_by_stop_id = {}
    new_stops = []
    for parsed_stop in parsed_stops:
        db_stop = models.Stop.from_parsed_stop(parsed_stop)
        db_stop.system_pk = self.feed_update.feed.system_pk
        parent_id_by_stop_id[parsed_stop.id] = (
            None if parsed_stop.parent_stop is None
            else parsed_stop.parent_stop.id)
        new_stops.append(db_stop)

    merged_stops, num_added, num_updated = self._merge_entities(new_stops)
    merged_stop_by_id = {merged.id: merged for merged in merged_stops}

    # NOTE: flush the session to populate the primary keys
    dbconnection.get_session().flush()

    for stop_id, parent_id in parent_id_by_stop_id.items():
        child = merged_stop_by_id[stop_id]
        # A missing or unknown parent ID resolves to no parent.
        parent = merged_stop_by_id.get(parent_id)
        child.parent_stop_pk = None if parent is None else parent.pk
    return num_added, num_updated
def test_outside_unit_of_work_error__after():
    """Asking for a session after a unit of work has finished must raise."""

    @dbconnection.unit_of_work
    def trivial_unit_of_work():
        return 1 + 1

    trivial_unit_of_work()

    with pytest.raises(dbconnection.OutsideUnitOfWorkError):
        dbconnection.get_session()
def _create_feed_update_helper(system_id, feed_id,
                               update_type) -> typing.Optional[int]:
    """
    Create a FeedUpdate with status SCHEDULED for a feed.

    :param system_id: ID of the system the feed is in
    :param feed_id: ID of the feed
    :param update_type: value stored as the new FeedUpdate's update type
    :return: the PK of the new FeedUpdate, or None if the feed does not exist
    """
    feed = feedqueries.get_in_system_by_id(system_id, feed_id)
    if feed is None:
        return None

    scheduled_update = models.FeedUpdate()
    scheduled_update.update_type = update_type
    scheduled_update.status = models.FeedUpdate.Status.SCHEDULED
    scheduled_update.feed = feed

    session = dbconnection.get_session()
    session.add(scheduled_update)
    # The flush populates the PK that is returned below.
    session.flush()
    return scheduled_update.pk
def list_all_from_feed(feed_pk):
    """
    List all Trips whose source FeedUpdate belongs to a given feed.

    :param feed_pk: the feed's PK
    :return: list of Trips
    """
    session = dbconnection.get_session()
    sourced_from_feed = (
        models.Trip.source_pk == models.FeedUpdate.pk,
        models.FeedUpdate.feed_pk == feed_pk,
    )
    return session.query(models.Trip).filter(*sourced_from_feed).all()
def _get_entity_pk_to_active_alerts(entity_type, entity_relationship, pks, current_time, load_messages): pks = list(pks) if len(pks) == 0: return {} if current_time is None: current_time = datetime.datetime.utcnow() query = (dbconnection.get_session().query( entity_type.pk, models.AlertActivePeriod, models.Alert).filter( models.AlertActivePeriod.alert_pk == models.Alert.pk).filter( sql.or_( models.AlertActivePeriod.starts_at <= current_time, models.AlertActivePeriod.starts_at.is_(None), )).filter( sql.or_( models.AlertActivePeriod.ends_at >= current_time, models.AlertActivePeriod.ends_at.is_(None), )).order_by(models.AlertActivePeriod.starts_at).join( entity_relationship).filter(entity_type.pk.in_(pks))) if load_messages: query = query.options(joinedload(models.Alert.messages)) pk_to_alert_pks = collections.defaultdict(set) pk_to_tuple = {pk: [] for pk in pks} for pk, active_period, alert in query.all(): if alert.pk in pk_to_alert_pks[pk]: continue pk_to_alert_pks[pk].add(alert.pk) pk_to_tuple[pk].append((active_period, alert)) return pk_to_tuple
def list_groups_and_maps_for_stops_in_route(route_pk):
    """
    Get the service maps for a route.

    One (service map group, service map) tuple is returned for each service
    map group in the route's system that has use_for_stops_in_route set. The
    service map is outer-joined, so it is None when the group has no map for
    the route.

    :param route_pk: the route's PK
    :return: the list of tuples described above
    """
    session = dbconnection.get_session()
    map_is_for_route_and_group = sql.and_(
        models.ServiceMap.route_pk == models.Route.pk,
        models.ServiceMap.group_pk == models.ServiceMapGroup.pk,
    )
    query = (
        session.query(models.ServiceMapGroup, models.ServiceMap)
        .join(models.System,
              models.System.pk == models.ServiceMapGroup.system_pk)
        .join(models.Route, models.Route.system_pk == models.System.pk)
        .outerjoin(models.ServiceMap, map_is_for_route_and_group)
        .filter(models.ServiceMapGroup.use_for_stops_in_route)
        .filter(models.Route.pk == route_pk)
        .options(selectinload(models.ServiceMap.vertices))
        .options(selectinload(models.ServiceMap.vertices,
                              models.ServiceMapVertex.stop))
    )
    # Each result row is re-packed into a plain two-tuple.
    return [(row[0], row[1]) for row in query]
def get_scheduled_trip_pk_to_path_in_system(system_pk):
    """
    Get a map of trip PK to the path of that trip for every scheduled trip in
    a system. The path of a trip is the list of stop PKs that the trip stops
    at, ordered by stop sequence.

    Only the two PK columns are selected, so no ORM entities are loaded. The
    time taken is logged at info level.

    :param system_pk: the system's PK
    :return: map of trip PK to list of stop PKs
    """
    start_time = time.time()
    session = dbconnection.get_session()
    query = (
        session.query(
            models.ScheduledTripStopTime.trip_pk,
            models.ScheduledTripStopTime.stop_pk,
        )
        .join(
            models.ScheduledTrip,
            models.ScheduledTrip.pk == models.ScheduledTripStopTime.trip_pk,
        )
        .join(
            models.ScheduledService,
            sql.and_(
                models.ScheduledService.pk == models.ScheduledTrip.service_pk,
                models.ScheduledService.system_pk == system_pk,
            ),
        )
        .order_by(
            models.ScheduledTripStopTime.trip_pk,
            models.ScheduledTripStopTime.stop_sequence,
        )
    )
    trip_pk_to_stop_pks = {}
    for trip_pk, stop_pk in query:
        trip_pk_to_stop_pks.setdefault(trip_pk, []).append(stop_pk)
    # Fixed-point format: a bare ".2" precision means two significant digits
    # and switches to scientific notation for durations of 10 seconds or more.
    logger.info(
        "Query get_scheduled_trip_pk_to_path_in_system took {:.2f} seconds".
        format(time.time() - start_time))
    return trip_pk_to_stop_pks
def sync(self, parsed_transfers: typing.Iterable[parse.Transfer]):
    """
    Add a Transfer to the session for every parsed transfer whose two stops
    can both be resolved in the feed's system.

    Parsed transfers that reference a stop ID without a PK in the system are
    skipped.

    :param parsed_transfers: the parsed transfers to persist
    :return: two-tuple (number of transfers added, 0)
    """
    parsed_transfers = list(parsed_transfers)
    # TODO use the query to speed it up:
    #  stopqueries.delete_transfers_in_system(self.feed_update.feed.system.pk)
    referenced_stop_ids = set()
    for parsed in parsed_transfers:
        referenced_stop_ids.update((parsed.from_stop_id, parsed.to_stop_id))
    stop_id_to_pk = stopqueries.get_id_to_pk_map_in_system(
        self.feed_update.feed.system.pk, referenced_stop_ids)

    session = dbconnection.get_session()
    num_added = 0
    for parsed in parsed_transfers:
        from_stop_pk = stop_id_to_pk.get(parsed.from_stop_id)
        to_stop_pk = stop_id_to_pk.get(parsed.to_stop_id)
        if from_stop_pk is not None and to_stop_pk is not None:
            db_transfer = models.Transfer.from_parsed_transfer(parsed)
            db_transfer.system_pk = self.feed_update.feed.system.pk
            db_transfer.source_pk = self.feed_update.pk
            db_transfer.from_stop_pk = from_stop_pk
            db_transfer.to_stop_pk = to_stop_pk
            session.add(db_transfer)
            num_added += 1
    return num_added, 0
def _save_feed_configuration(system, feeds_config):
    """
    Save feeds in a system.

    Each feed built from the config is merged into the session; a feed whose
    ID already exists in the system reuses the existing PK. Stale feeds --
    those currently attached to the system but with no entry in the config --
    are *not* deleted by this method; their IDs are only returned.

    :param system: the system to save the feeds in
    :param feeds_config: the feeds config passed to _build_feeds_from_config
    :return: a two-tuple of feed ID lists; the first contains the IDs of
        saved feeds that have required_for_install set and the second the IDs
        of the stale feeds
    """
    unmatched_feed_id_to_pk = genericqueries.get_id_to_pk_map(
        models.Feed, system.pk)
    session = dbconnection.get_session()
    feed_ids_to_update = []
    for feed in _build_feeds_from_config(feeds_config):
        feed.system_pk = system.pk
        if feed.id in unmatched_feed_id_to_pk:
            # Removing the entry leaves only stale feeds in the map.
            feed.pk = unmatched_feed_id_to_pk.pop(feed.id)
            logger.info("Updating feed {}/{}".format(system.id, feed.id))
        else:
            logger.info("Creating feed {}/{}".format(system.id, feed.id))
        session.merge(feed)
        if feed.required_for_install:
            feed_ids_to_update.append(feed.id)
    return feed_ids_to_update, list(unmatched_feed_id_to_pk)
def list_all_stops_in_stop_tree(stop_pk) -> typing.Iterable[models.Stop]:
    """
    List all stops in the stop tree of a given stop.

    :param stop_pk: PK of any stop in the tree
    :return: the stops in the tree
    """
    session = dbconnection.get_session()

    def pk_and_parent_pk():
        # Fresh (pk, parent_stop_pk) query used for each part of the CTEs.
        return session.query(models.Stop.pk, models.Stop.parent_stop_pk)

    # The first CTE starts at the given stop and walks up the parent links,
    # so it contains the stop and all of its ancestors.
    ancestors = pk_and_parent_pk().filter(models.Stop.pk == stop_pk).cte(
        name="ancestor", recursive=True)
    ancestors = ancestors.union_all(pk_and_parent_pk().filter(
        models.Stop.pk == ancestors.c.parent_stop_pk))

    # The second CTE starts at every stop of the first CTE and walks down to
    # all descendants. Because the first CTE contains the root of the tree,
    # the second CTE contains every stop in the tree.
    tree = pk_and_parent_pk().filter(models.Stop.pk == ancestors.c.pk).cte(
        name="relation", recursive=True)
    tree = tree.union_all(pk_and_parent_pk().filter(
        models.Stop.parent_stop_pk == tree.c.pk))

    return session.query(
        models.Stop).filter(models.Stop.pk == tree.c.pk).all()
def get_trip_pk_to_stop_time_data(
    trip_pk_stop_pk_stop_sequence_tuples,
) -> typing.Dict[int, StopTimeData]:
    """
    Look up trip stop times matching (trip PK, stop PK, stop sequence) tuples.

    Within a tuple, a stop PK or stop sequence of None is not used for
    matching; a tuple in which both are None is ignored. If several stop
    times of one trip match, the last row returned by the query wins.

    :param trip_pk_stop_pk_stop_sequence_tuples: iterable of three-tuples
        (trip PK, stop PK or None, stop sequence or None)
    :return: map of trip PK to the StopTimeData of a matching stop time
    """
    session = dbconnection.get_session()
    stop_time = models.TripStopTime
    per_tuple_conditions = []
    for trip_pk, stop_pk, stop_sequence in trip_pk_stop_pk_stop_sequence_tuples:
        if stop_pk is not None or stop_sequence is not None:
            parts = [stop_time.trip_pk == trip_pk]
            if stop_pk is not None:
                parts.append(stop_time.stop_pk == stop_pk)
            if stop_sequence is not None:
                parts.append(stop_time.stop_sequence == stop_sequence)
            per_tuple_conditions.append(sql.and_(*parts))
    if not per_tuple_conditions:
        return {}

    rows = session.query(
        stop_time.trip_pk,
        stop_time.pk,
        stop_time.stop_sequence,
        stop_time.stop_pk,
    ).filter(sql.or_(*per_tuple_conditions)).all()
    return {
        trip_pk: StopTimeData(pk=stop_time_pk,
                              stop_sequence=stop_sequence,
                              stop_pk=stop_pk)
        for trip_pk, stop_time_pk, stop_sequence, stop_pk in rows
    }
def list_in_system(DbEntity: models.Base,
                   system_id,
                   order_by_field=None,
                   ids=None):
    """
    List all entities of a certain type that are in a given system.

    Note this method only works with entities that are direct children of the
    system.

    :param DbEntity: the entity's type
    :param system_id: the system's ID
    :param order_by_field: optional field to order the results by
    :param ids: optional iterable of IDs to filter on; any iterable is
        accepted, and an empty one yields an empty list without querying
    :return: list of entities of type DbEntity
    """
    if ids is not None:
        # Materialise once so that generators and other iterables without a
        # length can be checked for emptiness and then used in the IN clause.
        ids = list(ids)
        if not ids:
            return []
    session = dbconnection.get_session()
    query = (
        session.query(DbEntity)
        .filter(DbEntity.system_pk == models.System.pk)
        .filter(models.System.id == system_id)
    )
    if ids is not None:
        query = query.filter(DbEntity.id.in_(ids))
    if order_by_field is not None:
        query = query.order_by(order_by_field)
    return query.all()
def build_stop_pk_to_descendant_pks_map(stop_pks, stations_only=False
                                        ) -> typing.Dict[int, typing.Set[int]]:
    """
    Construct a map whose key is a stop's PK and whose value is the set of
    PKs of that stop and all of its descendants.

    :param stop_pks: PKs of the stops to build the map for
    :param stations_only: if True, only follow child stops whose type is one
        of models.Stop.STATION_TYPES
    :return: map of stop PK to set of stop PKs; only PKs returned by the
        query appear as keys
    """
    session = dbconnection.get_session()
    # Anchor part: every requested stop is paired with itself.
    anchor = session.query(
        models.Stop.pk.label("ancestor_pk"),
        models.Stop.pk.label("descendent_pk"),
    ).filter(models.Stop.pk.in_(stop_pks))
    descendant_cte = anchor.cte(name="descendent_cte", recursive=True)

    # Recursive part: step from each stop found so far to its children,
    # carrying the originally requested ancestor along.
    children = session.query(
        descendant_cte.c.ancestor_pk.label("ancestor_pk"),
        models.Stop.pk.label("descendent_pk"),
    ).filter(models.Stop.parent_stop_pk == descendant_cte.c.descendent_pk)
    if stations_only:
        children = children.filter(
            models.Stop.type.in_(models.Stop.STATION_TYPES))
    descendant_cte = descendant_cte.union_all(children)

    stop_pk_to_descendant_pks = {}
    for ancestor_pk, descendant_pk in session.query(descendant_cte).all():
        stop_pk_to_descendant_pks.setdefault(ancestor_pk,
                                             set()).add(descendant_pk)
    return stop_pk_to_descendant_pks
def get_update_by_pk(feed_update_pk) -> Optional[models.FeedUpdate]:
    """
    Get a FeedUpdate by its PK, eagerly loading its feed and the feed's
    system.

    :param feed_update_pk: the FeedUpdate's PK
    :return: the FeedUpdate, or None if there is no such FeedUpdate
    """
    session = dbconnection.get_session()
    query = session.query(models.FeedUpdate)
    query = query.filter(models.FeedUpdate.pk == feed_update_pk)
    query = query.options(joinedload(models.FeedUpdate.feed))
    query = query.options(
        joinedload(models.FeedUpdate.feed, models.Feed.system))
    return query.one_or_none()
def list_aggregated_updates(feed_pks, start_time):
    """
    Aggregate the finished FeedUpdates of a collection of feeds.

    Only updates with status SUCCESS or FAILURE that completed after
    start_time are counted. They are grouped by feed, status and result.

    :param feed_pks: PKs of the feeds
    :param start_time: only updates completed after this time are included
    :return: map of feed PK to a list with one entry per (status, result)
        group; each entry is the row's remaining columns (count, status,
        result, earliest completed_at, latest completed_at)
    """
    session = dbconnection.get_session()
    feed_update = models.FeedUpdate
    query = (
        session.query(
            feed_update.feed_pk,
            func.count(feed_update.status),
            feed_update.status,
            feed_update.result,
            func.min(feed_update.completed_at),
            func.max(feed_update.completed_at),
        )
        .group_by(
            feed_update.feed_pk,
            feed_update.status,
            feed_update.result,
        )
        .filter(
            feed_update.feed_pk.in_(feed_pks),
            feed_update.completed_at > start_time,
            feed_update.status.in_({
                feed_update.Status.SUCCESS,
                feed_update.Status.FAILURE
            }),
        )
    )
    feed_pk_to_updates = {}
    for row in query.all():
        feed_pk_to_updates.setdefault(row[0], []).append(row[1:])
    return feed_pk_to_updates
def list_all():
    """
    List all TransfersConfigs, eagerly loading the systems of each.

    :return: list of TransfersConfigs
    """
    session = dbconnection.get_session()
    query = session.query(models.TransfersConfig)
    query = query.options(selectinload(models.TransfersConfig.systems))
    return query.all()
def list_stale_entities(DbEntity: typing.Type[models.Base],
                        feed_update: models.FeedUpdate):
    """
    List the entities of a given type whose source is a different FeedUpdate
    of the same feed as the given FeedUpdate.

    :param DbEntity: the entity's type
    :param feed_update: the FeedUpdate to compare against
    :return: list of entities of type DbEntity
    """
    session = dbconnection.get_session()
    from_same_feed = models.FeedUpdate.feed_pk == feed_update.feed_pk
    from_other_update = models.FeedUpdate.pk != feed_update.pk
    query = session.query(DbEntity).join(
        models.FeedUpdate, DbEntity.source_pk == models.FeedUpdate.pk)
    query = query.filter(from_same_feed).filter(from_other_update)
    return query.all()
def delete_stale_entities(DbEntity: typing.Type[models.Base],
                          feed_update: models.FeedUpdate):
    """
    Delete the entities of a given type whose source is a different
    FeedUpdate of the same feed as the given FeedUpdate.

    The delete is issued with synchronize_session=False.

    :param DbEntity: the entity's type
    :param feed_update: the FeedUpdate to compare against
    """
    session = dbconnection.get_session()
    stale_entities = (
        session.query(DbEntity)
        .filter(DbEntity.source_pk == models.FeedUpdate.pk)
        .filter(models.FeedUpdate.feed_pk == feed_update.feed_pk)
        .filter(models.FeedUpdate.pk != feed_update.pk)
    )
    stale_entities.delete(synchronize_session=False)
def list_by_system_and_trip_ids(system_id, trip_ids):
    """
    List the Trips in a system that have one of the given IDs.

    :param system_id: the system's ID
    :param trip_ids: iterable of trip IDs
    :return: list of Trips; empty if no trip IDs are given
    """
    trip_ids = list(trip_ids)
    if not trip_ids:
        return []
    query = (
        dbconnection.get_session()
        .query(models.Trip)
        .join(models.Route)
        .filter(models.Trip.id.in_(trip_ids))
        .filter(models.Route.system_pk == models.System.pk)
        .filter(models.System.id == system_id)
    )
    return query.all()
def get_id_to_pk_map_by_feed_pk(DbEntity: typing.Type[models.Base], feed_pk):
    """
    Map ID to PK for the entities of a given type whose source FeedUpdate
    belongs to a given feed.

    :param DbEntity: the entity's type
    :param feed_pk: the feed's PK
    :return: map of entity ID to entity PK
    """
    session = dbconnection.get_session()
    rows = (
        session.query(DbEntity.id, DbEntity.pk)
        .join(models.FeedUpdate, DbEntity.source_pk == models.FeedUpdate.pk)
        .filter(models.FeedUpdate.feed_pk == feed_pk)
        .all()
    )
    return {entity_id: entity_pk for entity_id, entity_pk in rows}
def create(system_ids, distance) -> str:
    """
    Create a TransfersConfig for the given systems and distance.

    The transfers are built and added via _build_and_add_transfers before the
    config is added to the session and flushed.

    :param system_ids: IDs passed to _list_systems to obtain the systems
    :param distance: the distance stored on the config
    :return: the ID of the new TransfersConfig
    """
    new_config = models.TransfersConfig(
        distance=distance,
        systems=_list_systems(system_ids),
    )
    _build_and_add_transfers(new_config)

    session = dbconnection.get_session()
    session.add(new_config)
    # The ID is read after the flush.
    session.flush()
    return new_config.id
def list_all_active():
    """
    List all feeds in systems whose status is ACTIVE.

    :return: list of Feeds
    """
    system_is_active = (
        models.System.status == models.System.SystemStatus.ACTIVE)
    return (
        dbconnection.get_session()
        .query(models.Feed)
        .join(models.System, models.System.pk == models.Feed.system_pk)
        .filter(system_is_active)
        .all()
    )
def get_by_id(DbEntity: models.Base, id_):
    """
    Get an entity by its ID.

    :param DbEntity: the entity's type
    :param id_: the entity's ID
    :return: the entity, if it exists in the DB, or None otherwise
    """
    matching = dbconnection.get_session().query(DbEntity).filter(
        DbEntity.id == id_)
    return matching.one_or_none()
def get_id_to_pk_map_in_system(system_pk, trip_ids):
    """
    Map trip IDs to trip PKs within a system.

    :param system_pk: the system's PK
    :param trip_ids: iterable of trip IDs
    :return: map containing every given trip ID; the value is the trip's PK,
        or None if no trip with that ID was found in the system
    """
    trip_ids = list(trip_ids)
    if not trip_ids:
        return {}
    # Start with every ID mapped to None, then fill in the PKs that exist.
    id_to_pk = dict.fromkeys(trip_ids)
    rows = (
        dbconnection.get_session()
        .query(models.Trip.id, models.Trip.pk)
        .join(models.Route)
        .filter(models.Trip.id.in_(trip_ids))
        .filter(models.Route.system_pk == system_pk)
        .all()
    )
    for trip_id, trip_pk in rows:
        id_to_pk[trip_id] = trip_pk
    return id_to_pk
def count_number_of_related_entities(relationship, instance) -> int:
    """
    Count the number of entities related to an instance along a specified
    relationship.

    :param relationship: the relationship attribute to count along
    :param instance: the instance, identified by its pk, on the relationship's
        owning side
    :return: the number of related entities
    """
    base_type = relationship.class_
    related_type = relationship.mapper.class_
    session = dbconnection.get_session()
    query = session.query(func.count(1)).select_from(base_type)
    query = query.join(related_type, relationship)
    query = query.filter(base_type.pk == instance.pk)
    (count,) = query.one()
    return count
def trim_feed_updates(feed_pk, before_datetime):
    """
    Trim the FeedUpdates of a feed: delete those that completed at or before
    a certain cut-off point and that are not the source of any updatable
    entity.

    :param feed_pk: pk of the feed
    :param before_datetime: the cut-off point
    """
    not_exists_conditions = []
    for UpdatableEntity in models.list_updatable_entities():
        is_source_of_entity = sql.exists(
            sql.select(1).where(
                UpdatableEntity.source_pk == models.FeedUpdate.pk))
        not_exists_conditions.append(~is_source_of_entity)

    deletable = sql.and_(
        models.FeedUpdate.feed_pk == feed_pk,
        models.FeedUpdate.completed_at <= before_datetime,
        *not_exists_conditions,
    )
    statement = sql.delete(models.FeedUpdate).where(
        deletable).execution_options(synchronize_session=False)
    dbconnection.get_session().execute(statement)
def get_trip_id_to_pk_map_by_feed_pk(feed_pk):
    """
    Map ID to PK for the scheduled trips whose service has a source
    FeedUpdate belonging to a given feed.

    :param feed_pk: the feed's PK
    :return: map of scheduled trip ID to scheduled trip PK
    """
    session = dbconnection.get_session()
    rows = session.query(
        models.ScheduledTrip.id,
        models.ScheduledTrip.pk,
    ).filter(
        models.ScheduledService.pk == models.ScheduledTrip.service_pk,
        models.ScheduledService.source_pk == models.FeedUpdate.pk,
        models.FeedUpdate.feed_pk == feed_pk,
    ).all()
    return {trip_id: trip_pk for trip_id, trip_pk in rows}
def list_updates_in_feed(feed_pk):
    """
    List the updates in a feed with the highest PKs, ordered by PK
    descending. At most 100 updates are returned.

    :param feed_pk: the Feed's PK
    :return: list of FeedUpdates
    """
    newest_first = models.FeedUpdate.pk.desc()
    return (
        dbconnection.get_session()
        .query(models.FeedUpdate)
        .filter(models.FeedUpdate.feed_pk == feed_pk)
        .order_by(newest_first)
        .limit(100)
        .all()
    )
def list_all_in_route_by_pk(route_pk):
    """
    List all of the Trips in a route, eagerly loading each Trip's stop times.

    :param route_pk: the route's PK
    :return: list of Trips
    """
    trips_in_route = dbconnection.get_session().query(models.Trip).filter(
        models.Trip.route_pk == route_pk)
    return trips_in_route.options(selectinload(models.Trip.stop_times)).all()
def list_all_auto_updating():
    """
    List all auto-updating Feeds.

    A feed is listed when auto update is enabled on both the feed and its
    system, and the system's status is ACTIVE.

    :return: list of Feeds
    """
    system_is_active = (
        models.System.status == models.System.SystemStatus.ACTIVE)
    return (
        dbconnection.get_session()
        .query(models.Feed)
        .join(models.System, models.System.pk == models.Feed.system_pk)
        .filter(models.Feed.auto_update_enabled)
        .filter(models.System.auto_update_enabled)
        .filter(system_is_active)
        .all()
    )