Example no. 1
 def maybe_build_cache(self,
                       stmt_embed_dict: Optional[Dict] = None) -> None:
     # assumes stmt_embed_dict is provided whenever the statement embedding
     # cache needs to be (re)built
     if ((not self.stmt_embed_cache.exists()) or self.config.inference.rebuild_stmt_cache) \
             and not self.config.inference.update_perf_caches_only:
         torch.save(stmt_embed_dict, self.stmt_embed_cache)
         logger.info(
             f"Cached {len(stmt_embed_dict['sids'])} statement embeddings for analysis:"
             f" {self.stmt_embed_cache}")
     if ((not self.tweet_model_perf_cache.exists())
             or (not self.nontweet_model_perf_cache.exists())
             or self.config.inference.rebuild_perf_cache):
         tweet_model_perf_tups = fetchallwrapper(
             self.cnxp.get_connection(),
             self.config.inference.sql.tweet_model_perf_cache_sql)
         torch.save(tweet_model_perf_tups, self.tweet_model_perf_cache)
         logger.info(
             f"Cached model test accuracy into {len(tweet_model_perf_tups)} buckets for tweet model "
             f"reporting: {self.tweet_model_perf_cache}")
         nontweet_model_perf_tups = fetchallwrapper(
             self.cnxp.get_connection(),
             self.config.inference.sql.nontweet_model_perf_cache_sql)
         torch.save(nontweet_model_perf_tups,
                    self.nontweet_model_perf_cache)
         logger.info(
             f"Cached model test accuracy into {len(nontweet_model_perf_tups)} buckets for nontweet model "
             f"reporting: {self.nontweet_model_perf_cache}")
         self.refresh_global_cache()  # refresh global cache after updating
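The perf caches written above with torch.save would presumably be read back with torch.load at report time; a minimal sketch of a hypothetical loader (the method name load_perf_caches is an assumption, the attribute names match the example):

 def load_perf_caches(self):
     # hypothetical helper: read back the performance caches written by
     # maybe_build_cache; assumes both cache files already exist on disk
     tweet_model_perf_tups = torch.load(self.tweet_model_perf_cache)
     nontweet_model_perf_tups = torch.load(self.nontweet_model_perf_cache)
     return tweet_model_perf_tups, nontweet_model_perf_tups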
Example no. 2
 def gen_pred_exp_ds(self) -> Tuple[Dict, Tuple]:
     pred_exp_tups = fetchallwrapper(self.cnxp.get_connection(),
                                     self.config.inference.sql.pred_exp_sql)
     pred_exp_set = []
     pred_exp_ds = OrderedDict({
         'bucket_type': [],
         'bucket_acc': [],
         'conf_percentile': [],
         'pos_pred_acc': [],
         'neg_pred_acc': [],
         'pos_pred_ratio': [],
         'neg_pred_ratio': [],
         'statement_id': [],
         'statement_text': [],
         'tp': [],
         'tn': [],
         'fp': [],
         'fn': []
     })
     for (bucket_type, bucket_acc, conf_percentile, pos_pred_acc,
          neg_pred_acc, pos_pred_ratio, neg_pred_ratio, statement_id,
          statement_text, ctxt_type, tp, tn, fp, fn) in pred_exp_tups:
         # tp or fn set means the ground-truth label is the positive class,
         # which here evidently denotes a false statement
         label = 'False' if tp == 1 or fn == 1 else 'True'
         pred_exp_set.append((statement_text, ctxt_type, label))
         for k, v in zip(pred_exp_ds.keys(), [
                 bucket_type, bucket_acc, conf_percentile, pos_pred_acc,
                 neg_pred_acc, pos_pred_ratio, neg_pred_ratio, statement_id,
                 statement_text, tp, tn, fp, fn
         ]):
             pred_exp_ds[k].append(v)
     pred_exp_attr_tups, global_metric_summ = Inference(
         self.config, pred_exp_set=pred_exp_set).init_predict()
     pred_exp_ds['pred_exp_attr_tups'] = pred_exp_attr_tups
     return pred_exp_ds, global_metric_summ
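The loop above pivots row tuples into a column-oriented dict, appending each field to the list stored under the matching key. The same pattern in miniature, with illustrative values only:

 # illustrative only: rows of (a, b) tuples become columns 'a' and 'b'
 cols = OrderedDict({'a': [], 'b': []})
 for a, b in [(1, 2), (3, 4)]:
     for k, v in zip(cols.keys(), (a, b)):
         cols[k].append(v)
 # cols == OrderedDict([('a', [1, 3]), ('b', [2, 4])])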
Example no. 3
 def gen_analysis_set(self) -> List[Tuple]:
     # The current use case involves a relatively small analysis set that fits
     # in memory and is used only once, so persisting it would be wasteful. If
     # later use cases require it, larger datasets can be pickled or persisted.
     report_sql = f"select * from {self.report_view}"
     analysis_set = ModelAnalysisRpt.prep_model_analysis_ds(
         fetchallwrapper(self.cnxp.get_connection(), report_sql))
     return analysis_set
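fetchallwrapper itself is not shown in these examples; a minimal sketch of what such a helper might look like, assuming a DB-API 2.0 connection drawn from a pool (the error handling and connection lifecycle here are assumptions):

 def fetchallwrapper(cnx, sql, params=None):
     # hypothetical sketch: run one query on a pooled connection, return all
     # rows, and release the connection back to the pool afterwards
     try:
         cursor = cnx.cursor()
         cursor.execute(sql, params or ())
         rows = cursor.fetchall()
         cursor.close()
     finally:
         cnx.close()  # for a pooled connection this returns it to the pool
     return rows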
Example no. 4
 def refresh_global_cache(self) -> None:
     global_model_perf_tups = fetchallwrapper(
         self.cnxp.get_connection(),
         self.config.inference.sql.global_model_perf_cache_sql)
     torch.save(global_model_perf_tups, self.global_model_perf_cache)
     logger.info(
         f"(Re)Built global model accuracy cache {self.global_model_perf_cache}"
     )
Example no. 5
 def fetch_mapping_inputs(self) -> List:
     # fetch the truth and falsehood rows used to build the embedding mapping
     truth_sql = self.config.data_source.sql.build_truths_embedding
     falsehood_sql = self.config.data_source.sql.build_falsehoods_embedding
     return [
         db_utils.fetchallwrapper(self.cnxp.get_connection(), idsql)
         for idsql in (truth_sql, falsehood_sql)
     ]
Example no. 6
 def gen_perf_exp_ds(self) -> Dict:
     perf_exp_dict = {}
     for cmatrix_rpt_type in [
             *db_constants.TEST_CMATRICES, *db_constants.TEST_CONF_CMATRICES
     ]:
         perf_exp_dict[cmatrix_rpt_type] = fetchallwrapper(
             self.cnxp.get_connection(),
             f"select * from {cmatrix_rpt_type}")
     return perf_exp_dict
Example no. 7
 def check_ids(self, test_ids: List) -> Set:
     # flatten the single-element row tuples and build a parameterized IN list
     test_ids = [tup[0] for tup in test_ids]
     inlist = ", ".join(["%s"] * len(test_ids))
     test_ids = tuple(test_ids)
     sql = f"{self.config['scraping'][self.stype]['sql']['check_tid']} ({inlist})"
     found_ids = fetchallwrapper(self.cnxp.get_connection(), sql, test_ids)
     # ids present in both the candidate set and the db are duplicates
     dup_ids = set(test_ids).intersection(f[0] for f in found_ids)
     return dup_ids
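A brief usage sketch for check_ids, assuming the candidate ids arrive as single-element row tuples (the variable names and the scraper instance below are hypothetical):

 # hypothetical usage: keep only the candidate ids not already in the db
 candidate_ids = [(1001,), (1002,), (1003,)]
 dup_ids = scraper.check_ids(candidate_ids)
 new_ids = [tid for (tid,) in candidate_ids if tid not in dup_ids]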
Example no. 8
 def maybe_publish(self, target_type: str) -> None:
     # N.B. publishes all statements and tweets that meet length thresholds,
     # driven by four tables: a published and "notpublished" table for both
     # statements and tweets, since the metadata differs substantially and is
     # not straightforward to combine cleanly
     if target_type == 'stmts':
         target_tups = fetchallwrapper(
             self.cnxp.get_connection(),
             self.config.experiment.tweetbot.sql.stmts_to_analyze_sql)
         interval = self.config.experiment.tweetbot.dcbot_poll_interval * self.non_twitter_updatefreq
     else:
         target_tups = fetchallwrapper(
             self.cnxp.get_connection(),
             self.config.experiment.tweetbot.sql.tweets_to_analyze_sql)
         interval = self.config.experiment.tweetbot.dcbot_poll_interval
     if target_tups:
         self.prep_new_threads(target_tups)
         self.publish_reports(
             Inference(self.config).init_predict(), target_type)
     else:
         logger.info(
             f"No new {target_type} found to analyze and publish. Trying again in {interval} seconds"
         )
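The interval computed above suggests a polling driver along these lines; this loop is not part of the example, so the bot instance and the use of time.sleep are assumptions:

 # hypothetical driver loop: poll for new statements and tweets, sleeping
 # for the configured dcbot_poll_interval between passes
 while True:
     bot.maybe_publish('stmts')
     bot.maybe_publish('tweets')
     time.sleep(bot.config.experiment.tweetbot.dcbot_poll_interval)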
Example no. 9
 def gen_report(self, rpt_type: str) -> None:
     analysis_set = self.gen_analysis_set()
     # ds_md_sql returns a single row: (dsid, train_start_date, train_end_date)
     ds_meta = fetchallwrapper(self.cnxp.get_connection(),
                               self.config.inference.sql.ds_md_sql)[0]
     self.config.data_source.dsid = ds_meta[0]
     self.config.data_source.train_start_date = datetime.datetime.combine(
         ds_meta[1], datetime.time())
     self.config.data_source.train_end_date = datetime.datetime.combine(
         ds_meta[2], datetime.time())
     rpt_tups, stmt_embed_dict = Inference(
         self.config, analysis_set=analysis_set,
         rpt_type=rpt_type).init_predict()
     self.persist_rpt_data(rpt_tups)
     self.maybe_build_cache(stmt_embed_dict)
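datetime.datetime.combine is used here to promote the DATE values returned by ds_md_sql to full datetimes at midnight:

 # illustrative: combining a date with an empty time() yields midnight
 import datetime
 d = datetime.date(2020, 5, 1)
 datetime.datetime.combine(d, datetime.time())  # datetime(2020, 5, 1, 0, 0)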
Example no. 10
 def publish_flow(self) -> None:
     # N.B. publishes all statements and tweets that meet length thresholds,
     # driven by separate statements/tweets tables since the metadata differs
     # substantially and is not straightforward to combine cleanly
     target_tups = []
     for sql in [
             self.config.experiment.infsvc.sql.stmts_to_analyze_sql,
             self.config.experiment.infsvc.sql.tweets_to_analyze_sql
     ]:
         target_tups.extend(fetchallwrapper(self.cnxp.get_connection(),
                                            sql))
     if target_tups:
         inf_metadata = self.prep_new_threads(target_tups)
         self.publish_inference(
             Inference(self.config).init_predict(), inf_metadata)
     else:
         logger.info(f"No new claims found to analyze and publish")