def create_session_factory(self, **kwargs) -> "DBManager":
    """
    Args:
        kwargs => passed to SQLAlchemy sessionmaker constructor

    Description:
        Create SQLAlchemy scoped_session if self._scoped_sessions is True,
        otherwise sessionmaker. All kwargs are passed to the sessionmaker
        constructor. This method should only be called _once_ by the
        DBManager. SQLAlchemy doesn't recommend manually closing all
        sessions, and the mechanics for doing so have changed across
        versions. See:
        https://docs.sqlalchemy.org/en/13/orm/session_api.html#session-and-sessionmaker
        and
        https://docs.sqlalchemy.org/en/13/orm/contextual.html#sqlalchemy.orm.scoping.scoped_session
        and
        https://docs.sqlalchemy.org/en/13/orm/session_api.html#sqlalchemy.orm.session.sessionmaker.close_all

    Preconditions:
        N/A

    Raises:
        RuntimeError: if self._session_factory is already defined, or if
            self._engine isn't defined
    """
    # Ensure self._session_factory isn't already defined
    if self._session_factory:
        raise RuntimeError("Session factory already created")

    # Ensure self._engine is defined
    if not self._engine:
        raise RuntimeError("Cannot create session factory without an Engine")

    # Generate sessionmaker session factory
    self._session_factory = SessionMaker(bind=self._engine, **kwargs)

    # If scoped sessions, wrap in a scoped_session factory
    if self._scoped_sessions:
        self._session_factory = ScopedSession(self._session_factory)

    return self

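# A standalone sketch of the sessionmaker/scoped_session pattern the method
# above wraps (independent of DBManager; the in-memory SQLite URL is purely
# illustrative):
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine("sqlite://")
factory = sessionmaker(bind=engine)   # plain factory: a new Session per call
scoped = scoped_session(factory)      # thread-local registry of Sessions

s1 = scoped()
s2 = scoped()
assert s1 is s2    # within one thread, scoped_session returns the same Session
scoped.remove()    # discard the current thread's Session
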
def __refresh_arrivals(session: scoped_session, naptan_id: StopId) -> None:
    ud = __UpdateDescription(CachedDataType.arrival, naptan_id,
                             timedelta(minutes=1), __cache_arrivals,
                             __delete_arrivals)
    __update_cache(session, ud)
    session.commit()

def insert_properties(properties: Set[str],
                      session: scoped_session) -> Optional[Any]:
    """Insert all the properties as defined in the APIDocumentation into DB."""
    prop_list = [BaseProperty(name=prop) for prop in properties
                 if not session.query(
                     exists().where(BaseProperty.name == prop)).scalar()]
    session.add_all(prop_list)
    session.commit()
    return None


# if __name__ == "__main__":
#     Session = sessionmaker(bind=engine)
#     session = Session()
#
#     doc = doc_gen("test", "test")
#     # Extract all classes with supportedProperty from both
#     classes = get_classes(doc.generate())
#
#     # Extract all properties from both
#     # import pdb; pdb.set_trace()
#     properties = get_all_properties(classes)
#     # Add all the classes
#     insert_classes(classes, session)
#     print("Classes inserted successfully")
#     # Add all the properties
#     insert_properties(properties, session)
#     print("Properties inserted successfully")

def generate(_db_session: scoped_session, school_name: str, is_commit: bool = False):
    _school = School(school_name)
    # The parameter is `_db_session`; the original referenced an undefined
    # bare `db_session` here, which would raise NameError.
    _db_session.add(_school)
    Board.generate(_db_session, _school)
    if is_commit:
        _db_session.commit()
    return _school

def __cache_arrivals(session: scoped_session, naptan_id: str) -> None:
    """Fetches the arrivals for a single stop point from TFL and stores them
    in the database"""
    logger = logging.getLogger(__name__)
    arrivals = fetch_arrivals(naptan_id)
    logger.info(f"Adding arrivals for '{naptan_id}' to database")
    for arrival in arrivals:
        db_arrival = session.query(Arrival).filter(
            Arrival.arrival_id == arrival.arrival_id).one_or_none()
        if db_arrival is not None:
            db_arrival.update_with(arrival)
        else:
            session.add(arrival)
    session.commit()

def test_it_will_send_event_if_email_address_is_updated(mock_publisher: MagicMock,
                                                        profile: Profile,
                                                        session: scoped_session,
                                                        commit: Callable[[], None]):
    event_publisher = send_update_events(publisher=mock_publisher)
    models_committed.connect(receiver=event_publisher)

    profile.add_email_address('*****@*****.**')
    session.add(profile)
    commit()

    assert mock_publisher.publish.call_count == 1
    assert mock_publisher.publish.call_args[0][0] == {'id': '12345678', 'type': 'profile'}

def __save_update_timestamp(session: scoped_session, type: CachedDataType,
                            id: str = "") -> None:
    """Stores the current time as the timestamp of the last update for a
    given CachedDataType and id pair"""
    ts = session.query(CacheTimestamp).filter(
        CacheTimestamp.data_type == type).filter(
            CacheTimestamp.data_id == id).one_or_none()
    if ts is None:
        session.add(CacheTimestamp(data_type=type, data_id=id))
    else:
        ts.update_time = datetime.utcnow()
    session.commit()

def remove_stale_modification_records(session: scoped_session,
                                      stale_records_removal_interval: int = 900):
    """
    Remove modification records that are older than the most recent 1000
    records. Reschedules itself to run every `stale_records_removal_interval`
    seconds.

    :param session: sqlalchemy session.
    :param stale_records_removal_interval: interval, in seconds, at which
        the removal job re-runs.
    """
    timer = Timer(stale_records_removal_interval,
                  remove_stale_modification_records, [session])
    timer.daemon = True
    timer.start()

    # Get all valid records.
    valid_records = session.query(Modification).order_by(
        Modification.job_id.desc()).limit(1000).all()

    # If the number of returned valid records is below the limit, there is
    # nothing to clean up.
    if len(valid_records) < 1000:
        return
    else:
        # Get the job_id of the last (oldest) valid record.
        job_id_of_last_valid_record = valid_records[-1].job_id

        # Get all records which are older than the oldest valid record.
        stale_records = session.query(Modification).filter(
            Modification.job_id < job_id_of_last_valid_record).all()

        for record in stale_records:
            session.delete(record)
        session.commit()
    session.remove()

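# A minimal, self-contained sketch of the self-rescheduling daemon Timer
# pattern used by remove_stale_modification_records above (the job body here
# is a placeholder, not the real cleanup logic):
from threading import Timer

def periodic_cleanup(interval_seconds: int = 900) -> None:
    # Re-arm first, so the next run is scheduled even if the body fails.
    timer = Timer(interval_seconds, periodic_cleanup, [interval_seconds])
    timer.daemon = True   # a daemon timer won't keep the process alive
    timer.start()
    print("running cleanup...")   # stand-in for the actual deletion work

periodic_cleanup(900)
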
def test_it_ignores_other_models_being_committed(orcid_token: OrcidToken,
                                                 orcid_config: Dict[str, str],
                                                 mock_orcid_client: MagicMock,
                                                 session: scoped_session,
                                                 url_safe_serializer: URLSafeSerializer):
    webhook_maintainer = maintain_orcid_webhook(orcid_config, mock_orcid_client,
                                                url_safe_serializer)
    models_committed.connect(receiver=webhook_maintainer)

    session.add(orcid_token)
    session.commit()

    assert mock_orcid_client.set_webhook.call_count == 0
    assert mock_orcid_client.remove_webhook.call_count == 0

def test_it_will_send_event_for_affiliation_insert(mock_publisher: MagicMock,
                                                   profile: Profile,
                                                   session: scoped_session,
                                                   commit: Callable[[], None]) -> None:
    event_publisher = send_update_events(publisher=mock_publisher)
    models_committed.connect(receiver=event_publisher)

    affiliation = Affiliation('1', Address(countries.get('gb'), 'City'),
                              'Organisation', Date(2017))
    profile.add_affiliation(affiliation)
    session.add(profile)
    commit()

    assert mock_publisher.publish.call_count == 1
    assert mock_publisher.publish.call_args[0][0] == {'id': '12345678', 'type': 'profile'}

def test_exception_is_handled_by_catch_exception_decorator(mock_publisher: MagicMock,
                                                           profile: Profile,
                                                           session: scoped_session,
                                                           commit: Callable[[], None]) -> None:
    mock_publisher.publish.side_effect = Exception('Some Exception')
    event_publisher = send_update_events(publisher=mock_publisher)
    models_committed.connect(receiver=event_publisher)

    affiliation = Affiliation('1', Address(countries.get('gb'), 'City'),
                              'Organisation', Date(2017))
    profile.add_affiliation(affiliation)
    session.add(profile)
    commit()

    assert mock_publisher.publish.call_count == 1

def sql(query, db: scoped_session = db, fetch=False, rollback=False):
    """Run a SQL query against db, optionally rolling back the current
    transaction first and/or fetching all result rows."""
    if rollback:
        db.rollback()
    if fetch:
        return db.execute(query).fetchall()
    db.execute(query)
    db.commit()

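# Hypothetical calls against the helper above (the table and statements are
# illustrative, not from the original module):
rows = sql("SELECT isbn, title FROM books LIMIT 5", fetch=True)
sql("DELETE FROM books WHERE year < '1900'")        # executes, then commits
sql("SELECT 1", rollback=True, fetch=True)          # clears a failed txn first
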
def update_item(db_session: scoped_session, query_data: dict,
                trial_results: dict, mab: MAB) -> str:
    """
    Check results of a MAB trial against data in the db and update the db
    with the trial results if necessary, as well as the pickled mab. If the
    trial has ended, set the best performing thumbnail to be `next`, set
    `active_trial` to False, and delete all options associated with the item.
    """
    # Re-pickle mab class test instance
    query_data["item"].mab = pickle.dumps(mab)

    # If there are more trials, set the next url
    if trial_results["new_trial"]:
        if trial_results["choice"] != query_data["item"].next:
            # item.id is already hashed with project, just need to add option_num
            next_id = hash_id(query_data["item"].id, trial_results["choice"])
            query_data["item"].option_num = trial_results["choice"]
            # Find the next image using the id we get from `trial_results`
            try:
                option = Option.query.get(next_id)
                query_data["item"].next = option.content
            except Exception:
                return "500 - internal server error"
    # The trial is over: set `next` to the best performing thumbnail and
    # `active_trial` to False
    else:
        query_data["item"].next = trial_results["best"]
        query_data["item"].active_trial = False
        # Delete all options as they are no longer needed
        for i in range(1, query_data["item"].total_num + 1):
            # item.id is already hashed, just add option_num
            option_id = hash_id(query_data["item"].id, i)
            Option.query.filter_by(id=option_id).delete()
        query_data["item"].total_num = 0

    try:
        db_session.add(query_data["item"])
        db_session.commit()
        return "200 - trial successfully updated."
    except Exception:
        return "500 - internal server error."

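# Self-contained sketch of the pickle round-trip applied to `item.mab` above
# (ToyMAB is a stand-in for the project's real MAB class, not its API):
import pickle

class ToyMAB:
    """Illustrative bandit state; the real MAB holds trial statistics."""
    def __init__(self) -> None:
        self.pulls = [0, 0, 0]

blob = pickle.dumps(ToyMAB())    # bytes of the kind stored on the item row
restored = pickle.loads(blob)    # what a later trial would unpickle
assert restored.pulls == [0, 0, 0]
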
def test_it_sets_a_webhook_when_a_profile_is_inserted(profile: Profile,
                                                      orcid_config: Dict[str, str],
                                                      mock_orcid_client: MagicMock,
                                                      session: scoped_session,
                                                      url_safe_serializer: URLSafeSerializer,
                                                      commit: Callable[[], None]):
    webhook_maintainer = maintain_orcid_webhook(orcid_config, mock_orcid_client,
                                                url_safe_serializer)
    models_committed.connect(receiver=webhook_maintainer)

    session.add(profile)
    commit()

    assert mock_orcid_client.set_webhook.call_count == 1
    assert mock_orcid_client.set_webhook.call_args[0][0] == '0000-0002-1825-0097'
    assert mock_orcid_client.set_webhook.call_args[0][1] == 'http://localhost/orcid-webhook/{}' \
        .format(url_safe_serializer.dumps('0000-0002-1825-0097'))

def get_by_id_raw(session: scoped_session, model: AnyModel, entity_id: int):
    """
    Get an entity by its id.

    :param session: sqlalchemy session
    :param model: model to query
    :param entity_id: id of the entity to get
    :return: the entity, or None if no row matches
    """
    return session.query(model).filter(model.id == entity_id).first()

def get(cls, s: scoped_session, id: int) -> Optional[Any]:
    p = s.query(cls).get(id)
    if p is None:
        return None
    if p.is_active is False:
        return None
    return p

def refresh_recently_requested_stop_ids(session: scoped_session) -> None:
    logger = logging.getLogger(__name__)
    tracking_time_limit = timedelta(minutes=5)

    session.query(ArrivalRequest).\
        filter(ArrivalRequest.request_time <
               (datetime.utcnow() - tracking_time_limit)).\
        delete()
    session.commit()

    recent_queries = session.query(ArrivalRequest).order_by(
        ArrivalRequest.naptan_id).all()
    if len(recent_queries) == 0:
        logger.info(
            f"No arrivals were requested in the last {tracking_time_limit.seconds/60} minutes"
        )
    for q in recent_queries:
        try:
            __refresh_arrivals(session, StopId(q.naptan_id))
        except exc.DataError as e:
            # A Logger instance is not callable; the original called
            # logger(...) directly here.
            logger.error(f"Cannot refresh arrival info for {q.naptan_id}: {e}")

def test_exception_is_handled_by_catch_exception_decorator(profile: Profile,
                                                           orcid_config: Dict[str, str],
                                                           mock_orcid_client: MagicMock,
                                                           session: scoped_session,
                                                           url_safe_serializer: URLSafeSerializer,
                                                           commit: Callable[[], None]):
    mock_orcid_client.remove_webhook.side_effect = Exception('Some Exception')

    session.add(profile)
    commit()

    webhook_maintainer = maintain_orcid_webhook(orcid_config, mock_orcid_client,
                                                url_safe_serializer)
    models_committed.connect(receiver=webhook_maintainer)

    session.delete(profile)
    commit()

    assert mock_orcid_client.remove_webhook.call_count == 1

def insert_classes(classes: List[Dict[str, Any]],
                   session: scoped_session) -> Optional[Any]:
    """Insert all the classes as defined in the APIDocumentation into DB.

    Raises:
        TypeError: If `session` is not an instance of `scoped_session` or `Session`.
    """
    # print(session.query(exists().where(RDFClass.name == "Datastream")).scalar())
    if not isinstance(session, scoped_session) and not isinstance(
            session, Session):
        raise TypeError(
            "session is not of type <sqlalchemy.orm.scoping.scoped_session> "
            "or <sqlalchemy.orm.session.Session>")

    class_list = [
        RDFClass(name=class_["label"].strip('.')) for class_ in classes
        if "label" in class_ and not session.query(exists().where(
            RDFClass.name == class_["label"].strip('.'))).scalar()
    ]
    class_list.extend([
        RDFClass(name=class_["title"].strip('.')) for class_ in classes
        if "title" in class_ and not session.query(exists().where(
            RDFClass.name == class_["title"].strip('.'))).scalar()
    ])
    # print(class_list)
    session.add_all(class_list)
    session.commit()
    return None

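# Hypothetical call, given an open session (the dicts mirror the
# "label"/"title" keys the function reads; the class names are illustrative):
classes = [{"label": "Datastream."}, {"title": "Sensor."}]
insert_classes(classes, session)   # strips the trailing '.' and skips duplicates
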
def populate_books_table(db: scoped_session = db, data_path='books.csv'):
    # Good progress bars
    try:
        __IPYTHON__
        from tqdm import tqdm_notebook as tqdm
    except NameError:
        from tqdm import tqdm as tqdm
    import pandas as pd

    if not table_exists('books', db=db):
        sql(connect_database.config['queries']['create books'],
            rollback=True, db=db)
    sql(connect_database.config['queries']['clear books'], rollback=True)

    data_path = Path(data_path).absolute()
    data = pd.read_csv(data_path)

    db.rollback()
    for i in tqdm(range(len(data))):
        book = data.iloc[i]
        isbn, title, author, year = escape_list(list(book))
        # set_trace()
        query = f"INSERT INTO books(isbn, title, author, year)" + \
                f""" VALUES ('{isbn}', '{title}', '{author}', '{year}')"""
        db.execute(query)
    db.commit()
    print("Successful")

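# A safer variant of the INSERT in the loop above, sketched with bound
# parameters instead of f-string interpolation (text() is SQLAlchemy's
# textual-SQL construct; with bound parameters the escape_list step becomes
# unnecessary):
from sqlalchemy import text

insert_stmt = text(
    "INSERT INTO books(isbn, title, author, year)"
    " VALUES (:isbn, :title, :author, :year)")
db.execute(insert_stmt,
           {"isbn": isbn, "title": title, "author": author, "year": year})
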
def __cache_stop_point(session: scoped_session, naptan_id: str) -> None:
    """Fetches the data for a single stop point from TFL and stores it in
    the database"""
    logger = logging.getLogger(__name__)
    stops = fetch_stops(naptan_id)
    for stop in stops:
        # logger.debug(f"Adding stop '{stop.name}', '{stop.indicator}' to database")
        # TODO This should do a proper update instead of remove-insert
        session.query(StopPoint).filter(
            StopPoint.naptan_id == stop.naptan_id).delete()
        __save_update_timestamp(session, CachedDataType.stop_point,
                                stop.naptan_id)
        session.add(stop)
    session.commit()

def __get_update_timestamp(session: scoped_session, type: CachedDataType,
                           id: str = None) -> Optional[datetime]:
    """Gets the timestamp of the last update for a given CachedDataType and
    id pair. In the case of CachedDataType.line_list the id is not used"""
    logger = logging.getLogger(__name__)
    update_record_query = session.query(CacheTimestamp).\
        filter(CacheTimestamp.data_type == type)
    if id is not None:
        update_record_query = update_record_query.filter(
            CacheTimestamp.data_id == id)
    update_record = update_record_query.order_by(CacheTimestamp.update_time.desc()).\
        limit(1).\
        one_or_none()
    if update_record is None:
        return None
    return update_record.update_time

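# Hedged sketch of how the two timestamp helpers above can back a staleness
# check (`__update_cache` itself is not shown in this module, so this is an
# assumption about the pattern, not its actual implementation):
from datetime import datetime, timedelta

def __is_stale(session: scoped_session, type: CachedDataType,
               id: str, max_age: timedelta) -> bool:
    last_update = __get_update_timestamp(session, type, id)
    return last_update is None or datetime.utcnow() - last_update > max_age
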
def labelBatchSmartContract(self, session: scoped_session, min_id: int,
                            max_id: int, batchSize: int = 500):
    # Materialize the query: a bare Query object is always truthy, so the
    # emptiness check below would otherwise never fire.
    processinglist = session.query(dbmodules.Processing)\
        .filter(dbmodules.Processing.id.between(min_id, max_id))\
        .filter(dbmodules.Processing.isprocessed == False)\
        .limit(batchSize)\
        .all()
    if not processinglist:
        logger.warning("All Smart Contracts Have Been Addressed!")
    #
    SCid: int = 1
    for processing in processinglist:
        print(": %d / %d" % (SCid, batchSize))
        print(processing)
        logger.info("Current Contract:{contract}".format(
            contract=processing.contractAddr))
        status = self.pullAllTransactionBySmartContract(
            processing.contractAddr, session=session)
        SCid = SCid + 1

def get_arrivals(session: scoped_session, naptan_id: StopId) -> List[Arrival]:
    # The module's logger, following the pattern used in __cache_arrivals;
    # the original called an undefined, non-callable `logger(...)` below.
    logger = logging.getLogger(__name__)
    try:
        __refresh_arrivals(session, naptan_id)
    except exc.DataError as e:
        logger.error(f"Cannot refresh arrival info for {naptan_id}: {e}")
        return []

    req = session.query(ArrivalRequest).filter(
        ArrivalRequest.naptan_id == naptan_id).one_or_none()
    if req is None:
        session.add(ArrivalRequest(naptan_id=naptan_id))
    else:
        req.request_time = datetime.utcnow()
    session.commit()

    return session.query(Arrival).\
        filter(Arrival.naptan_id == naptan_id).\
        filter(Arrival.ttl > datetime.utcnow()).\
        order_by(Arrival.expected).\
        limit(6).all()

def pullAllTransactionBySmartContract_raw(self, scAddress,
                                          session: scoped_session) -> bool:
    api = EtherscanAPI()
    existError: bool = False
    item = {'hash': ''}  # so the except-handler below can always reference item['hash']
    try:
        txList = api.getAllTransactionForContractAddress(scAddress)
        logger.info("TxList for: {address}, length: {length} ".format(
            address=scAddress, length=len(txList)))
        errorNumber: int = 0
        # testIndex = 0
        batchSize = 100
        for item in txList:
            if item['isError'] == "0":
                continue
            errDescp = ''
            existError = True
            errDescp = api.getTransactionStatus(
                item['hash'])['errDescription']
            if not item['to'] and item['contractAddress']:
                item['to'] = item['contractAddress']
            item['fromAddr'] = item['from']
            item['toAddr'] = item['to']
            item["errDescription"] = errDescp
            item['comulativeGasUsed'] = item['cumulativeGasUsed']
            sql = text(
                "INSERT INTO errTxList (blockNumber,timeStamp,hash,nonce,blockHash,transactionIndex,fromAddr,toAddr,value,gas,gasPrice,isError,errDescription,txreceipt_status,contractAddress,comulativeGasUsed,gasUsed,confirmations)"
                " VALUES (:blockNumber,:timeStamp,:hash,:nonce,:blockHash,:transactionIndex,:fromAddr,:toAddr,:value,:gas,:gasPrice,:isError,:errDescription,:txreceipt_status,:contractAddress,:comulativeGasUsed,:gasUsed,:confirmations)"
            )
            # The original referenced an undefined `sess` here.
            session.execute(sql, params=item)
            errorNumber = errorNumber + 1
            if errorNumber % batchSize == 0:
                # session.flush()
                session.commit()
                # time.sleep(0.25)
                #
                logger.info(
                    " |__ TxList for: %s, %d times, at most %d txs/submit." %
                    (scAddress, (errorNumber - 1) // batchSize + 1, batchSize))
        logger.info(" |__ TxList for: %s, total error transactions: %d." %
                    (scAddress, errorNumber))
    except Exception as e:
        logger.error(
            "Contract Address:{address}, Transaction Hash:{txhash}, Error Message:{message}"
            .format(address=scAddress, txhash=item['hash'], message=e))
        return False
    else:
        # session.query(dbmodules.Processing)\
        #     .filter(dbmodules.Processing.contractAddr == scAddress)\
        #     .update({"isprocessed": True})
        session.query(dbmodules.SmartContract)\
            .filter(dbmodules.SmartContract.contractAddr == scAddress)\
            .update({"label": "1" if existError else '0',
                     "txTotalCount": len(txList),
                     "txErrorTotalCount": errorNumber})
        session.commit()
        session.flush()
        # time.sleep(0.)
        logger.success(
            "Request all txlist of Contract Address:{address} successfully!"
            .format(address=scAddress))
        return True

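# Standalone sketch of the text() + named-parameter pattern used above (the
# table and columns here are illustrative, not the real errTxList schema):
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.connect() as conn:
    conn.execute(text("CREATE TABLE errors (hash TEXT, descr TEXT)"))
    stmt = text("INSERT INTO errors (hash, descr) VALUES (:hash, :descr)")
    conn.execute(stmt, {"hash": "0xabc", "descr": "Out of gas"})  # dict fills :hash/:descr
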
def __delete_arrivals(session: scoped_session, id: str) -> None:
    """Deletes all arrivals for a single stop point from the database"""
    session.query(Arrival).filter(Arrival.naptan_id == id).delete()
    session.commit()

def initDBofContract(self, session: scoped_session, submitInteral, sleepSec=1):
    if session.query(dbmodules.Processing).first():
        logger.warning("Processing table is already populated; skipping init")
        return
    for root, dirs, files in os.walk(self.rootpath):
        interval: int = 0
        templist = []
        templist_2 = []
        for file in files:
            if file.endswith("txt"):
                templist.append(
                    dbmodules.Processing(
                        contractAddr=os.path.splitext(file)[0],
                        isprocessed=False))
                templist_2.append(
                    dbmodules.SmartContract(
                        contractAddr=os.path.splitext(file)[0],
                        label='none'))
                interval = interval + 1
                if interval % submitInteral == 0:
                    #
                    session.add_all(templist)
                    session.add_all(templist_2)
                    templist = []
                    # templist_2 = []
                    session.commit()
                    # session.flush()
                    time.sleep(sleepSec)
                    # print("sumbit: %d times" % ((interval - 1) // submitInteral + 1))
        if templist:
            session.add_all(templist)
            session.add_all(templist_2)
            session.commit()
            session.flush()
            print("submit: %d times" % ((interval - 1) // submitInteral + 1))

def get_stop_point_by_url(session: scoped_session, url: str) -> Optional[StopPoint]:
    return session.query(StopPoint).filter(StopPoint.url == url).one_or_none()

def pullAllTransactionBySmartContract(self, scAddress,
                                      session: scoped_session) -> bool:
    api = EtherscanAPI()
    existError: bool = False
    batchSize = 100
    item = {'hash': ''}  # so the except-handler below can always reference item['hash']
    try:
        txList = api.getAllTransactionForContractAddress(scAddress)
        logger.info("TxList for: {address}, length: {length} ".format(
            address=scAddress, length=len(txList)))
        batchList = []
        errorNumber: int = 0
        # testIndex = 0
        for item in txList:
            # print("testindex :%d" % testIndex)
            # testIndex = testIndex + 1
            if item['isError'] == "0":
                continue
            errDescp = ''
            # time.sleep(0.2)
            existError = True
            errDescp = api.getTransactionStatus(
                item['hash'])['errDescription']
            if not item['to'] and item['contractAddress']:
                item['to'] = item['contractAddress']
            batchList.append(
                dbmodules.ErrorTransactionListForSC(
                    blockNumber=item['blockNumber'],
                    timeStamp=item['timeStamp'],
                    hash=item['hash'],
                    nonce=item['nonce'],
                    blockHash=item['blockHash'],
                    transactionIndex=item['transactionIndex'],
                    fromAddr=item['from'],
                    toAddr=item['to'],
                    value=item['value'],
                    gas=item['gas'],
                    gasPrice=item['gasPrice'],
                    isError=item['isError'],
                    errDescription=errDescp,
                    txreceipt_status=item['txreceipt_status'],
                    contractAddress=item['contractAddress'],
                    comulativeGasUsed=item['cumulativeGasUsed'],
                    gasUsed=item['gasUsed'],
                    confirmations=item['confirmations']))
            errorNumber = errorNumber + 1
            if errorNumber % batchSize == 0:
                #
                # session.add_all(batchList)
                session.bulk_save_objects(batchList)
                batchList = []
                #
                session.commit()
                # session.flush()
                time.sleep(0.25)
                #
                logger.info(
                    " |__ TxList for: %s, %d times, at most %d txs/submit." %
                    (scAddress, (errorNumber - 1) // batchSize + 1, batchSize))
        #
        if batchList:
            # session.add_all(batchList)
            session.bulk_save_objects(batchList)
            session.commit()
            session.flush()
            # time.sleep(0.5)
            logger.info(
                " |__ TxList for: %s, %d times, at most %d txs/submit." %
                (scAddress, (errorNumber - 1) // batchSize + 1, batchSize))
        logger.info(" |__ TxList for: %s, total error transactions: %d." %
                    (scAddress, errorNumber))
    except Exception as e:
        #
        logger.error(
            "Contract Address:{address}, Transaction Hash:{txhash}, Error Message:{message}"
            .format(address=scAddress, txhash=item['hash'], message=e))
        return False
    else:
        # session.query(dbmodules.Processing)\
        #     .filter(dbmodules.Processing.contractAddr == scAddress)\
        #     .update({"isprocessed": True})
        session.query(dbmodules.SmartContract)\
            .filter(dbmodules.SmartContract.contractAddr == scAddress)\
            .update({"label": "1" if existError else '0',
                     "txTotalCount": len(txList),
                     "txErrorTotalCount": errorNumber})
        session.commit()
        session.flush()
        # time.sleep(0.)
        logger.success(
            "Request all txlist of Contract Address:{address} successfully!"
            .format(address=scAddress))
        return True

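# Self-contained sketch of the batched bulk_save_objects pattern used above
# (the Row model and sizes are illustrative; imports follow the 1.x layout
# the surrounding code targets):
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

Base = declarative_base()

class Row(Base):
    __tablename__ = "rows"
    id = Column(Integer, primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

batch, batch_size = [], 100
for _ in range(250):
    batch.append(Row())
    if len(batch) == batch_size:
        session.bulk_save_objects(batch)   # one bulk INSERT round trip per batch
        batch = []
if batch:
    session.bulk_save_objects(batch)       # flush the remainder
session.commit()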