def with_cleaned_data_in_database() -> Iterator[OrchestratorResult]:
    """Seed one CleanedData row for ``TEST_CONTRACT`` around a test.

    Setup: removes leftover rows for ``TEST_CONTRACT``, verifies none remain,
    starts a forecast run on a real-database orchestrator and inserts one
    CleanedData row tagged with that run id.

    Yields the ``OrchestratorResult`` that owns the forecast run.

    Teardown: asserts that exactly the seeded CleanedData rows and the single
    ForecastRun row are deleted.
    """
    contract_rows = Query(CleanedData).filter(CleanedData.c.Contract_ID == TEST_CONTRACT)  # type: ignore

    # Leftovers from a previously failed test would corrupt the counts below.
    delete_test_data(contract_rows)
    with Database(DatabaseType.internal).transaction_context() as session:
        leftover_count = contract_rows.with_session(session).count()
        assert leftover_count == 0, "Found old test data in database when setting up the test"

    result = setup_orchestrator_result(use_real_database=True)
    result.orchestrator._initialize_forecast_run()
    run_id = cast(int, result.orchestrator._forecast_run_id)

    seeded_rows = [
        {
            "run_id": run_id,
            "Project_ID": "Test_Project",
            "Contract_ID": TEST_CONTRACT,
            "Wesco_Master_Number": "Test_Master_Number",
            "Date": "2020-03-01 00:00:00.000",
            "Date_YYYYMM": 202003,
            "Item_ID": -2,
            "Unit_Cost": 1,
            "Order_Quantity": 2,
            "Order_Cost": 2,
        }
    ]
    with Database(DatabaseType.internal).transaction_context() as session:
        session.execute(CleanedData.insert().values(seeded_rows))

    yield result

    assert delete_test_data(contract_rows) == len(seeded_rows)
    assert delete_test_data(Query(ForecastRun).filter(ForecastRun.id == run_id)) == 1  # type: ignore
def test_length_geom_linestring_missing_epsg_from_global_settings(session):
    """Channel length check reports nothing when no GlobalSetting EPSG code exists.

    No GlobalSetting row is created, so the EPSG subquery presumably yields
    NULL and the transformed length cannot be compared (verify for the
    backend in use).
    """
    if session.bind.name == "postgresql":
        pytest.skip("Postgres already has a constrain that checks on the length")

    two_point_wkt = (
        "SRID=4326;LINESTRING("
        "-0.38222938832999598 -0.13872236685816669, "
        "-0.38222930900909202 -0.13872236685816669)"
    )
    four_point_wkt = (
        "SRID=4326;LINESTRING("
        "-0.38222938468305784 -0.13872235682908687, "
        "-0.38222931083256106 -0.13872235591735235, "
        "-0.38222930992082654 -0.13872207236791409, "
        "-0.38222940929989008 -0.13872235591735235)"
    )
    factories.ChannelFactory(the_geom=two_point_wkt)
    factories.ChannelFactory(the_geom=four_point_wkt)

    # Reproject into the EPSG code of the first GlobalSetting row (none exist here).
    model_epsg = Query(models.GlobalSetting.epsg_code).limit(1)
    projected_length = geo_func.ST_Length(
        geo_func.ST_Transform(models.Channel.the_geom, model_epsg)
    )
    too_short = Query(models.Channel).filter(projected_length < 0.05)

    length_check = QueryCheck(
        column=models.Channel.the_geom,
        invalid=too_short,
        message="Length of the v2_channel is too short, should be at least 0.05m",
    )
    assert len(length_check.get_invalid(session)) == 0
def get_invalid(self, session: Session) -> List[NamedTuple]:
    """Return the referenced cross-section definitions whose profile is closed.

    A definition counts as referenced when its id appears as the
    cross-section definition of a cross-section location, pipe, culvert,
    weir or orifice. Closed-rectangle, circle and egg shapes are always
    closed. Tabulated rectangle/trapezium shapes are closed when the last
    space-separated value of ``width`` is zero.

    :param session: session used to execute the queries.
    :return: list of closed, referenced definition rows.
    """
    referenced_ids = Query(models.CrossSectionLocation.definition_id).union_all(
        Query(models.Pipe.cross_section_definition_id),
        Query(models.Culvert.cross_section_definition_id),
        Query(models.Weir.cross_section_definition_id),
        Query(models.Orifice.cross_section_definition_id),
    )
    definitions_in_use = self.to_check(session).filter(
        models.CrossSectionDefinition.id.in_(referenced_ids),
    )

    # closed_rectangle, circle, and egg cross-section definitions are always closed
    closed_definitions = definitions_in_use.filter(
        models.CrossSectionDefinition.shape.in_([
            constants.CrossSectionShape.CLOSED_RECTANGLE,
            constants.CrossSectionShape.CIRCLE,
            constants.CrossSectionShape.EGG,
        ])
    )
    result = list(closed_definitions.with_session(session).all())

    # tabulated cross-section definitions are closed when the last width is zero
    tabulated_definitions = definitions_in_use.filter(
        models.CrossSectionDefinition.shape.in_([
            constants.CrossSectionShape.TABULATED_RECTANGLE,
            constants.CrossSectionShape.TABULATED_TRAPEZIUM,
        ])
    )
    for definition in tabulated_definitions.with_session(session).all():
        try:
            last_width = float(definition.width.split(" ")[-1])
        except (AttributeError, TypeError, ValueError):
            # Missing (None) or malformed widths are reported by other checks;
            # only skip those, instead of hiding every possible error.
            continue
        if last_width == 0.0:  # closed channel
            result.append(definition)
    return result
def test_transform_ORM():
    """Raster ST_Transform must pass ``type_=Raster`` to be compiled as a raster.

    Without it, the raster column is wrapped in ST_AsEWKB exactly like a geometry.
    """
    # Shared pieces of the expected SQL
    geom_column = 'ST_AsEWKB(ST_Transform(raster_table_orm.geom, :ST_Transform_2)) AS "ST_Transform_1", '
    rast_transform = 'ST_Transform(raster_table_orm.rast, :ST_Transform_4)) AS "ST_Transform_3" \n'
    from_clause = "FROM raster_table_orm"

    # Naive query: the raster is processed as a Geometry (ST_AsEWKB wrapper)
    naive_query = Query([
        RasterTable.geom.ST_Transform(2154),
        RasterTable.rast.ST_Transform(2154),
    ])
    assert str(naive_query) == "SELECT " + geom_column + "ST_AsEWKB(" + rast_transform + from_clause

    # Explicit return type: the raster is correctly processed as a Raster
    typed_query = Query([
        RasterTable.geom.ST_Transform(2154),
        RasterTable.rast.ST_Transform(2154, type_=Raster),
    ])
    assert str(typed_query) == "SELECT " + geom_column + "raster(" + rast_transform + from_clause
def epsg_code_query():
    """Return a labelled ``epsg_code`` SQL expression.

    The expression is the ``epsg_code`` of the first GlobalSetting row,
    coalesced to ``DEFAULT_EPSG`` when that is NULL. ``scalar_subquery()``
    is tried first. ``as_scalar()`` is the fallback when it raises
    AttributeError, presumably on older SQLAlchemy versions.
    """
    first_setting = Query(models.GlobalSetting.epsg_code).limit(1)
    try:
        subquery = first_setting.scalar_subquery()
    except AttributeError:
        subquery = first_setting.as_scalar()
    fallback = literal(DEFAULT_EPSG)
    return func.coalesce(subquery, fallback).label("epsg_code")
def update_order(order_id: UUID, order: Dict) -> Dict:
    """Replace an order's customizations and recompute its setup time and amount.

    All requested customization codes are resolved before the existing
    customizations are deleted. An unknown code therefore raises without
    destroying the order's current customizations.

    :param order_id: id of the order to update.
    :param order: payload. Only ``order["customizations"]`` is read: an
        iterable of code objects exposing ``.value``.
    :return: the updated order as returned by ``get_order``.
    :raises OrderNotFoundException: if no order has ``order_id``.
    :raises CustomizationNotFoundException: if a customization code is unknown.
    """
    with db_session() as session:
        order_update = Query(Order, session=session).filter_by(id=order_id).first()
        if not order_update:
            raise OrderNotFoundException()
        order_update_dict = order_update.to_dict()

        # Validate every code first so a bad code fails before anything is deleted.
        customizations = []
        for customization_code in order["customizations"]:
            customization = get_customization_by_code(customization_code.value)
            if not customization:
                raise CustomizationNotFoundException()
            customizations.append(customization)

        delete_customizations_by_order_id(order_id)
        for customization in customizations:
            session.add(
                OrderCustomization(order_id=order_id, customization_id=customization["id"])
            )

        size = get_size_by_id(order_update_dict["size_id"])
        flavor = get_flavor_by_id(order_update_dict["flavor_id"])

        # Re-fetch after the helper calls above, which open their own db_session().
        order_update = Query(Order, session=session).filter_by(id=order_id).first()
        order_update.setup_time = utils.calculate_time(size, flavor, customizations)
        order_update.amount = utils.calculate_amount(size, customizations)
        session.commit()
    return get_order(order_id)
def test_query_check_manhole_drain_level_calc_type_2(session):
    """QueryChecks on Manhole.drain_level for CONNECTED manholes.

    ``drain_level`` may be null in general. When ``calculation_type`` is
    CONNECTED, it must be non-null and ``drain_level >= bottom_level``.
    """
    factories.ManholeFactory(drain_level=None)
    factories.ManholeFactory(drain_level=1)
    # invalid: drain_level cannot be null when calculation_type is CONNECTED
    m3_error = factories.ManholeFactory(
        drain_level=None, calculation_type=constants.CalculationTypeNode.CONNECTED
    )
    # invalid: drain_level (1) below bottom_level (2) while CONNECTED
    m4_error = factories.ManholeFactory(
        drain_level=1,
        bottom_level=2,
        calculation_type=constants.CalculationTypeNode.CONNECTED,
    )
    # valid: drain_level >= bottom_level while CONNECTED
    factories.ManholeFactory(
        drain_level=1,
        bottom_level=0,
        calculation_type=constants.CalculationTypeNode.CONNECTED,
    )
    # valid: null drain_level is allowed when EMBEDDED
    factories.ManholeFactory(
        drain_level=None,
        bottom_level=0,
        calculation_type=constants.CalculationTypeNode.EMBEDDED,
    )

    query_drn_lvl_st_bttm_lvl = Query(models.Manhole).filter(
        models.Manhole.drain_level < models.Manhole.bottom_level,
        models.Manhole.calculation_type == constants.CalculationTypeNode.CONNECTED,
    )
    # `.is_(None)` renders "IS NULL", same as `== None`, without the E711 lint hit.
    query_invalid_not_null = Query(models.Manhole).filter(
        models.Manhole.calculation_type == constants.CalculationTypeNode.CONNECTED,
        models.Manhole.drain_level.is_(None),
    )
    check_drn_lvl_gt_bttm_lvl = QueryCheck(
        column=models.Manhole.bottom_level,
        invalid=query_drn_lvl_st_bttm_lvl,
        message="Manhole.drain_level >= Manhole.bottom_level when "
        "Manhole.calculation_type is CONNECTED",
    )
    check_invalid_not_null = QueryCheck(
        column=models.Manhole.drain_level,
        invalid=query_invalid_not_null,
        message="Manhole.drain_level cannot be null when Manhole.calculation_type is "
        "CONNECTED",
    )

    errors1 = check_drn_lvl_gt_bttm_lvl.get_invalid(session)
    errors2 = check_invalid_not_null.get_invalid(session)
    assert len(errors1) == 1
    assert len(errors2) == 1
    assert m3_error.id == errors2[0].id
    assert m4_error.id == errors1[0].id
def to_check(self, session):
    """Narrow the parent check query to referenced cross-section definitions.

    When ``self.shapes`` is not None, only definitions with one of those
    shapes are kept. A definition is referenced when its id is used by a
    cross-section location, pipe, culvert, weir or orifice.
    """
    query = super().to_check(session)
    if self.shapes is not None:
        query = query.filter(models.CrossSectionDefinition.shape.in_(self.shapes))

    referenced_ids = Query(models.CrossSectionLocation.definition_id).union_all(
        Query(models.Pipe.cross_section_definition_id),
        Query(models.Culvert.cross_section_definition_id),
        Query(models.Weir.cross_section_definition_id),
        Query(models.Orifice.cross_section_definition_id),
    )
    return query.filter(models.CrossSectionDefinition.id.in_(referenced_ids))
def list(self, type: RoleType, limit: int, marker: uuid.UUID):
    """Return a page of AuthZ roles.

    GLOBAL roles are those with no project (``project_id IS NULL``). Any
    other type validates the project scope first, then lists the roles of
    the current request's project.
    """
    if type == RoleType.GLOBAL:
        scope_filter = AuthZRole.project_id.is_(None)
    else:
        self.mount.validate_project_scope()
        scope_filter = AuthZRole.project_id == cherrypy.request.project.id

    roles = Query(AuthZRole).filter(scope_filter)
    return self.paginate(AuthZRole, ResponseRole, limit, marker, starting_query=roles)
def with_cleaned_and_forecast_data_in_database() -> Iterator[Orchestrator]:
    """Seed CleanedData and ForecastData rows for ``TEST_CONTRACT`` around a test.

    Setup: removes leftovers for ``TEST_CONTRACT`` from both tables, verifies
    none remain, builds an ``Orchestrator`` on the internal database with
    mocked collaborators, and inserts the rows produced by
    ``_setup_cleaned_data`` / ``_setup_forecast_data``.

    Yields the orchestrator.

    Teardown: asserts the seeded rows and the single ForecastRun row are deleted.
    """
    cleaned_rows = Query(CleanedData).filter(CleanedData.c.Contract_ID == TEST_CONTRACT)  # type: ignore
    forecast_rows = Query(ForecastData).filter(ForecastData.c.Contract_ID == TEST_CONTRACT)  # type: ignore

    # Cleanup in case of previously failed test
    for stale in (cleaned_rows, forecast_rows):
        delete_test_data(stale)
    with Database(DatabaseType.internal).transaction_context() as session:
        leftover = (
            cleaned_rows.with_session(session).count()
            + forecast_rows.with_session(session).count()
        )
        assert leftover == 0, "Found old test data in database when setting up the test"

    config = RuntimeConfig(
        engine_run_type=EngineRunType.development,
        forecast_periods=1,
        output_location=".",
        prediction_month=pd.Timestamp(year=2020, month=2, day=1),
    )
    orchestrator = Orchestrator(
        config,
        Mock(spec=DataLoader),
        Mock(spec=DataOutput),
        Database(DatabaseType.internal),
        Mock(spec=Queue),
        Mock(),
        Mock(),
    )

    cleaned_payload = _setup_cleaned_data(orchestrator)
    forecast_payload = _setup_forecast_data(orchestrator)
    run_id = cast(int, orchestrator._forecast_run_id)

    with Database(DatabaseType.internal).transaction_context() as session:
        session.execute(CleanedData.insert().values(cleaned_payload))
        session.execute(ForecastData.insert().values(forecast_payload))

    yield orchestrator

    assert delete_test_data(cleaned_rows) == len(cleaned_payload)
    assert delete_test_data(forecast_rows) == len(forecast_payload)
    assert delete_test_data(Query(ForecastRun).filter(ForecastRun.id == run_id)) == 1  # type: ignore
def test_dwf_calculator_surface(session):
    """Surface DWF laterals weight inhabitants * DWF by the surface-map percentage.

    DWFCalculator runs a raw SQL query, so the fixtures must be committed.
    They are removed in ``finally`` so they never leak into other tests,
    even when the calculation raises.
    """
    # same algorithm as impervious surface
    conn_node = factories.ConnectionNodeFactory.create(id=1)
    sur1: Surface = factories.SurfaceFactory.create(
        id=1,
        code="030007",
        display_name="030007",
        nr_of_inhabitants=3.34,
        dry_weather_flow=120.0,
    )
    sur2: Surface = factories.SurfaceFactory.create(
        id=2,
        code="030008",
        display_name="030008",
        nr_of_inhabitants=1.92,
        dry_weather_flow=120.0,
    )
    sur_map1: SurfaceMap = factories.SurfaceMapFactory.create(
        id=1,
        surface_id=sur1.id,
        connection_node_id=conn_node.id,
        percentage=69.0,
    )
    sur_map2: SurfaceMap = factories.SurfaceMapFactory.create(
        id=2,
        surface_id=sur2.id,
        connection_node_id=conn_node.id,
        percentage=42.0,
    )
    # Because we use a raw SQL query in DWF we need to commit the data
    session.commit()

    try:
        calculator = DWFCalculator(session, InflowType.SURFACE)
        laterals = calculator.laterals
        weighted_flow_sum = (
            sur1.nr_of_inhabitants * sur1.dry_weather_flow * sur_map1.percentage / 100
            + sur2.nr_of_inhabitants * sur2.dry_weather_flow * sur_map2.percentage / 100
        )
        expected_values = [
            [i * 3600, (factor * weighted_flow_sum) / 1000 / 3600]
            for i, factor in DWF_FACTORS
        ]
    finally:
        # Remove committed data, also when the calculation above failed
        Query(SurfaceMap).with_session(session).delete()
        Query(Surface).with_session(session).delete()
        Query(ConnectionNode).with_session(session).delete()
        session.commit()

    np.testing.assert_array_almost_equal(laterals[0]["values"], expected_values)
def __initialize_controls(self):
    """Lazily build ``self._controls`` for ``self._control_group_id``.

    Does nothing when ``self._controls`` is already set. Otherwise fills
    ``self._controls["table"]`` and ``self._controls["memory"]`` with API
    controls assembled from the control, group, measure-group, table,
    memory and measure-map rows.

    :raises KeyError: if a measure group has no measure maps, or a control
        references a missing table/memory row.
    :raises SchematisationError: if a control has an unknown ``control_type``.
    """
    if self._controls is not None:
        return
    self._controls = {"table": [], "memory": []}

    table_lookup = {
        table.id: table
        for table in Query(ControlTable).with_session(self.session).all()
    }
    memory_lookup = {
        memory.id: memory
        for memory in Query(ControlMemory).with_session(self.session).all()
    }

    # measure_group_id -> list of measure locations converted from its map rows
    maps_lookup = {}
    for map_item in Query([ControlMeasureMap]).with_session(self.session).all():
        maps_lookup.setdefault(map_item.measure_group_id, []).append(
            control_measure_map_to_measure_location(map_item)
        )

    all_controls = (
        Query([Control, ControlGroup, ControlMeasureGroup])
        .join(ControlGroup, ControlMeasureGroup)
        .with_session(self.session)
        .filter(
            Control.control_group_id == self._control_group_id,
            ControlGroup.id == self._control_group_id,
        )
        .all()
    )
    for control, group, measuregroup in all_controls:
        # Values are measure locations (not ControlMeasureGroup rows).
        maps = maps_lookup[measuregroup.id]
        if control.control_type is ControlType.table:
            table: ControlTable = table_lookup[control.control_id]
            measure_spec = to_measure_specification(table, group, maps)
            api_control = to_table_control(control, table, measure_spec)
        elif control.control_type is ControlType.memory:
            memory: ControlMemory = memory_lookup[control.control_id]
            measure_spec = to_measure_specification(memory, group, maps)
            api_control = to_memory_control(control, memory, measure_spec)
        else:
            raise SchematisationError(
                f"Unknown control_type '{control.control_type.value}'")
        self._controls[control.control_type.value].append(api_control)
def test_join_multiple(self):
    """Outer-joining the same table (users) twice yields two distinct aliases."""
    outerjoin_spec = {
        'user': {'project': ['name']},
        'creator': {
            'project': ['id', 'tags'],
            'filter': {'id': {'$lt': 1}},
        },
    }
    mongo_query = models.Edit.mongoquery(Query([models.Edit])).query(
        project=['id'], outerjoin=outerjoin_spec)
    sql = q2sql(mongo_query.end())
    self._check_qs(
        """SELECT u_1.id AS u_1_id, u_1.name AS u_1_name, u_2.id AS u_2_id, u_2.tags AS u_2_tags, e.id AS e_id FROM e LEFT OUTER JOIN u AS u_1 ON u_1.id = e.uid LEFT OUTER JOIN u AS u_2 ON u_2.id = e.cuid AND u_2.id < 1""",
        sql)
def test_limit(self):
    """limit()/skip() render the expected LIMIT/OFFSET tail (or none at all)."""
    m = models.User

    def build(limit=None, skip=None):
        return m.mongoquery(Query([m])).limit(limit, skip).end()

    def check(lim, skip, expected_endswith):
        qs = q2sql(build(lim, skip))
        self.assertTrue(
            qs.endswith(expected_endswith),
            '{!r} should end with {!r}'.format(qs, expected_endswith))

    cases = [
        # skip only
        (None, None, 'FROM u'),
        (None, -1, 'FROM u'),
        (None, 0, 'FROM u'),
        (None, 1, 'LIMIT ALL OFFSET 1'),
        (None, 9, 'LIMIT ALL OFFSET 9'),
        # limit only
        (-1, None, 'FROM u'),
        (0, None, 'FROM u'),
        (1, None, 'LIMIT 1'),
        (9, None, 'LIMIT 9'),
        # both
        (5, 10, 'LIMIT 5 OFFSET 10'),
    ]
    for lim, skip, expected in cases:
        check(lim, skip, expected)

    # Applying limit twice: the second call wins
    first_pass = build(limit=10)
    second_pass = m.mongoquery(first_pass).limit(limit=15, skip=30).end()
    qs = q2sql(second_pass)
    self.assertTrue(qs.endswith('LIMIT 15 OFFSET 30'), qs)
def test_join(self):
    """join() accepts known relations and rejects unknown ones with AssertionError."""
    m = models.User

    # Okay
    mq = m.mongoquery(Query([models.User]))
    mq = mq.join(('articles', 'comments'))
    q = mq.end()
    qs = q2sql(q)
    self.assertIn('FROM u', qs)
    # No 'LEFT OUTER JOIN a' / 'LEFT OUTER JOIN c' assertions: immediateload(),
    # used in this case, does not add JOIN clauses (a subquery is used instead).

    # Unknown relation. Pass a one-element tuple like the call above;
    # ('???') without the trailing comma is just the string '???'.
    mq = m.mongoquery(Query([models.User]))
    self.assertRaises(AssertionError, mq.join, ('???',))
def example_with_statement():
    """Demonstrate DataAPI used as a context manager (one transaction)."""
    with DataAPI(database=database, resource_arn=resource_arn, secret_arn=secret_arn) as data_api:
        # The transaction starts on entering the `with` block.

        # INSERT INTO pets (name) VALUES ('dog')
        # `execute` accepts a SQL string or SQLAlchemy SQL objects.
        insert_dog = Insert(Pets, {'name': 'dog'})
        outcome = data_api.execute(insert_dog)
        print(outcome.number_of_records_updated)  # 1

        # SELECT pets.id, pets.name FROM pets WHERE pets.id = 1
        # (equivalently: data_api.execute('select id, name from pets'))
        select_dog = Query(Pets).filter(Pets.id == 1)
        outcome = data_api.execute(select_dog)
        # `Result` behaves like a SQLAlchemy result object.
        print(outcome.scalar())  # 1
        print(outcome.one())  # [Record<id=1, name='dog'>]

        # `Result` is a Sequence[Record].
        rows = list(outcome)
        print(rows)  # [Record<id=1, name='dog'>]

        # A Record is itself a sequence of column values.
        first_row = rows[0]
        print(first_row[0])  # 1
        print(first_row[1])  # dog
        for value in first_row:
            print(value)  # 1 ...

        # show record as dict()
        print(first_row.dict())  # {'id': 1, 'name': 'dog'}

        # batch insert
        extra_pets = [
            {'id': 2, 'name': 'cat'},
            {'id': 3, 'name': 'snake'},
            {'id': 4, 'name': 'rabbit'},
        ]
        data_api.batch_execute(Insert(Pets), extra_pets)

        everything = data_api.execute('select * from pets')
        print(list(everything))
        # [Record<id=1, name='dog'>, Record<id=2, name='cat'>, Record<id=3, name='snake'>, Record<id=4, name='rabbit'>]

        # The result can be iterated like any sequence.
        for row in everything:
            print(row)
def show_projects():
    """Render the project list, a project selector, and the chosen project's models."""
    projects_query = Query([
        db.Project.name.label('Название проекта'),
        db.Project.description.label('Описание'),
    ])
    projects_df = get_dataframe_from_query(projects_query)
    st.dataframe(projects_df)

    selected_project = st.selectbox('Выбор проекта', projects_df['Название проекта'].unique())

    # Human-readable yes/no column for the `pretrained` flag.
    pretrained_column = case([
        (db.Model.pretrained, 'Да'),
        (~db.Model.pretrained, 'Нет'),
    ]).label('Предобученная')

    models_query = (
        Query([
            db.Model.name.label('Модель'),
            db.Model.description.label('Описание'),
            db.Model.params.label('Параметры модели'),
            db.Model.metrics.label('Метрики'),
            db.Dataset.name.label('Датасет'),
            pretrained_column,
            db.Model.training_time.label('Время обучения, с'),
            db.Model.created_at.label('Дата обучения'),
        ])
        .join(db.Project)
        .join(db.Dataset)
        .filter(db.Project.name == selected_project)
        .order_by(desc(db.Model.created_at))  # newest models first
    )
    st.dataframe(get_dataframe_from_query(models_query))
def test_node_distance(session):
    """ConnectionNodesDistance(10) flags both nodes of a too-close pair (spatialite only)."""
    if session.bind.name == "postgresql":
        pytest.skip("Check only applicable to spatialite")

    first_close = factories.ConnectionNodeFactory(
        the_geom="SRID=4326;POINT(4.728282 52.64579283592512)"
    )
    second_close = factories.ConnectionNodeFactory(
        the_geom="SRID=4326;POINT(4.72828 52.64579283592512)"
    )
    # far enough from both nodes above
    factories.ConnectionNodeFactory(
        the_geom="SRID=4326;POINT(4.726838755789598 52.64514133594995)"
    )

    # Sanity check: pairwise distances between all nodes (nodes 1 and 2 are
    # too close). The result is not asserted on; it is a debugging aid.
    left, right = aliased(models.ConnectionNode), aliased(models.ConnectionNode)
    pairwise = Query(
        geo_func.ST_Distance(left.the_geom, right.the_geom, 1)
    ).filter(left.id != right.id)
    pairwise.with_session(session).all()

    invalid = ConnectionNodesDistance(minimum_distance=10).get_invalid(session)
    assert len(invalid) == 2
    flagged_ids = {row.id for row in invalid}
    assert first_close.id in flagged_ids
    assert second_close.id in flagged_ids
def list(self, image_id, region_id, zone_id, limit: int, marker: uuid.UUID):
    """Return a page of the current project's instances, optionally filtered.

    Filters by image, region and/or zone when the corresponding id is not None.

    :raises cherrypy.HTTPError: 404 when ``region_id`` or ``zone_id`` is unknown.
    """
    # TODO: allow filtering by tags
    project = cherrypy.request.project
    instances = Query(Instance).filter(Instance.project_id == project.id)

    if image_id is not None:
        instances = instances.filter(Instance.image_id == image_id)

    if region_id is not None:
        with cherrypy.request.db_session() as session:
            region = session.query(Region).filter(Region.id == region_id).first()
            if region is None:
                raise cherrypy.HTTPError(
                    404, "A region with the requested id does not exist.")
            instances = instances.filter(Instance.region_id == region.id)

    if zone_id is not None:
        with cherrypy.request.db_session() as session:
            zone = session.query(Zone).filter(Zone.id == zone_id).first()
            if zone is None:
                raise cherrypy.HTTPError(
                    404, "A zone with the requested id does not exist.")
            instances = instances.filter(Instance.zone_id == zone_id)

    return self.paginate(Instance, ResponseInstance, limit, marker,
                         starting_query=instances)
def purge(db_engine: Engine) -> None:
    """Delete every Task row with a single bulk DELETE.

    :param db_engine: engine for the database
    """
    with session_scope(bind=db_engine) as session:
        all_tasks = Query(db.Task, session=session)
        all_tasks.delete()
def test_sort(self):
    """sort() renders the expected ORDER BY tail and rejects invalid directions."""
    m = models.User

    def build(sort_spec):
        return m.mongoquery(Query([m])).sort(sort_spec).end()

    def check(sort_spec, expected_ends):
        qs = q2sql(build(sort_spec))
        self.assertTrue(
            qs.endswith(expected_ends),
            '{!r} should end with {!r}'.format(qs, expected_ends))

    # Empty specs produce no ORDER BY
    for empty_spec in (None, [], OrderedDict()):
        check(empty_spec, u'FROM u')

    # List form
    check(['id-', 'age-'], 'ORDER BY u.id DESC, u.age DESC')
    # Dict form
    check(OrderedDict([['id', -1], ['age', -1]]), 'ORDER BY u.id DESC, u.age DESC')

    # -2 is not a valid direction
    self.assertRaises(AssertionError, check,
                      OrderedDict([['id', -2], ['age', -1]]), '')
def test_length_geom_linestring_in_28992(session):
    """Channel length check flags only the channel shorter than 0.05.

    NOTE(review): the test name mentions EPSG:28992 and the coordinates look
    like RD metres, but the WKT declares SRID=4326. Confirm this is intended
    (e.g. the column SRID is fixed to 4326).
    """
    if session.bind.name == "postgresql":
        pytest.skip("Postgres already has a constrain that checks on the length")

    # around 0.109m: long enough
    factories.ChannelFactory(
        the_geom=(
            "SRID=4326;LINESTRING("
            "122829.98048471771471668 473589.68720115750329569, "
            "122830.00490918199648149 473589.68720115750329569, "
            "122829.95687440223991871 473589.70983449439518154, "
            "122829.9793449093849631 473589.68850379559444264)"
        )
    )
    # around 0.001m: too short
    short_channel = factories.ChannelFactory(
        the_geom=(
            "SRID=4326;LINESTRING("
            "122829.98185859377554152 473589.69248294795397669, "
            "122829.98260150455462281 473589.69248294795397669)"
        ),
    )

    too_short = Query(models.Channel).filter(
        geo_func.ST_Length(models.Channel.the_geom) < 0.05)
    length_check = QueryCheck(
        column=models.Channel.the_geom,
        invalid=too_short,
        message="Length of the v2_channel is too short, should be at least 0.05m",
    )

    errors = length_check.get_invalid(session)
    assert len(errors) == 1
    assert errors[0].id == short_channel.id
def __init__(
    self,
    # An optional `Declarative class <http://docs.sqlalchemy.org/en/latest/orm/tutorial.html#declare-a-mapping>`_ to query.
    declarative_class=None,
    # Optionally, begin with an existing query_.
    query=None):
    """Initialise from a declarative class and/or an existing Query.

    :param declarative_class: mapped class to select from. Required when
        ``query`` is None, or when the class cannot be inferred from ``query``.
        When both are given and inference succeeds, they must agree.
    :param query: optional existing ``Query`` to start from.
    """
    if declarative_class:
        assert _is_mapped_class(declarative_class)

    # If a query is provided, try to infer the declarative_class.
    if query is not None:
        assert isinstance(query, Query)
        self._query = query
        try:
            self._select = self._get_joinpoint_zero_class()
        except Exception:
            # We can't infer it. Use what's provided instead, and add this to the query.
            # (A bare `except:` would also swallow KeyboardInterrupt/SystemExit.)
            assert declarative_class
            self._select = declarative_class
            self._query = self._query.select_from(declarative_class)
        else:
            # If a declarative_class was provided, make sure it's consistent with the inferred class.
            if declarative_class:
                assert declarative_class is self._select
    else:
        # The declarative class must be provided if the query wasn't.
        assert declarative_class
        # Since a query was not provided, create an empty `query <http://docs.sqlalchemy.org/en/latest/orm/query.html>`_; ``to_query`` will fill in the missing information.
        self._query = Query([]).select_from(declarative_class)
        # Keep track of the last selectable construct, to generate the select in ``to_query``.
        self._select = declarative_class
def test_query_check_on_pumpstation(session):
    """A pumpstation whose lower_stop_level <= its start manhole's bottom_level is flagged."""
    high_node = factories.ConnectionNodeFactory()
    low_node = factories.ConnectionNodeFactory()
    factories.ManholeFactory(connection_node=high_node, bottom_level=1.0)
    factories.ManholeFactory(connection_node=low_node, bottom_level=-1.0)

    # 0.0 <= 1.0: invalid
    bad_pump = factories.PumpstationFactory(
        connection_node_start=high_node, lower_stop_level=0.0)
    # 2.0 > -1.0: valid
    factories.PumpstationFactory(
        connection_node_start=low_node, lower_stop_level=2.0)

    start_node_join = models.Pumpstation.connection_node_start_id == models.ConnectionNode.id
    manhole_join = models.Manhole.connection_node_id == models.ConnectionNode.id
    stop_below_bottom = Query(models.Pumpstation) \
        .join(models.ConnectionNode, start_node_join) \
        .join(models.Manhole, manhole_join) \
        .filter(models.Pumpstation.lower_stop_level <= models.Manhole.bottom_level)

    check = QueryCheck(
        column=models.Pumpstation.lower_stop_level,
        invalid=stop_below_bottom,
        message="Pumpstation lower_stop_level should be higher than Manhole "
        "bottom_level",
    )

    invalids = check.get_invalid(session)
    assert len(invalids) == 1
    assert invalids[0].id == bad_pump.id
def query(self, fields=None, _filter=None, sort=None, limit=None,
          raise_ex=False, return_field_value=None):
    """Query ``self.table`` with optional filter, field selection, sort and limit.

    :param fields: fields to select (passed to ``_select_fields``).
    :param _filter: filter spec passed to ``_generate_filter``.
    :param sort: sort spec passed to ``_sort_query``.
    :param limit: row limit, normalised by ``_validate_limit``. When it is 1,
        a single row (or ``{}`` if none) is fetched instead of a list.
    :param raise_ex: forwarded to ``_error``.
    :param return_field_value: together with ``limit == 1``, select only this
        field and return its value instead of the whole row.
    :return: the result passed through ``_deserialize`` (or a single field value).
    """
    # A single-field lookup only needs that one column.
    if return_field_value and limit == 1:
        fields = return_field_value

    stmt = Query(self.table)
    if _filter:
        stmt = self._generate_filter(stmt, _filter)
    if fields:
        stmt = self._select_fields(stmt, fields)
    if sort:
        stmt = self._sort_query(stmt, sort)
    limit = self._validate_limit(limit)

    with Session(self.engine) as session:
        bound = stmt.with_session(session)
        if not isinstance(limit, int):
            data = bound.all()
        elif limit == 1:
            rows = bound.limit(1).all()
            data = rows[0] if rows else {}
        else:
            data = bound.limit(limit).all()
        self._error('query', raise_ex)
        data = self._deserialize(data, fields)

    if limit == 1 and return_field_value and data:
        return data.get(return_field_value)
    return data
def get_flavor_by_id(flavor_id: UUID) -> Union[Dict, None]:
    """Return the flavor with ``flavor_id`` as a dict, or None if it does not exist.

    ``Query.one()`` raises NoResultFound instead of returning None, which made
    the None branch unreachable. ``one_or_none()`` honours the Optional
    return type and still raises MultipleResultsFound on duplicates.
    """
    with db_session() as session:
        flavor: Flavor = Query(Flavor, session=session).filter_by(id=flavor_id).one_or_none()
        if flavor is None:
            return None
        return flavor.to_dict()
def get_flavor_by_code(flavor_code: int) -> Union[Dict, None]:
    """Return the flavor with ``flavor_code`` as a dict, or None if it does not exist.

    Uses ``one_or_none()``. ``one()`` would raise NoResultFound and never
    reach the None branch; duplicates still raise MultipleResultsFound.
    """
    with db_session() as session:
        flavor: Flavor = Query(Flavor, session=session).filter_by(code=flavor_code).one_or_none()
        if flavor is None:
            return None
        return flavor.to_dict()
def query1(*args, trade_date='2013-07-18'):
    """Build a query on ``full`` rows for a single trade date.

    :param args: columns/entities to select. When none are given, defaults to
        ``full.code``, ``full.report_date`` and ``full.roe``. Previously any
        passed columns were overwritten and silently ignored.
    :param trade_date: value passed to ``pd.to_datetime`` for the
        ``full.trade_date`` filter. Defaults to the formerly hard-coded
        '2013-07-18'.
    :return: an unexecuted Query bound to a new session on ``fundamental.sqlite``.

    NOTE(review): a new engine and session are created per call and the
    session is never closed here; the caller owns it via the returned query.
    """
    engine = create_engine("sqlite:///fundamental.sqlite")
    session_factory = sessionmaker(bind=engine)
    session = session_factory()
    columns = list(args) if args else [full.code, full.report_date, full.roe]
    return Query(columns).with_session(session).filter(
        full.trade_date == pd.to_datetime(trade_date))
def get_size_by_code(size_code: int) -> Union[Dict, None]:
    """Return the size with ``size_code`` as a dict, or None if it does not exist.

    Uses ``one_or_none()``. ``one()`` would raise NoResultFound and never
    reach the None branch; duplicates still raise MultipleResultsFound.
    """
    with db_session() as session:
        size: Size = Query(Size, session=session).filter_by(code=size_code).one_or_none()
        if size is None:
            return None
        return size.to_dict()
def query(self, dt, *args, **kwargs):
    """Select the given columns plus ``fundamental.code`` and ``max(report_date)``.

    Only rows with ``report_date < dt`` are considered. ``kwargs`` are
    accepted but unused.

    NOTE(review): an aggregate is selected next to ``fundamental.code``
    without GROUP BY; presumably the caller adds ``group_by`` — confirm.
    """
    columns = [
        *args,
        fundamental.code,
        func.max(fundamental.report_date).label('report_date'),
    ]
    base = Query(columns).with_session(self.session)
    return base.filter(fundamental.report_date < dt)