def maybe_build_cache(self, stmt_embed_dict: Optional[Dict] = None) -> None:
    if ((not self.stmt_embed_cache.exists()) or self.config.inference.rebuild_stmt_cache) \
            and not self.config.inference.update_perf_caches_only:
        torch.save(stmt_embed_dict, self.stmt_embed_cache)
        logger.info(f"Cached {len(stmt_embed_dict['sids'])} statement embeddings for analysis:"
                    f" {self.stmt_embed_cache}")
    if ((not self.tweet_model_perf_cache.exists()) or (not self.nontweet_model_perf_cache.exists())
            or self.config.inference.rebuild_perf_cache):
        tweet_model_perf_tups = fetchallwrapper(self.cnxp.get_connection(),
                                                self.config.inference.sql.tweet_model_perf_cache_sql)
        torch.save(tweet_model_perf_tups, self.tweet_model_perf_cache)
        logger.info(f"Cached model test accuracy into {len(tweet_model_perf_tups)} buckets for tweet model "
                    f"reporting: {self.tweet_model_perf_cache}")
        nontweet_model_perf_tups = fetchallwrapper(self.cnxp.get_connection(),
                                                   self.config.inference.sql.nontweet_model_perf_cache_sql)
        torch.save(nontweet_model_perf_tups, self.nontweet_model_perf_cache)
        logger.info(f"Cached model test accuracy into {len(nontweet_model_perf_tups)} buckets for nontweet model "
                    f"reporting: {self.nontweet_model_perf_cache}")
        self.refresh_global_cache()  # refresh the global cache after updating the per-type perf caches
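# Hedged usage note: the caches written above are ordinary torch-serialized files, so a
# downstream consumer (not shown in this module) could reload them with torch.load, e.g.
#   stmt_embed_dict = torch.load(self.stmt_embed_cache)
#   tweet_model_perf_tups = torch.load(self.tweet_model_perf_cache)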
def gen_pred_exp_ds(self) -> Tuple[Dict, Tuple]:
    pred_exp_tups = fetchallwrapper(self.cnxp.get_connection(), self.config.inference.sql.pred_exp_sql)
    pred_exp_set = []
    pred_exp_ds = OrderedDict({
        'bucket_type': [], 'bucket_acc': [], 'conf_percentile': [], 'pos_pred_acc': [],
        'neg_pred_acc': [], 'pos_pred_ratio': [], 'neg_pred_ratio': [], 'statement_id': [],
        'statement_text': [], 'tp': [], 'tn': [], 'fp': [], 'fn': []
    })
    for (bucket_type, bucket_acc, conf_percentile, pos_pred_acc, neg_pred_acc, pos_pred_ratio,
         neg_pred_ratio, statement_id, statement_text, ctxt_type, tp, tn, fp, fn) in pred_exp_tups:
        label = 'False' if tp == 1 or fn == 1 else 'True'
        pred_exp_set.append((statement_text, ctxt_type, label))
        for k, v in zip(pred_exp_ds.keys(), [bucket_type, bucket_acc, conf_percentile, pos_pred_acc,
                                             neg_pred_acc, pos_pred_ratio, neg_pred_ratio, statement_id,
                                             statement_text, tp, tn, fp, fn]):
            pred_exp_ds[k].append(v)
    pred_exp_attr_tups, global_metric_summ = Inference(self.config, pred_exp_set=pred_exp_set).init_predict()
    pred_exp_ds['pred_exp_attr_tups'] = pred_exp_attr_tups
    return pred_exp_ds, global_metric_summ
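# Hedged note on the label recovery above: tp/tn/fp/fn arrive as one-hot flags per row, so the
# ground-truth class can be read back from them. Assuming the positive class denotes a falsehood
# (consistent with the label strings used in the loop):
#   tp == 1 or fn == 1  ->  actual class positive  ->  label 'False'
#   tn == 1 or fp == 1  ->  actual class negative  ->  label 'True'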
def gen_analysis_set(self) -> List[Tuple]:
    # The current use case involves a relatively small analysis set that fits in memory and is only
    # used once, so persisting it would be wasteful. If later use cases necessitate it, the set will
    # be pickled or otherwise persisted to support larger datasets.
    report_sql = f"select * from {self.report_view}"
    analysis_set = ModelAnalysisRpt.prep_model_analysis_ds(
        fetchallwrapper(self.cnxp.get_connection(), report_sql))
    return analysis_set
def refresh_global_cache(self):
    global_model_perf_tups = fetchallwrapper(self.cnxp.get_connection(),
                                             self.config.inference.sql.global_model_perf_cache_sql)
    torch.save(global_model_perf_tups, self.global_model_perf_cache)
    logger.info(f"(Re)Built global model accuracy cache {self.global_model_perf_cache}")
def fetch_mapping_inputs(self):
    mapping_inputs = []
    truth_sql = self.config.data_source.sql.build_truths_embedding
    falsehood_sql = self.config.data_source.sql.build_falsehoods_embedding
    for idsql in [truth_sql, falsehood_sql]:
        mapping_inputs.append(db_utils.fetchallwrapper(self.cnxp.get_connection(), idsql))
    return mapping_inputs
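# Shape note (derived from the loop above): mapping_inputs is a two-element list,
# [truth_rows, falsehood_rows], where each element is the raw row list returned by
# db_utils.fetchallwrapper for the corresponding embedding-build query.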
def gen_perf_exp_ds(self) -> Dict:
    perf_exp_dict = {}
    for cmatrix_rpt_type in [*db_constants.TEST_CMATRICES, *db_constants.TEST_CONF_CMATRICES]:
        perf_exp_dict[cmatrix_rpt_type] = fetchallwrapper(self.cnxp.get_connection(),
                                                          f"select * from {cmatrix_rpt_type}")
    return perf_exp_dict
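# Hedged shape note: the returned dict maps each confusion-matrix report view name (drawn from
# db_constants.TEST_CMATRICES and db_constants.TEST_CONF_CMATRICES) to the raw rows of a
# "select *" against that view; the per-view column layout is defined by the views themselves
# and is not assumed here.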
def check_ids(self, test_ids: List) -> Set:
    test_ids = [tup[0] for tup in test_ids]
    inlist = ", ".join(["%s"] * len(test_ids))  # one bind placeholder per candidate id
    test_ids = tuple(test_ids)
    sql = f"{self.config['scraping'][self.stype]['sql']['check_tid']} ({inlist})"
    found_ids = fetchallwrapper(self.cnxp.get_connection(), sql, test_ids)
    test_set = set(test_ids)
    found_set = {f[0] for f in found_ids}
    dup_ids = test_set.intersection(found_set)
    return dup_ids
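# Hedged illustration of the IN-clause construction above: for three candidate ids the assembled
# query would look roughly like
#   <check_tid sql> (%s, %s, %s)
# with the ids themselves passed separately as bind parameters rather than interpolated into the
# SQL text.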
def maybe_publish(self, target_type: str) -> None:
    # N.B. publishing all statements and tweets that meet length thresholds, driven by four tables:
    # a published and "notpublished" table for both statements and tweets
    # (since the metadata is substantially different and not straightforward to cleanly combine)
    if target_type == 'stmts':
        target_tups = fetchallwrapper(self.cnxp.get_connection(),
                                      self.config.experiment.tweetbot.sql.stmts_to_analyze_sql)
        interval = self.config.experiment.tweetbot.dcbot_poll_interval * self.non_twitter_updatefreq
    else:
        target_tups = fetchallwrapper(self.cnxp.get_connection(),
                                      self.config.experiment.tweetbot.sql.tweets_to_analyze_sql)
        interval = self.config.experiment.tweetbot.dcbot_poll_interval
    if target_tups:
        self.prep_new_threads(target_tups)
        self.publish_reports(Inference(self.config).init_predict(), target_type)
    else:
        logger.info(f"No new {target_type} found to analyze and publish. Trying again in {interval} seconds")
def gen_report(self, rpt_type: str) -> None:
    analysis_set = self.gen_analysis_set()
    ds_meta = fetchallwrapper(self.cnxp.get_connection(), self.config.inference.sql.ds_md_sql)[0]
    self.config.data_source.dsid = ds_meta[0]
    self.config.data_source.train_start_date = datetime.datetime.combine(ds_meta[1], datetime.time())
    self.config.data_source.train_end_date = datetime.datetime.combine(ds_meta[2], datetime.time())
    rpt_tups, stmt_embed_dict = Inference(self.config, analysis_set=analysis_set,
                                          rpt_type=rpt_type).init_predict()
    self.persist_rpt_data(rpt_tups)
    self.maybe_build_cache(stmt_embed_dict)
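# Hedged note on the datetime handling above: datetime.datetime.combine(d, datetime.time())
# promotes a bare DATE value to a midnight datetime, e.g.
#   datetime.datetime.combine(datetime.date(2020, 5, 1), datetime.time())
#   -> datetime.datetime(2020, 5, 1, 0, 0)
# (the example date is illustrative only; the actual values come from ds_md_sql).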
def publish_flow(self) -> None:
    # N.B. publishing all statements and tweets that meet length thresholds, driven by separate
    # statements/tweets tables since the metadata is substantially different and not
    # straightforward to cleanly combine
    target_tups = []
    for sql in [self.config.experiment.infsvc.sql.stmts_to_analyze_sql,
                self.config.experiment.infsvc.sql.tweets_to_analyze_sql]:
        target_tups.extend(fetchallwrapper(self.cnxp.get_connection(), sql))
    if target_tups:
        inf_metadata = self.prep_new_threads(target_tups)
        self.publish_inference(Inference(self.config).init_predict(), inf_metadata)
    else:
        logger.info("No new claims found to analyze and publish")