Example #1
0
def test_vehicle__move_between_trips_attached(
    db_session,
    add_model,
    system_1,
    route_1_1,
    previous_update,
    current_update,
    trip_for_vehicle,
    stop_1_3,
):
    """A vehicle re-reported on a new trip is detached from its old trip.

    The vehicle is first persisted attached to ``trip_for_vehicle``. The import
    then reports it on a newly parsed trip ``trip_id_2``. After the import the
    old trip has no vehicle and the new trip holds the original vehicle.
    """
    vehicle = add_model(
        models.Vehicle(
            id="vehicle_id",
            system=system_1,
            source=previous_update,
            trip=trip_for_vehicle,
        ))

    importdriver.run_import(
        current_update.pk,
        ParserForTesting([
            parse.Trip(id="trip_id_2", route_id=route_1_1.id),
            parse.Vehicle(id="vehicle_id", trip_id="trip_id_2"),
        ]),
    )

    db_session.refresh(trip_for_vehicle)
    new_trip = (db_session.query(
        models.Trip).filter(models.Trip.id == "trip_id_2").one_or_none())

    # one_or_none() may return None; fail with a clear message instead of an
    # AttributeError on ``new_trip.vehicle`` below.
    assert new_trip is not None, "trip_id_2 was not persisted by the import"
    assert trip_for_vehicle.vehicle is None
    assert new_trip.vehicle == vehicle
Example #2
0
def test_transfers(
    previous_update,
    current_update,
    stop_1_1,
    stop_1_2,
    from_stop_valid,
    to_stop_valid,
    expected_added,
    previous_transfer,
    expected_deleted,
):
    """Import one transfer, possibly with unknown endpoints, and check stats.

    When ``previous_transfer`` is truthy, a stop_1_1 -> stop_1_2 transfer is
    imported under ``previous_update`` first. The transfer under test then goes
    in under ``current_update``.
    """
    # Replace an endpoint with a bogus stop ID when it is flagged invalid.
    from_stop_id = stop_1_1.id if from_stop_valid else "blah"
    to_stop_id = stop_1_2.id if to_stop_valid else "blaf"
    transfer_under_test = parse.Transfer(
        from_stop_id=from_stop_id,
        to_stop_id=to_stop_id,
        min_transfer_time=300,
    )

    if previous_transfer:
        existing_transfer = parse.Transfer(
            from_stop_id=stop_1_1.id,
            to_stop_id=stop_1_2.id,
        )
        importdriver.run_import(previous_update.pk,
                                ParserForTesting([existing_transfer]))

    stats = importdriver.run_import(
        current_update.pk,
        ParserForTesting([transfer_under_test]),
    )

    verify_stats(stats, (expected_added, 0, expected_deleted))
Example #3
0
def test_vehicle__set_stop_simple_case(
    db_session,
    current_update,
    trip_for_vehicle,
    stop_1_3,
    provide_stop_id,
    provide_stop_sequence,
):
    """A vehicle's current stop is resolved from a stop ID and/or a sequence.

    If neither is supplied, the persisted vehicle has no current stop and no
    current stop sequence. If either is supplied, both resolve to ``stop_1_3``
    and sequence 3.
    """
    stop_id = stop_1_3.id if provide_stop_id else None
    stop_sequence = 3 if provide_stop_sequence else None
    parsed_vehicle = parse.Vehicle(
        id="vehicle_id",
        trip_id="trip_id",
        current_stop_id=stop_id,
        current_stop_sequence=stop_sequence,
    )

    importdriver.run_import(current_update.pk,
                            ParserForTesting([parsed_vehicle]))

    stored = db_session.query(models.Vehicle).all()[0]

    if provide_stop_id or provide_stop_sequence:
        assert stored.current_stop == stop_1_3
        assert stored.current_stop_sequence == 3
    else:
        assert stored.current_stop is None
        assert stored.current_stop_sequence is None
Example #4
0
def test_parse_error(current_update):
    """An exception raised by a parser propagates out of ``run_import``."""

    class RaisingRoutesParser(parse.TransiterParser):
        # Any attempt to read routes from this parser fails.
        def get_routes(self):
            raise ValueError

    with pytest.raises(ValueError):
        importdriver.run_import(
            current_update.pk,
            RaisingRoutesParser(),
        )
Example #5
0
def test_trip__invalid_route(db_session, system_1, route_1_1, current_update):
    """A trip that references a route ID not present is not persisted."""
    orphan_trip = parse.Trip(
        id="trip",
        route_id="unknown_route",
        direction_id=True,
    )

    importdriver.run_import(current_update.pk,
                            ParserForTesting([orphan_trip]))

    persisted_trips = db_session.query(models.Trip).all()
    assert persisted_trips == []
Example #6
0
def test_route__agency_linking(db_session, current_update):
    """A route gets linked to its agency.

    The parser yields the route before the agency it references.
    """
    parsed_agency = parse.Agency(id="agency",
                                 name="My Agency",
                                 timezone="",
                                 url="")
    parsed_route = parse.Route(
        id="route",
        type=parse.Route.Type.RAIL,
        agency_id="agency",
    )
    entities = [parsed_route, parsed_agency]

    importdriver.run_import(current_update.pk, ParserForTesting(entities))

    routes = db_session.query(models.Route).all()
    assert routes[0].agency is not None
Example #7
0
def test_stop__tree_linking(
    db_session,
    system_1,
    add_model,
    previous_update,
    current_update,
    old_id_to_parent_id,
    expected_id_to_parent_id,
):
    """Stop parent links are reconciled to match the newly parsed hierarchy.

    Persists a station tree described by ``old_id_to_parent_id`` and imports a
    tree described by ``expected_id_to_parent_id``. Then checks that every
    persisted stop has the expected parent, or no parent.
    """
    # Seed the database with the old hierarchy.
    db_stops = {}
    for stop_id in old_id_to_parent_id:
        db_stops[stop_id] = add_model(
            models.Stop(
                id=stop_id,
                name=stop_id,
                system=system_1,
                source=previous_update,
                longitude=0,
                latitude=0,
                type=parse.Stop.Type.STATION,
            ))
    for stop_id, parent_id in old_id_to_parent_id.items():
        if parent_id is not None:
            db_stops[stop_id].parent_stop = db_stops[parent_id]
    db_session.flush()

    # Build the parsed hierarchy the import should converge to.
    parsed_stops = {
        stop_id: parse.Stop(id=stop_id,
                            name=stop_id,
                            longitude=0,
                            latitude=0,
                            type=parse.Stop.Type.STATION)
        for stop_id in expected_id_to_parent_id
    }
    for stop_id, parent_id in expected_id_to_parent_id.items():
        if parent_id is not None:
            parsed_stops[stop_id].parent_stop = parsed_stops[parent_id]

    importdriver.run_import(current_update.pk,
                            ParserForTesting(list(parsed_stops.values())))

    actual_id_to_parent_id = {
        stop.id:
        (stop.parent_stop.id if stop.parent_stop is not None else None)
        for stop in db_session.query(models.Stop).all()
    }

    assert expected_id_to_parent_id == actual_id_to_parent_id
Example #8
0
def test_schedule(db_session, stop_1_1, route_1_1, previous_update,
                  current_update):
    """A scheduled service and all its children are imported, then removed.

    One service is imported under ``previous_update``. It has one trip, one
    frequency, one stop time, one added date and two removed dates. The test
    checks the row counts and stats. It then imports an empty entity list
    under ``current_update`` and checks that the service was removed.
    """
    stop_time = parse.ScheduledTripStopTime(stop_id=stop_1_1.id,
                                            stop_sequence=3,
                                            departure_time=None,
                                            arrival_time=None)

    trip = parse.ScheduledTrip(
        id="trid_id",
        route_id=route_1_1.id,
        direction_id=True,
        stop_times=[stop_time],
        frequencies=[
            parse.ScheduledTripFrequency(
                start_time=datetime.time(3, 4, 5),
                end_time=datetime.time(6, 7, 8),
                headway=30,
                frequency_based=False,
            )
        ],
    )

    schedule = parse.ScheduledService(
        id="schedule",
        monday=True,
        tuesday=True,
        wednesday=True,
        thursday=True,
        friday=True,
        saturday=True,
        sunday=True,
        trips=[trip],
        added_dates=[datetime.date(2016, 9, 10)],
        removed_dates=[datetime.date(2016, 9, 11),
                       datetime.date(2016, 9, 12)],
    )

    actual_counts = importdriver.run_import(previous_update.pk,
                                            ParserForTesting([schedule]))

    assert 1 == len(db_session.query(models.ScheduledService).all())
    assert 1 == len(db_session.query(models.ScheduledServiceAddition).all())
    assert 2 == len(db_session.query(models.ScheduledServiceRemoval).all())
    assert 1 == len(db_session.query(models.ScheduledTrip).all())
    assert 1 == len(db_session.query(models.ScheduledTripFrequency).all())
    assert 1 == len(db_session.query(models.ScheduledTripStopTime).all())
    verify_stats(actual_counts, (1, 0, 0))

    # Just to make sure we can delete it all: an empty import must not raise
    # and must actually remove the service imported above.
    importdriver.run_import(current_update.pk, ParserForTesting([]))

    assert [] == db_session.query(models.ScheduledService).all()
Example #9
0
def test_direction_rules(
    db_session,
    add_model,
    stop_1_1,
    previous_update,
    current_update,
    previous,
    current,
    expected_counts,
):
    """Direction rules are reconciled against the parsed set and stats checked.

    The ``previous`` rules are persisted under ``previous_update``. The
    ``current`` rules are then imported, and the persisted rules are compared
    against them on (stop_pk, track, source_pk).
    """
    for old_rule in previous:
        old_rule.stop_pk = stop_1_1.pk
        old_rule.source = previous_update
        add_model(old_rule)

    # list() copies the container, not the rules. The attributes set below
    # therefore land on the same rule objects that are handed to the parser.
    expected_entities = list(current)
    for expected_rule in expected_entities:
        expected_rule.stop_pk = stop_1_1.pk

    for parsed_rule in current:
        parsed_rule.stop_id = stop_1_1.id

    actual_counts = importdriver.run_import(current_update.pk,
                                            ParserForTesting(current))

    def comparison_key(rule):
        return rule.stop_pk, rule.track, rule.source_pk

    assert {comparison_key(rule) for rule in expected_entities} == {
        comparison_key(rule)
        for rule in db_session.query(models.DirectionRule).all()
    }
    verify_stats(actual_counts, expected_counts)
Example #10
0
def test_vehicle__no_vehicle_id(
    db_session,
    current_update,
    trip_for_vehicle,
    stop_1_3,
    vehicle_id,
):
    """An imported vehicle is linked to its trip in both directions.

    ``vehicle_id`` is parametrized. The test name suggests it may be absent.
    """
    importdriver.run_import(
        current_update.pk,
        ParserForTesting([parse.Vehicle(id=vehicle_id, trip_id="trip_id")]),
    )

    stored_vehicle = db_session.query(models.Vehicle).all()[0]

    # Reload the trip so its vehicle relationship reflects the import.
    db_session.refresh(trip_for_vehicle)

    assert trip_for_vehicle.vehicle == stored_vehicle
    assert stored_vehicle.trip == trip_for_vehicle
Example #11
0
def test_trip__stop_time_reconciliation(
    db_session,
    add_model,
    system_1,
    route_1_1,
    previous_update,
    current_update,
    old_stop_time_data,
    new_stop_time_data,
    expected_stop_time_data,
    feed_1_1_update_1,
):
    """Re-importing a trip with new stop times reconciles its persisted stops.

    Every stop referenced by either set of stop times is created first. The
    trip is imported with the old stop times and then again with the new ones.
    The persisted stop times are compared with ``expected_stop_time_data``.
    """
    referenced_stop_ids = {
        stop_time.stop_id
        for stop_time in (*old_stop_time_data, *new_stop_time_data)
    }

    stop_pk_to_stop = {}
    for stop_id in referenced_stop_ids:
        created_stop = add_model(
            models.Stop(
                id=stop_id,
                system=system_1,
                type=models.Stop.Type.STATION,
                source=feed_1_1_update_1,
            ))
        stop_pk_to_stop[created_stop.pk] = created_stop

    parsed_trip = parse.Trip(
        id="trip",
        route_id=route_1_1.id,
        direction_id=True,
        stop_times=old_stop_time_data,
    )
    importdriver.run_import(previous_update.pk,
                            ParserForTesting([parsed_trip]))

    # Same trip, second feed update, with the replacement stop times.
    parsed_trip.stop_times = new_stop_time_data
    importdriver.run_import(current_update.pk,
                            ParserForTesting([parsed_trip]))

    persisted_trip = db_session.query(models.Trip).all()[0]
    actual_stop_times = [
        convert_trip_stop_time_model_to_parse(stop_time, stop_pk_to_stop)
        for stop_time in persisted_trip.stop_times
    ]

    assert expected_stop_time_data == actual_stop_times
Example #12
0
def test_flush(db_session, add_model, system_1, previous_update,
               current_update):
    """A FLUSH update that parses nothing removes the previously sourced route.

    A stop and a route sourced from ``previous_update`` are persisted. An empty
    FLUSH import then runs under ``current_update``.
    """
    current_update.update_type = models.FeedUpdate.Type.FLUSH

    previous_source_pk = previous_update.pk
    add_model(
        models.Stop(system=system_1,
                    source_pk=previous_source_pk,
                    type=models.Stop.Type.STATION))
    add_model(models.Route(system=system_1, source_pk=previous_source_pk))

    importdriver.run_import(current_update.pk, ParserForTesting([]))

    remaining_routes = db_session.query(models.Route).all()
    assert [] == remaining_routes
Example #13
0
def test_trip__invalid_stops_in_stop_times(db_session, system_1, route_1_1,
                                           stop_1_1, current_update):
    """A stop time pointing at an unmatched stop ID is dropped; the trip stays.

    Of the two stop times, only the one referencing ``stop_1_1`` is persisted.
    """
    bogus_stop_id = stop_1_1.id + "blah_bla"
    parsed_trip = parse.Trip(
        id="trip",
        route_id=route_1_1.id,
        direction_id=True,
        stop_times=[
            parse.TripStopTime(stop_id=stop_1_1.id, stop_sequence=2),
            parse.TripStopTime(stop_id=bogus_stop_id, stop_sequence=3),
        ],
    )

    importdriver.run_import(current_update.pk,
                            ParserForTesting([parsed_trip]))

    persisted_trips = db_session.query(models.Trip).all()
    assert len(persisted_trips) == 1
    assert len(persisted_trips[0].stop_times) == 1
Example #14
0
def test_vehicle__merged_vehicle_edge_case(
    db_session,
    previous_update,
    current_update,
    trip_for_vehicle,
    stop_1_3,
):
    """Two partial vehicles are replaced by a single vehicle carrying both keys.

    The previous update has one vehicle known only by trip and one known only
    by ID. The current update reports one vehicle with both the ID and the
    trip, and the import stats are expected to be ``(0, 0, 2)``.
    """
    trip_only = parse.Vehicle(id=None, trip_id="trip_id")
    id_only = parse.Vehicle(id="vehicle_id", trip_id=None)
    importdriver.run_import(previous_update.pk,
                            ParserForTesting([trip_only, id_only]))
    db_session.refresh(trip_for_vehicle)

    merged = parse.Vehicle(id="vehicle_id", trip_id="trip_id")
    stats = importdriver.run_import(current_update.pk,
                                    ParserForTesting([merged]))

    verify_stats(stats, (0, 0, 2))
Example #15
0
def test_duplicate_ids(current_update, parsed_type, route_1_1, stop_1_1):
    """The same parsed entity passed twice is imported once.

    Trips and direction rules first get pointed at existing fixtures so they
    can be linked.
    """
    entity = parsed_type
    if isinstance(entity, parse.Trip):
        entity.route_id = route_1_1.id
    if isinstance(entity, parse.DirectionRule):
        entity.stop_id = stop_1_1.id

    duplicated = ParserForTesting([entity, entity])
    stats = importdriver.run_import(current_update.pk, duplicated)

    verify_stats(stats, (1, 0, 0))
Example #16
0
def test_alert__route_linking(
    db_session,
    current_update,
    route_1_1,
    stop_1_1,
    trip_1,
    agency_1_1,
    entity_type,
    valid_id,
):
    """An alert is linked to the informed entity only when its ID is known.

    ``entity_type`` selects which ``*_ids`` field of the parsed alert is
    populated. It also names the persisted alert relationship that is checked.
    """
    # entity_type -> (Alert kwarg, fixture entity, unmatched ID to use instead)
    link_specs = {
        "routes": ("route_ids", route_1_1, "buggy_route_id"),
        "stops": ("stop_ids", stop_1_1, "buggy_stop_id"),
        "trips": ("trip_ids", trip_1, "buggy_trip_id"),
        "agencies": ("agency_ids", agency_1_1, "buggy_agency_id"),
    }

    alert_kwargs = {}
    entity = None
    spec = link_specs.get(entity_type)
    if spec is not None:
        kwarg_name, entity, buggy_id = spec
        alert_kwargs[kwarg_name] = [entity.id if valid_id else buggy_id]

    importdriver.run_import(
        current_update.pk,
        ParserForTesting([parse.Alert(id="alert", **alert_kwargs)]),
    )

    stored_alert = db_session.query(models.Alert).all()[0]
    linked_entities = getattr(stored_alert, entity_type)

    expected_entities = [entity] if valid_id else []
    assert linked_entities == expected_entities
Example #17
0
def test_vehicle__duplicate_trip_ids(
    db_session,
    current_update,
    trip_for_vehicle,
    stop_1_3,
):
    """The same ID-less vehicle reported twice for one trip is imported once."""
    trip_only_vehicle = parse.Vehicle(id=None, trip_id="trip_id")
    parser = ParserForTesting([trip_only_vehicle] * 2)

    stats = importdriver.run_import(current_update.pk, parser)

    verify_stats(stats, (1, 0, 0))
Example #18
0
def test_trip__route_from_schedule(db_session, add_model, system_1, route_1_1,
                                   current_update, feed_1_1_update_1):
    """A trip parsed without a route ID takes the route of its scheduled trip.

    A scheduled trip with the same ID ("trip") exists on ``route_1_1``.
    """
    service = add_model(
        models.ScheduledService(id="service",
                                system=system_1,
                                source=feed_1_1_update_1))
    add_model(models.ScheduledTrip(id="trip", route=route_1_1,
                                   service=service))

    routeless_trip = parse.Trip(id="trip", route_id=None, direction_id=True)
    importdriver.run_import(current_update.pk,
                            ParserForTesting([routeless_trip]))

    persisted_trips = db_session.query(models.Trip).all()
    assert len(persisted_trips) == 1
    only_trip = persisted_trips[0]
    assert only_trip.id == "trip"
    assert only_trip.route == route_1_1
Example #19
0
def test_entities_skipped(db_session, current_update):
    """Entities of a type the parser does not declare as supported are ignored.

    The parser declares support only for stops but yields a route. No route
    may be persisted, and the stats must show no activity.
    """

    class BuggyParser(parse.parser.CallableBasedParser):
        @property
        def supported_types(self):
            return {parse.Stop}

    # Defined locally so the test does not depend on a ``new_route`` name that
    # is not defined anywhere in this test; the lambda would otherwise raise
    # NameError when the parser invokes it.
    unsupported_route = parse.Route(id="route", type=parse.Route.Type.RAIL)
    parser = BuggyParser(lambda: [unsupported_route])

    result = importdriver.run_import(current_update.pk, parser)

    assert [] == db_session.query(models.Route).all()
    verify_stats(result, (0, 0, 0))
Example #20
0
def test_vehicle__delete_with_trip_attached(
    db_session,
    add_model,
    system_1,
    previous_update,
    current_update,
    trip_for_vehicle,
    stop_1_3,
):
    """Deleting a vehicle that is attached to a trip leaves the trip vehicle-less.

    The vehicle is sourced from ``previous_update``. An empty import under
    ``current_update`` removes it.
    """
    attached_vehicle = models.Vehicle(
        id="vehicle_id",
        system=system_1,
        source=previous_update,
        trip=trip_for_vehicle,
    )
    add_model(attached_vehicle)

    importdriver.run_import(current_update.pk, ParserForTesting([]))

    # Reload so the trip's vehicle relationship reflects the deletion.
    db_session.refresh(trip_for_vehicle)
    assert trip_for_vehicle.vehicle is None
Example #21
0
def test_direction_rules__skip_unknown_stop(
    db_session,
    system_1,
    current_update,
):
    """A direction rule whose stop ID matches no stop is not imported."""
    rule = parse.DirectionRule(
        name="Rule",
        id="blah",
        track="new track",
        stop_id="104401",
    )

    stats = importdriver.run_import(current_update.pk,
                                    ParserForTesting([rule]))

    persisted_rules = db_session.query(models.DirectionRule).all()
    assert [] == persisted_rules
    verify_stats(stats, (0, 0, 0))
Example #22
0
def test_simple_create_update_delete(
    db_session,
    add_model,
    system_1,
    previous_update,
    current_update,
    entity_type,
    previous,
    current,
    expected_counts,
):
    """Generic create/update/delete reconciliation for simple entity types.

    The ``previous`` entities are persisted under ``previous_update`` and the
    ``current`` entities are imported. The persisted ``entity_type`` rows must
    then match ``current`` on a per-type tuple of attributes, and the import
    stats must equal ``expected_counts``.
    """
    for old_entity in previous:
        old_entity.system_pk = system_1.pk
        old_entity.source = previous_update
        add_model(old_entity)

    actual_counts = importdriver.run_import(current_update.pk,
                                            ParserForTesting(current))

    # Attributes compared per supported model type, in comparison order.
    compared_attributes = {
        models.Route: ("id", "description", "source_pk"),
        models.Stop: ("id", "name", "source_pk"),
        models.Alert: ("id", "cause", "effect"),
        models.Agency: ("id", "name"),
        models.Vehicle: ("id", "label", "current_status"),
    }

    def fields_to_compare(entity):
        attribute_names = compared_attributes.get(entity_type)
        if attribute_names is None:
            raise NotImplementedError
        return tuple(getattr(entity, name) for name in attribute_names)

    expected_rows = {fields_to_compare(entity) for entity in current}
    actual_rows = {
        fields_to_compare(entity)
        for entity in db_session.query(entity_type).all()
    }
    assert expected_rows == actual_rows
    verify_stats(actual_counts, expected_counts)
Example #23
0
def test_unknown_type(current_update):
    """Handing the importer a plain string as an entity raises TypeError."""
    with pytest.raises(TypeError):
        importdriver.run_import(
            current_update.pk,
            ParserForTesting(["string"]),
        )
Example #24
0
def test_move_entity_across_feeds(current_update, other_feed_update, route_1_1,
                                  entity):
    """Import the same entity from another feed's update, then from this one.

    The body makes no assertions. It checks only that neither import raises.
    """
    for feed_update in (other_feed_update, current_update):
        importdriver.run_import(feed_update.pk, ParserForTesting([entity]))