示例#1
0
文件: derived.py 项目: macph/nextbus
    def ts_rank(cls, query):
        """ Build an expression ranking ``cls.vector`` against a search query.

            The query string is parsed with ``websearch_to_tsquery`` using the
            class's ``DICTIONARY``, reduced with ``querytree`` and cast back to
            ``TSQUERY``. It is then passed to ``ts_rank`` together with the
            class's ``WEIGHTS`` and the normalisation argument ``1``.

            :param query: Query as string.
            :returns: Expression to be used in a query.
        """
        weights = db.bindparam("weights", cls.WEIGHTS)
        parsed = db.func.websearch_to_tsquery(
            db.bindparam("dictionary", cls.DICTIONARY), query
        )
        # querytree() yields text, so cast it back to a tsquery for ts_rank().
        reduced = db.cast(db.func.querytree(parsed), TSQUERY)
        return db.func.ts_rank(weights, cls.vector, reduced, 1)
示例#2
0
    def similar(self, direction=None, threshold=None):
        """ Find all services sharing stops with this service in a direction.

            Each stored pair is matched from either side, so this service may
            appear as ``service0`` or ``service1``. The other service of each
            pair is joined with its journey patterns, and their distinct
            origins and destinations are aggregated with ``' / '``.

            :param direction: Direction of this service to compare, or None to
                include both.
            :param threshold: Only pairs with similarity strictly greater than
                this value are kept, or None to include all.
            :returns: List of rows of (service, direction, origin,
                destination), ordered by line, description and direction.
        """
        p_id = db.bindparam("id", self.id)
        p_dir = None if direction is None else db.bindparam("dir", direction)
        p_threshold = (None if threshold is None
                       else db.bindparam("threshold", threshold))

        def _other_side(other, other_dir, own, own_dir):
            # Select the partner service of every pair in which this service
            # is the ``own`` side.
            side = (
                db.session.query(other.label("id"),
                                 other_dir.label("direction"))
                .filter(own == p_id)
            )
            if p_dir is not None:
                side = side.filter(own_dir == p_dir)
            if p_threshold is not None:
                side = side.filter(_pair.c.similarity > p_threshold)
            return side

        service = db.aliased(Service, name="service")
        pairs = db.union_all(
            _other_side(_pair.c.service0, _pair.c.direction0,
                        _pair.c.service1, _pair.c.direction1),
            _other_side(_pair.c.service1, _pair.c.direction1,
                        _pair.c.service0, _pair.c.direction0),
        ).alias()

        origins = db.func.string_agg(JourneyPattern.origin.distinct(), ' / ')
        destinations = db.func.string_agg(
            JourneyPattern.destination.distinct(), ' / '
        )
        query = (
            db.session.query(
                service,
                JourneyPattern.direction,
                origins.label("origin"),
                destinations.label("destination"),
            )
            .join(pairs, pairs.c.id == service.id)
            .join(
                JourneyPattern,
                (service.id == JourneyPattern.service_ref)
                & (pairs.c.direction == JourneyPattern.direction),
            )
            .group_by(service, pairs.c.direction, JourneyPattern.direction)
            .order_by(service.line, service.description, pairs.c.direction)
        )
        return query.all()
示例#3
0
def _query_journeys_at_stop(atco_code):
    """ Build a query for journeys that stop at the given stop point.

        Journeys are joined through their pattern's links to the journey's
        record set. Only record-set entries that are flagged as stopping, have
        a non-null ``depart`` value and share the link's sequence are kept.

        :param atco_code: ATCO code of the stop point.
        :returns: Unexecuted query with the journey ID, the journey's
            departure and the record set's ``depart`` value labelled
            ``t_offset``.
    """
    records = models.Journey.record_set()
    stops_here = db.and_(
        records.c.stopping,
        records.c.depart.isnot(None),
        models.JourneyLink.sequence == records.c.sequence,
    )

    return (
        db.session.query(
            models.Journey.id,
            models.Journey.departure,
            records.c.depart.label("t_offset"),
        )
        .select_from(models.Journey)
        .join(models.Journey.pattern)
        .join(models.JourneyPattern.links)
        .join(records, stops_here)
        .filter(
            models.JourneyLink.stop_point_ref
            == db.bindparam("atco_code", atco_code)
        )
    )
示例#4
0
    def call(cls, limit):
        """ Record a request and check whether it is within the daily limit.

            Every row of the mapped table is updated (there is no WHERE
            clause), presumably a single-row counter table. ``last_called`` is
            set to now. ``call_count`` is reset to 1 if the last call fell on
            an earlier UTC date, and incremented otherwise. The call is counted
            even when it ends up exceeding the limit.

            :param limit: The limit on number of calls each day starting at
            00:00 UTC. Ignored if is None or negative.
            :returns: True if the request is allowed, False otherwise.
        """
        utc = db.bindparam("utc", "UTC")
        one = db.literal_column("1")

        def utc_date(moment):
            # Calendar date of a timestamp, evaluated in UTC.
            return db.func.date(db.func.timezone(utc, moment))

        new_count = db.case(
            (utc_date(cls.last_called) < utc_date(db.func.now()), one),
            else_=cls.call_count + one,
        )
        statement = (
            db.update(cls)
            .values(last_called=db.func.now(), call_count=new_count)
            .returning(cls.call_count)
        )
        # NOTE(review): scalar() gives None when no row was updated, which
        # would make the comparison below raise TypeError.
        count = db.session.execute(statement).scalar()

        if limit is not None and limit >= 0:
            if count <= limit:
                utils.logger.debug(f"Request was allowed: {count} <= {limit}")
                return True
            utils.logger.warning(f"Request limit exceeded: {count} > {limit}")
            return False

        utils.logger.debug(f"Request limit {limit!r} ignored")
        return True
示例#5
0
文件: derived.py 项目: macph/nextbus
    def match(cls, query):
        """ Build a ``@@`` match of ``cls.vector`` against a search query.

            The query string is converted with ``websearch_to_tsquery`` using
            the class's ``DICTIONARY``.

            :param query: Query as string.
            :returns: Expression to be used in a query.
        """
        tsquery = db.func.websearch_to_tsquery(
            db.bindparam("dictionary", cls.DICTIONARY), query
        )
        matches = cls.vector.op("@@")
        return matches(tsquery)
示例#6
0
def _query_journeys(service_id, direction, date):
    """ Build a query for all journeys of a service that run on a given day.

        Journeys are included or excluded by matching them against special
        dates, bank holidays, date ranges associated with organisations, weeks
        of the month and days of the week (via ``_filter_journey_dates``).

        :param service_id: ID of the service.
        :param direction: Journey pattern direction to match.
        :param date: Date on which the journeys should run.
        :returns: Unexecuted query with columns ``journey_id`` and
            ``departure``.
    """
    # Bound parameters let the SQL refer to each value by name instead of
    # repeating the literal dates.
    service_param = db.bindparam("service", service_id)
    direction_param = db.bindparam("direction", direction)
    date_param = db.bindparam("date", date, type_=db.Date)

    departure_range = _get_departure_range(
        date_param + models.Journey.departure,
        "departures"
    )
    departure = db.column("departures")

    # Discard the times generated 1 hour before and after each departure by
    # requiring the local (GB) hour to equal the scheduled departure's hour.
    same_hour = (
        db.extract("HOUR", db.func.timezone(_GB_TZ, departure))
        == db.extract("HOUR", models.Journey.departure)
    )

    query = (
        db.session.query(
            models.Journey.id.label("journey_id"),
            departure.label("departure"),
        )
        .select_from(models.JourneyPattern)
        .join(models.JourneyPattern.journeys)
        # No CROSS JOIN in the query API, so inner join ON TRUE instead.
        .join(departure_range, db.true())
        .filter(
            models.JourneyPattern.service_ref == service_param,
            models.JourneyPattern.direction.is_(direction_param),
            same_hour,
        )
        .group_by(models.Journey.id, departure)
    )

    # Apply the date-based inclusion/exclusion rules.
    return _filter_journey_dates(query, date_param)
示例#7
0
def _query_next_services(atco_code, timestamp=None, interval=None):
    """ Creates query for getting all services stopping at this stop point in an
        interval.

        :param atco_code: ATCO code of the stop point.
        :param timestamp: Start of the interval, or None to use the database's
            ``now()``. A naive datetime is treated as local GB time; an aware
            one is passed through as is.
        :param interval: Length of the interval (cast to an SQL interval), or
            None for one hour.
        :returns: Unexecuted query with columns ``line``, ``origin``,
            ``destination``, ``op_code``, ``op_name``, ``expected`` and
            ``seconds`` (whole seconds from ``timestamp`` to ``expected``),
            ordered by ``expected``.
    """
    if timestamp is None:
        p_timestamp = db.func.now()
    elif timestamp.tzinfo is None:
        # Assume this is a local timestamp with GB timezone.
        p_timestamp = db.func.timezone(
            _GB_TZ, db.bindparam("timestamp", timestamp)
        )
    else:
        p_timestamp = db.bindparam("timestamp", timestamp)

    if interval is None:
        p_interval = _ONE_HOUR
    else:
        p_interval = db.cast(db.bindparam("interval", interval), db.Interval)

    journey_match = _query_journeys_at_stop(atco_code).cte("journey_match")

    # ``expected`` below is a journey departure plus its ``t_offset``, so
    # shifting the requested window back by ``t_offset`` gives the window the
    # journey's own departure must fall in, in both UTC and local GB time.
    time_start = p_timestamp - journey_match.c.t_offset
    time_end = time_start + p_interval
    times = (
        db.select([
            time_start.label("utc_start"),
            time_end.label("utc_end"),
            db.func.timezone(_GB_TZ, time_start).label("local_start"),
            db.func.timezone(_GB_TZ, time_end).label("local_end"),
        ])
        .correlate(journey_match)
        .lateral("times")
    )

    local_start_date = db.cast(times.c.local_start, db.Date)
    local_end_date = db.cast(times.c.local_end, db.Date)
    local_start_time = db.cast(times.c.local_start, db.Time)
    local_end_time = db.cast(times.c.local_end, db.Time)
    departure_time = journey_match.c.departure

    # NOTE(review): assumes the interval is shorter than a day, so the local
    # window spans at most two consecutive dates.
    same_day = local_start_date == local_end_date
    crosses_midnight = local_start_date < local_end_date
    # Local time-of-day prefilter. Both branches are inclusive at the window
    # edges, consistent with the inclusive BETWEEN used for the same-day case
    # and for the UTC filter below, so a departure exactly at either edge is
    # not dropped when the window crosses midnight.
    in_local_window = (
        (same_day
         & db.between(departure_time, local_start_time, local_end_time))
        | (crosses_midnight
           & ((departure_time >= local_start_time)
              | (departure_time <= local_end_time)))
    )

    journey_departure = (
        db.session.query(
            journey_match.c.id,
            journey_match.c.t_offset,
            times.c.utc_start,
            times.c.utc_end,
            # Departures at or after the start time belong to the start date;
            # in a window crossing midnight the rest belong to the end date.
            db.case(
                (
                    same_day | (departure_time >= local_start_time),
                    local_start_date
                ),
                else_=local_end_date,
            ).label("date"),
            departure_time.label("time"),
        )
        .select_from(journey_match)
        .join(times, db.true())
        .filter(in_local_window)
        .cte("journey_departure")
    )

    utc_departures = _get_departure_range(
        journey_departure.c.date + journey_departure.c.time,
        "utc_departure",
    )
    utc_departure = db.column("utc_departure")

    # Keep only generated departures inside the exact UTC window, then apply
    # the date-based inclusion/exclusion rules.
    journey_filter = _filter_journey_dates(
        db.session.query(
            journey_departure.c.id,
            (utc_departure + journey_departure.c.t_offset).label("expected")
        )
        .select_from(journey_departure)
        .join(
            utc_departures,
            db.between(
                utc_departure,
                journey_departure.c.utc_start,
                journey_departure.c.utc_end,
            )
        )
        .join(models.Journey, journey_departure.c.id == models.Journey.id)
        .join(models.Journey.pattern)
        .group_by(
            journey_departure.c.id,
            journey_departure.c.t_offset,
            utc_departure
        ),
        journey_departure.c.date,
    ).cte("journey_filter")

    seconds_until = db.cast(
        db.extract("EPOCH", journey_filter.c.expected - p_timestamp),
        db.Integer,
    )

    query = (
        db.session.query(
            models.Service.line.label("line"),
            models.JourneyPattern.origin.label("origin"),
            models.JourneyPattern.destination.label("destination"),
            models.Operator.code.label("op_code"),
            models.Operator.name.label("op_name"),
            journey_filter.c.expected,
            seconds_until.label("seconds"),
        )
        .select_from(journey_filter)
        .join(models.Journey, journey_filter.c.id == models.Journey.id)
        .join(models.Journey.pattern)
        .join(models.JourneyPattern.service)
        .join(models.JourneyPattern.operator)
        .order_by(journey_filter.c.expected)
    )

    return query
示例#8
0
"""
Creating timetables for a service.
"""
from collections import abc
import functools

from nextbus import db, graph, models


# SQL expression for a one-hour interval ('1 hour' cast to INTERVAL).
_ONE_HOUR = db.cast(db.literal_column("'1 hour'"), db.Interval)
# Shared named bind parameters, so each constant value is sent as a parameter
# instead of being repeated as a literal in the generated SQL.
# Time zone name for Great Britain local time.
_GB_TZ = db.bindparam("gb", "Europe/London")
# Time zone name for UTC.
_UTC_TZ = db.bindparam("utc", "UTC")
# date_trunc() precision used when formatting times.
_TRUNCATE_MIN = db.bindparam("trunc_min", "minute")
# to_char() pattern producing 24-hour HHMM strings, eg 0730.
_FORMAT_TIME = db.bindparam("format_time", "HH24MI")


def _bit_array_contains(array, col):
    """ SQL expression testing whether bit number ``col`` is set in ``array``.

        Equivalent to ``(1 << col) & array > 0``, where ``array`` is an
        integer used as a bit array.
    """
    bit = db.literal_column("1").op("<<")(col)
    masked = bit.op("&")(array)
    return masked > db.literal_column("0")


def _format_time(timestamp):
    """ SQL expression to format a date or timestamp as `HHMM`, eg 0730. """
    return db.func.to_char(
        db.func.date_trunc(_TRUNCATE_MIN, timestamp),
        _FORMAT_TIME,