def populate_books_table(db: scoped_session = db, data_path='books.csv'):
    """Create (if needed), clear, and repopulate the ``books`` table from a CSV.

    Args:
        db: Session to use. NOTE(review): the default binds the module-level
            ``db`` at function-definition time, not at call time.
        data_path: CSV file whose rows are assumed to be
            (isbn, title, author, year) in that order -- TODO confirm.
    """
    # Pick the notebook-aware progress bar when running under IPython;
    # __IPYTHON__ only exists as a builtin inside IPython, so a NameError
    # means "plain interpreter" and we fall back to terminal tqdm.
    try:
        __IPYTHON__
        from tqdm import tqdm_notebook as tqdm
    except NameError:
        from tqdm import tqdm as tqdm

    import pandas as pd
    # Create the table on first run; SQL text comes from the config
    # attached to connect_database elsewhere in the project.
    if not table_exists('books', db=db):
        sql(connect_database.config['queries']['create books'],
            rollback=True,
            db=db)
    # NOTE(review): this clear call omits db= and therefore runs on the
    # module-level default session, not the ``db`` argument -- confirm.
    sql(connect_database.config['queries']['clear books'], rollback=True)

    data_path = Path(data_path).absolute()
    data = pd.read_csv(data_path)
    # Discard any pending transaction state before the bulk insert.
    db.rollback()
    for i in tqdm(range(len(data))):
        book = data.iloc[i]
        isbn, title, author, year = escape_list(list(book))
        #     set_trace()
        # NOTE(review): values are interpolated into the SQL string; safety
        # rests entirely on escape_list -- parameterized queries preferred.
        query = f"INSERT INTO books(isbn, title, author, year)" + \
                f""" VALUES ('{isbn}', '{title}', '{author}', '{year}')"""
        db.execute(query)

    db.commit()

    print("Successful")
示例#2
0
def insert_properties(properties: Set[str], session: scoped_session) -> Optional[Any]:
    """Insert all the properties as defined in the APIDocumentation into DB.

    Names already present in the BaseProperty table are skipped, so the
    call is idempotent.

    Args:
        properties: Property names to persist.
        session: SQLAlchemy session used for the insert.

    Returns:
        None. Commits the new rows as a side effect.
    """
    # Perf fix: fetch all pre-existing names in ONE query instead of the
    # previous one-EXISTS-query-per-property (N+1 round trips).
    existing = {name for (name,) in session.query(BaseProperty.name).filter(
        BaseProperty.name.in_(list(properties))).all()}
    prop_list = [BaseProperty(name=prop)
                 for prop in properties if prop not in existing]
    session.add_all(prop_list)
    session.commit()
    return None


# if __name__ == "__main__":
#     Session = sessionmaker(bind=engine)
#     session = Session()
#
#     doc = doc_gen("test", "test")
#     # Extract all classes with supportedProperty from both
#     classes = get_classes(doc.generate())
#
#     # Extract all properties from both
#     # import pdb; pdb.set_trace()
#     properties = get_all_properties(classes)
#     # Add all the classes
#     insert_classes(classes, session)
#     print("Classes inserted successfully")
#     # Add all the properties
#     insert_properties(properties, session)
#     print("Properties inserted successfully")
示例#3
0
def insert_classes(classes: List[Dict[str, Any]],
                   session: scoped_session) -> Optional[Any]:
    """Insert all the classes as defined in the APIDocumentation into DB.

    A class dict may carry its name under either the "label" or the
    "title" key; trailing dots are stripped. Names already present in
    the RDFClass table are skipped.

    Raises:
        TypeError: If `session` is not an instance of `scoped_session` or `Session`.

    """
    if not isinstance(session, scoped_session) and not isinstance(
            session, Session):
        raise TypeError(
            "session is not of type <sqlalchemy.orm.scoping.scoped_session>"
            "or <sqlalchemy.orm.session.Session>")

    def _new_rows(key: str) -> List[Any]:
        # Shared between the "label" and "title" passes; previously this
        # comprehension was duplicated verbatim for each key.
        # NOTE(review): still one EXISTS query per class -- could be
        # batched like insert_properties if this becomes hot.
        return [
            RDFClass(name=class_[key].strip('.')) for class_ in classes
            if key in class_ and not session.query(exists().where(
                RDFClass.name == class_[key].strip('.'))).scalar()
        ]

    # Preserve the original ordering: all "label" rows first, then "title".
    class_list = _new_rows("label")
    class_list.extend(_new_rows("title"))
    session.add_all(class_list)
    session.commit()
    return None
示例#4
0
def __refresh_arrivals(session: scoped_session,
                       naptan_id: StopId) -> None:
    """Refresh the cached arrival data for one stop point.

    Describes the arrivals cache entry (1-minute freshness window, backed
    by __cache_arrivals / __delete_arrivals) and delegates the actual
    refresh decision to __update_cache.

    Annotation fix: the original declared ``-> List[Arrival]`` but never
    returned anything; callers query the Arrival table themselves.
    """
    ud = __UpdateDescription(CachedDataType.arrival, naptan_id,
                             timedelta(minutes=1), __cache_arrivals,
                             __delete_arrivals)
    __update_cache(session, ud)
    session.commit()
def sql(query, db: scoped_session = db, fetch=False, rollback=False):
    """Run a SQL statement against ``db``.

    Optionally rolls back any pending transaction first. With
    ``fetch=True`` the result rows are returned and nothing is committed;
    otherwise the statement is executed and committed.
    """
    if rollback:
        db.rollback()
    if not fetch:
        db.execute(query)
        db.commit()
        return None
    return db.execute(query).fetchall()
示例#6
0
def remove_stale_modification_records(session: scoped_session,
                                      stale_records_removal_interval: int = 900
                                      ):
    """
    Remove modification records which are older than last 1000 records.

    Reschedules itself every ``stale_records_removal_interval`` seconds on
    a daemon Timer thread.

    :param session: sqlalchemy session.
    :param stale_records_removal_interval: Interval time to run the removal job.
    """
    # Re-arm the timer up front so a cleanup failure cannot stop the job.
    # Bug fix: forward the interval too -- previously the rescheduled run
    # received only [session] and silently reverted to the 900 s default.
    timer = Timer(stale_records_removal_interval,
                  remove_stale_modification_records,
                  [session, stale_records_removal_interval])
    timer.daemon = True
    timer.start()
    # The newest 1000 records (by descending job_id) are the keep-set.
    valid_records = session.query(Modification).order_by(
        Modification.job_id.desc()).limit(1000).all()
    # Fewer than 1000 records in total: nothing to clean up yet.
    if len(valid_records) < 1000:
        return
    # Everything strictly older than the oldest kept record is stale.
    job_id_of_last_valid_record = valid_records[-1].job_id
    stale_records = session.query(Modification).filter(
        Modification.job_id < job_id_of_last_valid_record).all()
    for record in stale_records:
        session.delete(record)
    session.commit()
    # NOTE(review): the scoped session is disposed only on this path,
    # not on the early return above -- confirm that is intentional.
    session.remove()
示例#7
0
def insert_properties(properties: Set[str],
                      session: scoped_session) -> Optional[Any]:
    """Insert all the properties as defined in the APIDocumentation into DB.

    Names already present in the BaseProperty table are skipped, so the
    call is idempotent.

    Args:
        properties: Property names to persist.
        session: SQLAlchemy session used for the insert.

    Returns:
        None. Commits the new rows as a side effect.
    """
    # Perf fix: fetch all pre-existing names in ONE query instead of the
    # previous one-EXISTS-query-per-property (N+1 round trips).
    existing = {name for (name,) in session.query(BaseProperty.name).filter(
        BaseProperty.name.in_(list(properties))).all()}
    prop_list = [BaseProperty(name=prop)
                 for prop in properties if prop not in existing]
    session.add_all(prop_list)
    session.commit()
    return None
 def generate(_db_session: scoped_session,
              school_name: str,
              is_commit: bool = False):
     """Create a School named ``school_name`` plus its default Board.

     The new rows are committed only when ``is_commit`` is True;
     otherwise the caller owns the transaction.

     Returns the newly created School instance.
     """
     _school = School(school_name)
     # Bug fix: the school was previously added via the undefined name
     # ``db_session`` (NameError at runtime); use the passed-in session.
     _db_session.add(_school)
     Board.generate(_db_session, _school)
     if is_commit:
         _db_session.commit()
     return _school
示例#9
0
def __cache_arrivals(session: scoped_session, naptan_id: str) -> None:
    """Fetches the arrivals for a single stop point from TFL and stores in the database"""
    log = logging.getLogger(__name__)
    incoming = fetch_arrivals(naptan_id)
    log.info(f"Adding arrivals for '{naptan_id}' to database")
    for arrival in incoming:
        cached = session.query(Arrival).filter(
            Arrival.arrival_id == arrival.arrival_id).one_or_none()
        if cached is None:
            # First sighting of this arrival_id: insert a new row.
            session.add(arrival)
        else:
            # Already cached: refresh the existing row in place.
            cached.update_with(arrival)
    session.commit()
示例#10
0
def __cache_stop_point(session: scoped_session, naptan_id: str) -> None:
    """Fetches the data for a single stop point from TFL and stores in the database"""
    for stop in fetch_stops(naptan_id):
        # TODO This should do a proper update instead of remove-insert
        session.query(StopPoint).filter(
            StopPoint.naptan_id == stop.naptan_id).delete()
        # Record when this stop point's cache entry was last refreshed.
        __save_update_timestamp(session, CachedDataType.stop_point,
                                stop.naptan_id)
        session.add(stop)
    session.commit()
示例#11
0
def __save_update_timestamp(session: scoped_session,
                            type: CachedDataType,
                            id: str = "") -> None:
    """Stores the current time as the timestamp of the last update
    for a given CachedDataType and id pair.

    Note: the parameter names ``type`` and ``id`` shadow builtins but are
    kept unchanged for backward compatibility with keyword callers.
    """
    ts = session.query(CacheTimestamp).filter(
        CacheTimestamp.data_type == type).filter(
            CacheTimestamp.data_id == id).one_or_none()
    # Idiom fix: identity comparison with None uses ``is``, not ``==``.
    if ts is None:
        # No row yet for this (type, id) pair: create one. Presumably the
        # model defaults update_time to "now" -- TODO confirm.
        session.add(CacheTimestamp(data_type=type, data_id=id))
    else:
        ts.update_time = datetime.utcnow()
    session.commit()
def test_it_ignores_other_models_being_committed(orcid_token: OrcidToken,
                                                 orcid_config: Dict[str, str],
                                                 mock_orcid_client: MagicMock,
                                                 session: scoped_session,
                                                 url_safe_serializer: URLSafeSerializer):
    """Committing a model the webhook does not track must not call ORCID."""
    handler = maintain_orcid_webhook(orcid_config, mock_orcid_client,
                                     url_safe_serializer)
    models_committed.connect(receiver=handler)

    session.add(orcid_token)
    session.commit()

    # Neither webhook operation may have been invoked by the commit.
    assert mock_orcid_client.set_webhook.call_count == 0
    assert mock_orcid_client.remove_webhook.call_count == 0
示例#13
0
def update_item(db_session: scoped_session, query_data: dict, trial_results: dict, mab: MAB) -> str:
    """
    Check results of MAB trial against data in db, update db with trial results if necessary,
    as well as the pickled mab.

    If the trial has ended, set the best performing thumbnail to be `next`, set `active_trial`
    to False, and delete all options associated with the item.

    Returns a status string: "200 ..." on success, "500 ..." on failure.
    (Annotation fix: the original ``-> None or str`` evaluated to ``str``
    anyway; a status string is always returned.)
    """
    item = query_data["item"]

    # Re-pickle the (possibly updated) MAB instance onto the item.
    item.mab = pickle.dumps(mab)

    # If there are more trials, set the next url.
    if trial_results["new_trial"]:
        if trial_results["choice"] != item.next:
            # item.id is already hashed with project, just need to add option_num
            next_id = hash_id(item.id, trial_results["choice"])
            item.option_num = trial_results["choice"]

            # Find the next image using the id we get from `trial_results`.
            try:
                option = Option.query.get(next_id)
                item.next = option.content
            except Exception:
                # Narrowed from a bare ``except:`` so KeyboardInterrupt /
                # SystemExit are no longer swallowed.
                return "500 - internal server error"

    # The trial is over: lock in the best performing thumbnail.
    else:
        item.next = trial_results["best"]
        item.active_trial = False

        # Delete all options as they are no longer needed.
        for i in range(1, item.total_num + 1):
            # item.id is already hashed, just add option_num
            option_id = hash_id(item.id, i)
            Option.query.filter_by(id=option_id).delete()

        item.total_num = 0

    try:
        db_session.add(item)
        db_session.commit()
        return "200 - trial successfully updated."
    except Exception:
        # Narrowed from a bare ``except:`` (same rationale as above).
        return "500 - internal server error."
示例#14
0
    def initDBofContract(self,
                         session: scoped_session,
                         submitInteral,
                         sleepSec=1):
        """Seed the Processing and SmartContract tables from files on disk.

        Walks ``self.rootpath`` and, for every ``*.txt`` file, records its
        basename (a contract address) as an unprocessed Processing row and
        an unlabeled SmartContract row. Rows are flushed in batches of
        ``submitInteral`` with a ``sleepSec`` pause between commits.

        No-op (with a warning) if the Processing table already has rows.
        """
        # Guard: a populated table presumably means seeding already ran;
        # re-running would create duplicate rows -- TODO confirm intent.
        if session.query(dbmodules.Processing).first():
            logger.warning("wrong")
            return

        for root, dirs, files in os.walk(self.rootpath):

            # Per-directory batch state; counts files accepted so far.
            interval: int = 0
            templist = []
            templist_2 = []
            for file in files:
                if file.endswith("txt"):
                    # File basename (without extension) is the contract address.
                    templist.append(
                        dbmodules.Processing(
                            contractAddr=os.path.splitext(file)[0],
                            isprocessed=False))
                    templist_2.append(
                        dbmodules.SmartContract(
                            contractAddr=os.path.splitext(file)[0],
                            label='none'))
                    interval = interval + 1
                    if interval % submitInteral == 0:  # batch is full
                        session.add_all(templist)
                        session.add_all(templist_2)

                        templist = []  # start a fresh batch
                        templist_2 = []
                        session.commit()  # persist this batch
                        session.flush()
                        time.sleep(sleepSec)  # throttle between commits
                        print("sumbit: %d times" %
                              ((interval - 1) // submitInteral + 1))

            # Flush the final partial batch, if any.
            if templist:
                session.add_all(templist)
                session.add_all(templist_2)
                session.commit()
                session.flush()
                print("sumbit: %d times" %
                      ((interval - 1) // submitInteral + 1))
示例#15
0
def refresh_recently_requested_stop_ids(
        session: scoped_session) -> None:
    """Refresh arrival data for every stop requested in the last 5 minutes.

    Prunes ArrivalRequest rows older than the tracking window, then
    refreshes arrivals for each remaining request.

    Annotation fix: the original declared ``-> List[StopId]`` but never
    returned anything.
    """
    logger = logging.getLogger(__name__)
    tracking_time_limit = timedelta(minutes=5)
    # Drop tracking rows that fell outside the 5-minute window.
    session.query(ArrivalRequest).\
        filter(ArrivalRequest.request_time < (datetime.utcnow() - tracking_time_limit)).\
        delete()
    session.commit()
    recent_queries = session.query(ArrivalRequest).order_by(
        ArrivalRequest.naptan_id).all()
    if len(recent_queries) == 0:
        logger.info(
            f"No arrivals were requested in the last {tracking_time_limit.seconds/60} minutes"
        )
    for q in recent_queries:
        try:
            __refresh_arrivals(session, StopId(q.naptan_id))
        except exc.DataError as e:
            # Bug fix: ``logger(...)`` called the Logger object itself,
            # which raises TypeError; log through .error() instead.
            logger.error(f"Cannot refresh arrival info for {q.naptan_id}: {e}")
示例#16
0
def get_arrivals(session: scoped_session, naptan_id: StopId) -> List[Arrival]:
    """Return up to 6 upcoming, unexpired arrivals for a stop, soonest first.

    Also creates/refreshes the ArrivalRequest row so the background
    refresher keeps this stop's data warm. Returns [] when the arrival
    data cannot be refreshed.
    """
    logger = logging.getLogger(__name__)
    try:
        __refresh_arrivals(session, naptan_id)
    except exc.DataError as e:
        # Bug fix: ``logger(...)`` called an undefined/non-callable name;
        # use a module logger and .error(), matching the sibling helpers.
        logger.error(f"Cannot refresh arrival info for {naptan_id}: {e}")
        return []

    req = session.query(ArrivalRequest).filter(
        ArrivalRequest.naptan_id == naptan_id).one_or_none()
    # Idiom fix: identity comparison with None uses ``is``, not ``==``.
    if req is None:
        session.add(ArrivalRequest(naptan_id=naptan_id))
    else:
        req.request_time = datetime.utcnow()
    session.commit()

    # Only arrivals whose ttl is still in the future, ordered by expected
    # time, capped at 6 results.
    return session.query(Arrival).\
        filter(Arrival.naptan_id == naptan_id).\
        filter(Arrival.ttl > datetime.utcnow()).\
        order_by(Arrival.expected).\
        limit(6).all()
示例#17
0
    def pullAllTransactionBySmartContract(self, scAddress,
                                          session: scoped_session) -> bool:
        """Fetch all transactions of a contract and persist its failed ones.

        Pulls the full transaction list for ``scAddress`` from Etherscan,
        stores every failed transaction (isError != "0") as an
        ErrorTransactionListForSC row in batches of ``batchSize``, then
        marks the contract as processed and records error/total counts.

        Returns:
            True on success, False if any exception occurred mid-pull.
        """
        api = EtherscanAPI()

        existError: bool = False
        batchSize = 100
        try:
            txList = api.getAllTransactionForContractAddress(scAddress)
            logger.info("TxList for: {address}, length: {length} ".format(
                address=scAddress, length=len(txList)))
            batchList = []
            errorNumber: int = 0
            # testIndex = 0
            for item in txList:
                # print("testindex :%d" % testIndex)
                # testIndex =testIndex +1

                # Successful transactions are not stored.
                if item['isError'] == "0":
                    continue

                errDescp = ''
                # time.sleep(0.2)
                existError = True
                # Second API round trip to fetch the failure description.
                errDescp = api.getTransactionStatus(
                    item['hash'])['errDescription']

                # Contract-creation txs carry an empty 'to'; substitute the
                # created contract's address.
                if not item['to'] and item['contractAddress']:
                    item['to'] = item['contractAddress']

                batchList.append(
                    dbmodules.ErrorTransactionListForSC(
                        blockNumber=item['blockNumber'],
                        timeStamp=item['timeStamp'],
                        hash=item['hash'],
                        nonce=item['nonce'],
                        blockHash=item['blockHash'],
                        transactionIndex=item['transactionIndex'],
                        fromAddr=item['from'],
                        toAddr=item['to'],
                        value=item['value'],
                        gas=item['gas'],
                        gasPrice=item['gasPrice'],
                        isError=item['isError'],
                        errDescription=errDescp,
                        txreceipt_status=item['txreceipt_status'],
                        contractAddress=item['contractAddress'],
                        comulativeGasUsed=item['cumulativeGasUsed'],
                        gasUsed=item['gasUsed'],
                        confirmations=item['confirmations']))

                errorNumber = errorNumber + 1

                if errorNumber % batchSize == 0:  # batch is full
                    # session.add_all(batchList)
                    session.bulk_save_objects(batchList)
                    batchList = []  # start a fresh batch
                    session.commit()  # persist this batch
                    session.flush()
                    time.sleep(0.25)  # throttle between commits
                    logger.info(
                        " |__ TxList for: %s, %d times, at most %d txs/submit."
                        % (scAddress,
                           (errorNumber - 1) // batchSize + 1, batchSize))

            # Flush the final partial batch, if any.
            if batchList:
                # session.add_all(batchList)
                session.bulk_save_objects(batchList)
                session.commit()
                session.flush()
                # time.sleep(0.5)
                logger.info(
                    " |__ TxList for: %s, %d times, at most %d txs/submit." %
                    (scAddress, (errorNumber - 1) // batchSize + 1, batchSize))
            logger.info(" |__ TxList for: %s, total error transactions: %d." %
                        (scAddress, errorNumber))

        except Exception as e:
            # NOTE(review): ``item`` may be unbound here if the API call
            # itself failed before the loop started -- confirm.
            logger.error(
                "Contract Address:{address}, Transaction Hash:{txhash}, Error Message:{message}"
                .format(address=scAddress, txhash=item['hash'], message=e))
            return False

        else:
            # Mark the contract processed and record the tallies.
            session.query(dbmodules.Processing)\
                .filter(dbmodules.Processing.contractAddr == scAddress)\
                .update({"isprocessed":True})
            session.query(dbmodules.SmartContract)\
                    .filter(dbmodules.SmartContract.contractAddr == scAddress)\
                    .update({"label": "1" if existError else '0',
                             "txTotalCount": len(txList),
                             "txErrorTotalCount": errorNumber
                             })
            session.commit()
            session.flush()
            # time.sleep(0.)

            logger.success(
                "Request all txlist of Contract Address:{address} successfully!"
                .format(address=scAddress))
            return True
示例#18
0
    def pullAllTransactionBySmartContract_raw(self, scAddress,
                                              session: scoped_session) -> bool:
        """Raw-SQL variant: persist a contract's failed transactions.

        Same flow as pullAllTransactionBySmartContract but inserts into
        ``errTxList`` with a parameterized raw INSERT instead of ORM
        objects, committing every ``batchSize`` rows.

        Returns:
            True on success, False if any exception occurred mid-pull.
        """
        api = EtherscanAPI()

        existError: bool = False

        try:
            txList = api.getAllTransactionForContractAddress(scAddress)
            logger.info("TxList for: {address}, length: {length} ".format(
                address=scAddress, length=len(txList)))
            errorNumber: int = 0
            # testIndex = 0
            batchSize = 100
            for item in txList:

                # Successful transactions are not stored.
                if item['isError'] == "0":
                    continue

                errDescp = ''
                existError = True
                # Second API round trip to fetch the failure description.
                errDescp = api.getTransactionStatus(
                    item['hash'])['errDescription']

                # Contract-creation txs carry an empty 'to'; substitute the
                # created contract's address.
                if not item['to'] and item['contractAddress']:
                    item['to'] = item['contractAddress']

                # Rename/augment keys so the dict matches the bind params
                # of the INSERT below.
                item['fromAddr'] = item['from']
                item['toAddr'] = item['to']
                item["errDescription"] = errDescp
                item['comulativeGasUsed'] = item['cumulativeGasUsed']

                sql = text(
                    "INSERT INTO errTxList (blockNumber,timeStamp,hash,nonce,blockHash,transactionIndex,fromAddr,toAddr,value,gas,gasPrice,isError,errDescription,txreceipt_status,contractAddress,comulativeGasUsed,gasUsed,confirmations)"
                    " VALUES (:blockNumber,:timeStamp,:hash,:nonce,:blockHash,:transactionIndex,:fromAddr,:toAddr,:value,:gas,:gasPrice,:isError,:errDescription,:txreceipt_status,:contractAddress,:comulativeGasUsed,:gasUsed,:confirmations)"
                )

                # Bug fix: the INSERT previously ran on the undefined name
                # ``sess`` (NameError); execute on the passed-in session.
                session.execute(sql, params=item)

                errorNumber = errorNumber + 1

                if errorNumber % batchSize == 0:  # batch is full
                    session.flush()
                    session.commit()  # persist this batch
                    time.sleep(0.25)  # throttle between commits
                    logger.info(
                        " |__ TxList for: %s, %d times, at most %d txs/submit."
                        % (scAddress,
                           (errorNumber - 1) // batchSize + 1, batchSize))

            logger.info(" |__ TxList for: %s, total error transactions: %d." %
                        (scAddress, errorNumber))

        except Exception as e:
            # NOTE(review): ``item`` may be unbound here if the API call
            # itself failed before the loop started -- confirm.
            logger.error(
                "Contract Address:{address}, Transaction Hash:{txhash}, Error Message:{message}"
                .format(address=scAddress, txhash=item['hash'], message=e))
            return False

        else:
            # Mark the contract processed and record the tallies.
            session.query(dbmodules.Processing)\
                .filter(dbmodules.Processing.contractAddr == scAddress)\
                .update({"isprocessed":True})
            session.query(dbmodules.SmartContract)\
                    .filter(dbmodules.SmartContract.contractAddr == scAddress)\
                    .update({"label": "1" if existError else '0',
                             "txTotalCount": len(txList),
                             "txErrorTotalCount": errorNumber
                             })
            session.commit()
            session.flush()
            # time.sleep(0.)

            logger.success(
                "Request all txlist of Contract Address:{address} successfully!"
                .format(address=scAddress))
            return True
示例#19
0
def __delete_arrivals(session: scoped_session, id: str) -> None:
    """Remove every cached arrival belonging to the given stop point."""
    stop_arrivals = session.query(Arrival).filter(Arrival.naptan_id == id)
    stop_arrivals.delete()
    session.commit()