def insert_annotation_data(self, chunk: List[int], mat_metadata: dict):
    """Insert annotation data into database

    Args:
        chunk (List[int]): chunk of annotation ids
        mat_metadata (dict): materialized metadata

    Returns:
        bool: True if data was inserted
    """
    aligned_volume = mat_metadata["aligned_volume"]
    analysis_version = mat_metadata["analysis_version"]
    annotation_table_name = mat_metadata["annotation_table_name"]
    datastack = mat_metadata["datastack"]

    session = sqlalchemy_cache.get(aligned_volume)
    engine = sqlalchemy_cache.get_engine(aligned_volume)

    # build table models
    AnnotationModel = create_annotation_model(mat_metadata, with_crud_columns=False)
    SegmentationModel = create_segmentation_model(mat_metadata)
    analysis_table = get_analysis_table(
        aligned_volume, datastack, annotation_table_name, analysis_version
    )

    # Select all annotation columns plus the segmentation columns, skipping
    # the segmentation "id" which duplicates the annotation primary key.
    query_columns = list(AnnotationModel.__table__.columns)
    query_columns.extend(
        col for col in SegmentationModel.__table__.columns if col.name != "id"
    )

    chunked_id_query = query_id_range(AnnotationModel.id, chunk[0], chunk[1])

    anno_ids = (
        session.query(AnnotationModel.id)
        .filter(chunked_id_query)
        .filter(AnnotationModel.valid == True)
    )
    query = (
        session.query(*query_columns)
        .join(SegmentationModel)
        .filter(SegmentationModel.id == AnnotationModel.id)
        .filter(SegmentationModel.id.in_(anno_ids))
    )
    data = query.all()
    # Records for an executemany-style bulk insert.
    mat_records = pd.DataFrame(data).to_dict(orient="records")

    SQL_URI_CONFIG = get_config_param("SQLALCHEMY_DATABASE_URI")
    analysis_sql_uri = create_analysis_sql_uri(
        SQL_URI_CONFIG, datastack, analysis_version
    )
    analysis_session, analysis_engine = create_session(analysis_sql_uri)
    try:
        # Bug fix: pass the record list directly; the previous
        # "[data for data in mat_df]" built a needless copy (PERF402).
        analysis_engine.execute(analysis_table.insert(), mat_records)
    except Exception as e:
        celery_logger.error(e)
        analysis_session.rollback()
    finally:
        analysis_session.close()
        analysis_engine.dispose()
        session.close()
        engine.dispose()
    return True
def generic_report(datastack_name, id):
    """Render a generic summary report page for an analysis table.

    Bug fix: ``datastack_name`` was referenced but never defined — the
    redirecting view passes it as a route argument (see ``table_view``,
    which calls ``url_for("views.generic_report", datastack_name=..., id=...)``),
    so it must be accepted as a parameter here.

    Args:
        datastack_name: datastack the table belongs to
        id: AnalysisTable primary key

    Returns:
        Rendered "generic.html" template with the annotation count.
    """
    aligned_volume_name, pcg_table_name = get_relevant_datastack_info(datastack_name)
    session = sqlalchemy_cache.get(aligned_volume_name)
    table = session.query(AnalysisTable).filter(AnalysisTable.id == id).first()

    make_dataset_models(
        table.analysisversion.dataset, [], version=table.analysisversion.version
    )
    Model = make_annotation_model(
        table.analysisversion.dataset,
        table.schema,
        table.tablename,
        version=table.analysisversion.version,
    )
    n_annotations = Model.query.count()

    return render_template(
        "generic.html",
        n_annotations=n_annotations,
        dataset=table.analysisversion.dataset,
        analysisversion=table.analysisversion.version,
        version=__version__,
        table_name=table.tablename,
        schema_name=table.schema,
    )
def upload_data(self, data: List, bulk_upload_info: dict):
    """Bulk-insert annotation and segmentation rows in one transaction.

    Args:
        data (List): two-element list: [annotation_rows, segmentation_rows]
        bulk_upload_info (dict): upload metadata supplying the aligned
            volume, table names, schema and pcg table name

    Returns:
        bool: True if both inserts committed
    """
    aligned_volume = bulk_upload_info["aligned_volume"]
    model_data = {
        "annotation_table_name": bulk_upload_info["annotation_table_name"],
        "schema": bulk_upload_info["schema"],
        "pcg_table_name": bulk_upload_info["pcg_table_name"],
    }
    AnnotationModel = create_annotation_model(model_data)
    SegmentationModel = create_segmentation_model(model_data)

    session = sqlalchemy_cache.get(aligned_volume)
    engine = sqlalchemy_cache.get_engine(aligned_volume)

    try:
        # engine.begin() wraps both inserts in a single transaction:
        # either both tables get their rows, or neither does.
        with engine.begin() as connection:
            connection.execute(AnnotationModel.__table__.insert(), data[0])
            connection.execute(SegmentationModel.__table__.insert(), data[1])
    except Exception as e:
        celery_logger.error(f"ERROR: {e}")
        # Bug fix: retry with the caught exception *instance* (was the bare
        # Exception class), so Celery records the actual failure reason.
        raise self.retry(exc=e, countdown=3)
    finally:
        session.close()
        engine.dispose()
    return True
def insert_segmentation_data(materialization_data: dict, mat_metadata: dict) -> dict:
    """Bulk-insert supervoxel/root-id rows into the segmentation table.

    Args:
        materialization_data (dict): supervoxel and/or root id data
        mat_metadata (dict): materialization metadata

    Returns:
        dict: description of the number of rows inserted, or an "empty"
        status when there was nothing to insert
    """
    if not materialization_data:
        return {"status": "empty"}

    aligned_volume = mat_metadata.get("aligned_volume")
    SegmentationModel = create_segmentation_model(mat_metadata)

    session = sqlalchemy_cache.get(aligned_volume)
    engine = sqlalchemy_cache.get_engine(aligned_volume)

    try:
        # engine.begin() commits on success and rolls back on error.
        with engine.begin() as connection:
            connection.execute(
                SegmentationModel.__table__.insert(), materialization_data
            )
    except SQLAlchemyError as error:
        session.rollback()
        celery_logger.error(error)
    finally:
        session.close()
    return {"Segmentation data inserted": len(materialization_data)}
def get(self, datastack_name: str, version: int, table_name: str):
    """get frozen table metadata

    Args:
        datastack_name (str): datastack name
        version (int): version number
        table_name (str): table name

    Returns:
        dict: dictionary of table metadata
    """
    aligned_volume_name, pcg_table_name = get_relevant_datastack_info(datastack_name)
    session = sqlalchemy_cache.get(aligned_volume_name)
    analysis_version, analysis_table = get_analysis_version_and_table(
        datastack_name, table_name, version, session
    )

    # Start from the frozen table's metadata...
    tables = AnalysisTableSchema().dump(analysis_table)

    # ...then fold in the live annotation metadata, minus bookkeeping keys.
    db = dynamic_annotation_cache.get_db(aligned_volume_name)
    ann_md = db.get_table_metadata(table_name)
    for bookkeeping_key in ("id", "deleted"):
        ann_md.pop(bookkeeping_key)
    tables.update(ann_md)

    return tables, 200
def datastack_view(datastack_name):
    """Render the datastack page listing its analysis versions.

    By default only valid versions are shown; passing any "all" query
    argument disables the validity filter.
    """
    aligned_volume_name, pcg_table_name = get_relevant_datastack_info(datastack_name)
    session = sqlalchemy_cache.get(aligned_volume_name)

    query = session.query(AnalysisVersion).filter(
        AnalysisVersion.datastack == datastack_name
    )
    # "?all=<anything>" shows invalid versions too.
    if request.args.get("all", False) is False:
        query = query.filter(AnalysisVersion.valid == True)
    versions = query.order_by(AnalysisVersion.version.desc()).all()

    if versions:
        df = make_df_with_links_to_id(
            versions,
            AnalysisVersionSchema(many=True),
            "views.version_view",
            "version",
            datastack_name=datastack_name,
        )
        df_html_table = df.to_html(escape=False)
    else:
        df_html_table = ""

    return render_template(
        "datastack.html",
        datastack=datastack_name,
        table=df_html_table,
        version=__version__,
    )
def get(self, datastack_name: str, version: int):
    """get frozen tables

    Args:
        datastack_name (str): datastack name
        version (int): version number

    Returns:
        list(str): list of frozen tables in this version
    """
    aligned_volume_name, pcg_table_name = get_relevant_datastack_info(datastack_name)
    session = sqlalchemy_cache.get(aligned_volume_name)

    av = (
        session.query(AnalysisVersion)
        .filter(AnalysisVersion.version == version)
        .filter(AnalysisVersion.datastack == datastack_name)
        .first()
    )
    if av is None:
        return None, 404

    response = (
        session.query(AnalysisTable)
        .filter(AnalysisTable.analysisversion_id == av.id)
        .filter(AnalysisTable.valid == True)
        .all()
    )
    # Bug fix: Query.all() returns a list, never None, so the previous
    # "is None" check made the 404 branch unreachable for empty results.
    if not response:
        return None, 404
    return [r.table_name for r in response], 200
def get(self, datastack_name: str, version: int, table_name: str):
    """get annotation count in table

    Args:
        datastack_name (str): datastack name of table
        version (int): version of table
        table_name (str): table name

    Returns:
        int: number of rows in this table
    """
    aligned_volume_name, pcg_table_name = get_relevant_datastack_info(datastack_name)

    # Resolve the flat model against the live database...
    live_session = sqlalchemy_cache.get(aligned_volume_name)
    Model = get_flat_model(datastack_name, table_name, version, live_session)

    # ...then count rows in the frozen materialized database.
    mat_session = sqlalchemy_cache.get("{}__mat{}".format(datastack_name, version))
    return mat_session().query(Model).count(), 200
def get_new_roots(self, supervoxel_chunk: list, mat_metadata: dict):
    """Look up new root ids for supervoxels whose roots have expired and
    bulk-update the segmentation table.

    Args:
        supervoxel_chunk (list): rows carrying the segmentation id plus
            one *_supervoxel_id and one *_root_id column
        mat_metadata (dict): materialization metadata

    Returns:
        str: summary of how many rows were updated
    """
    pcg_table_name = mat_metadata.get("pcg_table_name")

    materialization_time_stamp = mat_metadata["materialization_time_stamp"]
    # Timestamp may arrive ISO ("T"-separated) or space-separated.
    # Bug fix: catch only ValueError instead of a bare except, which
    # also swallowed KeyboardInterrupt/SystemExit.
    try:
        formatted_mat_ts = datetime.datetime.strptime(
            materialization_time_stamp, "%Y-%m-%dT%H:%M:%S.%f"
        )
    except ValueError:
        formatted_mat_ts = datetime.datetime.strptime(
            materialization_time_stamp, "%Y-%m-%d %H:%M:%S.%f"
        )

    root_ids_df = pd.DataFrame(supervoxel_chunk, dtype=object)
    supervoxel_col_name = list(
        root_ids_df.loc[:, root_ids_df.columns.str.endswith("supervoxel_id")]
    )
    root_id_col_name = list(
        root_ids_df.loc[:, root_ids_df.columns.str.endswith("root_id")]
    )
    supervoxel_df = root_ids_df.loc[:, supervoxel_col_name[0]]
    supervoxel_data = supervoxel_df.to_list()

    root_id_array = lookup_new_root_ids(
        pcg_table_name, supervoxel_data, formatted_mat_ts
    )
    del supervoxel_data

    root_ids_df.loc[supervoxel_df.index, root_id_col_name[0]] = root_id_array
    # Bug fix: DataFrame.drop is not in-place — the result was discarded,
    # so the update mappings needlessly rewrote the supervoxel column.
    root_ids_df = root_ids_df.drop(columns=[supervoxel_col_name[0]])

    SegmentationModel = create_segmentation_model(mat_metadata)
    aligned_volume = mat_metadata.get("aligned_volume")
    session = sqlalchemy_cache.get(aligned_volume)
    data = root_ids_df.to_dict(orient="records")
    try:
        session.bulk_update_mappings(SegmentationModel, data)
        session.commit()
    except Exception as e:
        session.rollback()
        celery_logger.error(f"ERROR: {e}")
        raise self.retry(exc=e, countdown=3)
    finally:
        session.close()
    return f"Number of rows updated: {len(data)}"
def get(self, aligned_volume_name):
    """List analysis versions recorded for an aligned volume.

    Returns the serialized versions with 200, or aborts with 404 when
    none are found.
    """
    check_aligned_volume(aligned_volume_name)
    session = sqlalchemy_cache.get(aligned_volume_name)

    rows = (
        session.query(AnalysisVersion)
        .filter(AnalysisVersion.datastack == aligned_volume_name)
        .all()
    )
    # NOTE(review): two-value dump() is the marshmallow v2 calling
    # convention — presumably this file pins marshmallow<3; confirm.
    versions, error = AnalysisVersionSchema(many=True).dump(rows)
    logging.info(versions)

    if not versions:
        logging.error(error)
        return abort(404)
    return versions, 200
def get_sql_supervoxel_ids(chunks: List[int], mat_metadata: dict) -> List[int]:
    """Iterates over columns with 'supervoxel_id' present in the name and
    returns supervoxel ids between start and stop ids.

    Parameters
    ----------
    chunks: dict
        name of database to target
    mat_metadata : dict
        Materialization metadata

    Returns
    -------
    List[int]
        list of supervoxel ids between 'start_id' and 'end_id'
    """
    SegmentationModel = create_segmentation_model(mat_metadata)
    aligned_volume = mat_metadata.get("aligned_volume")
    session = sqlalchemy_cache.get(aligned_volume)

    supervoxel_id_columns = [
        column.name
        for column in SegmentationModel.__table__.columns
        if "supervoxel_id" in column.name
    ]
    mapped_columns = [
        getattr(SegmentationModel, column_name)
        for column_name in supervoxel_id_columns
    ]
    try:
        filter_query = session.query(SegmentationModel.id, *mapped_columns)
        if len(chunks) > 1:
            # Fix: call .between() on the id column directly; wrapping the
            # column in a single-argument or_() was a no-op indirection.
            query = filter_query.filter(
                SegmentationModel.id.between(int(chunks[0]), int(chunks[1]))
            )
        elif len(chunks) == 1:
            query = filter_query.filter(SegmentationModel.id == chunks[0])
        data = query.all()
        return pd.DataFrame(data).to_dict(orient="list")
    except Exception as e:
        celery_logger.error(e)
        session.rollback()
        # falls through returning None on failure (as before)
    finally:
        session.close()
def get(self, aligned_volume_name, version):
    """List analysis tables for a given version of an aligned volume.

    Returns the serialized tables with 200, or aborts with 404 when
    none are found.
    """
    check_aligned_volume(aligned_volume_name)
    session = sqlalchemy_cache.get(aligned_volume_name)
    # Bug fix: ".filter(AnalysisTable.analysisversion)" passed a
    # relationship attribute where a boolean criterion is required; join
    # the related AnalysisVersion explicitly instead.
    response = (
        session.query(AnalysisTable)
        .join(AnalysisVersion, AnalysisTable.analysisversion)
        .filter(AnalysisVersion.version == version)
        .filter(AnalysisVersion.datastack == aligned_volume_name)
        .all()
    )
    schema = AnalysisTableSchema(many=True)
    # NOTE(review): two-value dump() is the marshmallow v2 convention.
    tables, error = schema.dump(response)
    if tables:
        return tables, 200
    else:
        logging.error(error)
        return abort(404)
def chunk_ids(mat_metadata, model, chunk_size: int):
    """Yield [start, end] primary-key chunk boundaries for ``model``.

    Numbers every row with ROW_NUMBER() ordered by the model's key, keeps
    every ``chunk_size``-th row's id as a boundary, then yields consecutive
    boundary pairs. The final chunk's end is None (open-ended).

    Args:
        mat_metadata: materialization metadata (supplies "aligned_volume")
        model: ORM entity/column to chunk over
        chunk_size (int): number of rows per chunk

    Yields:
        list: [chunk_start_id, chunk_end_id_or_None]
    """
    aligned_volume = mat_metadata.get("aligned_volume")
    session = sqlalchemy_cache.get(aligned_volume)

    # Wrap the windowed query so the "row_count" label is addressable
    # in a filter on the outer query.
    q = session.query(
        model, func.row_number().over(order_by=model).label("row_count")
    ).from_self(model)

    if chunk_size > 1:
        # Keep rows 1, chunk_size+1, 2*chunk_size+1, ...
        # NOTE(review): "%%" presumably escapes the SQL modulo operator for
        # the DBAPI paramstyle — confirm against the configured driver.
        q = q.filter(text("row_count %% %d=1" % chunk_size))

    chunks = [id for id, in q]

    # Pair consecutive boundary ids; the last chunk is open-ended.
    while chunks:
        chunk_start = chunks.pop(0)
        chunk_end = chunks[0] if chunks else None
        yield [chunk_start, chunk_end]
def get(self, datastack_name: str):
    """get materialized metadata for all valid versions

    Args:
        datastack_name (str): datastack name

    Returns:
        list: list of metadata dictionaries
    """
    aligned_volume_name, pcg_table_name = get_relevant_datastack_info(datastack_name)
    session = sqlalchemy_cache.get(aligned_volume_name)
    response = (
        session.query(AnalysisVersion)
        .filter(AnalysisVersion.datastack == datastack_name)
        .filter(AnalysisVersion.valid == True)
        .all()
    )
    # Bug fix: Query.all() returns a list, never None, so the previous
    # "is None" guard could never fire; treat an empty result as 404.
    if not response:
        return "No valid versions found", 404
    schema = AnalysisVersionSchema()
    return schema.dump(response, many=True), 200
def create_missing_segmentation_table(self, mat_metadata: dict) -> dict:
    """Create missing segmentation tables associated with an annotation
    table if it does not already exist, and record its metadata row.

    Parameters
    ----------
    mat_metadata : dict
        Materialization metadata

    Returns:
        dict: Materialization metadata (passed through unchanged)
    """
    segmentation_table_name = mat_metadata.get("segmentation_table_name")
    aligned_volume = mat_metadata.get("aligned_volume")

    SegmentationModel = create_segmentation_model(mat_metadata)

    session = sqlalchemy_cache.get(aligned_volume)
    engine = sqlalchemy_cache.get_engine(aligned_volume)

    if (
        not session.query(SegmentationMetadata)
        .filter(SegmentationMetadata.table_name == segmentation_table_name)
        .scalar()
    ):
        SegmentationModel.__table__.create(bind=engine, checkfirst=True)
        creation_time = datetime.datetime.utcnow()
        metadata_dict = {
            "annotation_table": mat_metadata.get("annotation_table_name"),
            "schema_type": mat_metadata.get("schema"),
            "table_name": segmentation_table_name,
            "valid": True,
            "created": creation_time,
            "pcg_table_name": mat_metadata.get("pcg_table_name"),
        }
        seg_metadata = SegmentationMetadata(**metadata_dict)
        try:
            session.add(seg_metadata)
            session.commit()
        except Exception as e:
            celery_logger.error(f"SQL ERROR: {e}")
            session.rollback()
        finally:
            # Bug fix: the session was only closed in the "else" branch,
            # leaking the connection whenever the insert failed.
            session.close()
    return mat_metadata
def get_supervoxel_ids(root_id_chunk: list, mat_metadata: dict):
    """Get supervoxel ids associated with expired root ids

    Args:
        root_id_chunk (list): chunk of expired root ids
        mat_metadata (dict): materialization metadata

    Returns:
        dict: supervoxels of a group of expired root ids
        None: no supervoxel ids exist for the expired root id
    """
    aligned_volume = mat_metadata.get("aligned_volume")
    SegmentationModel = create_segmentation_model(mat_metadata)

    session = sqlalchemy_cache.get(aligned_volume)
    columns = [column.name for column in SegmentationModel.__table__.columns]
    root_id_columns = [column for column in columns if "root_id" in column]
    expired_root_id_data = {}
    try:
        for root_id_column in root_id_columns:
            # Derive the paired supervoxel column name, e.g.
            # "pre_pt_root_id" -> "pre_pt_supervoxel_id".
            prefix = root_id_column.rsplit("_", 2)[0]
            supervoxel_name = f"{prefix}_supervoxel_id"
            # Fix: filter on the column's .in_() directly; wrapping the
            # column in a single-argument or_() was a no-op indirection.
            supervoxels = (
                session.query(
                    SegmentationModel.id,
                    getattr(SegmentationModel, root_id_column),
                    getattr(SegmentationModel, supervoxel_name),
                )
                .filter(
                    getattr(SegmentationModel, root_id_column).in_(root_id_chunk)
                )
                .all()
            )
            if supervoxels:
                expired_root_id_data[root_id_column] = pd.DataFrame(
                    supervoxels
                ).to_dict(orient="records")
    except Exception:
        # Fix: bare "raise" re-raises with the original traceback intact
        # (the previous "raise e" re-raised the same exception needlessly).
        raise
    finally:
        session.close()
    if expired_root_id_data:
        return expired_root_id_data
    else:
        return None
def get(self, datastack_name: str):
    """get available versions

    Args:
        datastack_name (str): datastack name

    Returns:
        list(int): list of versions that are available
    """
    aligned_volume_name, pcg_table_name = get_relevant_datastack_info(datastack_name)
    session = sqlalchemy_cache.get(aligned_volume_name)

    valid_versions = (
        session.query(AnalysisVersion)
        .filter(AnalysisVersion.datastack == datastack_name)
        .filter(AnalysisVersion.valid == True)
        .all()
    )
    return [row.version for row in valid_versions], 200
def update_metadata(self, mat_metadata: dict):
    """Stamp 'last_updated' on a segmentation table's metadata row.

    Args:
        mat_metadata (dict): materialization metadata

    Returns:
        str: description of table that was updated
    """
    aligned_volume = mat_metadata["aligned_volume"]
    segmentation_table_name = mat_metadata["segmentation_table_name"]
    materialization_time_stamp = mat_metadata["materialization_time_stamp"]
    session = sqlalchemy_cache.get(aligned_volume)

    # The timestamp may be space-separated or ISO "T"-separated.
    try:
        last_updated_time_stamp = datetime.datetime.strptime(
            materialization_time_stamp, "%Y-%m-%d %H:%M:%S.%f"
        )
    except ValueError:
        last_updated_time_stamp = datetime.datetime.strptime(
            materialization_time_stamp, "%Y-%m-%dT%H:%M:%S.%f"
        )

    try:
        metadata_row = (
            session.query(SegmentationMetadata)
            .filter(SegmentationMetadata.table_name == segmentation_table_name)
            .one()
        )
        metadata_row.last_updated = last_updated_time_stamp
        session.commit()
    except Exception as e:
        celery_logger.error(f"SQL ERROR: {e}")
        session.rollback()
    finally:
        session.close()
    return {
        f"Table: {segmentation_table_name}": f"Time stamp {materialization_time_stamp}"
    }
def update_segmentation_data(materialization_data: dict, mat_metadata: dict) -> dict:
    """Bulk-apply root-id updates to the segmentation table.

    Returns an "empty" status dict when there is nothing to update;
    otherwise a summary string with the updated row count. Re-raises any
    database error after rolling back.
    """
    if not materialization_data:
        return {"status": "empty"}

    SegmentationModel = create_segmentation_model(mat_metadata)
    session = sqlalchemy_cache.get(mat_metadata.get("aligned_volume"))

    try:
        session.bulk_update_mappings(SegmentationModel, materialization_data)
        session.commit()
    except Exception as e:
        session.rollback()
        celery_logger.error(f"ERROR: {e}")
        raise e
    finally:
        session.close()
    return f"Number of rows updated: {len(materialization_data)}"
def table_view(datastack_name, id: int):
    """Redirect to the schema-specific report view for an analysis table.

    Synapse and cell_type_local tables get dedicated reports; everything
    else falls through to the generic report.
    """
    aligned_volume_name, pcg_table_name = get_relevant_datastack_info(datastack_name)
    session = sqlalchemy_cache.get(aligned_volume_name)
    table = session.query(AnalysisTable).filter(AnalysisTable.id == id).first()

    report_endpoints = {
        "synapse": "views.synapse_report",
        "cell_type_local": "views.cell_type_local_report",
    }
    endpoint = report_endpoints.get(table.schema, "views.generic_report")
    return redirect(url_for(endpoint, datastack_name=datastack_name, id=id))
def get(self, datastack_name: str, version: int):
    """get version metadata

    Args:
        datastack_name (str): datastack name
        version (int): version number

    Returns:
        dict: metadata dictionary for this version
    """
    aligned_volume_name, pcg_table_name = get_relevant_datastack_info(datastack_name)
    session = sqlalchemy_cache.get(aligned_volume_name)

    analysis_version = (
        session.query(AnalysisVersion)
        .filter(AnalysisVersion.datastack == datastack_name)
        .filter(AnalysisVersion.version == version)
        .first()
    )
    if analysis_version is None:
        return "No version found", 404
    return AnalysisVersionSchema().dump(analysis_version), 200
def get_ids_with_missing_roots(mat_metadata: dict) -> List:
    """Get a chunk generator of the primary key ids for rows that contain
    at least one missing root id.

    Finds the min and max primary key id values globally across the table
    where a missing root id is present in a column.

    Args:
        mat_metadata (dict): materialization metadata

    Returns:
        List: generator of chunked primary key ids, a single-element list
        when exactly one row matches, or None when nothing is missing.
    """
    SegmentationModel = create_segmentation_model(mat_metadata)
    aligned_volume = mat_metadata.get("aligned_volume")
    session = sqlalchemy_cache.get(aligned_volume)

    root_id_columns = [
        seg_column.name
        for seg_column in SegmentationModel.__table__.columns
        if "root_id" in seg_column.name
    ]
    missing_filters = [
        getattr(SegmentationModel, root_id_column).is_(None)
        for root_id_column in root_id_columns
    ]
    max_id = (
        session.query(func.max(SegmentationModel.id))
        .filter(or_(*missing_filters))
        .scalar()
    )
    min_id = (
        session.query(func.min(SegmentationModel.id))
        .filter(or_(*missing_filters))
        .scalar()
    )
    # Bug fix: compare against None explicitly — a legitimate primary key
    # of 0 is falsy and would previously have skipped the chunking branch.
    if min_id is not None and max_id is not None:
        if min_id < max_id:
            id_range = range(min_id, max_id + 1)
            return create_chunks(id_range, 500)
        if min_id == max_id:
            return [min_id]
    celery_logger.info(
        f"No missing root_ids found in '{SegmentationModel.__table__.name}'"
    )
    return None
def synapse_report(datastack_name, id):
    """Render synapse-table statistics: totals, autapses, missing roots.

    NOTE(review): make_dataset_models receives ``analysisversion.datastack``
    while make_annotation_model receives ``analysisversion.dataset`` —
    looks inconsistent with the generic report; confirm which is intended.
    """
    aligned_volume_name, pcg_table_name = get_relevant_datastack_info(datastack_name)
    session = sqlalchemy_cache.get(aligned_volume_name)
    table = session.query(AnalysisTable).filter(AnalysisTable.id == id).first()
    if table.schema != "synapse":
        abort(504, "this table is not a synapse table")

    make_dataset_models(
        table.analysisversion.datastack, [], version=table.analysisversion.version
    )
    SynapseModel = make_annotation_model(
        table.analysisversion.dataset,
        "synapse",
        table.tablename,
        version=table.analysisversion.version,
    )
    synapses = SynapseModel.query.count()

    # Autapses: pre and post roots match and neither is the null root (0).
    n_autapses = (
        SynapseModel.query.filter(
            SynapseModel.pre_pt_root_id == SynapseModel.post_pt_root_id
        )
        .filter(
            and_(
                SynapseModel.pre_pt_root_id != 0,
                SynapseModel.post_pt_root_id != 0,
            )
        )
        .count()
    )
    # Synapses missing a root id on either side.
    n_no_root = SynapseModel.query.filter(
        or_(SynapseModel.pre_pt_root_id == 0, SynapseModel.post_pt_root_id == 0)
    ).count()

    return render_template(
        "synapses.html",
        num_synapses=synapses,
        num_autapses=n_autapses,
        num_no_root=n_no_root,
        dataset=table.analysisversion.dataset,
        analysisversion=table.analysisversion.version,
        version=__version__,
        table_name=table.tablename,
        schema_name="synapses",
    )
def cell_type_local_report(datastack_name, id):
    """Render the top-100 (root id, cell type) annotation counts for a
    cell_type_local table."""
    aligned_volume_name, pcg_table_name = get_relevant_datastack_info(datastack_name)
    session = sqlalchemy_cache.get(aligned_volume_name)
    table = AnalysisTable.query.filter(AnalysisTable.id == id).first_or_404()
    if table.schema != "cell_type_local":
        abort(504, "this table is not a cell_type_local table")

    make_dataset_models(
        table.analysisversion.dataset, [], version=table.analysisversion.version
    )
    CellTypeModel = make_annotation_model(
        table.analysisversion.dataset,
        table.schema,
        table.tablename,
        version=table.analysisversion.version,
    )
    n_annotations = CellTypeModel.query.count()

    # Count annotations per (root id, cell type), largest groups first,
    # capped at 100 rows for display.
    cell_type_merge_query = (
        db.session.query(
            CellTypeModel.pt_root_id,
            CellTypeModel.cell_type,
            func.count(CellTypeModel.pt_root_id).label("num_cells"),
        )
        .group_by(CellTypeModel.pt_root_id, CellTypeModel.cell_type)
        .order_by("num_cells DESC")
    ).limit(100)

    df = pd.read_sql(
        cell_type_merge_query.statement, db.get_engine(), coerce_float=False
    )
    return render_template(
        "cell_type_local.html",
        version=__version__,
        schema_name=table.schema,
        table_name=table.tablename,
        dataset=table.analysisversion.dataset,
        table=df.to_html(),
    )
def update_table_metadata(self, mat_info: List[dict]):
    """Update 'analysistables' with all the tables to be created in the
    frozen materialized database.

    Args:
        mat_info (List[dict]): list of dicts containing table metadata

    Returns:
        list: list of tables that were added to 'analysistables'
    """
    aligned_volume = mat_info[0]["aligned_volume"]
    version = mat_info[0]["analysis_version"]
    session = sqlalchemy_cache.get(aligned_volume)

    # Bug fix: Query.first() returns a Row tuple, not the bare id the
    # foreign key needs — use .scalar(). The lookup is also loop-invariant,
    # so run it once instead of once per table.
    version_id = (
        session.query(AnalysisVersion.id)
        .filter(AnalysisVersion.version == version)
        .scalar()
    )

    tables = []
    for mat_metadata in mat_info:
        analysis_table = AnalysisTable(
            aligned_volume=aligned_volume,
            schema=mat_metadata["schema"],
            table_name=mat_metadata["annotation_table_name"],
            valid=False,
            created=mat_metadata["materialization_time_stamp"],
            analysisversion_id=version_id,
        )
        tables.append(analysis_table.table_name)
        session.add(analysis_table)
    try:
        session.commit()
    except Exception as e:
        session.rollback()
        celery_logger.error(e)
    finally:
        session.close()
    return tables
def version_view(datastack_name: str, id: int):
    """Render the page for one analysis version, listing its tables with
    links to per-table reports and schema views."""
    aligned_volume_name, pcg_table_name = get_relevant_datastack_info(datastack_name)
    session = sqlalchemy_cache.get(aligned_volume_name)
    version = session.query(AnalysisVersion).filter(AnalysisVersion.id == id).first()

    tables = (
        session.query(AnalysisTable)
        .filter(AnalysisTable.analysisversion == version)
        .all()
    )
    df = make_df_with_links_to_id(
        tables,
        AnalysisTableSchema(many=True),
        "views.table_view",
        "id",
        datastack_name=datastack_name,
    )
    schema_url = "<a href='{}/schema/views/type/{}/view'>{}</a>"
    df["schema"] = df.schema.map(
        lambda x: schema_url.format(current_app.config["GLOBAL_SERVER"], x, x)
    )
    df["table_name"] = df.table_name.map(
        lambda x: "<a href='/annotation/views/aligned_volume/{}/table/{}'>{}</a>".format(
            aligned_volume_name, x, x
        )
    )
    # Bug fix: max_colwidth of -1 is deprecated (and removed in modern
    # pandas); None is the supported way to disable cell truncation.
    with pd.option_context("display.max_colwidth", None):
        output_html = df.to_html(escape=False)
    return render_template(
        "version.html",
        datastack=version.datastack,
        analysisversion=version.version,
        table=output_html,
        version=__version__,
    )
def shutdown_session(exception=None):
    """App-teardown hook: remove every cached scoped session."""
    for cache_key in sqlalchemy_cache._sessions:
        sqlalchemy_cache.get(cache_key).remove()
def health():
    """Liveness endpoint: report the analysis-version count for the
    configured test/annotation database."""
    db_name = current_app.config.get("TEST_DB_NAME", "annotation")
    session = sqlalchemy_cache.get(db_name)
    version_count = session.query(AnalysisVersion).count()
    session.close()
    return jsonify({db_name: version_count}), 200
def create_tables(self, bulk_upload_params: dict):
    """Create the annotation and segmentation tables (plus their metadata
    rows) for a bulk upload, then drop their indices ahead of the load.

    Args:
        bulk_upload_params (dict): upload parameters; must supply the
            annotation/segmentation table names, aligned volume, pcg table
            name, last_updated and upload_creation_time.

    Returns:
        str: confirmation naming both created tables.
    """
    table_name = bulk_upload_params["annotation_table_name"]
    aligned_volume = bulk_upload_params["aligned_volume"]
    pcg_table_name = bulk_upload_params["pcg_table_name"]
    last_updated = bulk_upload_params["last_updated"]
    seg_table_name = bulk_upload_params["seg_table_name"]
    upload_creation_time = bulk_upload_params["upload_creation_time"]
    session = sqlalchemy_cache.get(aligned_volume)
    engine = sqlalchemy_cache.get_engine(aligned_volume)
    # Create the annotation table + metadata row only if not already present.
    if (not session.query(AnnoMetadata).filter(
            AnnoMetadata.table_name == table_name).scalar()):
        AnnotationModel = create_annotation_model(bulk_upload_params)
        AnnotationModel.__table__.create(bind=engine, checkfirst=True)
        anno_metadata_dict = {
            "table_name": table_name,
            "schema_type": bulk_upload_params.get("schema"),
            "valid": True,
            "created": upload_creation_time,
            "user_id": bulk_upload_params.get("user_id", "*****@*****.**"),
            "description": bulk_upload_params["description"],
            "reference_table": bulk_upload_params.get("reference_table"),
            "flat_segmentation_source":
            bulk_upload_params.get("flat_segmentation_source"),
        }
        anno_metadata = AnnoMetadata(**anno_metadata_dict)
        session.add(anno_metadata)
    # Same for the segmentation table.
    # NOTE(review): this existence check filters SegmentationMetadata by the
    # *annotation* table name, not seg_table_name — confirm that is intended.
    if (not session.query(SegmentationMetadata).filter(
            SegmentationMetadata.table_name == table_name).scalar()):
        SegmentationModel = create_segmentation_model(bulk_upload_params)
        SegmentationModel.__table__.create(bind=engine, checkfirst=True)
        seg_metadata_dict = {
            "annotation_table": table_name,
            "schema_type": bulk_upload_params.get("schema"),
            "table_name": seg_table_name,
            "valid": True,
            "created": upload_creation_time,
            "pcg_table_name": pcg_table_name,
            "last_updated": last_updated,
        }
        seg_metadata = SegmentationMetadata(**seg_metadata_dict)
    try:
        session.flush()
        session.add(seg_metadata)
        session.commit()
    except Exception as e:
        celery_logger.error(f"SQL ERROR: {e}")
        session.rollback()
        raise e
    finally:
        # Drop indices before the bulk load (re-created after the upload);
        # runs even when the commit above failed and re-raised.
        drop_seg_indexes = index_cache.drop_table_indices(
            SegmentationModel.__table__.name, engine)
        # wait for indexes to drop
        time.sleep(10)
        drop_anno_indexes = index_cache.drop_table_indices(
            AnnotationModel.__table__.name, engine)
        celery_logger.info(
            f"Table {AnnotationModel.__table__.name} indices have been dropped {drop_anno_indexes}."
        )
        celery_logger.info(
            f"Table {SegmentationModel.__table__.name} indices have been dropped {drop_seg_indexes}."
        )
        session.close()
    return f"Tables {table_name}, {seg_table_name} created."
def check_tables(self, mat_info: list, analysis_version: int):
    """Check if each materialized table has the same number of rows as the
    aligned volumes tables in the live database that are set as valid. If
    row numbers match, set the validity of both the analysis tables as well
    as the analysis version (materialized database) as True.

    Args:
        mat_info (list): list of dicts containing metadata for each
            materialized table
        analysis_version (int): the materialized version number

    Returns:
        str: returns statement if all tables are valid

    Raises:
        ValueError: row counts differ, or not every table validated
        IndexMatchError: live and materialized indexes differ
    """
    aligned_volume = mat_info[0][
        "aligned_volume"]  # get aligned_volume name from datastack
    table_count = len(mat_info)
    analysis_database = mat_info[0]["analysis_database"]

    # Sessions/engines for both the live and the materialized databases.
    session = sqlalchemy_cache.get(aligned_volume)
    engine = sqlalchemy_cache.get_engine(aligned_volume)
    mat_session = sqlalchemy_cache.get(analysis_database)
    mat_engine = sqlalchemy_cache.get_engine(analysis_database)
    mat_client = dynamic_annotation_cache.get_db(analysis_database)
    versioned_database = (session.query(AnalysisVersion).filter(
        AnalysisVersion.version == analysis_version).one())

    valid_table_count = 0
    for mat_metadata in mat_info:
        annotation_table_name = mat_metadata["annotation_table_name"]

        # Expected row count recorded during materialization vs the actual
        # count in the materialized table.
        live_table_row_count = (mat_session.query(
            MaterializedMetadata.row_count).filter(
                MaterializedMetadata.table_name ==
                annotation_table_name).scalar())

        mat_row_count = mat_client._get_table_row_count(annotation_table_name)
        celery_logger.info(
            f"ROW COUNTS: {live_table_row_count} {mat_row_count}")

        # Empty tables are skipped (they can never validate).
        if mat_row_count == 0:
            celery_logger.warning(
                f"{annotation_table_name} has {mat_row_count} rows, skipping.")
            continue

        if live_table_row_count != mat_row_count:
            raise ValueError(
                f"""Row count doesn't match for table '{annotation_table_name}': Row count in '{aligned_volume}': {live_table_row_count} - Row count in {analysis_database}: {mat_row_count}"""
            )
        celery_logger.info(f"{annotation_table_name} row counts match")

        # Rebuild the expected flat model so its indexes can be compared
        # against what actually exists on the materialized table.
        schema = mat_metadata["schema"]
        table_metadata = None
        if mat_metadata.get("reference_table"):
            table_metadata = {
                "reference_table": mat_metadata.get("reference_table")
            }
        anno_model = make_flat_model(
            table_name=annotation_table_name,
            schema_type=schema,
            segmentation_source=None,
            table_metadata=table_metadata,
        )
        live_mapped_indexes = index_cache.get_index_from_model(
            anno_model, mat_engine)
        mat_mapped_indexes = index_cache.get_table_indices(
            annotation_table_name, mat_engine)
        if live_mapped_indexes != mat_mapped_indexes:
            raise IndexMatchError(
                f"Indexes did not match: annotation indexes {live_mapped_indexes}; materialized indexes {mat_mapped_indexes}"
            )
        celery_logger.info(
            f"Indexes matches: {live_mapped_indexes} {mat_mapped_indexes}")

        # Everything checked out — mark this analysis table valid.
        table_validity = (session.query(AnalysisTable).filter(
            AnalysisTable.analysisversion_id == versioned_database.id).filter(
                AnalysisTable.table_name == annotation_table_name).one())
        table_validity.valid = True
        valid_table_count += 1
    celery_logger.info(
        f"Valid tables {valid_table_count}, Mat tables {table_count}")

    if valid_table_count != table_count:
        raise ValueError(
            f"Valid table amounts don't match {valid_table_count} {table_count}"
        )
    versioned_database.valid = True

    try:
        session.commit()
        return "All materialized tables match valid row number from live tables"
    except Exception as e:
        session.rollback()
        celery_logger.error(e)
    finally:
        # Release both databases' sessions and engines regardless of outcome.
        session.close()
        mat_client.cached_session.close()
        mat_session.close()
        engine.dispose()
        mat_engine.dispose()