示例#1
0
    async def set_active_state_of_custom_poi(
            self, *, db: AsyncSession, obj_in: schemas.CutomDataUploadState,
            current_user: models.User):
        """Set active state of custom poi.

        Adds the data upload ``obj_in.data_upload_id`` to
        ``current_user.active_data_upload_ids`` when ``obj_in.state == True``
        or removes it when ``obj_in.state == False``, then persists the change.

        Args:
            db: Async database session used for the lookup and the commit.
            obj_in: Carries ``data_upload_id`` and the requested ``state``.
            current_user: User whose active data uploads are updated.

        Returns:
            ``current_user``. It is refreshed from the database when a change
            was committed, and returned untouched when the list already
            matched the requested state.

        Raises:
            HTTPException: 404 if the data upload does not exist, 400 if it
                belongs to a different user.
        """
        result = await db.execute(
            select(models.DataUpload).filter(
                models.DataUpload.id == obj_in.data_upload_id))
        data_upload_obj = result.scalars().first()
        if data_upload_obj is None:
            # Without this check the attribute access below would crash (500).
            raise HTTPException(status_code=404,
                                detail="Data upload not found")
        if data_upload_obj.user_id != current_user.id:
            raise HTTPException(status_code=400,
                                detail="User ID does not match")

        # Work on a copy of the stored list (treating a missing value as empty)
        # instead of mutating the ORM attribute in place.
        active_ids = list(current_user.active_data_upload_ids or [])
        is_active = data_upload_obj.id in active_ids

        # Explicit comparisons (not truthiness) so that any state other than
        # True/False leaves the list untouched.
        if obj_in.state == False and is_active:  # noqa: E712
            active_ids.remove(data_upload_obj.id)
        elif obj_in.state == True and not is_active:  # noqa: E712
            active_ids.append(data_upload_obj.id)
        else:
            # Already in the requested state: nothing to persist.
            return current_user

        current_user.active_data_upload_ids = active_ids
        # Mark the column dirty explicitly so the change is always written.
        flag_modified(current_user, "active_data_upload_ids")
        db.add(current_user)
        await db.commit()
        await db.refresh(current_user)
        return current_user
示例#2
0
async def add_user_to_group_channel(channel_id: str, user_name: str,
                                    user_hash: str, session: AsyncSession):
    """Add the user identified by ``user_name``/``user_hash`` to a channel.

    Args:
        channel_id: Channel the membership is created for.
        user_name: Name passed to ``get_user``.
        user_hash: Hash passed to ``get_user``.
        session: Async session used both for the lookup and the insert.

    Returns:
        ``False`` if ``get_user`` returns ``None``; otherwise ``True`` once
        the new ``Membership`` row has been committed.
    """
    user = await get_user(user_name, user_hash, session)
    if user is None:
        return False
    session.add(models.Membership(user_id=user.id, channel_id=channel_id))
    # ``get_user`` runs on this same session. If it executed a query, the
    # session has already auto-begun a transaction, and ``session.begin()``
    # would raise "A transaction is already begun". A plain commit works
    # whether or not a transaction is open.
    try:
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    return True
示例#3
0
async def new_group_channel(owner_id: int, name: str,
                            session: AsyncSession) -> str:
    """Create a group channel owned by ``owner_id`` and return its id.

    Args:
        owner_id: Id of the owning user.
        name: Channel name.
        session: Async session with no transaction in progress (the channel
            is inserted inside ``session.begin()``).

    Returns:
        The ``id`` of the newly created ``Channel``.
    """
    channel = models.Channel(name=name,
                             type=models.ChannelType.GROUP,
                             owner_id=owner_id)
    async with session.begin():
        session.add(channel)
        # Flush so a database-generated id is populated, and read it while the
        # transaction is still open. Leaving ``begin()`` commits, and with the
        # default ``expire_on_commit=True`` a later ``channel.id`` access would
        # need an implicit lazy refresh, which AsyncSession cannot perform.
        await session.flush()
        channel_id = channel.id
    return channel_id
 async def resolve_users(self,
                         info: ResolveInfo,
                         ids: Optional[List[int]] = None) -> List[User]:
     """Resolve users, optionally restricted to a set of ids.

     Args:
         info: Resolver info; passed to ``_get_engine`` to obtain the engine.
         ids: ``None`` (the default) returns every user. Otherwise only users
             whose id is in ``ids`` are returned, so an empty list yields an
             empty result instead of the whole table.

     Returns:
         The matching ``User`` rows.
     """
     engine = await _get_engine(info)
     statement = select(User)
     # Compare against None explicitly: an empty list is a real filter.
     if ids is not None:
         statement = statement.filter(User.id.in_(ids))
     async with AsyncSession(engine) as session:
         query_result = await session.execute(statement)
         # Materialise the rows while the session is still open.
         users: List[User] = query_result.scalars().all()
     return users
示例#5
0
    async def update_scenario_features(
        self,
        db: AsyncSession,
        current_user: models.User,
        scenario_id: int,
        layer_name: str,
        feature_in: schemas.ScenarioFeatureUpdate,
    ) -> Any:
        """Update existing features of a scenario layer.

        Each incoming feature is turned into a dict of column values. Enum
        members are stored by their ``.value``, and the ``id`` field only
        selects the row. Rows of the layer model that belong to
        ``scenario_id`` and match an incoming id get every non-``None`` value
        applied; incoming ids with no matching row are ignored. For the
        building/population modified layers, the database function
        ``basic.population_modification`` is executed afterwards.

        Args:
            db: Async database session.
            current_user: Requesting user (not used by this method).
            scenario_id: Scenario whose features are updated.
            layer_name: Layer enum member. Its ``.value`` is used as the key
                into ``scenario_layer_models`` (despite the ``str``
                annotation).
            feature_in: Payload whose ``features`` are the updates to apply.

        Returns:
            The updated ORM rows, refreshed from the database.

        Raises:
            HTTPException: 400 if a population feature intersects no sub
                study area, or if a feature cannot be converted.
        """
        layer = scenario_layer_models[layer_name.value]
        is_population = (
            layer_name.value == schemas.ScenarioLayerFeatureEnum.population_modified.value
        )
        features_obj = {}
        feature_ids = []
        for feature in feature_in.features:
            feature_dict = {}
            # Population features must fall inside a sub study area.
            if is_population:
                point = WKTElement(feature.geom, srid=4326)
                statement = select(models.SubStudyArea).where(
                    models.SubStudyArea.geom.ST_Intersects(point)
                )
                result = await db.execute(statement)
                sub_study_area = result.scalars().first()
                if sub_study_area is None:
                    raise HTTPException(
                        status_code=400,
                        detail="The population feature does not intersect with any sub study area",
                    )
                feature_dict["sub_study_area_id"] = sub_study_area.id
            try:
                for key, value in feature:
                    if key == "id":
                        # The id selects the row; it is never written back.
                        feature_ids.append(value)
                    elif isinstance(value, enum.Enum):
                        feature_dict[key] = value.value
                    else:
                        feature_dict[key] = value
            except Exception as e:
                raise HTTPException(status_code=400, detail="Invalid feature") from e

            features_obj[feature.id] = feature_dict

        result = await db.execute(
            select(layer).where(and_(layer.scenario_id == scenario_id, layer.id.in_(feature_ids)))
        )
        features_in_db = result.scalars().fetchall()

        for db_feature in features_in_db:
            for key, value in features_obj[db_feature.id].items():
                # None means "not provided": keep the stored value.
                if value is not None:
                    # TODO: For population check if geometry and building with {building_modified_id} intersect
                    setattr(db_feature, key, value)

        db.add_all(features_in_db)
        await db.commit()

        # Execute population distribution on population modified
        if layer_name.value in (schemas.ScenarioLayerFeatureEnum.building_modified.value, schemas.ScenarioLayerFeatureEnum.population_modified.value):
            await db.execute(
                func.basic.population_modification(scenario_id)
            )
            await db.commit()

        for db_feature in features_in_db:
            await db.refresh(db_feature)

        return features_in_db
示例#6
0
    async def create_scenario_features(
        self,
        db: AsyncSession,
        current_user: models.User,
        scenario_id: int,
        layer_name: str,
        feature_in: schemas.ScenarioFeatureCreate,
    ) -> Any:
        """Create new features in a scenario layer.

        Each incoming feature becomes an instance of the layer model with
        ``scenario_id`` set. Special cases per field:

        * ``uid`` on the POI layer: a fresh ``uuid4().hex`` when missing.
          Otherwise the given value is kept, and if it has at least five
          ``_``-separated parts, the last part (with ``u`` characters removed)
          is parsed as ``data_upload_id``.
        * ``way_id`` on the way layer: copied only when not ``None``.
        * enum values are stored by ``.value``.
        * a missing ``class_id`` defaults to ``100``.
        * other ``None`` values are omitted.

        For the building/population modified layers, the database function
        ``basic.population_modification`` is executed after the insert.

        Args:
            db: Async database session.
            current_user: Requesting user (not used by this method).
            scenario_id: Scenario the features are created in.
            layer_name: Layer enum member. Its ``.value`` is used as the key
                into ``scenario_layer_models`` (despite the ``str``
                annotation).
            feature_in: Payload whose ``features`` are the features to create.

        Returns:
            The created ORM rows, refreshed from the database.

        Raises:
            HTTPException: 400 if a population feature intersects no sub
                study area, or if a feature cannot be converted/instantiated.
        """
        layer = scenario_layer_models[layer_name.value]
        layer_value = layer_name.value
        is_population = layer_value == schemas.ScenarioLayerFeatureEnum.population_modified.value
        is_poi = layer_value == schemas.ScenarioLayerFeatureEnum.poi_modified.value
        is_way = layer_value == schemas.ScenarioLayerFeatureEnum.way_modified.value

        features_in_db = []
        for feature in feature_in.features:
            feature_dict = {"scenario_id": scenario_id}

            # Population features must fall inside a sub study area.
            if is_population:
                point = WKTElement(feature.geom, srid=4326)
                statement = select(models.SubStudyArea).where(
                    models.SubStudyArea.geom.ST_Intersects(point)
                )
                result = await db.execute(statement)
                sub_study_area = result.scalars().first()
                if sub_study_area is None:
                    raise HTTPException(
                        status_code=400,
                        detail="The population feature does not intersect with any sub study area",
                    )
                feature_dict["sub_study_area_id"] = sub_study_area.id
            try:
                for key, value in feature:
                    if key == "uid" and is_poi:
                        if value is None:
                            # new POI
                            feature_dict["uid"] = uuid.uuid4().hex
                        else:
                            # existing POI
                            feature_dict["uid"] = value
                            uid_parts = value.split("_")
                            if len(uid_parts) >= 5:
                                feature_dict["data_upload_id"] = int(
                                    uid_parts[-1].replace("u", "")
                                )
                            # TODO: check if uid is valid (poi / poi_user)
                    elif is_way and key == "way_id":
                        if value is not None:
                            feature_dict["way_id"] = value
                    # TODO: For population check if geometry and building with {building_modified_id} intersect
                    elif isinstance(value, enum.Enum):
                        feature_dict[key] = value.value
                    elif key == "class_id" and value is None:
                        # Default class id when none is supplied.
                        feature_dict[key] = 100
                    elif value is None:
                        continue
                    else:
                        feature_dict[key] = value
                feature_obj = layer.from_orm(layer(**feature_dict))
                features_in_db.append(feature_obj)
            except Exception as e:
                raise HTTPException(status_code=400, detail="Invalid feature") from e

        db.add_all(features_in_db)
        await db.commit()
        # Execute population distribution on population modified
        if layer_name.value in (schemas.ScenarioLayerFeatureEnum.building_modified.value, schemas.ScenarioLayerFeatureEnum.population_modified.value):
            await db.execute(
                func.basic.population_modification(scenario_id)
            )
            await db.commit()

        for feature_obj in features_in_db:
            await db.refresh(feature_obj)
        return features_in_db
示例#7
0
async def _create_user(username: str, password: str) -> None:
    """Create a user by running the synchronous ``create_user`` helper.

    Loads the configuration, builds an engine from ``config.database`` and
    calls ``create_user(<sync session>, username, password)`` through
    ``AsyncSession.run_sync``. The engine's connection pool is disposed
    afterwards, even on error, so no connections outlive this call.
    """
    config = load_configuration()
    engine = get_engine(config.database)
    try:
        async with AsyncSession(engine) as session:
            await session.run_sync(create_user, username, password)
    finally:
        await engine.dispose()
示例#8
0
    async def compute_reached_pois_user(self, db: AsyncSession,
                                        current_user: models.User,
                                        data_upload_id: int):
        """Compute the reached pois for a certain data upload id.

        Deletes existing ``ReachedPoiHeatmap`` rows for the upload and runs
        ``basic.reached_pois_heatmap`` on the user's active study area for the
        ``poi_user`` table. It then runs the same function for the
        ``poi_modified`` table of every scenario whose ``data_upload_ids``
        overlap this upload. Finally it sets
        ``reached_poi_heatmap_computed`` on the upload.

        Args:
            db: Async database session.
            current_user: User whose active study area and uploads are used.
            data_upload_id: Data upload to compute.

        Returns:
            ``{"msg": "Data upload already computed."}`` if the upload was
            already computed; otherwise ``None``.

        Raises:
            HTTPException: 400 if the upload is not among the user's active
                data uploads, 404 if no such data upload exists.
        """

        # Check if data upload in current active data uploads
        if data_upload_id not in (current_user.active_data_upload_ids or []):
            raise HTTPException(
                status_code=400,
                detail="Data upload id not in your active data uploads.")

        # Load the data upload; fail with 404 rather than crashing if missing.
        result = await db.execute(
            select(models.DataUpload).where(
                models.DataUpload.id == data_upload_id))
        data_upload_obj = result.scalars().first()
        if data_upload_obj is None:
            raise HTTPException(status_code=404,
                                detail="Data upload not found.")

        # Check if data upload is already computed
        if data_upload_obj.reached_poi_heatmap_computed:
            return {"msg": "Data upload already computed."}

        # Delete old reached pois for the data upload id
        await db.execute(
            delete(models.ReachedPoiHeatmap).where(
                models.ReachedPoiHeatmap.data_upload_id == data_upload_id))

        # Compute reached pois for the data upload id
        await db.execute(
            text("""
            SELECT r.* 
            FROM basic.study_area s, 
            LATERAL basic.reached_pois_heatmap(:table_name, s.geom, :user_id_input, :scenario_id_input, :data_upload_ids) r 
            WHERE s.id = :active_study_area_id
            """),
            {
                "active_study_area_id": current_user.active_study_area_id,
                "table_name": "poi_user",
                "user_id_input": current_user.id,
                "scenario_id_input": 0,
                "data_upload_ids": [data_upload_id],
            },
        )
        await db.commit()

        # Scenarios that reference this data upload.
        scenario_rows = await db.execute(
            text("""
            SELECT s.id  
            FROM customer.scenario s 
            WHERE s.data_upload_ids && :data_upload_ids"""),
            {"data_upload_ids": [data_upload_id]},
        )
        scenario_ids = [row[0] for row in scenario_rows.fetchall()]

        # Calculate all scenario for the data upload id
        for scenario_id in scenario_ids:
            await db.execute(
                text("""
                    SELECT r.* 
                    FROM basic.study_area s, 
                    LATERAL basic.reached_pois_heatmap(:table_name, s.geom, :user_id_input, :scenario_id_input) r 
                    WHERE s.id = :active_study_area_id
                    """),
                {
                    "active_study_area_id":
                    current_user.active_study_area_id,
                    "table_name": "poi_modified",
                    "user_id_input": current_user.id,
                    "scenario_id_input": scenario_id,
                },
            )
            await db.commit()

        # Update data upload
        data_upload_obj.reached_poi_heatmap_computed = True
        db.add(data_upload_obj)
        await db.commit()
async def get_session(request: Request) -> AsyncSession:
    """Yield a database session bound to the application's engine.

    The session is opened on ``request.app.state.engine`` and closed by the
    ``async with`` block once the consumer of this generator is finished.

    NOTE(review): this is an async generator, so the declared return type
    describes the yielded value, not what calling the function returns.
    """
    app_engine = cast(AppState, request.app.state).engine
    async with AsyncSession(app_engine) as db_session:
        yield db_session