def test_vehicle__move_between_trips_attached(
    db_session,
    add_model,
    system_1,
    route_1_1,
    previous_update,
    current_update,
    trip_for_vehicle,
    stop_1_3,
):
    """A vehicle re-reported on a new trip is detached from its old trip.

    The vehicle starts attached to ``trip_for_vehicle``; the import introduces
    ``trip_id_2`` and reports the same vehicle on it.
    """
    existing_vehicle = add_model(
        models.Vehicle(
            id="vehicle_id",
            system=system_1,
            source=previous_update,
            trip=trip_for_vehicle,
        )
    )
    parsed_entities = [
        parse.Trip(id="trip_id_2", route_id=route_1_1.id),
        parse.Vehicle(id="vehicle_id", trip_id="trip_id_2"),
    ]

    importdriver.run_import(current_update.pk, ParserForTesting(parsed_entities))

    db_session.refresh(trip_for_vehicle)
    second_trip = (
        db_session.query(models.Trip)
        .filter(models.Trip.id == "trip_id_2")
        .one_or_none()
    )
    assert trip_for_vehicle.vehicle is None
    assert second_trip.vehicle == existing_vehicle
def test_transfers(
    previous_update,
    current_update,
    stop_1_1,
    stop_1_2,
    from_stop_valid,
    to_stop_valid,
    expected_added,
    previous_transfer,
    expected_deleted,
):
    """Import one transfer, whose endpoints may be unknown stops, and check the stats.

    When ``previous_transfer`` is truthy, a transfer between the two valid
    stops is imported first under ``previous_update``.
    """
    from_stop_id = stop_1_1.id if from_stop_valid else "blah"
    to_stop_id = stop_1_2.id if to_stop_valid else "blaf"
    transfer = parse.Transfer(
        from_stop_id=from_stop_id,
        to_stop_id=to_stop_id,
        min_transfer_time=300,
    )

    if previous_transfer:
        seed_transfer = parse.Transfer(from_stop_id=stop_1_1.id, to_stop_id=stop_1_2.id)
        importdriver.run_import(previous_update.pk, ParserForTesting([seed_transfer]))

    stats = importdriver.run_import(current_update.pk, ParserForTesting([transfer]))

    verify_stats(stats, (expected_added, 0, expected_deleted))
def test_vehicle__set_stop_simple_case(
    db_session,
    current_update,
    trip_for_vehicle,
    stop_1_3,
    provide_stop_id,
    provide_stop_sequence,
):
    """A vehicle's current stop and stop sequence are persisted from the feed.

    The test expects both fields to be populated when either one is provided.
    Presumably the missing one is inferred from the trip's stop times; confirm
    against the import driver.
    """
    stop_id = None
    if provide_stop_id:
        stop_id = stop_1_3.id
    stop_sequence = None
    if provide_stop_sequence:
        stop_sequence = 3

    parsed_vehicle = parse.Vehicle(
        id="vehicle_id",
        trip_id="trip_id",
        current_stop_id=stop_id,
        current_stop_sequence=stop_sequence,
    )
    importdriver.run_import(current_update.pk, ParserForTesting([parsed_vehicle]))

    saved_vehicle = db_session.query(models.Vehicle).all()[0]
    if provide_stop_id or provide_stop_sequence:
        assert saved_vehicle.current_stop == stop_1_3
        assert saved_vehicle.current_stop_sequence == 3
    else:
        assert saved_vehicle.current_stop is None
        assert saved_vehicle.current_stop_sequence is None
def test_parse_error(current_update):
    """An exception raised inside the parser propagates out of ``run_import``."""

    class ExplodingParser(parse.TransiterParser):
        def get_routes(self):
            raise ValueError

    with pytest.raises(ValueError):
        importdriver.run_import(current_update.pk, ExplodingParser())
def test_trip__invalid_route(db_session, system_1, route_1_1, current_update):
    """A trip whose route ID matches no known route is not persisted."""
    orphan_trip = parse.Trip(id="trip", route_id="unknown_route", direction_id=True)

    importdriver.run_import(current_update.pk, ParserForTesting([orphan_trip]))

    persisted_trips = db_session.query(models.Trip).all()
    assert persisted_trips == []
def test_route__agency_linking(db_session, current_update):
    """A route is linked to the agency named by its ``agency_id``.

    The agency is emitted after the route in the parser output.
    """
    route = parse.Route(id="route", type=parse.Route.Type.RAIL, agency_id="agency")
    agency = parse.Agency(id="agency", name="My Agency", timezone="", url="")

    importdriver.run_import(current_update.pk, ParserForTesting([route, agency]))

    first_route = db_session.query(models.Route).all()[0]
    assert first_route.agency is not None
def test_stop__tree_linking(
    db_session,
    system_1,
    add_model,
    previous_update,
    current_update,
    old_id_to_parent_id,
    expected_id_to_parent_id,
):
    """Importing stops rewires parent/child links to match the parsed stop tree.

    Both fixtures map a stop ID to its parent's ID, or ``None`` for a root.
    """
    # Seed the database with the "old" stop tree.
    db_stops = {}
    for stop_id in old_id_to_parent_id:
        db_stops[stop_id] = add_model(
            models.Stop(
                id=stop_id,
                name=stop_id,
                system=system_1,
                source=previous_update,
                longitude=0,
                latitude=0,
                type=parse.Stop.Type.STATION,
            )
        )
    for stop_id, parent_id in old_id_to_parent_id.items():
        if parent_id is not None:
            db_stops[stop_id].parent_stop = db_stops[parent_id]
    db_session.flush()

    # Build the parsed stop tree that the import should leave behind.
    parsed_stops = {
        stop_id: parse.Stop(
            id=stop_id,
            name=stop_id,
            longitude=0,
            latitude=0,
            type=parse.Stop.Type.STATION,
        )
        for stop_id in expected_id_to_parent_id
    }
    for stop_id, parent_id in expected_id_to_parent_id.items():
        if parent_id is not None:
            parsed_stops[stop_id].parent_stop = parsed_stops[parent_id]

    importdriver.run_import(
        current_update.pk, ParserForTesting(list(parsed_stops.values()))
    )

    observed_id_to_parent_id = {
        stop.id: (stop.parent_stop.id if stop.parent_stop is not None else None)
        for stop in db_session.query(models.Stop).all()
    }
    assert expected_id_to_parent_id == observed_id_to_parent_id
def test_schedule(db_session, stop_1_1, route_1_1, previous_update, current_update):
    """Import a full scheduled service and check every schedule table is filled.

    A second, empty import then checks that the whole schedule can be deleted
    without error.
    """
    only_stop_time = parse.ScheduledTripStopTime(
        stop_id=stop_1_1.id,
        stop_sequence=3,
        departure_time=None,
        arrival_time=None,
    )
    only_frequency = parse.ScheduledTripFrequency(
        start_time=datetime.time(3, 4, 5),
        end_time=datetime.time(6, 7, 8),
        headway=30,
        frequency_based=False,
    )
    scheduled_trip = parse.ScheduledTrip(
        id="trid_id",
        route_id=route_1_1.id,
        direction_id=True,
        stop_times=[only_stop_time],
        frequencies=[only_frequency],
    )
    service = parse.ScheduledService(
        id="schedule",
        monday=True,
        tuesday=True,
        wednesday=True,
        thursday=True,
        friday=True,
        saturday=True,
        sunday=True,
        trips=[scheduled_trip],
        added_dates=[datetime.date(2016, 9, 10)],
        removed_dates=[datetime.date(2016, 9, 11), datetime.date(2016, 9, 12)],
    )

    actual_counts = importdriver.run_import(
        previous_update.pk, ParserForTesting([service])
    )

    expected_row_counts = (
        (models.ScheduledService, 1),
        (models.ScheduledServiceAddition, 1),
        (models.ScheduledServiceRemoval, 2),
        (models.ScheduledTrip, 1),
        (models.ScheduledTripFrequency, 1),
        (models.ScheduledTripStopTime, 1),
    )
    for model, expected_rows in expected_row_counts:
        assert expected_rows == len(db_session.query(model).all())
    verify_stats(actual_counts, (1, 0, 0))

    # Importing nothing must be able to tear the whole schedule down again.
    importdriver.run_import(current_update.pk, ParserForTesting([]))
def test_direction_rules(
    db_session,
    add_model,
    stop_1_1,
    previous_update,
    current_update,
    previous,
    current,
    expected_counts,
):
    """Direction rules are created, updated or deleted to match the import.

    ``previous`` rules are seeded against ``stop_1_1``. ``current`` rules are
    imported. Persisted rules are then compared to ``current`` by
    (stop_pk, track, source_pk).
    """
    for seeded_rule in previous:
        seeded_rule.stop_pk = stop_1_1.pk
        seeded_rule.source = previous_update
        add_model(seeded_rule)

    # Same objects as ``current``; they also serve as the expected result.
    expected_entities = list(current)
    for expected_rule in expected_entities:
        expected_rule.stop_pk = stop_1_1.pk
    for incoming_rule in current:
        incoming_rule.stop_id = stop_1_1.id

    actual_counts = importdriver.run_import(current_update.pk, ParserForTesting(current))

    def comparison_key(rule):
        return rule.stop_pk, rule.track, rule.source_pk

    expected_keys = {comparison_key(rule) for rule in expected_entities}
    persisted_keys = {
        comparison_key(rule) for rule in db_session.query(models.DirectionRule).all()
    }
    assert expected_keys == persisted_keys
    verify_stats(actual_counts, expected_counts)
def test_vehicle__no_vehicle_id(
    db_session,
    current_update,
    trip_for_vehicle,
    stop_1_3,
    vehicle_id,
):
    """A vehicle is linked both ways to the trip it references.

    Per the test name, ``vehicle_id`` presumably includes a missing/empty ID
    case; confirm against the parametrization.
    """
    importdriver.run_import(
        current_update.pk,
        ParserForTesting([parse.Vehicle(id=vehicle_id, trip_id="trip_id")]),
    )

    saved_vehicle = db_session.query(models.Vehicle).all()[0]
    db_session.refresh(trip_for_vehicle)
    assert trip_for_vehicle.vehicle == saved_vehicle
    assert saved_vehicle.trip == trip_for_vehicle
def test_trip__stop_time_reconciliation(
    db_session,
    add_model,
    system_1,
    route_1_1,
    previous_update,
    current_update,
    old_stop_time_data,
    new_stop_time_data,
    expected_stop_time_data,
    feed_1_1_update_1,
):
    """Re-importing a trip with new stop times reconciles them against the old ones."""
    # Every stop referenced by either set of stop times must already exist.
    referenced_stop_ids = {
        stop_time.stop_id
        for stop_time in itertools.chain(old_stop_time_data, new_stop_time_data)
    }
    stop_pk_to_stop = {}
    for stop_id in referenced_stop_ids:
        created_stop = add_model(
            models.Stop(
                id=stop_id,
                system=system_1,
                type=models.Stop.Type.STATION,
                source=feed_1_1_update_1,
            )
        )
        stop_pk_to_stop[created_stop.pk] = created_stop

    parsed_trip = parse.Trip(
        id="trip",
        route_id=route_1_1.id,
        direction_id=True,
        stop_times=old_stop_time_data,
    )
    importdriver.run_import(previous_update.pk, ParserForTesting([parsed_trip]))

    parsed_trip.stop_times = new_stop_time_data
    importdriver.run_import(current_update.pk, ParserForTesting([parsed_trip]))

    persisted_trip = db_session.query(models.Trip).all()[0]
    actual_stop_times = [
        convert_trip_stop_time_model_to_parse(stop_time, stop_pk_to_stop)
        for stop_time in persisted_trip.stop_times
    ]
    assert expected_stop_time_data == actual_stop_times
def test_flush(db_session, add_model, system_1, previous_update, current_update):
    """A FLUSH update with no entities removes what the feed imported earlier.

    A stop and a route are both seeded under ``previous_update``. Presumably
    that is the same feed as ``current_update``, since the route is expected
    to be deleted. Both entities must be gone after the flush.
    """
    current_update.update_type = models.FeedUpdate.Type.FLUSH
    add_model(
        models.Stop(
            system=system_1,
            source_pk=previous_update.pk,
            type=models.Stop.Type.STATION,
        )
    )
    add_model(
        models.Route(
            system=system_1,
            source_pk=previous_update.pk,
        )
    )

    importdriver.run_import(current_update.pk, ParserForTesting([]))

    assert [] == db_session.query(models.Route).all()
    # The seeded stop must be flushed as well; without this check only half of
    # the seeded data was verified.
    assert [] == db_session.query(models.Stop).all()
def test_trip__invalid_stops_in_stop_times(
    db_session, system_1, route_1_1, stop_1_1, current_update
):
    """Stop times pointing at unknown stops are dropped; the trip itself is kept."""
    known_stop_id = stop_1_1.id
    stop_times = [
        parse.TripStopTime(stop_id=known_stop_id, stop_sequence=2),
        parse.TripStopTime(stop_id=known_stop_id + "blah_bla", stop_sequence=3),
    ]
    parsed_trip = parse.Trip(
        id="trip",
        route_id=route_1_1.id,
        direction_id=True,
        stop_times=stop_times,
    )

    importdriver.run_import(current_update.pk, ParserForTesting([parsed_trip]))

    persisted_trips = db_session.query(models.Trip).all()
    assert 1 == len(persisted_trips)
    assert 1 == len(persisted_trips[0].stop_times)
def test_vehicle__merged_vehicle_edge_case(
    db_session,
    previous_update,
    current_update,
    trip_for_vehicle,
    stop_1_3,
):
    """Two partial vehicles are later reported as one vehicle carrying both keys.

    The first import has one vehicle keyed only by trip and one keyed only by
    ID. The second import reports a single vehicle with both. The stats are
    expected to be ``(0, 0, 2)``.
    """
    trip_only_vehicle = parse.Vehicle(id=None, trip_id="trip_id")
    id_only_vehicle = parse.Vehicle(id="vehicle_id", trip_id=None)
    combined_vehicle = parse.Vehicle(id="vehicle_id", trip_id="trip_id")

    importdriver.run_import(
        previous_update.pk, ParserForTesting([trip_only_vehicle, id_only_vehicle])
    )
    db_session.refresh(trip_for_vehicle)
    stats = importdriver.run_import(
        current_update.pk, ParserForTesting([combined_vehicle])
    )

    verify_stats(stats, (0, 0, 2))
def test_duplicate_ids(current_update, parsed_type, route_1_1, stop_1_1):
    """An entity emitted twice in one import is counted as a single addition."""
    # Point references at real entities. Presumably this keeps the entity from
    # being skipped as invalid.
    if isinstance(parsed_type, parse.Trip):
        parsed_type.route_id = route_1_1.id
    if isinstance(parsed_type, parse.DirectionRule):
        parsed_type.stop_id = stop_1_1.id

    duplicated = [parsed_type] * 2
    actual_counts = importdriver.run_import(current_update.pk, ParserForTesting(duplicated))

    verify_stats(actual_counts, (1, 0, 0))
def test_alert__route_linking(
    db_session,
    current_update,
    route_1_1,
    stop_1_1,
    trip_1,
    agency_1_1,
    entity_type,
    valid_id,
):
    """An alert is linked to the informed entity only when that entity's ID is known.

    ``entity_type`` names both the relationship on the persisted alert and the
    kind of fixture the alert references.
    """
    # entity_type -> (parse.Alert keyword, fixture entity, ID used when invalid)
    link_specs = {
        "routes": ("route_ids", route_1_1, "buggy_route_id"),
        "stops": ("stop_ids", stop_1_1, "buggy_stop_id"),
        "trips": ("trip_ids", trip_1, "buggy_trip_id"),
        "agencies": ("agency_ids", agency_1_1, "buggy_agency_id"),
    }
    alert_kwargs = {}
    entity = None
    if entity_type in link_specs:
        kwarg_name, entity, invalid_id = link_specs[entity_type]
        alert_kwargs[kwarg_name] = [entity.id if valid_id else invalid_id]

    parsed_alert = parse.Alert(id="alert", **alert_kwargs)
    importdriver.run_import(current_update.pk, ParserForTesting([parsed_alert]))

    persisted_alert = db_session.query(models.Alert).all()[0]
    expected_links = [entity] if valid_id else []
    assert getattr(persisted_alert, entity_type) == expected_links
def test_vehicle__duplicate_trip_ids(
    db_session,
    current_update,
    trip_for_vehicle,
    stop_1_3,
):
    """Two ID-less vehicles on the same trip collapse into a single addition."""
    trip_keyed_vehicle = parse.Vehicle(id=None, trip_id="trip_id")

    stats = importdriver.run_import(
        current_update.pk, ParserForTesting([trip_keyed_vehicle] * 2)
    )

    verify_stats(stats, (1, 0, 0))
def test_trip__route_from_schedule(
    db_session, add_model, system_1, route_1_1, current_update, feed_1_1_update_1
):
    """A realtime trip with no route ID takes its route from the scheduled trip of the same ID."""
    service = add_model(
        models.ScheduledService(id="service", system=system_1, source=feed_1_1_update_1)
    )
    add_model(models.ScheduledTrip(id="trip", route=route_1_1, service=service))

    routeless_trip = parse.Trip(id="trip", route_id=None, direction_id=True)
    importdriver.run_import(current_update.pk, ParserForTesting([routeless_trip]))

    persisted_trips = db_session.query(models.Trip).all()
    assert 1 == len(persisted_trips)
    only_trip = persisted_trips[0]
    assert "trip" == only_trip.id
    assert route_1_1 == only_trip.route
def test_entities_skipped(db_session, current_update):
    """Entities whose type the parser does not declare as supported are ignored.

    The parser advertises support only for stops. The route it yields must not
    be persisted, and the import must report no changes.
    """

    class BuggyParser(parse.parser.CallableBasedParser):
        @property
        def supported_types(self):
            return {parse.Stop}

    # The lambda below closes over ``new_route``. It must be defined here,
    # otherwise calling the lambda raises NameError (or silently picks up some
    # outer name) instead of exercising the unsupported-type skip path.
    new_route = parse.Route(id="route", type=parse.Route.Type.RAIL)
    parser = BuggyParser(lambda: [new_route])

    result = importdriver.run_import(current_update.pk, parser)

    assert [] == db_session.query(models.Route).all()
    verify_stats(result, (0, 0, 0))
def test_vehicle__delete_with_trip_attached(
    db_session,
    add_model,
    system_1,
    previous_update,
    current_update,
    trip_for_vehicle,
    stop_1_3,
):
    """Deleting a vehicle through an empty import clears its trip's back-reference."""
    stale_vehicle = models.Vehicle(
        id="vehicle_id",
        system=system_1,
        source=previous_update,
        trip=trip_for_vehicle,
    )
    add_model(stale_vehicle)

    importdriver.run_import(current_update.pk, ParserForTesting([]))

    db_session.refresh(trip_for_vehicle)
    assert trip_for_vehicle.vehicle is None
def test_direction_rules__skip_unknown_stop(
    db_session,
    system_1,
    current_update,
):
    """A direction rule referencing an unknown stop is skipped entirely."""
    unknown_stop_rule = parse.DirectionRule(
        name="Rule", id="blah", track="new track", stop_id="104401"
    )

    actual_counts = importdriver.run_import(
        current_update.pk, ParserForTesting([unknown_stop_rule])
    )

    assert [] == db_session.query(models.DirectionRule).all()
    verify_stats(actual_counts, (0, 0, 0))
def test_simple_create_update_delete(
    db_session,
    add_model,
    system_1,
    previous_update,
    current_update,
    entity_type,
    previous,
    current,
    expected_counts,
):
    """Seed ``previous`` entities, import ``current``, and compare the result.

    Persisted rows of ``entity_type`` are compared with ``current`` on a
    per-type set of fields, and the import stats are checked.
    """
    for seeded in previous:
        seeded.system_pk = system_1.pk
        seeded.source = previous_update
        add_model(seeded)

    actual_counts = importdriver.run_import(current_update.pk, ParserForTesting(current))

    field_extractors = {
        models.Route: lambda e: (e.id, e.description, e.source_pk),
        models.Stop: lambda e: (e.id, e.name, e.source_pk),
        models.Alert: lambda e: (e.id, e.cause, e.effect),
        models.Agency: lambda e: (e.id, e.name),
        models.Vehicle: lambda e: (e.id, e.label, e.current_status),
    }

    def fields_to_compare(entity):
        extractor = field_extractors.get(entity_type)
        if extractor is None:
            raise NotImplementedError
        return extractor(entity)

    expected_fields = {fields_to_compare(e) for e in current}
    persisted_fields = {
        fields_to_compare(e) for e in db_session.query(entity_type).all()
    }
    assert expected_fields == persisted_fields
    verify_stats(actual_counts, expected_counts)
def test_unknown_type(current_update):
    """Parser output that is not a recognised entity type raises TypeError."""
    bogus_entities = ["string"]
    with pytest.raises(TypeError):
        importdriver.run_import(current_update.pk, ParserForTesting(bogus_entities))
def test_move_entity_across_feeds(current_update, other_feed_update, route_1_1, entity):
    """Smoke test: importing the same entity first via ``other_feed_update`` and
    then via ``current_update`` completes without raising.

    No assertions are made on the resulting state.
    """
    for update in (other_feed_update, current_update):
        importdriver.run_import(update.pk, ParserForTesting([entity]))