def with_cleaned_data_in_database() -> Iterator[OrchestratorResult]:
    """Fixture-style generator: seed one CleanedData row for TEST_CONTRACT,
    yield the orchestrator result, then remove the seeded row and its
    ForecastRun record.

    Yields:
        OrchestratorResult: built with ``use_real_database=True`` so the
        orchestrator talks to the internal database.
    """
    cleanup_cleaned_data_query = Query(CleanedData).filter(CleanedData.c.Contract_ID == TEST_CONTRACT)  # type: ignore
    delete_test_data(cleanup_cleaned_data_query)  # Cleanup in case of previously failed test

    # Guard: the table must be free of test-contract rows before seeding.
    with Database(DatabaseType.internal).transaction_context() as session:
        cleaned_data_count = cleanup_cleaned_data_query.with_session(session).count()
    assert cleaned_data_count == 0, "Found old test data in database when setting up the test"

    orchestrator_result = setup_orchestrator_result(use_real_database=True)
    # Creates a ForecastRun row and sets _forecast_run_id on the orchestrator.
    orchestrator_result.orchestrator._initialize_forecast_run()
    test_run_id = cast(int, orchestrator_result.orchestrator._forecast_run_id)
    test_cleaned_data = [
        {
            "run_id": test_run_id,
            "Project_ID": "Test_Project",
            "Contract_ID": TEST_CONTRACT,
            "Wesco_Master_Number": "Test_Master_Number",
            "Date": "2020-03-01 00:00:00.000",
            "Date_YYYYMM": 202003,
            "Item_ID": -2,
            "Unit_Cost": 1,
            "Order_Quantity": 2,
            "Order_Cost": 2,
        }
    ]

    with Database(DatabaseType.internal).transaction_context() as session:
        session.execute(CleanedData.insert().values(test_cleaned_data))

    yield orchestrator_result

    # Teardown: the deletes double as assertions that exactly the seeded rows
    # (and the single ForecastRun) are removed.
    assert delete_test_data(cleanup_cleaned_data_query) == len(test_cleaned_data)
    assert delete_test_data(Query(ForecastRun).filter(ForecastRun.id == test_run_id)) == 1  # type: ignore
Exemplo n.º 2
0
def test_length_geom_linestring_missing_epsg_from_global_settings(session):
    """Both factory channels are long enough, so the 0.05 m length check
    (transforming via the EPSG code taken from v2_global_settings) must
    report no errors."""
    if session.bind.name == "postgresql":
        # Fixed typo in skip reason: "constrain" -> "constraint".
        pytest.skip(
            "Postgres already has a constraint that checks on the length")
    factories.ChannelFactory(the_geom="SRID=4326;LINESTRING("
                             "-0.38222938832999598 -0.13872236685816669, "
                             "-0.38222930900909202 -0.13872236685816669)", )
    factories.ChannelFactory(the_geom="SRID=4326;LINESTRING("
                             "-0.38222938468305784 -0.13872235682908687, "
                             "-0.38222931083256106 -0.13872235591735235, "
                             "-0.38222930992082654 -0.13872207236791409, "
                             "-0.38222940929989008 -0.13872235591735235)", )

    # Transform to the (first) EPSG code from global settings before measuring.
    q = Query(models.Channel).filter(
        geo_func.ST_Length(
            geo_func.ST_Transform(
                models.Channel.the_geom,
                Query(models.GlobalSetting.epsg_code).limit(1))) < 0.05)
    check_length_linestring = QueryCheck(
        column=models.Channel.the_geom,
        invalid=q,
        message=
        "Length of the v2_channel is too short, should be at least 0.05m",
    )

    errors = check_length_linestring.get_invalid(session)
    assert len(errors) == 0
Exemplo n.º 3
0
    def get_invalid(self, session: Session) -> List[NamedTuple]:
        """Return cross-section definitions that are in use and closed."""
        referencing_ids = Query(
            models.CrossSectionLocation.definition_id).union_all(
                Query(models.Pipe.cross_section_definition_id),
                Query(models.Culvert.cross_section_definition_id),
                Query(models.Weir.cross_section_definition_id),
                Query(models.Orifice.cross_section_definition_id),
            )
        definitions_in_use = self.to_check(session).filter(
            models.CrossSectionDefinition.id.in_(referencing_ids))

        # Shapes that are closed by definition.
        always_closed = definitions_in_use.filter(
            models.CrossSectionDefinition.shape.in_([
                constants.CrossSectionShape.CLOSED_RECTANGLE,
                constants.CrossSectionShape.CIRCLE,
                constants.CrossSectionShape.EGG,
            ]))
        invalid = list(always_closed.with_session(session).all())

        # Tabulated shapes are closed only when the last 'width' entry is zero.
        tabulated = definitions_in_use.filter(
            models.CrossSectionDefinition.shape.in_([
                constants.CrossSectionShape.TABULATED_RECTANGLE,
                constants.CrossSectionShape.TABULATED_TRAPEZIUM,
            ]))
        for definition in tabulated.with_session(session).all():
            try:
                last_width = float(definition.width.split(" ")[-1])
            except Exception:
                # Malformed width strings are reported by other checks.
                continue
            if last_width == 0.0:
                # Closed channel
                invalid.append(definition)
        return invalid
Exemplo n.º 4
0
def test_transform_ORM():
    """ST_Transform on a raster column needs ``type_=Raster``; otherwise the
    result is serialized as a geometry (ST_AsEWKB) instead of raster()."""
    # Naive: the raster column is transformed as if it were a geometry.
    naive_query = Query([
        RasterTable.geom.ST_Transform(2154),
        RasterTable.rast.ST_Transform(2154)
    ])

    # The geometry part of the SELECT is identical in both cases.
    geom_select_prefix = (
        "SELECT "
        "ST_AsEWKB("
        "ST_Transform(raster_table_orm.geom, :ST_Transform_2)) AS \"ST_Transform_1\", "
    )

    assert str(naive_query) == geom_select_prefix + (
        "ST_AsEWKB("  # <= Note that the raster is processed as a Geometry here
        "ST_Transform(raster_table_orm.rast, :ST_Transform_4)) AS \"ST_Transform_3\" \n"
        "FROM raster_table_orm"
    )

    # Correct: the explicit type_=Raster keeps the raster a raster.
    typed_query = Query([
        RasterTable.geom.ST_Transform(2154),
        RasterTable.rast.ST_Transform(2154, type_=Raster)
    ])

    assert str(typed_query) == geom_select_prefix + (
        "raster("  # <= This time the raster is correctly processed as a Raster
        "ST_Transform(raster_table_orm.rast, :ST_Transform_4)) AS \"ST_Transform_3\" \n"
        "FROM raster_table_orm"
    )
Exemplo n.º 5
0
def epsg_code_query():
    """Scalar subquery selecting the first GlobalSetting.epsg_code, falling
    back to DEFAULT_EPSG when the table is empty; labeled "epsg_code"."""
    base = Query(models.GlobalSetting.epsg_code).limit(1)
    try:
        # SQLAlchemy >= 1.4 spelling.
        epsg_code = base.scalar_subquery()
    except AttributeError:
        # Older SQLAlchemy fallback.
        epsg_code = base.as_scalar()
    return func.coalesce(epsg_code, literal(DEFAULT_EPSG)).label("epsg_code")
Exemplo n.º 6
0
def update_order(order_id: UUID, order: Dict) -> Dict:
    """Replace an order's customizations and recompute its timing/amount.

    Args:
        order_id: id of the order to update.
        order: payload; only ``order["customizations"]`` (an iterable of
            customization-code enums) is read here.

    Returns:
        Dict: the updated order, re-read via ``get_order``.

    Raises:
        OrderNotFoundException: when no order with ``order_id`` exists.
        CustomizationNotFoundException: when a customization code is unknown.
    """
    with db_session() as session:
        order_update = Query(Order,
                             session=session).filter_by(id=order_id).first()
        if not order_update:
            raise OrderNotFoundException()
        order_update_dict = order_update.to_dict()
        delete_customizations_by_order_id(order_id)

        customizations = []
        for customization_code in order["customizations"]:
            customization = get_customization_by_code(customization_code.value)
            # BUGFIX: validate before collecting — previously an unknown code
            # was appended (as a falsy value) before the existence check ran.
            if not customization:
                raise CustomizationNotFoundException()
            customizations.append(customization)

            session.add(
                OrderCustomization(order_id=order_id,
                                   customization_id=customization["id"]))
        size = get_size_by_id(order_update_dict["size_id"])
        flavor = get_flavor_by_id(order_update_dict["flavor_id"])
        # Re-read the order within this session before mutating it.
        order_update = Query(Order,
                             session=session).filter_by(id=order_id).first()
        order_update.setup_time = utils.calculate_time(size, flavor,
                                                       customizations)
        order_update.amount = utils.calculate_amount(size, customizations)
        session.commit()

    return get_order(order_id)
Exemplo n.º 7
0
def test_query_check_manhole_drain_level_calc_type_2(session):
    # manhole.drain_level can be null, but if manhole.calculation_type == 2 (Connected)
    # then manhole.drain_level >= manhole.bottom_level
    factories.ManholeFactory(drain_level=None)
    factories.ManholeFactory(drain_level=1)
    m3_error = factories.ManholeFactory(
        drain_level=None,
        calculation_type=constants.CalculationTypeNode.CONNECTED
    )  # drain_level cannot be null when calculation_type is CONNECTED
    m4_error = factories.ManholeFactory(
        drain_level=1,
        bottom_level=2,
        calculation_type=constants.CalculationTypeNode.CONNECTED,
    )  # bottom_level  >= drain_level when calculation_type is CONNECTED
    # Valid: drain_level above bottom_level.
    factories.ManholeFactory(
        drain_level=1,
        bottom_level=0,
        calculation_type=constants.CalculationTypeNode.CONNECTED,
    )
    # Valid: NULL drain_level is allowed for non-CONNECTED types.
    factories.ManholeFactory(
        drain_level=None,
        bottom_level=0,
        calculation_type=constants.CalculationTypeNode.EMBEDDED,
    )

    # CONNECTED manholes whose drain_level is below bottom_level.
    query_drn_lvl_st_bttm_lvl = Query(models.Manhole).filter(
        models.Manhole.drain_level < models.Manhole.bottom_level,
        models.Manhole.calculation_type ==
        constants.CalculationTypeNode.CONNECTED,
    )
    # CONNECTED manholes with a NULL drain_level. '== None' is intentional:
    # SQLAlchemy renders it as IS NULL (do not change to 'is None').
    query_invalid_not_null = Query(models.Manhole).filter(
        models.Manhole.calculation_type ==
        constants.CalculationTypeNode.CONNECTED,
        models.Manhole.drain_level == None,
    )
    check_drn_lvl_gt_bttm_lvl = QueryCheck(
        column=models.Manhole.bottom_level,
        invalid=query_drn_lvl_st_bttm_lvl,
        message="Manhole.drain_level >= Manhole.bottom_level when "
        "Manhole.calculation_type is CONNECTED",
    )
    check_invalid_not_null = QueryCheck(
        column=models.Manhole.drain_level,
        invalid=query_invalid_not_null,
        message=
        "Manhole.drain_level cannot be null when Manhole.calculation_type is "
        "CONNECTED",
    )
    # Each check flags exactly its corresponding invalid manhole.
    errors1 = check_drn_lvl_gt_bttm_lvl.get_invalid(session)
    errors2 = check_invalid_not_null.get_invalid(session)
    assert len(errors1) == 1
    assert len(errors2) == 1
    assert m3_error.id == errors2[0].id
    assert m4_error.id == errors1[0].id
 def to_check(self, session):
     """Limit the base queryset to cross-section definitions that are
     actually referenced (by locations, pipes, culverts, weirs or orifices),
     optionally restricted to ``self.shapes``.
     """
     qs = super().to_check(session)
     if self.shapes is not None:
         qs = qs.filter(models.CrossSectionDefinition.shape.in_(
             self.shapes))
     # Keep only definitions referenced from any of the using tables.
     return qs.filter(
         models.CrossSectionDefinition.id.in_(
             Query(models.CrossSectionLocation.definition_id).union_all(
                 Query(models.Pipe.cross_section_definition_id),
                 Query(models.Culvert.cross_section_definition_id),
                 Query(models.Weir.cross_section_definition_id),
                 Query(models.Orifice.cross_section_definition_id),
             )))
Exemplo n.º 9
0
 def list(self, type: RoleType, limit: int, marker: uuid.UUID):
     """Paginate AuthZRole rows: global roles (project_id IS NULL) for
     RoleType.GLOBAL, otherwise the current request project's roles.
     """
     if type == RoleType.GLOBAL:
         # '== None' is intentional: SQLAlchemy renders it as IS NULL.
         starting_query = Query(AuthZRole).filter(
             AuthZRole.project_id == None)  # noqa: E711
     else:
         # Project-scoped listing requires a validated project scope.
         self.mount.validate_project_scope()
         starting_query = Query(AuthZRole).filter(
             AuthZRole.project_id == cherrypy.request.project.id)
     return self.paginate(AuthZRole,
                          ResponseRole,
                          limit,
                          marker,
                          starting_query=starting_query)
def with_cleaned_and_forecast_data_in_database() -> Iterator[Orchestrator]:
    """Fixture-style generator: seed CleanedData and ForecastData rows for
    TEST_CONTRACT, yield the Orchestrator, then delete the seeded rows and
    the ForecastRun record.
    """
    cleanup_cleaned_data_query = Query(CleanedData).filter(
        CleanedData.c.Contract_ID == TEST_CONTRACT)  # type: ignore
    cleanup_forecast_data_query = Query(ForecastData).filter(
        ForecastData.c.Contract_ID == TEST_CONTRACT)  # type: ignore
    delete_test_data(cleanup_cleaned_data_query
                     )  # Cleanup in case of previously failed test
    delete_test_data(cleanup_forecast_data_query
                     )  # Cleanup in case of previously failed test

    # Guard: both tables must be free of test-contract rows before seeding.
    with Database(DatabaseType.internal).transaction_context() as session:
        cleaned_data_count = cleanup_cleaned_data_query.with_session(
            session).count()
        forecast_data_count = cleanup_forecast_data_query.with_session(
            session).count()
    assert cleaned_data_count + forecast_data_count == 0, "Found old test data in database when setting up the test"

    runtime_config = RuntimeConfig(
        engine_run_type=EngineRunType.development,
        forecast_periods=1,
        output_location=".",
        prediction_month=pd.Timestamp(year=2020, month=2, day=1),
    )

    # Only the database is real; loader, output, queue etc. are mocked.
    orchestrator = Orchestrator(
        runtime_config,
        Mock(spec=DataLoader),
        Mock(spec=DataOutput),
        Database(DatabaseType.internal),
        Mock(spec=Queue),
        Mock(),
        Mock(),
    )

    test_cleaned_data = _setup_cleaned_data(orchestrator)
    test_forecast_data = _setup_forecast_data(orchestrator)
    test_run_id = cast(int, orchestrator._forecast_run_id)

    with Database(DatabaseType.internal).transaction_context() as session:
        session.execute(CleanedData.insert().values(test_cleaned_data))
        session.execute(ForecastData.insert().values(test_forecast_data))

    yield orchestrator

    # Teardown: the deletes double as assertions that exactly the seeded
    # rows (and the single ForecastRun) are removed.
    assert delete_test_data(cleanup_cleaned_data_query) == len(
        test_cleaned_data)
    assert delete_test_data(cleanup_forecast_data_query) == len(
        test_forecast_data)
    assert delete_test_data(
        Query(ForecastRun).filter(
            ForecastRun.id == test_run_id)) == 1  # type: ignore
Exemplo n.º 11
0
def test_dwf_calculator_surface(session):  # same algorithm as impervious surface
    """Dry-weather-flow laterals: two surfaces mapped onto one connection
    node must produce the percentage-weighted hourly flow series."""
    conn_node = factories.ConnectionNodeFactory.create(id=1)
    sur1: Surface = factories.SurfaceFactory.create(
        id=1,
        code="030007",
        display_name="030007",
        nr_of_inhabitants=3.34,
        dry_weather_flow=120.0,
    )
    sur2: Surface = factories.SurfaceFactory.create(
        id=2,
        code="030008",
        display_name="030008",
        nr_of_inhabitants=1.92,
        dry_weather_flow=120.0,
    )
    sur_map1: SurfaceMap = factories.SurfaceMapFactory.create(
        id=1,
        surface_id=sur1.id,
        connection_node_id=conn_node.id,
        percentage=69.0,
    )
    sur_map2: SurfaceMap = factories.SurfaceMapFactory.create(
        id=2,
        surface_id=sur2.id,
        connection_node_id=conn_node.id,
        percentage=42.0,
    )
    # Because we use a raw SQL query in DWF we need to commit the data
    session.commit()

    calculator = DWFCalculator(session, InflowType.SURFACE)
    laterals = calculator.laterals

    # inhabitants * flow * (percentage/100), summed over both surfaces.
    weighted_flow_sum = (
        sur1.nr_of_inhabitants * sur1.dry_weather_flow * sur_map1.percentage / 100
        + sur2.nr_of_inhabitants * sur2.dry_weather_flow * sur_map2.percentage / 100
    )
    # Hourly timestamps (s) with per-hour DWF factors, converted l/day -> m3/s.
    expected_values = [
        [i * 3600, (factor * weighted_flow_sum) / 1000 / 3600]
        for i, factor in DWF_FACTORS
    ]

    # Remove committed data
    Query(SurfaceMap).with_session(session).delete()
    Query(Surface).with_session(session).delete()
    Query(ConnectionNode).with_session(session).delete()
    session.commit()

    np.testing.assert_array_almost_equal(laterals[0]["values"], expected_values)
Exemplo n.º 12
0
    def __initialize_controls(self):
        """Lazily populate ``self._controls`` with the API control objects
        ('table' and 'memory' lists) of this control group.

        Raises:
            SchematisationError: when a control has an unknown control_type.
        """
        if self._controls is None:
            self._controls = {"table": [], "memory": []}
            # Lookup tables by primary key for the two control payload kinds
            # (dict comprehensions replace the old dict([...]) construction).
            table_lookup = {
                x.id: x
                for x in Query(ControlTable).with_session(self.session).all()
            }
            memory_lookup = {
                x.id: x
                for x in Query(ControlMemory).with_session(self.session).all()
            }

            # Group measure locations by measure group id; setdefault keeps
            # plain-dict semantics (missing groups still raise KeyError below).
            maps_lookup = {}
            for map_item in Query([ControlMeasureMap
                                   ]).with_session(self.session).all():
                maps_lookup.setdefault(map_item.measure_group_id, []).append(
                    control_measure_map_to_measure_location(map_item))

            all_controls = (Query([
                Control, ControlGroup, ControlMeasureGroup
            ]).join(ControlGroup,
                    ControlMeasureGroup).with_session(self.session).filter(
                        Control.control_group_id == self._control_group_id,
                        ControlGroup.id == self._control_group_id,
                    ).all())

            for control, group, measuregroup in all_controls:
                control: Control
                maps: List[ControlMeasureGroup] = maps_lookup[measuregroup.id]

                # Every branch either assigns api_control or raises, so no
                # api_control = None pre-initialization is needed.
                if control.control_type is ControlType.table:
                    table: ControlTable = table_lookup[control.control_id]
                    measure_spec = to_measure_specification(table, group, maps)
                    api_control = to_table_control(control, table,
                                                   measure_spec)
                elif control.control_type is ControlType.memory:
                    memory: ControlMemory = memory_lookup[control.control_id]
                    measure_spec = to_measure_specification(
                        memory, group, maps)
                    api_control = to_memory_control(control, memory,
                                                    measure_spec)
                else:
                    raise SchematisationError(
                        f"Unknown control_type '{control.control_type.value}'")
                self._controls[control.control_type.value].append(api_control)
Exemplo n.º 13
0
    def test_join_multiple(self):
        """ Test join() same table multiple times"""

        # Two outer joins to the same users table via different relations
        # ('user' and 'creator'); each must get its own alias (u_1 / u_2).
        mq = models.Edit.mongoquery(Query([models.Edit]))
        mq = mq.query(project=['id'],
                      outerjoin={
                          'user': {
                              'project': ['name']
                          },
                          'creator': {
                              'project': ['id', 'tags'],
                              'filter': {
                                  'id': {
                                      '$lt': 1
                                  }
                              }
                          }
                      })
        q = mq.end()

        qs = q2sql(q)
        # The 'creator' filter is folded into its JOIN condition (AND u_2.id < 1).
        self._check_qs(
            """SELECT u_1.id AS u_1_id, u_1.name AS u_1_name, u_2.id AS u_2_id, u_2.tags AS u_2_tags, e.id AS e_id
        FROM e LEFT OUTER JOIN u AS u_1 ON u_1.id = e.uid LEFT OUTER JOIN u AS u_2 ON u_2.id = e.cuid AND u_2.id < 1""",
            qs)
Exemplo n.º 14
0
    def test_limit(self):
        """ Test limit() """
        m = models.User

        def build_limited(limit=None, skip=None):
            # Build a fresh query with the given limit/skip applied.
            return m.mongoquery(Query([m])).limit(limit, skip).end()

        def assert_ends_with(lim, skip, expected_endswith):
            qs = q2sql(build_limited(lim, skip))
            self.assertTrue(
                qs.endswith(expected_endswith),
                '{!r} should end with {!r}'.format(qs, expected_endswith))

        # Skip only: non-positive values are ignored
        assert_ends_with(None, None, 'FROM u')
        assert_ends_with(None, -1, 'FROM u')
        assert_ends_with(None, 0, 'FROM u')
        assert_ends_with(None, 1, 'LIMIT ALL OFFSET 1')
        assert_ends_with(None, 9, 'LIMIT ALL OFFSET 9')

        # Limit only: non-positive values are ignored
        assert_ends_with(-1, None, 'FROM u')
        assert_ends_with(0, None, 'FROM u')
        assert_ends_with(1, None, 'LIMIT 1')
        assert_ends_with(9, None, 'LIMIT 9')

        # Limit and skip combined
        assert_ends_with(5, 10, 'LIMIT 5 OFFSET 10')

        # Applying limit() twice: the later call wins
        q = build_limited(limit=10)
        q = m.mongoquery(q).limit(limit=15, skip=30).end()
        qs = q2sql(q)
        self.assertTrue(qs.endswith('LIMIT 15 OFFSET 30'), qs)
Exemplo n.º 15
0
    def test_join(self):
        """ Test join() """
        m = models.User

        # Okay
        mq = m.mongoquery(Query([models.User]))
        mq = mq.join(('articles', 'comments'))
        q = mq.end()
        qs = q2sql(q)
        self.assertIn('FROM u', qs)
        #self.assertIn('LEFT OUTER JOIN a', qs)  # immediateload(), used in this case, does not add any JOIN clauses: a subquery is used for that
        #self.assertIn('LEFT OUTER JOIN c', qs)

        # Unknown relation
        # NOTE(review): ('???') is a plain string, not a 1-tuple, so join()
        # receives the string '???' — confirm a trailing comma wasn't intended.
        mq = m.mongoquery(Query([models.User]))
        self.assertRaises(AssertionError, mq.join, ('???'))
Exemplo n.º 16
0
def example_with_statement():
    """Demo of DataAPI used as a context manager: the ``with`` block wraps
    the insert, query, batch insert and raw select in one transaction."""
    # DataAPI supports with statement for handling transaction
    with DataAPI(database=database, resource_arn=resource_arn, secret_arn=secret_arn) as data_api:

        # start transaction

        insert: Insert = Insert(Pets, {'name': 'dog'})
        # INSERT INTO pets (name) VALUES ('dog')

        # `execute` accepts SQL statement as str or SQL Alchemy SQL objects
        result: Result = data_api.execute(insert)
        print(result.number_of_records_updated)
        # 1

        query = Query(Pets).filter(Pets.id == 1)
        result: Result = data_api.execute(query)  # or data_api.execute('select id, name from pets')
        # SELECT pets.id, pets.name FROM pets WHERE pets.id = 1

        # `Result` like a Result object in SQL Alchemy
        print(result.scalar())
        # 1

        print(result.one())
        # [Record<id=1, name='dog'>]

        # `Result` is Sequence[Record]
        records: List[Record] = list(result)
        print(records)
        # [Record<id=1, name='dog'>]

        # Record is Sequence and Iterator
        record = records[0]
        print(record[0])
        # 1
        print(record[1])
        # dog

        for column in record:
            print(column)
            # 1 ...

        # show record as dict()
        print(record.dict())
        # {'id': 1, 'name': 'dog'}

        # batch insert
        insert: Insert = Insert(Pets)
        data_api.batch_execute(insert, [
            {'id': 2, 'name': 'cat'},
            {'id': 3, 'name': 'snake'},
            {'id': 4, 'name': 'rabbit'},
        ])

        result = data_api.execute('select * from pets')
        print(list(result))
        # [Record<id=1, name='dog'>, Record<id=2, name='cat'>, Record<id=3, name='snake'>, Record<id=4, name='rabbit'>]

        # result is a sequence object
        for record in result:
            print(record)
Exemplo n.º 17
0
def show_projects():
    """Streamlit page: list all projects, then show the models of the
    project chosen in the selectbox. The Cyrillic strings are user-facing
    UI labels and must stay unchanged."""
    df = get_dataframe_from_query(
        Query([
            db.Project.name.label('Название проекта'),
            db.Project.description.label('Описание'),
        ]))
    st.dataframe(df)

    selected_project = st.selectbox('Выбор проекта',
                                    df['Название проекта'].unique())
    # Models of the selected project, newest first.
    st.dataframe(
        get_dataframe_from_query(
            Query([
                db.Model.name.label('Модель'),
                db.Model.description.label('Описание'),
                db.Model.params.label('Параметры модели'),
                db.Model.metrics.label('Метрики'),
                db.Dataset.name.label('Датасет'),
                case([(db.Model.pretrained, 'Да'),
                      (~db.Model.pretrained, 'Нет')]).label('Предобученная'),
                db.Model.training_time.label('Время обучения, с'),
                db.Model.created_at.label('Дата обучения'),
            ]).join(db.Project).join(db.Dataset).filter(
                db.Project.name == selected_project).order_by(
                    desc(db.Model.created_at))))
Exemplo n.º 18
0
def test_node_distance(session):
    """Two connection nodes closer than the 10 m minimum must both be
    reported by ConnectionNodesDistance; spatialite only."""
    if session.bind.name == "postgresql":
        pytest.skip("Check only applicable to spatialite")
    con1_too_close = factories.ConnectionNodeFactory(
        the_geom="SRID=4326;POINT(4.728282 52.64579283592512)"
    )
    con2_too_close = factories.ConnectionNodeFactory(
        the_geom="SRID=4326;POINT(4.72828 52.64579283592512)"
    )
    # Good distance
    factories.ConnectionNodeFactory(
        the_geom="SRID=4326;POINT(4.726838755789598 52.64514133594995)"
    )

    # sanity check to see the distances between the nodes
    node_a = aliased(models.ConnectionNode)
    node_b = aliased(models.ConnectionNode)
    distances_query = Query(
        geo_func.ST_Distance(node_a.the_geom, node_b.the_geom, 1)
    ).filter(node_a.id != node_b.id)
    # Shows the distances between all 3 nodes: node 1 and 2 are too close
    # (result intentionally unused — debugging aid only).
    distances_query.with_session(session).all()

    check = ConnectionNodesDistance(minimum_distance=10)
    invalid = check.get_invalid(session)
    assert len(invalid) == 2
    invalid_ids = [i.id for i in invalid]
    assert con1_too_close.id in invalid_ids
    assert con2_too_close.id in invalid_ids
Exemplo n.º 19
0
 def list(self, image_id, region_id, zone_id, limit: int,
          marker: uuid.UUID):
     """Paginate the current project's instances, optionally filtered by
     image, region and zone; unknown region/zone ids raise HTTP 404.
     """
     # TODO: allow filtering by tags
     project = cherrypy.request.project
     starting_query = Query(Instance).filter(
         Instance.project_id == project.id)
     if image_id is not None:
         starting_query = starting_query.filter(
             Instance.image_id == image_id)
     if region_id is not None:
         # Validate the region exists before filtering on it.
         with cherrypy.request.db_session() as session:
             region = session.query(Region).filter(
                 Region.id == region_id).first()
             if region is None:
                 raise cherrypy.HTTPError(
                     404, "A region with the requested id does not exist.")
         starting_query = starting_query.filter(
             Instance.region_id == region.id)
     if zone_id is not None:
         # Validate the zone exists before filtering on it.
         with cherrypy.request.db_session() as session:
             zone = session.query(Zone).filter(Zone.id == zone_id).first()
             if zone is None:
                 raise cherrypy.HTTPError(
                     404, "A zone with the requested id does not exist.")
         starting_query = starting_query.filter(Instance.zone_id == zone_id)
     return self.paginate(Instance,
                          ResponseInstance,
                          limit,
                          marker,
                          starting_query=starting_query)
Exemplo n.º 20
0
def purge(db_engine: Engine) -> None:
    """Remove all tasks from the database

    :param db_engine: engine for the database
    """
    with session_scope(bind=db_engine) as session:
        all_tasks = Query(db.Task, session)
        all_tasks.delete()
Exemplo n.º 21
0
    def test_sort(self):
        """ Test sort() """
        m = models.User

        def build_sorted(sort_spec):
            # Build a fresh query with the given sort spec applied.
            return m.mongoquery(Query([m])).sort(sort_spec).end()

        def assert_sorted(sort_spec, expected_ends):
            qs = q2sql(build_sorted(sort_spec))
            self.assertTrue(
                qs.endswith(expected_ends),
                '{!r} should end with {!r}'.format(qs, expected_ends))

        # Empty spec: no ORDER BY clause
        assert_sorted(None, u'FROM u')
        assert_sorted([], u'FROM u')
        assert_sorted(OrderedDict(), u'FROM u')

        # List form: '-' suffix means descending
        assert_sorted(['id-', 'age-'], 'ORDER BY u.id DESC, u.age DESC')

        # Dict form: -1 means descending
        assert_sorted(OrderedDict([['id', -1], ['age', -1]]),
                      'ORDER BY u.id DESC, u.age DESC')

        # Invalid direction value (-2) is rejected
        self.assertRaises(AssertionError, assert_sorted,
                          OrderedDict([['id', -2], ['age', -1]]), '')
Exemplo n.º 22
0
def test_length_geom_linestring_in_28992(session):
    """A ~0.001 m channel must be flagged by the 0.05 m minimum-length
    check while a ~0.109 m channel passes."""
    if session.bind.name == "postgresql":
        # Fixed typo in skip reason: "constrain" -> "constraint".
        pytest.skip(
            "Postgres already has a constraint that checks on the length")
    # around 0.109m
    factories.ChannelFactory(
        the_geom="SRID=4326;LINESTRING("
        "122829.98048471771471668 473589.68720115750329569, "
        "122830.00490918199648149 473589.68720115750329569, "
        "122829.95687440223991871 473589.70983449439518154, "
        "122829.9793449093849631 473589.68850379559444264)")
    # around 0.001m
    channel_too_short = factories.ChannelFactory(
        the_geom="SRID=4326;LINESTRING("
        "122829.98185859377554152 473589.69248294795397669, "
        "122829.98260150455462281 473589.69248294795397669)", )

    check_length_linestring = QueryCheck(
        column=models.Channel.the_geom,
        invalid=Query(models.Channel).filter(
            geo_func.ST_Length(models.Channel.the_geom) < 0.05),
        message=
        "Length of the v2_channel is too short, should be at least 0.05m",
    )

    errors = check_length_linestring.get_invalid(session)
    assert len(errors) == 1
    assert errors[0].id == channel_too_short.id
Exemplo n.º 23
0
    def __init__(
            self,
            # An optional `Declarative class <http://docs.sqlalchemy.org/en/latest/orm/tutorial.html#declare-a-mapping>`_ to query.
            declarative_class=None,
            # Optionally, begin with an existing query_.
            query=None):
        """Record the query and the class being selected from.

        At least one of ``declarative_class`` or ``query`` must be given;
        when both are given they must be consistent with each other.
        """
        if declarative_class:
            assert _is_mapped_class(declarative_class)

        # If a query is provided, try to infer the declarative_class.
        if query is not None:
            assert isinstance(query, Query)
            self._query = query
            try:
                self._select = self._get_joinpoint_zero_class()
            # BUGFIX: was a bare ``except:``, which also swallowed
            # KeyboardInterrupt/SystemExit; catch Exception instead.
            except Exception:
                # We can't infer it. Use what's provided instead, and add this to the query.
                assert declarative_class
                self._select = declarative_class
                self._query = self._query.select_from(declarative_class)
            else:
                # If a declarative_class was provided, make sure it's consistent with the inferred class.
                if declarative_class:
                    assert declarative_class is self._select
        else:
            # The declarative class must be provided if the query wasn't.
            assert declarative_class
            # Since a query was not provided, create an empty `query <http://docs.sqlalchemy.org/en/latest/orm/query.html>`_; ``to_query`` will fill in the missing information.
            self._query = Query([]).select_from(declarative_class)
            # Keep track of the last selectable construct, to generate the select in ``to_query``.
            self._select = declarative_class
Exemplo n.º 24
0
def test_query_check_on_pumpstation(session):
    """A pumpstation whose lower_stop_level is at or below the bottom_level
    of the manhole at its start node must be reported."""
    connection_node1 = factories.ConnectionNodeFactory()
    connection_node2 = factories.ConnectionNodeFactory()
    factories.ManholeFactory(connection_node=connection_node1,
                             bottom_level=1.0)
    factories.ManholeFactory(connection_node=connection_node2,
                             bottom_level=-1.0)
    # Invalid: lower_stop_level (0.0) <= manhole bottom_level (1.0).
    pumpstation_wrong = factories.PumpstationFactory(
        connection_node_start=connection_node1, lower_stop_level=0.0)
    # Valid: lower_stop_level (2.0) > manhole bottom_level (-1.0).
    factories.PumpstationFactory(connection_node_start=connection_node2,
                                 lower_stop_level=2.0)

    # Join pumpstation -> start connection node -> manhole to compare levels.
    query = (
        Query(models.Pumpstation).join(
            models.ConnectionNode,
            models.Pumpstation.connection_node_start_id ==
            models.ConnectionNode.id,  # noqa: E501
        ).join(
            models.Manhole,
            models.Manhole.connection_node_id == models.ConnectionNode.id,
        ).filter(
            models.Pumpstation.lower_stop_level <=
            models.Manhole.bottom_level, ))
    check = QueryCheck(
        column=models.Pumpstation.lower_stop_level,
        invalid=query,
        message="Pumpstation lower_stop_level should be higher than Manhole "
        "bottom_level",
    )
    invalids = check.get_invalid(session)
    assert len(invalids) == 1
    assert invalids[0].id == pumpstation_wrong.id
Exemplo n.º 25
0
    def query(self,
              fields=None,
              _filter=None,
              sort=None,
              limit=None,
              raise_ex=False,
              return_field_value=None):
        """Run a query against ``self.table`` with optional filtering,
        field selection, sorting and limiting.

        When ``limit == 1`` and ``return_field_value`` is set, only that
        field is selected and its value (not a row dict) is returned.
        Otherwise deserialized rows are returned: a single dict (possibly
        empty) when ``limit == 1``, else a list.
        """
        if return_field_value and limit and limit == 1:
            # Narrow the selection to just the requested field.
            fields = return_field_value

        # Each helper is applied only when its argument was provided.
        query = Query(self.table)
        query = self._generate_filter(query, _filter) if _filter else query
        query = self._select_fields(query, fields) if fields else query
        query = self._sort_query(query, sort) if sort else query
        limit = self._validate_limit(limit)

        with Session(self.engine) as session:
            if isinstance(limit, int):
                if limit == 1:
                    # Single-row mode: unwrap the row, default to {} on miss.
                    data = query.with_session(session).limit(1).all()
                    data = data[0] if data else {}
                else:
                    data = query.with_session(session).limit(limit).all()
            else:
                # _validate_limit returned a non-int: no limit applied.
                data = query.with_session(session).all()

        # NOTE(review): _error appears to be an error-injection/raise hook
        # gated by raise_ex — its semantics are not visible here; confirm.
        self._error('query', raise_ex)

        data = self._deserialize(data, fields)
        if limit == 1 and return_field_value and data:
            data = data.get(return_field_value)

        return data
Exemplo n.º 26
0
def get_flavor_by_id(flavor_id: UUID) -> Union[Dict, None]:
    """Look up a flavor by primary key.

    Returns:
        The flavor as a dict, or None when no flavor with that id exists.
    """
    with db_session() as session:
        # BUGFIX: use .first() (None on miss). The original used .one(),
        # which raises NoResultFound and made the None-check unreachable,
        # contradicting the declared Union[Dict, None] return type.
        flavor: Flavor = Query(Flavor,
                               session=session).filter_by(id=flavor_id).first()
        if not flavor:
            return None
        return flavor.to_dict()
Exemplo n.º 27
0
def get_flavor_by_code(flavor_code: int) -> Union[Dict, None]:
    """Look up a flavor by its code.

    Returns:
        The flavor as a dict, or None when no flavor with that code exists.
    """
    with db_session() as session:
        # BUGFIX: use .first() (None on miss). The original used .one(),
        # which raises NoResultFound and made the None-check unreachable,
        # contradicting the declared Union[Dict, None] return type.
        flavor: Flavor = Query(
            Flavor, session=session).filter_by(code=flavor_code).first()
        if not flavor:
            return None
        return flavor.to_dict()
Exemplo n.º 28
0
def query1(*args):
    """Return a query selecting code/report_date/roe from ``full`` for the
    trade date 2013-07-18, bound to a fresh SQLite session.

    NOTE(review): the *args parameter is accepted but immediately
    overwritten below, so callers' arguments are ignored — confirm intent.
    """
    engine = create_engine("sqlite:///fundamental.sqlite")
    Session = sessionmaker(bind=engine)
    session = Session()
    # Hard-coded column selection replaces whatever the caller passed.
    args = [full.code, full.report_date, full.roe]
    return Query(args).with_session(session).filter(
        full.trade_date == pd.to_datetime('2013-07-18'))
Exemplo n.º 29
0
def get_size_by_code(size_code: int) -> Union[Dict, None]:
    """Look up a size by its code.

    Returns:
        The size as a dict, or None when no size with that code exists.
    """
    with db_session() as session:
        # BUGFIX: use .first() (None on miss). The original used .one(),
        # which raises NoResultFound and made the None-check unreachable,
        # contradicting the declared Union[Dict, None] return type.
        size: Size = Query(Size,
                           session=session).filter_by(code=size_code).first()
        if not size:
            return None
        return size.to_dict()
Exemplo n.º 30
0
 def query(self, dt, *args, **kwargs):
     """Query fundamental rows reported strictly before ``dt``, selecting
     any extra ``*args`` columns plus ``code`` and the aggregated
     max(report_date) labelled 'report_date'.

     NOTE(review): **kwargs is accepted but never used — confirm intent.
     """
     args = list(args) + [
         fundamental.code,
         func.max(fundamental.report_date).label('report_date')
     ]
     return Query(args).with_session(
         self.session).filter(fundamental.report_date < dt)