def _joinpath(fpath: Union[S3Path, Path], other: str) -> Union[S3Path, Path]:
    """Join ``other`` onto ``fpath``.

    Local ``Path`` objects are joined with ``Path.joinpath`` and made
    absolute. ``S3Path`` objects are joined on their string form so that
    exactly one '/' separates the base and ``other``.

    Parameters
    ----------
    fpath :
        The base path, either a local ``Path`` or an ``S3Path``.
    other :
        The path component to append.

    Returns
    -------
    :
        A new path of the same kind as ``fpath``.
    """
    if isinstance(fpath, Path):
        return fpath.joinpath(other).absolute()

    base = fpath.to_string()
    base_has_sep = base.endswith('/')
    other_has_sep = other.startswith('/')

    if base_has_sep != other_has_sep:
        # Exactly one side supplies the separator: plain concatenation
        joined = base + other
    elif base_has_sep and other_has_sep:
        # Both sides supply one: drop the duplicate from `other`
        joined = base + other[1:]
    elif not base_has_sep and not other_has_sep:
        # Neither side supplies one: insert it
        joined = base + '/' + other
    else:
        raise ValueError(f'Unable to join {fpath.to_string()} and '
                         f'{other} with "/"')
    return S3Path.from_string(joined)
def get_dir_iter(path: str, file_ending: Optional[str] = None) -> List:
    """List the files found at a local directory or an S3 prefix.

    Parameters
    ----------
    path :
        Location to list. Strings starting with 's3://' are listed via
        ``S3Path.list_objects`` with a signed client; anything else is
        treated as a local directory, of which only the immediate regular
        files (no sub-directories) are returned.
    file_ending :
        Optional suffix such as '.pkl'. When given, only entries whose path
        ends with it are kept.

    Returns
    -------
    :
        The matching files as strings: absolute POSIX paths for local
        directories, 's3://' URLs for S3.
    """
    if not path.startswith('s3://'):
        entries = [
            entry.absolute().as_posix()
            for entry in Path(path).glob('*')
            if entry.is_file()
        ]
    else:
        client = get_s3_client(unsigned=False)
        entries = [obj.to_string()
                   for obj in S3Path.from_string(path).list_objects(client)]

    if file_ending:
        entries = [entry for entry in entries if entry.endswith(file_ending)]
    return entries
def load_db_content(ns_list, pkl_filename=None, ro=None, reload=False):
    """Get preassembled statement metadata from the PaMeta table.

    Parameters
    ----------
    ns_list : list of str
        Agent namespaces to query; each is matched against
        `PaMeta.db_name` with SQL LIKE.
    pkl_filename : str or S3Path
        Pickle file to save to (when querying) or to load from (when not
        reloading). Strings starting with 's3:' are treated as S3
        locations. If not given, the database is always queried.
    ro : ReadonlyDatabaseManager
        Database to query. If not given, `get_ro('primary')` is used.
    reload : bool
        If True, re-query the database even if `pkl_filename` is given.

    Returns
    -------
    set of tuples
        Tuples of (stmt_hash, db_name, db_id, agent_num, evidence_count,
        stmt_type), with stmt_type converted to its string name.
    """
    if isinstance(pkl_filename, str) and pkl_filename.startswith('s3:'):
        pkl_filename = S3Path.from_string(pkl_filename)
    # Get the raw data
    if reload or not pkl_filename:
        if not ro:
            ro = get_ro('primary')
        logger.info("Querying the database for statement metadata...")
        results = []
        for ns in ns_list:
            logger.info("Querying for {ns}".format(ns=ns))
            res = ro.select_all([
                ro.PaMeta.mk_hash,
                ro.PaMeta.db_name,
                ro.PaMeta.db_id,
                ro.PaMeta.ag_num,
                ro.PaMeta.ev_count,
                ro.PaMeta.type_num
            ], ro.PaMeta.db_name.like(ns))
            results.extend(res)
        results = {(h, dbn, dbi, ag_num, ev_cnt, ro_type_map.get_str(tn))
                   for h, dbn, dbi, ag_num, ev_cnt, tn in results}
        if pkl_filename:
            if isinstance(pkl_filename, S3Path):
                upload_pickle_to_s3(results, pkl_filename)
            else:
                with open(pkl_filename, 'wb') as f:
                    pickle.dump(results, f)
    # Get a cached pickle
    else:
        logger.info("Loading database content from %s" % pkl_filename)
        # pkl_filename has already been converted to an S3Path above if it
        # was an s3 url, so dispatch on type (S3Path has no .startswith)
        if isinstance(pkl_filename, S3Path):
            results = load_pickle_from_s3(pkl_filename)
        else:
            with open(pkl_filename, 'rb') as f:
                results = pickle.load(f)
    logger.info("{len} stmts loaded".format(len=len(results)))
    return results
def dump_sif(df_file=None, db_res_file=None, csv_file=None,
             src_count_file=None, reload=False, reconvert=True, ro=None):
    """Build the sif DataFrame and optionally write aggregate statistics.

    Parameters
    ----------
    df_file :
        Pickle location passed to `make_dataframe` as `pkl_filename`.
    db_res_file :
        Pickle location passed to `load_db_content` as `pkl_filename`.
    csv_file :
        If given, evidence counts summed per agent pair and statement type
        are written here as CSV. Strings starting with 's3:' are uploaded
        to S3.
    src_count_file :
        If given, source counts are computed with `get_source_counts` and
        saved to this location.
    reload :
        Forwarded to `load_db_content`.
    reconvert :
        Forwarded to `make_dataframe`.
    ro :
        Database manager; `get_db('primary')` is used when not given.
    """
    db = get_db('primary') if ro is None else ro

    # Fetch the db content (fresh query or cached pickle) and pair it up
    db_content = load_db_content(reload=reload, ns_list=NS_LIST,
                                 pkl_filename=db_res_file, ro=db)
    df = make_dataframe(pkl_filename=df_file, reconvert=reconvert,
                        db_content=db_content)

    if csv_file:
        target = csv_file
        if isinstance(target, str) and target.startswith('s3:'):
            target = S3Path.from_string(target)

        logger.info("Saving to CSV...")
        key_cols = ['agA_ns', 'agA_id', 'agA_name',
                    'agB_ns', 'agB_id', 'agB_name', 'stmt_type']
        selected = df.filter(items=key_cols + ['evidence_count'])
        type_counts = selected.groupby(by=key_cols).sum()

        if not isinstance(target, S3Path):
            # Local destination
            type_counts.to_csv(target)
        else:
            # Writing to an s3 url directly needs s3fs under the hood; see
            # https://pandas.pydata.org/pandas-docs/stable/whatsnew/v0.20.0.html#s3-file-handling
            try:
                type_counts.to_csv(target.to_string())
            except Exception as direct_err:
                try:
                    logger.warning('Failed to upload csv to s3 using direct '
                                   's3 url, trying boto3: %s.' % direct_err)
                    client = get_s3_client(unsigned=False)
                    buffer = StringIO()
                    type_counts.to_csv(buffer)
                    client.put_object(Body=buffer.getvalue(), **target.kw())
                    logger.info('Uploaded CSV file to s3')
                except Exception as fallback_err:
                    logger.error('Failed to upload csv file with fallback '
                                 'method')
                    logger.exception(fallback_err)

    if src_count_file:
        get_source_counts(src_count_file, ro=db)
def get_s3_path(self) -> S3Path:
    """Build an S3Path from this object's `s3_location`.

    Returns
    -------
    S3Path
        The parsed s3 location.

    Raises
    ------
    ValueError
        If `s3_location` has not been set.
    """
    location = self.s3_location
    if location is not None:
        return S3Path.from_string(location)
    raise ValueError('s3_location is not set')
def is_dir(path: str):
    """Validate that `path` names an existing directory (local or S3).

    For 's3://' paths, existence is checked with `S3Path.exists` using a
    signed client. Local paths must be existing directories.

    Returns
    -------
    str
        `path`, unchanged, when the check passes.

    Raises
    ------
    ValueError
        If the path cannot be found.
    """
    if not path.startswith('s3://'):
        if not Path(path).is_dir():
            raise ValueError(f'Path {path} does not exist')
        return path

    from indra_db.util.s3_path import S3Path
    from .aws import get_s3_client
    client = get_s3_client(False)
    if not S3Path.from_string(path).exists(client):
        raise ValueError(f'Path {path} does not seem to exists')
    return path
def s3_file_opener(s3_url: str, unsigned: bool = False, **kwargs) -> \
        Union[object, pd.DataFrame, Dict]:
    """Load an object from S3, choosing the reader from the key's suffix.

    Dispatch:

    - '.json' -> `read_json_from_s3`
    - '.pkl' -> `load_pickle_from_s3`
    - '.csv' / '.tsv' -> `pd.read_csv` on the decoded body; ``**kwargs``
      are forwarded to `pd.read_csv` and are ignored for all other types
    - anything else -> the raw response of `S3Path.get`, with a warning

    Parameters
    ----------
    s3_url : str
        Url of the form 's3://<bucket>/<key>'; the key is expected to
        carry a file extension.
    unsigned : bool
        Whether to use an unsigned S3 client. Default: False

    Returns
    -------
    Union[object, pd.DataFrame, Dict]
        The loaded object.
    """
    from indra_db.util.s3_path import S3Path
    from .aws import load_pickle_from_s3, read_json_from_s3, get_s3_client

    logger.info(f'Loading {s3_url} from s3')
    s3_path = S3Path.from_string(s3_url)
    s3 = get_s3_client(unsigned=unsigned)
    bucket, key = s3_path.bucket, s3_path.key

    if key.endswith('.json'):
        return read_json_from_s3(s3=s3, key=key, bucket=bucket)
    if key.endswith('.pkl'):
        return load_pickle_from_s3(s3=s3, key=key, bucket=bucket)
    if key.endswith(('.csv', '.tsv')):
        response = s3_path.get(s3=s3)
        text = response['Body'].read().decode('utf-8')
        return pd.read_csv(StringIO(text), **kwargs)

    suffix = key.split(".")[-1]
    logger.warning(f'File type {suffix} not recognized, '
                   f'returning S3 file stream handler (access from '
                   f'`res["Body"].read()`)')
    return s3_path.get(s3=s3)
def check_path(fpath: str):
    """Validate that `fpath` is an existing file with the expected suffix.

    `file_ending` is taken from the enclosing scope. For S3 paths the
    suffix is checked before existence; for local paths existence is
    checked first.

    Returns
    -------
    str
        `fpath`, unchanged, when valid.

    Raises
    ------
    ValueError
        If the file is missing or has the wrong suffix.
    """
    if not fpath.startswith('s3://'):
        local = Path(fpath)
        if not local.is_file():
            raise ValueError(f'File {fpath} does not exist')
        if file_ending and not local.name.endswith(file_ending):
            raise ValueError(
                f'Unrecognized file type {local.name.split(".")[-1]}')
        return fpath

    if file_ending and not fpath.endswith(file_ending):
        raise ValueError(f'Unrecognized file type '
                         f'{fpath.split("/")[-1]}')
    from indra_db.util.s3_path import S3Path
    from .aws import get_s3_client
    if not S3Path.from_string(fpath).exists(s3=get_s3_client(False)):
        raise ValueError(f'File {fpath} does not exist')
    return fpath
def _load_file(path):
    """Load a pickle or JSON file from S3 or from local disk.

    Parameters
    ----------
    path : Union[str, S3Path, os.PathLike]
        An s3 url string (starting with 's3:'), an S3Path instance, or a
        local file path. The format is chosen from the suffix: paths
        ending in 'pkl' are unpickled, paths ending in 'json' are parsed
        as JSON.

    Returns
    -------
    object
        The deserialized content.

    Raises
    ------
    ValueError
        If the suffix is neither 'pkl' nor 'json' (for both S3 and local
        paths; previously local paths silently returned None).
    """
    if isinstance(path, str) and path.startswith('s3:') or \
            isinstance(path, S3Path):
        s3path = S3Path.from_string(path) if isinstance(path, str) else path
        s3_str = s3path.to_string()
        if s3_str.endswith('pkl'):
            return load_pickle_from_s3(s3path)
        elif s3_str.endswith('json'):
            return load_json_from_s3(s3path)
        raise ValueError(f'Unknown file format of {path}')

    # str() also accepts pathlib.Path objects
    local_path = str(path)
    if local_path.endswith('pkl'):
        with open(local_path, 'rb') as f:
            return pickle.load(f)
    elif local_path.endswith('json'):
        with open(local_path, 'r') as f:
            return json.load(f)
    raise ValueError(f'Unknown file format of {path}')
def get_source_counts(pkl_filename=None, ro=None):
    """Fetch the per-source evidence counts for every statement.

    The result maps each statement hash (`SourceMeta.mk_hash`) to the
    stored `SourceMeta.src_json` value, which holds the evidence count per
    supporting source.

    Parameters
    ----------
    pkl_filename :
        Optional location to pickle the result to. Strings starting with
        's3:' are uploaded to S3; other values are written locally.
    ro :
        Database manager to query; `get_ro('primary-ro')` when not given.

    Returns
    -------
    dict
        Mapping of statement hash to source counts.
    """
    logger.info('Getting source counts per statement')
    if isinstance(pkl_filename, str) and pkl_filename.startswith('s3:'):
        pkl_filename = S3Path.from_string(pkl_filename)
    if not ro:
        ro = get_ro('primary-ro')

    rows = ro.select_all([ro.SourceMeta.mk_hash, ro.SourceMeta.src_json])
    ev = dict(rows)

    if not pkl_filename:
        return ev
    if isinstance(pkl_filename, S3Path):
        upload_pickle_to_s3(obj=ev, s3_path=pkl_filename)
    else:
        with open(pkl_filename, 'wb') as handle:
            pickle.dump(ev, handle)
    return ev
def load_db_content(ns_list, pkl_filename=None, ro=None, reload=False):
    """Get preassembled stmt metadata from the DB for export.

    Queries the NameMeta, TextMeta, and OtherMeta tables as needed to get
    agent/stmt metadata for agents from the given namespaces.

    Parameters
    ----------
    ns_list : list of str
        List of agent namespaces to include in the metadata query. 'NAME'
        is read from NameMeta, 'TEXT' from TextMeta, and every other
        namespace from OtherMeta (filtered on db_name).
    pkl_filename : str or S3Path
        Name of pickle file to save to (if reloading) or load from (if not
        reloading). If an S3 path is given (i.e., pkl_filename starts with
        `s3:` or is an S3Path), the file is loaded to/saved from S3. If not
        given, automatically reloads the content (overriding reload).
    ro : ReadonlyDatabaseManager
        Readonly database to load the content from. If not given, calls
        `get_ro('primary')` to get the primary readonly DB.
    reload : bool
        Whether to re-query the database for content or to load the content
        from `pkl_filename`. Note that even if `reload` is False,
        if no `pkl_filename` is given, data will be reloaded anyway.

    Returns
    -------
    set of tuples
        Set of tuples containing statement information organized by agent.
        Tuples contain (stmt_hash, agent_ns, agent_id, agent_num,
        evidence_count, stmt_type). Rows flagged as complex duplicates
        (`is_complex_dup`) are excluded from fresh queries.
    """
    if isinstance(pkl_filename, str) and pkl_filename.startswith('s3:'):
        pkl_filename = S3Path.from_string(pkl_filename)
    # Get the raw data
    if reload or not pkl_filename:
        if not ro:
            ro = get_ro('primary')
        logger.info("Querying the database for statement metadata...")
        results = {}
        for ns in ns_list:
            logger.info("Querying for {ns}".format(ns=ns))
            filters = []
            if ns == 'NAME':
                tbl = ro.NameMeta
            elif ns == 'TEXT':
                tbl = ro.TextMeta
            else:
                # OtherMeta holds all remaining namespaces, so narrow it
                # down to the current one
                tbl = ro.OtherMeta
                filters.append(tbl.db_name.like(ns))
            # SQLAlchemy needs `== False` (not `is False`) to build the
            # SQL comparison
            filters.append(tbl.is_complex_dup == False)  # noqa: E712
            res = ro.select_all([tbl.mk_hash, tbl.db_id, tbl.ag_num,
                                 tbl.ev_count, tbl.type_num], *filters)
            results[ns] = res
        # The namespace is not selected from the tables; it comes from the
        # key each result list was stored under
        results = {(h, dbn, dbi, ag_num, ev_cnt, ro_type_map.get_str(tn))
                   for dbn, value_list in results.items()
                   for h, dbi, ag_num, ev_cnt, tn in value_list}
        if pkl_filename:
            if isinstance(pkl_filename, S3Path):
                upload_pickle_to_s3(results, pkl_filename)
            else:
                with open(pkl_filename, 'wb') as f:
                    pickle.dump(results, f)
    # Get a cached pickle
    else:
        logger.info("Loading database content from %s" % pkl_filename)
        # s3 urls were already converted to S3Path above, and S3Path has no
        # .startswith, so dispatch on type
        if isinstance(pkl_filename, S3Path):
            results = load_pickle_from_s3(pkl_filename)
        else:
            with open(pkl_filename, 'rb') as f:
                results = pickle.load(f)
    logger.info("{len} stmts loaded".format(len=len(results)))
    return results
def dump_sif(src_count_file, res_pos_file, belief_file, df_file=None,
             db_res_file=None, csv_file=None, reload=True, reconvert=True,
             ro=None, normalize_names: bool = True):
    """Build and dump a sif dataframe of PA statements with grounded agents

    The supporting files (source counts, residue/position, belief) must be
    pickle (suffix 'pkl') or JSON (suffix 'json') files.

    Parameters
    ----------
    src_count_file : Union[str, S3Path]
        A location to load the source count dict from. Can be local file
        path, an s3 url string or an S3Path instance.
    res_pos_file : Union[str, S3Path]
        A location to load the residue-position dict from. Can be local
        file path, an s3 url string or an S3Path instance.
    belief_file : Union[str, S3Path]
        A location to load the belief dict from. Can be local file path,
        an s3 url string or an S3Path instance.
    df_file : Optional[Union[str, S3Path]]
        If provided, dump the sif to this location. Can be local file path,
        an s3 url string or an S3Path instance.
    db_res_file : Optional[Union[str, S3Path]]
        If provided, save the db content to this location. Can be local
        file path, an s3 url string or an S3Path instance.
    csv_file : Optional[Union[str, S3Path]]
        If provided, calculate dataframe statistics and save to local file
        or s3. Can be local file path, an s3 url string or an S3Path
        instance.
    reconvert : bool
        Whether to generate a new DataFrame from the database content or
        to load and return a DataFrame from `df_file`. If False, `df_file`
        must be given. Default: True.
    reload : bool
        If True, load new content from the database and make a new
        dataframe. If False, content can be loaded from provided files.
        Default: True.
    ro : Optional[PrincipalDatabaseManager]
        Provide a DatabaseManager to load database content from. If not
        provided, `get_db('primary')` will be used.
    normalize_names : bool
        If True, detect and try to merge name duplicates (same entity with
        different names, e.g. Loratadin vs loratadin). Default: True

    Raises
    ------
    ValueError
        If a supporting file has neither a 'pkl' nor a 'json' suffix.
    """
    def _load_file(path):
        # Load a pickle/JSON file from S3 (url string or S3Path) or from
        # local disk; unknown formats raise instead of returning None
        if isinstance(path, str) and path.startswith('s3:') or \
                isinstance(path, S3Path):
            s3path = S3Path.from_string(path) if isinstance(path, str) \
                else path
            s3_str = s3path.to_string()
            if s3_str.endswith('pkl'):
                return load_pickle_from_s3(s3path)
            elif s3_str.endswith('json'):
                return load_json_from_s3(s3path)
            raise ValueError(f'Unknown file format of {path}')
        # str() also accepts pathlib.Path objects
        local_path = str(path)
        if local_path.endswith('pkl'):
            with open(local_path, 'rb') as f:
                return pickle.load(f)
        elif local_path.endswith('json'):
            with open(local_path, 'r') as f:
                return json.load(f)
        raise ValueError(f'Unknown file format of {path}')

    if ro is None:
        ro = get_db('primary')

    # Get the db content from a new DB dump or from file
    db_content = load_db_content(reload=reload, ns_list=NS_LIST,
                                 pkl_filename=db_res_file, ro=ro)

    # Load supporting files
    res_pos = _load_file(res_pos_file)
    src_count = _load_file(src_count_file)
    belief = _load_file(belief_file)

    # Convert the database query result into a set of pairwise relationships
    df = make_dataframe(pkl_filename=df_file, reconvert=reconvert,
                        db_content=db_content, src_count_dict=src_count,
                        res_pos_dict=res_pos, belief_dict=belief,
                        normalize_names=normalize_names)

    if csv_file:
        if isinstance(csv_file, str) and csv_file.startswith('s3:'):
            csv_file = S3Path.from_string(csv_file)
        # Aggregate rows by genes and stmt type
        logger.info("Saving to CSV...")
        filt_df = df.filter(items=['agA_ns', 'agA_id', 'agA_name',
                                   'agB_ns', 'agB_id', 'agB_name',
                                   'stmt_type', 'evidence_count'])
        type_counts = filt_df.groupby(by=['agA_ns', 'agA_id', 'agA_name',
                                          'agB_ns', 'agB_id', 'agB_name',
                                          'stmt_type']).sum()
        # This requires package s3fs under the hood. See:
        # https://pandas.pydata.org/pandas-docs/stable/whatsnew/v0.20.0.html#s3-file-handling
        if isinstance(csv_file, S3Path):
            try:
                type_counts.to_csv(csv_file.to_string())
            except Exception as e:
                try:
                    logger.warning('Failed to upload csv to s3 using direct '
                                   's3 url, trying boto3: %s.' % e)
                    s3 = get_s3_client(unsigned=False)
                    csv_buf = StringIO()
                    type_counts.to_csv(csv_buf)
                    # Upload the text content: after to_csv the buffer's
                    # position is at EOF, so passing the buffer itself
                    # would upload nothing
                    csv_file.upload(s3, csv_buf.getvalue())
                    logger.info('Uploaded CSV file to s3')
                except Exception as second_e:
                    logger.error('Failed to upload csv file with fallback '
                                 'method')
                    logger.exception(second_e)
        # save locally
        else:
            type_counts.to_csv(csv_file)
    return
def make_dataframe(reconvert, db_content, res_pos_dict, src_count_dict,
                   belief_dict, pkl_filename=None,
                   normalize_names: bool = False):
    """Make a pickled DataFrame of the db content, one row per stmt.

    Parameters
    ----------
    reconvert : bool
        Whether to generate a new DataFrame from the database content or
        to load and return a DataFrame from the given pickle file. If False,
        `pkl_filename` must be given.
    db_content : set of tuples
        Set of tuples of agent/stmt data as returned by `load_db_content`.
    res_pos_dict : Dict[str, Dict[str, str]]
        Dict with top-level keys 'residue' and 'position', each mapping a
        statement hash to its value.
    src_count_dict : Dict[str, Dict[str, int]]
        Dict of dicts containing source counts per source api keyed by hash.
    belief_dict : Dict[str, float]
        Dict of belief scores keyed by the *string* form of the hash.
    pkl_filename : str or S3Path
        Name of pickle file to save to (if reconverting) or load from (if
        not reconverting). If an S3 path is given (i.e., pkl_filename starts
        with `s3:`), the file is loaded to/saved from S3. Optional when
        reconverting (nothing is saved); required otherwise, in which case
        a missing value raises FileExistsError.
    normalize_names :
        If True, detect and try to merge name duplicates (same entity with
        different names, e.g. Loratadin vs loratadin). Default: False

    Returns
    -------
    pandas.DataFrame
        DataFrame containing the content, with columns: 'agA_ns', 'agA_id',
        'agA_name', 'agB_ns', 'agB_id', 'agB_name', 'stmt_type',
        'evidence_count', 'stmt_hash', 'residue', 'position',
        'source_counts', 'belief'.

    Notes
    -----
    When agent names are missing, the offending (hash, agent_num) keys are
    written to 'key_errors.csv' in the current working directory.
    """
    if isinstance(pkl_filename, str) and pkl_filename.startswith('s3:'):
        pkl_filename = S3Path.from_string(pkl_filename)
    if reconvert:
        # Content consists of tuples organized by agent, e.g.
        # (-11421523615931377, 'UP', 'P04792', 1, 1, 'Phosphorylation')
        #
        # First we need to organize by statement, collecting all agents
        # for each statement along with evidence count and type.
        # We also separately store the NAME attribute for each statement
        # agent (indexing by hash/agent_num).
        logger.info("Organizing by statement...")
        stmt_info = {}  # Store statement info (agents, ev, type) by hash
        ag_name_by_hash_num = {}  # Store name for each stmt agent
        for h, db_nm, db_id, num, n, t in tqdm(db_content):
            # NOTE(review): the namespace returned by fix_id is bound to the
            # unused name `db_nmn`, so only the fixed id is kept and `db_nm`
            # below is still the *unfixed* namespace. This looks like a typo
            # for `db_nm`, but "fixing" it could feed namespaces into the
            # NS_PRIORITY_LIST assert below that it does not contain --
            # confirm against fix_id before changing.
            db_nmn, db_id = fix_id(db_nm, db_id)
            # Populate the 'NAME' dictionary per agent
            if db_nm == 'NAME':
                ag_name_by_hash_num[(h, num)] = db_id
            # Evidence count and type are taken from the first tuple seen
            # for each hash
            if h not in stmt_info.keys():
                stmt_info[h] = {'agents': [], 'ev_count': n, 'type': t}
            stmt_info[h]['agents'].append((num, db_nm, db_id))
        # Turn into dataframe with geneA, geneB, type, indexed by hash;
        # expand out complexes to multiple rows

        # Organize by pairs of genes, counting evidence.
        nkey_errors = 0
        error_keys = []
        rows = []
        logger.info("Converting to pairwise entries...")
        # Iterate over each statement
        # (note: `hash` shadows the builtin within this function)
        for hash, info_dict in tqdm(stmt_info.items()):
            # Get the priority grounding for the agents in each position
            agents_by_num = {}
            for num, db_nm, db_id in info_dict['agents']:
                # Agent name is handled separately so we skip it here
                if db_nm == 'NAME':
                    continue
                # For other namespaces, we get the top-priority namespace
                # given all namespaces for the agent
                else:
                    assert db_nm in NS_PRIORITY_LIST
                    # Lower index in NS_PRIORITY_LIST == higher priority
                    db_rank = NS_PRIORITY_LIST.index(db_nm)
                    # If we don't already have an agent for this num, use the
                    # one we've found
                    if num not in agents_by_num:
                        agents_by_num[num] = (num, db_nm, db_id, db_rank)
                    # Otherwise, take the current agent if the identifier type
                    # has a higher rank
                    else:
                        cur_rank = agents_by_num[num][3]
                        if db_rank < cur_rank:
                            agents_by_num[num] = (num, db_nm, db_id, db_rank)
            # Make ordered list of agents for this statement, picking up
            # the agent name from the ag_name_by_hash_num dict that we
            # built earlier
            agents = []
            # Sorting the (num, ...) tuples orders agents by agent number
            for num, db_nm, db_id, _ in sorted(agents_by_num.values()):
                # Try to get the agent name
                ag_name = ag_name_by_hash_num.get((hash, num), None)
                # If the name is not found, log it but allow the agent
                # to be included as None
                if ag_name is None:
                    nkey_errors += 1
                    error_keys.append((hash, num))
                    # Only the first 10 missing names are logged
                    if nkey_errors < 11:
                        logger.warning('Missing key in agent name dict: '
                                       '(%s, %s)' % (hash, num))
                    elif nkey_errors == 11:
                        logger.warning('Got more than 10 key warnings: '
                                       'muting further warnings.')
                agents.append((db_nm, db_id, ag_name))

            # Need at least two agents.
            if len(agents) < 2:
                continue

            # If this is a complex, or there are more than two agents, permute!
            if info_dict['type'] == 'Complex':
                # Skip complexes with 4 or more members
                if len(agents) > 3:
                    continue
                pairs = permutations(agents, 2)
            else:
                # Non-complex statements keep only their agent list as one
                # "pair"; only the first two agents are used for the row
                pairs = [agents]

            # Add all the pairs, and count up total evidence.
            for pair in pairs:
                row = OrderedDict([
                    ('agA_ns', pair[0][0]),
                    ('agA_id', pair[0][1]),
                    ('agA_name', pair[0][2]),
                    ('agB_ns', pair[1][0]),
                    ('agB_id', pair[1][1]),
                    ('agB_name', pair[1][2]),
                    ('stmt_type', info_dict['type']),
                    ('evidence_count', info_dict['ev_count']),
                    ('stmt_hash', hash),
                    ('residue', res_pos_dict['residue'].get(hash)),
                    ('position', res_pos_dict['position'].get(hash)),
                    ('source_counts', src_count_dict.get(hash)),
                    # belief_dict is looked up by str(hash), unlike the
                    # dicts above -- presumably its keys are strings (e.g.
                    # loaded from JSON); confirm
                    ('belief', belief_dict.get(str(hash)))
                ])
                rows.append(row)
        if nkey_errors:
            ef = 'key_errors.csv'
            logger.warning('%d KeyErrors. Offending keys found in %s' %
                           (nkey_errors, ef))
            with open(ef, 'w') as f:
                f.write('hash,PaMeta.ag_num\n')
                for kn in error_keys:
                    f.write('%s,%s\n' % kn)
        df = pd.DataFrame.from_dict(rows)

        if pkl_filename:
            if isinstance(pkl_filename, S3Path):
                upload_pickle_to_s3(obj=df, s3_path=pkl_filename)
            else:
                with open(pkl_filename, 'wb') as f:
                    pickle.dump(df, f)
    else:
        if not pkl_filename:
            logger.error('Have to provide pickle file if not reconverting')
            raise FileExistsError
        else:
            if isinstance(pkl_filename, S3Path):
                df = load_pickle_from_s3(pkl_filename)
            else:
                with open(pkl_filename, 'rb') as f:
                    df = pickle.load(f)
    if normalize_names:
        # Called for its side effect on `df`; the return value is not used
        normalize_sif_names(sif_df=df)
    return df
'--single-proc', action='store_true', help='Run all scripts on a single process. This option is good when ' 'debugging or if the environment for some reason does not ' 'support multiprocessing. Default: False.' ) args = parser.parse_args() base_path: str = args.base_path outdir: str = args.outdir single_proc: bool = args.single_proc s3 = boto3.client('s3') # Set input dir if base_path.startswith('s3://'): s3_base_path = S3Path.from_string(base_path) input_iter = \ [s3p.to_string() for s3p in s3_base_path.list_objects(s3) if s3p.to_string().endswith('.pkl')] else: local_base_path = Path(base_path) input_iter = [f.absolute().as_posix() for f in local_base_path.glob('*.pkl')] # Set output dir if outdir.startswith('s3://'): output_dir = S3Path.from_string(outdir) else: output_dir = Path(outdir) dry = args.dry
def _exists(fpath: str) -> bool: if fpath.startswith('s3://'): return S3Path.from_string(fpath).exists(s3) else: return Path(fpath).is_file()
def get_corr_stats_axb(self,
                       z_corr: Optional[Union[str, pd.DataFrame]] = None,
                       max_proc: Optional[int] = None,
                       max_so_pairs_size: int = 10000,
                       mp_pairs: bool = True,
                       run_linear: bool = False) -> Results:
    """Get statistics of the correlations from different explanation types

    The statistics are cached in `self.corr_stats_axb`. If not already
    present, they are first looked up on S3 (at
    `self.get_s3_corr_stats_path()`). Otherwise they are generated with
    `axb_stats` and uploaded back to that S3 location when possible.

    Note: the provided options have no effect if the data is loaded from
    cache.

    Parameters
    ----------
    z_corr : Optional[Union[pd.DataFrame, str]]
        A pd.DataFrame containing the correlation z scores used to create
        the statistics in this object. If a string is given, it is passed
        to `self.load_z_corr` as a local file path. If not provided,
        `self.load_z_corr()` is called without arguments.
    max_proc : int > 0
        The maximum number of processes to run in the multiprocessing in
        get_corr_stats_mp. Default: multiprocessing.cpu_count()
    max_so_pairs_size : int
        The maximum number of correlation pairs to process. If the number
        of eligible pairs is larger than this number, a random sample of
        max_so_pairs_size is used. Default: 10 000. If the number of pairs
        to check is smaller than 10 000, no sampling is done.
    mp_pairs : bool
        If True, get the pairs to process using multiprocessing if larger
        than 10 000. Default: True.
    run_linear : bool
        If True, gather the data without multiprocessing. This option is
        good when debugging or if the environment for some reason does
        not support multiprocessing. Default: False.

    Returns
    -------
    Results
        A BaseModel containing correlation data for different explanations
    """
    if not self.corr_stats_axb:
        s3 = get_s3_client(unsigned=False)
        try:
            corr_stats_loc = self.get_s3_corr_stats_path()
            if S3Path.from_string(corr_stats_loc).exists(s3):
                logger.info(f'Found corr stats data at {corr_stats_loc}')
                corr_stats_json = file_opener(corr_stats_loc)
                self.corr_stats_axb = Results(**corr_stats_json)
            else:
                logger.info(f'No corr stats data found at '
                            f'{corr_stats_loc}')
        except ValueError as ve:
            # Raised when s3 location is not set
            logger.warning(ve)

        # If not found on s3 or ValueError was raised
        if not self.corr_stats_axb:
            logger.info('Generating corr stats data')
            # Load correlation matrix
            if z_corr is None:
                z_corr = self.load_z_corr()
            if isinstance(z_corr, str):
                z_corr = self.load_z_corr(local_file_path=z_corr)
            # Load reactome if present
            try:
                reactome = self.load_reactome()
            except FileNotFoundError:
                logger.info('No reactome file used in script')
                reactome = None
            self.corr_stats_axb: Results = axb_stats(
                self.expl_df, self.stats_df, z_corr=z_corr,
                reactome=reactome, eval_str=False, max_proc=max_proc,
                max_corr_pairs=max_so_pairs_size, do_mp_pairs=mp_pairs,
                run_linear=run_linear)
            # Cache the freshly generated stats on S3 for later calls
            try:
                corr_stats_loc = self.get_s3_corr_stats_path()
                logger.info(f'Uploading corr stats to S3 at '
                            f'{corr_stats_loc}')
                s3p_loc = S3Path.from_string(corr_stats_loc)
                s3p_loc.put(s3=s3, body=self.corr_stats_axb.json())
                logger.info('Finished uploading corr stats to S3')
            except ValueError:
                logger.warning('Unable to upload corr stats to S3')
    else:
        logger.info('Data already present in corr_stats_axb')
    return self.corr_stats_axb
def main(indra_net: str, z_score: str, outname: str, graph_type: str,
         sd_range: Tuple[float, Union[None, float]], random: bool = False,
         raw_data: Optional[List[str]] = None,
         raw_corr: Optional[List[str]] = None,
         expl_funcs: Optional[List[str]] = None,
         pb_node_mapping: Optional[Dict[str, Set]] = None,
         n_chunks: Optional[int] = 256,
         is_a_part_of: Optional[List[str]] = None,
         immediate_only: Optional[bool] = False,
         return_unexplained: Optional[bool] = False,
         reactome_path: Optional[str] = None,
         subset_list: Optional[List[Union[str, int]]] = None,
         apriori_explained: Optional[Union[bool, str, Dict[str, str]]] = False,
         allowed_ns: Optional[List[str]] = None,
         allowed_sources: Optional[List[str]] = None,
         info: Optional[Dict[Hashable, Any]] = None,
         indra_date: Optional[str] = None,
         depmap_date: Optional[str] = None,
         sample_size: Optional[int] = None,
         shuffle: Optional[bool] = False,
         overwrite: Optional[bool] = False,
         normalize_names: Optional[bool] = False,
         argparse_dict: Optional[Dict[str, Union[str, float, int,
                                                 List[str]]]] = None):
    """Set up correlation matching of depmap data with an indranet graph

    Parameters
    ----------
    indra_net : str
        Path (local file or S3 url) to the graph representation of the
        indra network. It is loaded with `file_opener` and must be an
        nx.DiGraph (or a subclass such as nx.MultiDiGraph). Each edge
        should have an attribute named 'statements' containing a list of
        sources supporting that edge. If signed search, indranet is
        expected to be an nx.MultiDiGraph with edges keyed by
        (gene, gene, sign) tuples.
    outname : str
        A file path (can be an S3 url) to where to store the final pickle
        file containing the DepmapExplainer. '.pkl' is appended if missing.
    graph_type : str
        The graph type of the graph used for the explanations. Can be one
        of 'unsigned', 'signed', 'pybel'.
    sd_range : Tuple[float, Union[float, None]]
        A tuple of the lower and optionally the upper bound of the z-score
        range to use when getting correlations
    random : bool
        Whether to do a random sampling or not. If True do a random sample
        instead of cutting the correlations off to the given SD range.
    z_score : str
        The path to the correlation DataFrame (read with `pd.read_hdf` if
        the file exists). If either raw data or raw corr are used, this
        filepath will be used to save the resulting DataFrame instead.
    raw_data : Optional[List[str]]
        File paths to CRISPR raw data and RNAi raw data from the DepMap
        Portal. Missing entries are passed on as None.
    raw_corr : Optional[List[str]]
        File paths to raw correlation data (before z-score conversion)
        containing hdf compressed correlation data. These files contain the
        result of running `raw_df.corr()`. Missing entries are passed on
        as None.
    expl_funcs : Optional[List[str]]
        Provide a list of explanation functions to apply. Default: All
        functions are applied. Currently available functions:
        - 'expl_ab': Explain pair by checking for an edge between a and b
        - 'expl_ba': Explain pair by checking for an edge between b and a
        - 'expl_axb': Explain pair by looking for intermediate nodes
          connecting a to b
        - 'expl_bxa': Explain pair by looking for intermediate nodes
          connecting b to a
        - 'get_sr': Explain pair by finding common upstream nodes
        - 'get_st': Explain pair by finding common downstream nodes
        - 'get_sd': Explain pair by finding common downstream nodes two
          edges from s and o
        - 'find_cp': Explain pair by looking for ontological parents
        - 'apriori_explained': Map entities to a-priori explanations
        - 'common_reactome_paths': Explain pair by matching common reactome
          pathways
    pb_node_mapping : Optional[Union[Dict, Set[Any]]]
        If graph type is "pybel", use this argument to provide a mapping
        from HGNC symbols to pybel nodes in the pybel model, either as a
        dict or as a path to a file loadable with `file_opener`.
    n_chunks : Optional[int]
        How many chunks to split the data into in the multiprocessing part
        of the script
    is_a_part_of : Optional[Iterable]
        A set of identifiers to look for when applying the common parent
        explanation between a pair of correlating nodes.
    immediate_only : Optional[bool]
        Only look for immediate parents. This option might limit the number
        of results that are returned. Default: False.
    return_unexplained : Optional[bool]
        If True: return explanation data even if there is no set
        intersection of nodes up- or downstream of A, B for shared
        regulators and shared targets. Default: False.
    reactome_path : Optional[str]
        File path to reactome data.
    subset_list : Optional[List[Union[str, int]]]
        Provide a list of entities that defines a subset of the entities in
        the correlation data frame that will be picked as 'a' when the pairs
        (a, b) are generated
    apriori_explained : Optional[Union[bool, str, Dict[str, str]]]
        A-priori explained entities. If True, or a string containing the
        MitoCarta file name, the default MitoCarta 3.0 file is used. Any
        other string is read as a CSV file with the columns "name" and
        "description". A dict mapping entity names to a short explanation
        of why the entity is explained is used as is. To build the default
        mapping yourself, run the following code:
        >>> from depmap_analysis.scripts.depmap_script2 import mito_file
        >>> from depmap_analysis.preprocessing import get_mitocarta_info
        >>> apriori_mapping = get_mitocarta_info(mito_file)
        then pass `apriori_mapping` as `apriori_explained` when calling this
        function:
        >>> main(apriori_explained=apriori_mapping, ...)
    allowed_ns : Optional[List[str]]
        A list of allowed name spaces for explanations involving
        intermediary nodes. Default: Any namespace.
    allowed_sources : Optional[List[str]]
        The allowed sources for edges. This will not affect subsequent
        edges in explanations involving 2 or more edges. Default: all
        sources are allowed.
    info : Optional[Dict[Hashable, Any]]
        An optional dict in which to save meta data about this run
    indra_date : Optional[str]
        The date of the sif dump used to create the graph
    depmap_date : Optional[str]
        The date (usually a quarter e.g. 19Q4) the depmap data was published
        on depmap.org
    sample_size : Optional[int]
        Number of correlation pairs to approximately get out of the
        correlation matrix after down sampling it
    shuffle : Optional[bool]
        If True, shuffle the correlation matrix. This is good to do in case
        the input data have some sort of structure that could lead to large
        discrepancies between compute times for the different processes.
        Default: False.
    overwrite : Optional[bool]
        If True, overwrite any output files. Default: False.
    normalize_names : Optional[bool]
        If True, try to normalize the names in the correlation matrix that
        are not found in the provided graph. Default: False.
    argparse_dict : Optional[Dict[str, Union[str, float, int, List[str]]]]
        Provide the argparse options from running this file as a script
    """
    global indranet, hgnc_node_mapping, output_list

    def _nth(seq, idx):
        # seq[idx] if present, else None: raw_data and raw_corr are optional
        # and may be given independently of each other
        return seq[idx] if seq is not None and len(seq) > idx else None

    indranet = file_opener(indra_net)
    assert isinstance(indranet, nx.DiGraph)

    assert expl_funcs is None or isinstance(expl_funcs, (list, tuple, set))

    # 1 Check options
    sd_l, sd_u = sd_range if sd_range and len(sd_range) == 2 else \
        ((sd_range[0], None) if sd_range and len(sd_range) == 1
         else (None, None))

    if not random and not sd_l and not sd_u:
        raise ValueError('Must specify at least a lower bound for the SD '
                         'range or flag run for random explanation')

    if graph_type == 'pybel' and not pb_node_mapping:
        raise ValueError('Must provide PyBEL node mapping with option '
                         'pb_node_mapping if graph type is "pybel"')

    # Keep the argument as given so it can be recorded in the run metadata;
    # `apriori_explained` itself is replaced by the loaded mapping below
    apriori_explained_arg = apriori_explained
    if apriori_explained and not isinstance(apriori_explained, dict):
        if apriori_explained is True or \
                mito_file_name in apriori_explained:
            # Run default
            apriori_explained = get_mitocarta_info(mito_file)
        else:
            # Hope it's a csv/tsv
            try:
                expl_df = pd.read_csv(apriori_explained)
                apriori_explained = {
                    e: d for e, d in zip(expl_df.name, expl_df.description)
                }
            except Exception as err:
                raise ValueError('A-priori explained entities must be in a '
                                 'file that can be parsed as CSV/TSV with '
                                 'column names "name" for entity name and '
                                 '"description" for explanation why the '
                                 'entity is explained.') \
                    from err
    if apriori_explained:
        logger.info(f'Using explained set with '
                    f'{len(apriori_explained)} entities')

    outname = outname if outname.endswith('.pkl') else \
        outname + '.pkl'
    if not overwrite:
        if outname.startswith('s3://'):
            s3 = get_s3_client(unsigned=False)
            if S3Path.from_string(outname).exists(s3):
                raise FileExistsError(f'File {str(outname)} already exists!')
        elif Path(outname).is_file():
            raise FileExistsError(f'File {str(outname)} already exists!')

    if z_score is not None and Path(z_score).is_file():
        z_corr = pd.read_hdf(z_score)
    else:
        z_sc_options = {
            'crispr_raw': _nth(raw_data, 0),
            'rnai_raw': _nth(raw_data, 1),
            'crispr_corr': _nth(raw_corr, 0),
            'rnai_corr': _nth(raw_corr, 1),
            'z_corr_path': z_score
        }
        z_corr = run_corr_merge(**z_sc_options)

    if reactome_path:
        up2path, _, pathid2pathname = file_opener(reactome_path)
        reactome_dict = {
            'uniprot_mapping': up2path,
            'pathid_name_mapping': pathid2pathname
        }
    else:
        reactome_dict = None

    # Get mapping of correlation names to pybel nodes
    if graph_type == 'pybel':
        if isinstance(pb_node_mapping, dict):
            hgnc_node_mapping = pb_node_mapping
        elif isinstance(pb_node_mapping, str) and \
                Path(pb_node_mapping).is_file():
            hgnc_node_mapping = file_opener(pb_node_mapping)
        else:
            raise ValueError('Could not load pybel node mapping')

    # 2. Filter to SD range OR run random sampling
    if random:
        logger.info('Doing random sampling through df.sample')
        z_corr = z_corr.sample(142, axis=0)
        z_corr = z_corr.filter(list(z_corr.index), axis=1)
        # Remove correlation values to not confuse with real data
        z_corr.loc[:, :] = 0
    else:
        if sd_l and sd_u:
            logger.info(f'Filtering correlations to {sd_l} - {sd_u} SD')
            z_corr = z_corr[((z_corr > sd_l) & (z_corr < sd_u)) |
                            ((z_corr < -sd_l) & (z_corr > -sd_u))]
        elif isinstance(sd_l, (int, float)) and sd_l and not sd_u:
            logger.info(f'Filtering correlations to {sd_l}+ SD')
            z_corr = z_corr[(z_corr > sd_l) | (z_corr < -sd_l)]

    sd_range = (sd_l, sd_u) if sd_u else (sd_l, None)

    # 3. Pick a sample
    if sample_size is not None and not random:
        logger.info(f'Reducing correlation matrix to a random approximately '
                    f'{sample_size} correlation pairs.')
        z_corr = down_sample_df(z_corr, sample_size)

    # Shuffle corr matrix without removing items
    elif shuffle and not random:
        logger.info('Shuffling correlation matrix...')
        z_corr = z_corr.sample(frac=1, axis=0)
        z_corr = z_corr.filter(list(z_corr.index), axis=1)

    if normalize_names:
        logger.info('Normalizing correlation matrix column names')
        z_corr = normalize_corr_names(z_corr, indranet)
    else:
        logger.info('Leaving correlation matrix column names as is')

    # 4. Add meta data
    info_dict = {}
    if info:
        info_dict['info'] = info

    # Set the script_settings
    script_settings = {
        'raw_data': raw_data,
        'raw_corr': raw_corr,
        'z_score': z_score,
        'random': random,
        'indranet': indra_net,
        'shuffle': shuffle,
        'sample_size': sample_size,
        'n_chunks': n_chunks,
        'outname': outname,
        # Record the file the a-priori explanations came from, if any
        'apriori_explained': apriori_explained_arg if isinstance(
            apriori_explained_arg, str) else 'no info',
        'graph_type': graph_type,
        'pybel_node_mapping': pb_node_mapping if isinstance(
            pb_node_mapping, str) else 'no info',
        'argparse_info': argparse_dict
    }

    # Create output list in global scope
    output_list = []
    explanations = match_correlations(corr_z=z_corr, sd_range=sd_range,
                                      script_settings=script_settings,
                                      graph_filepath=indra_net,
                                      z_corr_filepath=z_score,
                                      apriori_explained=apriori_explained,
                                      graph_type=graph_type,
                                      allowed_ns=allowed_ns,
                                      allowed_sources=allowed_sources,
                                      is_a_part_of=is_a_part_of,
                                      expl_funcs=expl_funcs,
                                      reactome_filepath=reactome_path,
                                      indra_date=indra_date,
                                      info=info_dict,
                                      depmap_date=depmap_date,
                                      n_chunks=n_chunks,
                                      immediate_only=immediate_only,
                                      return_unexplained=return_unexplained,
                                      reactome_dict=reactome_dict,
                                      subset_list=subset_list)
    if outname.startswith('s3://'):
        try:
            logger.info(f'Uploading results to s3: {outname}')
            s3 = get_s3_client(unsigned=False)
            s3outpath = S3Path.from_string(outname)
            explanations.s3_location = s3outpath.to_string()
            s3outpath.upload(s3=s3, body=pickle.dumps(explanations))
            logger.info('Finished uploading results to s3')
        except Exception:
            # Fall back to a local copy mirroring the bucket/key layout;
            # keep the traceback so the upload failure can be diagnosed
            new_path = Path(outname.replace('s3://', ''))
            logger.warning(f'Something went wrong in s3 upload, trying to '
                           f'save locally instead to {new_path}',
                           exc_info=True)
            new_path.parent.mkdir(parents=True, exist_ok=True)
            dump_it_to_pickle(fname=new_path.absolute().resolve().as_posix(),
                              pyobj=explanations, overwrite=overwrite)

    else:
        # mkdir in case it doesn't exist
        outpath = Path(outname)
        logger.info(f'Dumping results to {outpath}')
        outpath.parent.mkdir(parents=True, exist_ok=True)
        dump_it_to_pickle(fname=outpath.absolute().resolve().as_posix(),
                          pyobj=explanations, overwrite=overwrite)
    logger.info('Script finished')
    explanations.summarize()
def plot_corr_stats(self, outdir: str,
                    z_corr: Optional[Union[str, pd.DataFrame]] = None,
                    show_plot: bool = False,
                    max_proc: Optional[int] = None,
                    index_counter: Optional[Union[Iterator, Generator]] = None,
                    max_so_pairs_size: int = 10000,
                    mp_pairs: bool = True,
                    run_linear: bool = False,
                    log_scale_y: bool = False):
    """Plot the results of running explainer.get_corr_stats_axb()

    One histogram is produced per non-empty entry of the statistics
    returned by ``get_corr_stats_axb``. Each is saved as
    ``<stat name>_<graph_type>.pdf``, either in a local directory or
    uploaded to S3.

    Parameters
    ----------
    outdir : str
        The output directory to save the plots in. If string starts with
        's3://' upload to s3. outdir must then have the form
        's3://<bucket>/<sub_dir>' where <bucket> must be specified and
        <sub_dir> is optional and may contain subdirectories. A local
        directory is created if it does not exist.
    z_corr : Union[str, pd.DataFrame]
        A pd.DataFrame containing the correlation z scores used to
        create the statistics in this object. If not provided, an
        attempt will be made to load it from the file path present in
        script_settings.
    show_plot : bool
        If True, also show plots after saving them. Default False.
    max_proc : int > 0
        The maximum number of processes to run in the multiprocessing
        in get_corr_stats_mp. If None, the default used by
        get_corr_stats_axb applies.
    index_counter : Union[Iterator, Generator]
        An object which produces a new int by using 'next()' on it. The
        integers are used to separate the figures so as to not append
        new plots in the same figure. If not given, the position of the
        statistic in the results is used as figure number.
    max_so_pairs_size : int
        The maximum number of correlation pairs to process. If the
        number of eligible pairs is larger than this number, a random
        sample of max_so_pairs_size is used. Default: 10 000.
    mp_pairs : bool
        If True, get the pairs to process using multiprocessing if larger
        than 10 000. Default: True.
    run_linear : bool
        If True, gather the data without multiprocessing. This option is
        good when debugging or if the environment for some reason does
        not support multiprocessing. Default: False.
    log_scale_y : bool
        If True, plot the plots in this method with log10 scale on
        y-axis. Default: False.
    """
    # Local file or s3
    if outdir.startswith('s3://'):
        s3_path = S3Path.from_string(outdir)
        logger.info(f'Outdir path is on S3: {str(s3_path)}')
        od = None
    else:
        s3_path = None
        od = Path(outdir)
        if not od.is_dir():
            logger.info(f'Creating directory/ies for {od}')
            od.mkdir(parents=True, exist_ok=True)

    # Get corr stats
    corr_stats: Results = self.get_corr_stats_axb(
        z_corr=z_corr, max_proc=max_proc,
        max_so_pairs_size=max_so_pairs_size, mp_pairs=mp_pairs,
        run_linear=run_linear)
    sd_str = self.get_sd_str()
    graph_type = self.script_settings["graph_type"]
    for m, (plot_type, data) in enumerate(corr_stats.dict().items()):
        if len(data) == 0:
            logger.warning(f'Empty result for {plot_type} in '
                           f'range {sd_str} for graph type '
                           f'{graph_type}')
            continue

        name = f'{plot_type}_{graph_type}.pdf'
        logger.info(f'Using file name {name}')
        # On S3 the figure is rendered into memory and uploaded afterwards
        fname = BytesIO() if od is None else od.joinpath(name).as_posix()

        # Some statistics are stored as tuples; the plotted value is the
        # last element of each tuple
        if isinstance(data[0], tuple):
            data = [t[-1] for t in data]

        fig_index = next(index_counter) if index_counter else m
        plt.figure(fig_index)
        plt.hist(x=data, bins='auto', log=log_scale_y)
        title = f'{plot_type.replace("_", " ").capitalize()}; '\
                f'{sd_str} {graph_type}'
        plt.title(title)
        plt.xlabel('combined z-score')
        plt.ylabel('count')

        # Save to file or ByteIO and S3
        plt.savefig(fname, format='pdf')
        if od is None:
            # Reset pointer so the upload reads the whole buffer
            fname.seek(0)
            # Upload to s3
            full_s3_path = _joinpath(s3_path, name)
            _upload_bytes_io_to_s3(bytes_io_obj=fname, s3p=full_s3_path)

        # Show plot
        if show_plot:
            plt.show()

        # Close figure
        plt.close(fig_index)
def plot_interesting(self, outdir: str,
                     z_corr: Optional[Union[str, pd.DataFrame]] = None,
                     show_plot: Optional[bool] = False,
                     max_proc: Optional[int] = None,
                     index_counter: Optional[Union[Iterator, Generator]] = None,
                     max_so_pairs_size: int = 10000,
                     mp_pairs: bool = True,
                     run_linear: bool = False,
                     log_scale_y: bool = False):
    """Plots the same type of plot as plot_dists, but filters A, B

    Two density-normalized histograms are overlaid in one figure:
    ``azfb_avg_corrs`` ("Filtered A-X-B for any X") and
    ``avg_x_filtered_corrs`` ("Filtered A-X-B for X in network"), both
    taken from the results of ``get_corr_stats_axb``. The figure is saved
    as ``<sd_str>_<graph_type>_axb_filtered_hist_comparison.pdf``.

    As described by the original author, A, B are filtered to those that
    fulfill the following:
        - No a-b or b-a explanations
        - Not explained by apriori explanations
        - Without common reactome pathways
        - With a-x-b, b-x-a or shared target explanation

    Parameters
    ----------
    outdir : str
        The output directory to save the plots in. If string starts with
        's3://' upload to s3. outdir must then have the form
        's3://<bucket>/<sub_dir>' where <bucket> must be specified and
        <sub_dir> is optional and may contain subdirectories. A local
        directory is created if it does not exist.
    z_corr : Union[str, pd.DataFrame]
        A pd.DataFrame containing the correlation z scores used to
        create the statistics in this object. If not provided, an
        attempt will be made to load it from the file path present in
        script_settings.
    show_plot : bool
        If True also show plots
    max_proc : int > 0
        The maximum number of processes to run in the multiprocessing in
        get_corr_stats_mp. If None, the default used by
        get_corr_stats_axb applies.
    index_counter : Union[Iterator, Generator]
        An object which produces a new int by using 'next()' on it. The
        integers are used to separate the figures so as to not append
        new plots in the same figure. If not given, the current POSIX
        time in whole seconds is used as figure number.
    max_so_pairs_size : int
        The maximum number of correlation pairs to process. If the
        number of eligible pairs is larger than this number, a random
        sample of max_so_pairs_size is used. Default: 10000.
    mp_pairs : bool
        If True, get the pairs to process using multiprocessing if larger
        than 10 000. Default: True.
    run_linear : bool
        If True, gather the data without multiprocessing. This option is
        good when debugging or if the environment for some reason does
        not support multiprocessing. Default: False.
    log_scale_y : bool
        If True, plot the plots in this method with log10 scale on
        y-axis. Default: False.
    """
    # Local file or s3
    if outdir.startswith('s3://'):
        s3_path = S3Path.from_string(outdir)
        logger.info(f'Outdir path is on S3: {str(s3_path)}')
        od = None
    else:
        s3_path = None
        od = Path(outdir)
        if not od.is_dir():
            logger.info(f'Creating directory/ies for {od}')
            od.mkdir(parents=True, exist_ok=True)

    # Get corr stats
    corr_stats: Results = self.get_corr_stats_axb(
        z_corr=z_corr, max_proc=max_proc,
        max_so_pairs_size=max_so_pairs_size, mp_pairs=mp_pairs,
        run_linear=run_linear)

    # Without a counter, use the current POSIX time (whole seconds) as
    # figure number. datetime.utcnow() is avoided: it returns a naive
    # datetime whose .timestamp() is interpreted as local time (shifting
    # the value by the UTC offset), and it is deprecated since Python 3.12.
    fig_index = next(index_counter) if index_counter \
        else floor(datetime.now().timestamp())
    plt.figure(fig_index)
    plt.hist(corr_stats.azfb_avg_corrs, bins='auto', density=True,
             color='b', alpha=0.3, log=log_scale_y)
    plt.hist(corr_stats.avg_x_filtered_corrs, bins='auto', density=True,
             color='r', alpha=0.3, log=log_scale_y)
    legend = [
        'Filtered A-X-B for any X', 'Filtered A-X-B for X in network'
    ]

    sd_str = self.get_sd_str()
    graph_type = self.script_settings['graph_type']
    title = f'avg X corrs, filtered {sd_str} ' \
            f'({graph_type})'
    plt.title(title)
    plt.ylabel('Norm. Density')
    plt.xlabel('mean(abs(corr(a,x)), abs(corr(x,b))) (SD)')
    plt.legend(legend)
    name = '%s_%s_axb_filtered_hist_comparison.pdf' % (sd_str, graph_type)

    # Save to file or ByteIO and S3
    fname = BytesIO() if od is None else od.joinpath(name).as_posix()
    plt.savefig(fname, format='pdf')
    if od is None:
        # Reset pointer so the upload reads the whole buffer
        fname.seek(0)
        # Upload to s3
        full_s3_path = _joinpath(s3_path, name)
        _upload_bytes_io_to_s3(bytes_io_obj=fname, s3p=full_s3_path)

    # Show plot
    if show_plot:
        plt.show()

    # Close figure
    plt.close(fig_index)
def make_dataframe(reconvert, db_content, pkl_filename=None):
    """Build, or load from a cache, a pairwise statement DataFrame.

    Parameters
    ----------
    reconvert : bool
        If True, convert ``db_content`` into a DataFrame and, if
        ``pkl_filename`` is given, cache the result there. If False, load a
        previously cached DataFrame from ``pkl_filename``.
    db_content : Iterable[tuple]
        Rows of ``(stmt_hash, db_name, db_id, ag_num, ev_count, stmt_type)``
        such as those produced by ``load_db_content``. Rows with
        ``db_name == 'NAME'`` carry the agent's name in ``db_id``. Only
        used when ``reconvert`` is True.
    pkl_filename : Optional[Union[str, S3Path]]
        Local path or S3 location of the pickle cache. Strings starting
        with 's3:' are treated as S3 locations.

    Returns
    -------
    pd.DataFrame
        One row per agent pair with the columns agA_ns, agA_id, agA_name,
        agB_ns, agB_id, agB_name, stmt_type, evidence_count and stmt_hash.

    Raises
    ------
    FileExistsError
        If ``reconvert`` is False and no ``pkl_filename`` is given.
    ValueError
        If an agent namespace in ``db_content`` is not in
        NS_PRIORITY_LIST.
    """
    if isinstance(pkl_filename, str) and pkl_filename.startswith('s3:'):
        pkl_filename = S3Path.from_string(pkl_filename)

    # Load a cached DataFrame
    if not reconvert:
        if not pkl_filename:
            logger.error('Have to provide pickle file if not reconverting')
            raise FileExistsError('Have to provide pickle file if not '
                                  'reconverting')
        if isinstance(pkl_filename, S3Path):
            return load_pickle_from_s3(pkl_filename)
        with open(pkl_filename, 'rb') as f:
            return pickle.load(f)

    columns = ['agA_ns', 'agA_id', 'agA_name', 'agB_ns', 'agB_id',
               'agB_name', 'stmt_type', 'evidence_count', 'stmt_hash']

    # Organize by statement
    logger.info("Organizing by statement...")
    stmt_info = {}
    ag_name_by_hash_num = {}
    for h, db_nm, db_id, num, n, t in db_content:
        # Populate the 'NAME' dictionary per agent
        if db_nm == 'NAME':
            ag_name_by_hash_num[(h, num)] = db_id
        # Evidence count and type are taken from the first row seen for
        # each statement hash
        if h not in stmt_info:
            stmt_info[h] = {'agents': [], 'ev_count': n, 'type': t}
        stmt_info[h]['agents'].append((num, db_nm, db_id))

    # Turn into dataframe with geneA, geneB, type, indexed by hash;
    # expand out complexes to multiple rows
    nkey_errors = 0
    error_keys = []
    rows = []
    logger.info("Converting to pairwise entries...")
    for stmt_hash, info_dict in stmt_info.items():
        # For each agent position (ag_num), keep the grounding whose
        # namespace comes first in NS_PRIORITY_LIST (lowest index wins)
        agents_by_num = {}
        for num, db_nm, db_id in info_dict['agents']:
            if db_nm == 'NAME':
                continue
            if db_nm not in NS_PRIORITY_LIST:
                raise ValueError(f'Namespace {db_nm!r} of statement '
                                 f'{stmt_hash} is not in NS_PRIORITY_LIST')
            db_rank = NS_PRIORITY_LIST.index(db_nm)
            if num not in agents_by_num or db_rank < agents_by_num[num][3]:
                agents_by_num[num] = (num, db_nm, db_id, db_rank)

        # Attach agent names, ordered by agent position
        agents = []
        for num, db_nm, db_id, _ in sorted(agents_by_num.values()):
            try:
                ag_name = ag_name_by_hash_num[(stmt_hash, num)]
            except KeyError:
                nkey_errors += 1
                error_keys.append((stmt_hash, num))
                if nkey_errors < 11:
                    logger.warning('Missing key in agent name dict: '
                                   '(%s, %s)', stmt_hash, num)
                elif nkey_errors == 11:
                    logger.warning('Got more than 10 key warnings: '
                                   'muting further warnings.')
                # NOTE(review): only this agent is dropped; the statement
                # is still used if at least two named agents remain.
                continue
            agents.append((db_nm, db_id, ag_name))

        # Need at least two agents.
        if len(agents) < 2:
            continue

        if info_dict['type'] == 'Complex':
            # Skip complexes with 4 or more members; the rest yield every
            # ordered pair of members
            if len(agents) > 3:
                continue
            pairs = permutations(agents, 2)
        else:
            # Non-complex statements: only the first two agents (by agent
            # position) form the pair
            pairs = [agents]

        # Add all the pairs, with the statement's total evidence.
        for pair in pairs:
            (a_ns, a_id, a_name), (b_ns, b_id, b_name) = pair[0], pair[1]
            rows.append(OrderedDict([('agA_ns', a_ns),
                                     ('agA_id', a_id),
                                     ('agA_name', a_name),
                                     ('agB_ns', b_ns),
                                     ('agB_id', b_id),
                                     ('agB_name', b_name),
                                     ('stmt_type', info_dict['type']),
                                     ('evidence_count',
                                      info_dict['ev_count']),
                                     ('stmt_hash', stmt_hash)]))

    if nkey_errors:
        ef = 'key_errors.csv'
        logger.warning('%d KeyErrors. Offending keys found in %s',
                       nkey_errors, ef)
        with open(ef, 'w') as f:
            f.write('hash,PaMeta.ag_num\n')
            f.writelines('%s,%s\n' % kn for kn in error_keys)

    # Explicit columns keep the schema even when no rows were produced
    df = pd.DataFrame(rows, columns=columns)

    if pkl_filename:
        if isinstance(pkl_filename, S3Path):
            upload_pickle_to_s3(obj=df, s3_path=pkl_filename)
        else:
            with open(pkl_filename, 'wb') as f:
                pickle.dump(df, f)
    return df