Example No. 1
 def fetch_auth_creds(self):
     # authenticate the twitter bot and cache the API handle alongside the
     # other service credentials
     self.svc_auth['twitter'] = self.authenticate_dcbot()
     # pull pinata/cloudflare credentials from the DB rather than keeping
     # them in the in-memory config
     for creds, sql in zip(['pinata', 'cloudflare'], [
             self.config.experiment.infsvc.sql.get_pinata_creds_sql,
             self.config.experiment.infsvc.sql.get_cloudflare_creds_sql
     ]):
         self.svc_auth[creds] = fetch_one(self.cnxp.get_connection(), sql)
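The `fetch_one` helper used throughout these examples is project-specific and not shown. A minimal sketch of its assumed behavior (execute a statement on a pooled MySQL connection and return the first row); the connection handling here is an assumption:

 def fetch_one(cnx, sql):
     # assumed behavior: run the statement and return the first row as a
     # tuple; closing a pooled mysql.connector connection returns it to
     # the pool rather than tearing it down
     try:
         cursor = cnx.cursor()
         cursor.execute(sql)
         row = cursor.fetchone()
         cursor.close()
         return row
     finally:
         cnx.close()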
Example No. 2
 def pin_flow(self, preds: List, rm_previous: bool = False) -> bool:
     current_cid = None
     headers = {
         'pinata_api_key': self.svc_auth['pinata'][0],
         'pinata_secret_api_key': self.svc_auth['pinata'][1],
         'Content-Type': 'application/json'
     }
     if rm_previous:
         # look up the currently pinned CID so it can be unpinned after a
         # successful re-pin
         current_cid = fetch_one(
             self.cnxp.get_connection(),
             self.config.experiment.infsvc.sql.fetch_current_pinned_cid_sql)
     pinned_tup, pin_cnt, pin_error = self.pin_cid(preds, headers)
     if pin_cnt == 1 and pin_error == 0:
         logger.info(
             f'Pinned latest unlabeled model predictions {pinned_tup[1]} with size {pinned_tup[2]}'
         )
         self.patch_dns(pinned_tup[1])
         self.latest_pinned_cid = pinned_tup[1]
         # guard against current_cid being None (nothing previously pinned)
         if rm_previous and current_cid and (current_cid[0] != pinned_tup[1]):
             return self.unpin_cid(current_cid[0], headers)
         return True
     else:
         logger.warning(
             f'Unexpected pinning results. Pinned {pin_cnt} items, with {pin_error} errors detected '
             f'while logging pinned items. You may want to inspect pinata/ipfs items.'
         )
         return False
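`pin_cid` and `unpin_cid` are not shown here. For orientation, a minimal sketch of what `unpin_cid` presumably does, assuming Pinata's legacy key/secret header auth (the same headers built above) and its public `pinning/unpin` endpoint; the project's actual implementation may differ:

 import requests

 def unpin_cid(self, cid: str, headers: dict) -> bool:
     # DELETE against Pinata's unpin endpoint; HTTP 200 indicates success
     resp = requests.delete(
         f'https://api.pinata.cloud/pinning/unpin/{cid}', headers=headers)
     if resp.status_code != 200:
         logger.warning(f'Failed to unpin {cid}: {resp.text}')
     return resp.status_code == 200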
Example No. 3
 def construct_gen(self, start_dt: datetime, end_dt: datetime) -> Iterator:
     sql_stmts = []
     # bound each per-class statement to the requested date window
     sql_bound_pred = f"STR_TO_DATE('{start_dt.strftime('%Y-%m-%d')}','%Y-%m-%d') " \
                      f"and STR_TO_DATE('{end_dt.strftime('%Y-%m-%d')}','%Y-%m-%d')"
     for sql in self.dataset_conf['class_sql']:
         sql_stmts.append(f"{sql} {sql_bound_pred}")
     gens = [
         db_utils.db_ds_gen(self.cnxp, sql,
                            self.config.data_source.db_fetch_size)
         for sql in sql_stmts
     ]
     self.dataset_conf['start_date'] = start_dt
     self.dataset_conf['end_date'] = end_dt
     if self.config.data_source.class_balancing_strategy == "class_weights":
         gen = dataprep.dataprep_utils.class_weight_gen(gens)
     else:
         if not dataprep.dataprep_utils.validate_normalized(
                 self.config.data_source.sampling_weights):
             logger.error(
                 f"Invalid sampling_weights specified: {self.config.data_source.sampling_weights} "
                 f"(the weights must sum to 1.0). Please reconfigure and restart.")
             sys.exit(1)  # non-zero exit code: this is a configuration error
         cls_bound_cards = []
         # fetch per-class record counts (cardinalities) within the date bounds
         for sql in self.dataset_conf['class_bound_sql']:
             sql = f"{sql} {sql_bound_pred}"
             cls_bound_cards.append(
                 db_utils.fetch_one(self.cnxp.get_connection(), sql)[0])
         gen = dataprep.dataprep_utils.ds_minority_oversample_gen(
             self.config.data_source.sampling_weights, gens,
             cls_bound_cards)
     return gen
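`validate_normalized` is a project utility; a minimal sketch of the check it presumably performs (non-negative weights summing to 1.0 within floating-point tolerance):

 import math

 def validate_normalized(weights) -> bool:
     # hypothetical implementation: valid iff all weights are non-negative
     # and they sum to 1.0 (within floating-point tolerance)
     return all(w >= 0 for w in weights) and math.isclose(sum(weights), 1.0)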
Example No. 4
 def define_dataset_structure(self) -> None:
     if self.config.experiment.debug.use_debug_dataset:
         bsql = self.config.data_source.sql.debug.dist_dt_bound_sql
         dt_lower_bound = datetime.datetime.strptime(
             '2017-01-19', '%Y-%m-%d').date()
     else:
         bsql = self.config.data_source.sql.converge_dist_dt_bound_sql
         dt_lower_bound = datetime.datetime.strptime(
             '1900-01-01', '%Y-%m-%d').date()
     cume_v = 0
     # set up db dataset generators per the specified train/dev/test splits
     for k, v in self.dataset_conf[self.target_ds_structure].items():
         # each split's upper date bound is derived from its cumulative
         # split fraction
         cume_v += v
         sql = f"{bsql} {cume_v}"
         dt_upper_bound = db_utils.fetch_one(self.cnxp.get_connection(),
                                             sql)[0]
         ds_gen = self.construct_gen(dt_lower_bound, dt_upper_bound)
         recs, xformer_examples, ctxt_features = dataprep.dataprep_utils.parse_sql_to_example(
             ds_gen, k)
         self.dataset_conf[f'num_{k}_recs'] = recs
         self.dataset_conf[f'{k}_start_date'] = self.dataset_conf['start_date']
         self.dataset_conf[f'{k}_end_date'] = self.dataset_conf['end_date']
         xformer_features = convert_examples_to_features(
             xformer_examples,
             self.dataset_conf['albert_tokenizer'],
             label_list=self.dataset_conf['class_labels'],
             max_length=self.config.data_source.max_seq_length,
             output_mode="classification")
         logger.info(
             f"Saving features into cached file {self.dataset_conf['ds_datafiles'][k]}")
         torch.save([xformer_features, ctxt_features],
                    self.dataset_conf['ds_datafiles'][k])
         dt_lower_bound = dt_upper_bound
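The cached splits written above can later be reloaded directly with `torch.load`. A short usage sketch, assuming the same `ds_datafiles` mapping and the `[xformer_features, ctxt_features]` pair saved by this method:

 import torch

 # reload the cached 'train' split written by define_dataset_structure
 xformer_features, ctxt_features = torch.load(
     dataset_conf['ds_datafiles']['train'])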
Example No. 5
 def authenticate_dcbot(self) -> tweepy.API:
     # store tokens in DB and retrieve them rather than set them in memory
     consumer_key, consumer_secret, access_token, access_token_secret = \
         fetch_one(self.cnxp.get_connection(), self.config.experiment.tweetbot.sql.get_bot_creds_sql)
     auth = tweepy.OAuthHandler(consumer_key, consumer_secret)
     auth.set_access_token(access_token, access_token_secret)
     # wait_on_rate_limit_notify assumes tweepy 3.x (it was removed in v4)
     api = tweepy.API(auth,
                      wait_on_rate_limit=True,
                      wait_on_rate_limit_notify=True)
     try:
         api.verify_credentials()
         logger.info("Authentication OK")
     except TweepError:  # tweepy 3.x: requires `from tweepy import TweepError`
         exc_type, exc_value, exc_traceback = sys.exc_info()
         logger.error(
             f"Error during authentication:"
             f" {repr(traceback.format_exception(exc_type, exc_value, exc_traceback))}"
         )
     return api
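Once authenticated, the returned handle is an ordinary tweepy 3.x `API` object. A brief usage sketch (the status text is illustrative):

 api = self.authenticate_dcbot()
 # post a status on behalf of the bot account (tweepy 3.x)
 api.update_status('Latest unlabeled model predictions have been pinned.')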
 def __init__(self, config: MutableMapping, dbcnxp: MySQLConnectionPool,
              stype: str) -> None:
     self.config = config
     self.updatedb = self.config.db.update_db
     self.cnxp = dbcnxp
     self.stype = stype
     self.init_wait = self.config['scraping'][self.stype]['init_wait']
     self.debug_mode = bool(self.config.db.debug_enabled)
     self.initial_load = False
     self.driver = None
     try:
         self.latest_db_stmt = fetch_one(
             self.cnxp.get_connection(),
             self.config['scraping'][self.stype]['sql']['latest_statement'])[0]
     except Error as e:
         logger.error(f'{constants.DB_ERROR_MSG}{e}')
         raise
     if not self.latest_db_stmt:
         # no prior statements found in the DB: fall back to the configured
         # default (or the beginning of time) and flag an initial load
         self.latest_db_stmt = self.config['scraping'][self.stype]['default_latest_stmt'] or \
                               constants.BEGINNING_OF_TIME
         self.initial_load = True
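A minimal sketch of wiring up this constructor, assuming `mysql.connector`'s pooling API; the class name `Scraper`, the `stype` value, and the connection parameters are all illustrative:

 from mysql.connector.pooling import MySQLConnectionPool

 # hypothetical pool and config; `scraper_config` must provide the db,
 # scraping, and sql keys referenced in __init__
 cnx_pool = MySQLConnectionPool(pool_name='scraper_pool', pool_size=4,
                                host='localhost', database='tweets',
                                user='scraper', password='...')
 scraper = Scraper(scraper_config, cnx_pool, stype='twitter')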