Example #1
def _upload_bytes_io_to_s3(bytes_io_obj: BytesIO, s3p: S3Path):
    """Upload a BytesIO object to s3

    Parameters
    ----------
    bytes_io_obj : BytesIO
        Object to upload
    s3p : S3Path
        An S3Path instance of the full upload url
    """
    logger.info(f'Uploading BytesIO object to s3: {str(s3p)}')
    bytes_io_obj.seek(0)  # Just in case
    s3 = get_s3_client(unsigned=False)
    s3p.put(body=bytes_io_obj, s3=s3)
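A minimal usage sketch for the uploader above. The bucket and key are placeholders, and S3Path is assumed to come from indra_db.util.s3_path, as in the later snippets:

from io import BytesIO
from indra_db.util.s3_path import S3Path  # assumed import path

# Hypothetical destination url; replace bucket/key with real values
target = S3Path.from_string('s3://my-bucket/plots/example.pdf')
buf = BytesIO(b'some bytes to store')
_upload_bytes_io_to_s3(bytes_io_obj=buf, s3p=target)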
Example #2
def _joinpath(fpath: Union[S3Path, Path], other: str) -> Union[S3Path, Path]:
    if isinstance(fpath, Path):
        return fpath.joinpath(other).absolute()
    else:
        if fpath.to_string().endswith('/') and not other.startswith('/') or \
                not fpath.to_string().endswith('/') and other.startswith('/'):
            return S3Path.from_string(fpath.to_string() + other)
        elif fpath.to_string().endswith('/') and other.startswith('/'):
            return S3Path.from_string(fpath.to_string() + other[1:])
        elif not fpath.to_string().endswith('/') and not other.startswith('/'):
            return S3Path.from_string(fpath.to_string() + '/' + other)
        else:
            raise ValueError(f'Unable to join {fpath.to_string()} and '
                             f'{other} with "/"')
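A short sketch of how _joinpath behaves for the two supported path types; the locations are placeholders and S3Path is assumed to be the class used throughout these snippets:

from pathlib import Path
from indra_db.util.s3_path import S3Path  # assumed import path

# Local paths delegate to pathlib and come back absolute
print(_joinpath(Path('results'), 'plot.pdf'))  # e.g. /current/dir/results/plot.pdf

# S3 paths are joined with exactly one '/', whatever slashes are on either side
s3_dir = S3Path.from_string('s3://my-bucket/output/')
print(_joinpath(s3_dir, 'plot.pdf').to_string())  # s3://my-bucket/output/plot.pdf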
Example #3
def load_db_content(ns_list, pkl_filename=None, ro=None, reload=False):
    if isinstance(pkl_filename, str) and pkl_filename.startswith('s3:'):
        pkl_filename = S3Path.from_string(pkl_filename)
    # Get the raw data
    if reload or not pkl_filename:
        if not ro:
            ro = get_ro('primary')
        logger.info("Querying the database for statement metadata...")
        results = []
        for ns in ns_list:
            logger.info("Querying for {ns}".format(ns=ns))
            res = ro.select_all([
                ro.PaMeta.mk_hash, ro.PaMeta.db_name, ro.PaMeta.db_id,
                ro.PaMeta.ag_num, ro.PaMeta.ev_count, ro.PaMeta.type_num
            ], ro.PaMeta.db_name.like(ns))
            results.extend(res)
        results = {(h, dbn, dbi, ag_num, ev_cnt, ro_type_map.get_str(tn))
                   for h, dbn, dbi, ag_num, ev_cnt, tn in results}
        if pkl_filename:
            if isinstance(pkl_filename, S3Path):
                upload_pickle_to_s3(results, pkl_filename)
            else:
                with open(pkl_filename, 'wb') as f:
                    pickle.dump(results, f)
    # Get a cached pickle
    else:
        logger.info("Loading database content from %s" % pkl_filename)
        # pkl_filename was converted to an S3Path above if it was an s3 url
        if isinstance(pkl_filename, S3Path):
            results = load_pickle_from_s3(pkl_filename)
        else:
            with open(pkl_filename, 'rb') as f:
                results = pickle.load(f)
    logger.info("{len} stmts loaded".format(len=len(results)))
    return results
Example #4
def get_dir_iter(path: str, file_ending: Optional[str] = None) -> List:
    """Takes a directory path and returns a list of files

    Parameters
    ----------
    path :
        The path to the directory to loop over
    file_ending :
        If provided, only files with this file ending are returned,
        e.g. '.pkl'

    Returns
    -------
    :
        A list of files in the directory
    """
    if path.startswith('s3://'):
        s3 = get_s3_client(unsigned=False)
        s3_base_path = S3Path.from_string(path)
        input_iter = \
            [s3p.to_string() for s3p in s3_base_path.list_objects(s3)]
    else:
        local_base_path = Path(path)
        input_iter = [
            f.absolute().as_posix() for f in local_base_path.glob('*')
            if f.is_file()
        ]

    if file_ending:
        input_iter = [f for f in input_iter if f.endswith(file_ending)]

    return input_iter
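A usage sketch; the directory and bucket below are placeholders, and the S3 listing assumes valid AWS credentials:

# All '.pkl' files in a local directory
local_pkls = get_dir_iter('/tmp/dumps', file_ending='.pkl')

# All '.pkl' objects under an S3 prefix (placeholder bucket/prefix)
s3_pkls = get_dir_iter('s3://my-bucket/dumps/', file_ending='.pkl')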
Example #5
def dump_sif(df_file=None,
             db_res_file=None,
             csv_file=None,
             src_count_file=None,
             reload=False,
             reconvert=True,
             ro=None):
    if ro is None:
        ro = get_db('primary')

    # Get the db content from a new DB dump or from file
    db_content = load_db_content(reload=reload,
                                 ns_list=NS_LIST,
                                 pkl_filename=db_res_file,
                                 ro=ro)

    # Convert the database query result into a set of pairwise relationships
    df = make_dataframe(pkl_filename=df_file,
                        reconvert=reconvert,
                        db_content=db_content)

    if csv_file:
        if isinstance(csv_file, str) and csv_file.startswith('s3:'):
            csv_file = S3Path.from_string(csv_file)
        # Aggregate rows by genes and stmt type
        logger.info("Saving to CSV...")
        filt_df = df.filter(items=[
            'agA_ns', 'agA_id', 'agA_name', 'agB_ns', 'agB_id', 'agB_name',
            'stmt_type', 'evidence_count'
        ])
        type_counts = filt_df.groupby(by=[
            'agA_ns', 'agA_id', 'agA_name', 'agB_ns', 'agB_id', 'agB_name',
            'stmt_type'
        ]).sum()
        # This requires package s3fs under the hood. See:
        # https://pandas.pydata.org/pandas-docs/stable/whatsnew/v0.20.0.html#s3-file-handling
        if isinstance(csv_file, S3Path):
            try:
                type_counts.to_csv(csv_file.to_string())
            except Exception as e:
                try:
                    logger.warning('Failed to upload csv to s3 using direct '
                                   's3 url, trying boto3: %s.' % e)
                    s3 = get_s3_client(unsigned=False)
                    csv_buf = StringIO()
                    type_counts.to_csv(csv_buf)
                    s3.put_object(Body=csv_buf.getvalue(), **csv_file.kw())
                    logger.info('Uploaded CSV file to s3')
                except Exception as second_e:
                    logger.error('Failed to upload csv file with fallback '
                                 'method')
                    logger.exception(second_e)
        # save locally
        else:
            type_counts.to_csv(csv_file)

    if src_count_file:
        _ = get_source_counts(src_count_file, ro=ro)
    return
Example #6
def get_s3_dump(force_update=False, include_config=True):
    if not CONFIG or force_update:
        _load(include_config)

    if 's3_dump' not in CONFIG:
        return None

    return S3Path(CONFIG['s3_dump']['bucket'], CONFIG['s3_dump'].get('prefix'))
Example #7
    def get_s3_path(self) -> S3Path:
        """Return an S3Path object of the saved s3 location

        Returns
        -------
        S3Path
        """
        if self.s3_location is None:
            raise ValueError('s3_location is not set')
        return S3Path.from_string(self.s3_location)
Example #8
def is_dir(path: str):
    if path.startswith('s3://'):
        from indra_db.util.s3_path import S3Path
        from .aws import get_s3_client
        s3 = get_s3_client(False)
        s3dp = S3Path.from_string(path)
        if not s3dp.exists(s3):
            raise ValueError(f'Path {path} does not seem to exist')
    else:
        dp = Path(path)
        if not dp.is_dir():
            raise ValueError(f'Path {path} is not an existing directory')
    return path
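Because the validator returns the path unchanged and raises ValueError on failure, it can be used e.g. as an argparse type callable; the flag below is hypothetical:

import argparse

parser = argparse.ArgumentParser()
# argparse turns the ValueError raised by is_dir into a clean CLI error message
parser.add_argument('--base-path', type=is_dir, required=True)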
Example #9
def s3_file_opener(s3_url: str, unsigned: bool = False, **kwargs) -> \
        Union[object, pd.DataFrame, Dict]:
    """Open a file from s3 given a standard s3-path

    kwargs are only relevant for csv/tsv files and are used for pd.read_csv()

    Parameters
    ----------
    s3_url : str
        S3 url of the format 's3://<bucket>/<key>'. The key is assumed to
        also contain a file ending
    unsigned : bool
        If True, perform S3 calls unsigned. Default: False

    Returns
    -------
    Union[object, pd.DataFrame, Dict]
        Object stored on S3
    """
    from indra_db.util.s3_path import S3Path
    from .aws import load_pickle_from_s3, read_json_from_s3, get_s3_client
    logger.info(f'Loading {s3_url} from s3')
    s3_path = S3Path.from_string(s3_url)
    s3 = get_s3_client(unsigned=unsigned)
    bucket, key = s3_path.bucket, s3_path.key
    if key.endswith('.json'):
        return read_json_from_s3(s3=s3, key=key, bucket=bucket)
    elif key.endswith('.pkl'):
        return load_pickle_from_s3(s3=s3, key=key, bucket=bucket)
    elif key.endswith(('.csv', '.tsv')):
        fileio = S3Path.from_string(s3_url).get(s3=s3)
        csv_str = fileio['Body'].read().decode('utf-8')
        raw_file = StringIO(csv_str)
        return pd.read_csv(raw_file, **kwargs)
    else:
        logger.warning(f'File type {key.split(".")[-1]} not recognized, '
                       f'returning S3 file stream handler (access from '
                       f'`res["Body"].read()`)')
        return S3Path.from_string(s3_url).get(s3=s3)
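A sketch of the dispatch on file ending; all urls are placeholders, and extra keyword arguments are forwarded to pd.read_csv for csv/tsv keys:

config = s3_file_opener('s3://my-bucket/data/config.json')         # -> dict
stmts = s3_file_opener('s3://my-bucket/data/stmts.pkl')            # -> unpickled object
table = s3_file_opener('s3://my-bucket/data/table.tsv', sep='\t')  # -> pd.DataFrame
# Unrecognized endings return the raw get_object response; read via res['Body'].read()
raw = s3_file_opener('s3://my-bucket/data/archive.gz')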
Example #10
def check_path(fpath: str):
    # `file_ending` is assumed to be defined in an enclosing scope
    # (it is not part of this snippet)
    if fpath.startswith('s3://'):
        if file_ending and not fpath.endswith(file_ending):
            raise ValueError(f'Unrecognized file type '
                             f'{fpath.split("/")[-1]}')
        from indra_db.util.s3_path import S3Path
        from .aws import get_s3_client
        if not S3Path.from_string(fpath).exists(s3=get_s3_client(False)):
            raise ValueError(f'File {fpath} does not exist')
        return fpath
    p = Path(fpath)
    if not p.is_file():
        raise ValueError(f'File {fpath} does not exist')
    if file_ending and not p.name.endswith(file_ending):
        raise ValueError(f'Unrecognized file type {p.name.split(".")[-1]}')
    return fpath
Example #11
def _load_file(path):
    if isinstance(path, str) and path.startswith('s3:') or \
            isinstance(path, S3Path):
        if isinstance(path, str):
            s3path = S3Path.from_string(path)
        else:
            s3path = path
        if s3path.to_string().endswith('pkl'):
            return load_pickle_from_s3(s3path)
        elif s3path.to_string().endswith('json'):
            return load_json_from_s3(s3path)
        else:
            raise ValueError(f'Unknown file format of {path}')
    else:
        if path.endswith('pkl'):
            with open(path, 'rb') as f:
                return pickle.load(f)
        elif path.endswith('json'):
            with open(path, 'r') as f:
                return json.load(f)
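A sketch of the loader above with placeholder locations; the format is picked from the file ending, and both s3 url strings and S3Path instances are accepted:

belief = _load_file('s3://my-bucket/dumps/belief.json')
src_counts = _load_file(S3Path.from_string('s3://my-bucket/dumps/src_counts.pkl'))
res_pos = _load_file('/tmp/res_pos.json')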
Example #12
def get_source_counts(pkl_filename=None, ro=None):
    """Returns a dict of dicts with evidence count per source, per statement

    The dictionary is keyed at the top level by statement hash, and each
    entry contains a dictionary keyed by the sources supporting the
    statement, where the values are the evidence counts for each source."""
    logger.info('Getting source counts per statement')
    if isinstance(pkl_filename, str) and pkl_filename.startswith('s3:'):
        pkl_filename = S3Path.from_string(pkl_filename)
    if not ro:
        ro = get_ro('primary-ro')
    ev = {h: j for h, j in ro.select_all([ro.SourceMeta.mk_hash,
                                          ro.SourceMeta.src_json])}

    if pkl_filename:
        if isinstance(pkl_filename, S3Path):
            upload_pickle_to_s3(obj=ev, s3_path=pkl_filename)
        else:
            with open(pkl_filename, 'wb') as f:
                pickle.dump(ev, f)
    return ev
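A usage sketch with a placeholder s3 location; the shape of one entry is shown as an assumed example:

# Dump the per-statement source counts to s3 and keep the dict in memory
src_counts = get_source_counts('s3://my-bucket/dumps/src_counts.pkl')
some_hash = next(iter(src_counts))
print(src_counts[some_hash])  # e.g. {'reach': 4, 'sparser': 1}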
Example #13
def load_db_content(ns_list, pkl_filename=None, ro=None, reload=False):
    """Get preassembled stmt metadata from the DB for export.

    Queries the NameMeta, TextMeta, and OtherMeta tables as needed to get
    agent/stmt metadata for agents from the given namespaces.

    Parameters
    ----------
    ns_list : list of str
        List of agent namespaces to include in the metadata query.
    pkl_filename : str
        Name of pickle file to save to (if reloading) or load from (if not
        reloading). If an S3 path is given (i.e., pkl_filename starts with
        `s3:`), the file is saved to/loaded from S3. If not given, the
        content is automatically reloaded (overriding `reload`).
    ro : ReadonlyDatabaseManager
        Readonly database to load the content from. If not given, calls
        `get_ro('primary')` to get the primary readonly DB.
    reload : bool
        Whether to re-query the database for content or to load the content
        from `pkl_filename`. Note that even if `reload` is False,
        if no `pkl_filename` is given, data will be reloaded anyway.

    Returns
    -------
    set of tuples
        Set of tuples containing statement information organized
        by agent. Tuples contain (stmt_hash, agent_ns, agent_id, agent_num,
        evidence_count, stmt_type).
    """
    if isinstance(pkl_filename, str) and pkl_filename.startswith('s3:'):
        pkl_filename = S3Path.from_string(pkl_filename)
    # Get the raw data
    if reload or not pkl_filename:
        if not ro:
            ro = get_ro('primary')
        logger.info("Querying the database for statement metadata...")
        results = {}
        for ns in ns_list:
            logger.info("Querying for {ns}".format(ns=ns))
            filters = []
            if ns == 'NAME':
                tbl = ro.NameMeta
            elif ns == 'TEXT':
                tbl = ro.TextMeta
            else:
                tbl = ro.OtherMeta
                filters.append(tbl.db_name.like(ns))
            filters.append(tbl.is_complex_dup == False)
            res = ro.select_all([tbl.mk_hash, tbl.db_id, tbl.ag_num,
                                 tbl.ev_count, tbl.type_num], *filters)
            results[ns] = res
        results = {(h, dbn, dbi, ag_num, ev_cnt, ro_type_map.get_str(tn))
                   for dbn, value_list in results.items()
                   for h, dbi, ag_num, ev_cnt, tn in value_list}
        if pkl_filename:
            if isinstance(pkl_filename, S3Path):
                upload_pickle_to_s3(results, pkl_filename)
            else:
                with open(pkl_filename, 'wb') as f:
                    pickle.dump(results, f)
    # Get a cached pickle
    else:
        logger.info("Loading database content from %s" % pkl_filename)
        # pkl_filename was converted to an S3Path above if it was an s3 url
        if isinstance(pkl_filename, S3Path):
            results = load_pickle_from_s3(pkl_filename)
        else:
            with open(pkl_filename, 'rb') as f:
                results = pickle.load(f)
    logger.info("{len} stmts loaded".format(len=len(results)))
    return results
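A usage sketch; the namespace list and cache location below are placeholders chosen for illustration:

content = load_db_content(ns_list=['NAME', 'TEXT', 'HGNC', 'UP', 'FPLX'],
                          pkl_filename='s3://my-bucket/dumps/db_content.pkl',
                          reload=True)
# Each element: (stmt_hash, agent_ns, agent_id, agent_num, ev_count, stmt_type)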
Example #14
def dump_sif(src_count_file, res_pos_file, belief_file, df_file=None,
             db_res_file=None, csv_file=None, reload=True, reconvert=True,
             ro=None, normalize_names: bool = True):
    """Build and dump a sif dataframe of PA statements with grounded agents

    Parameters
    ----------
    src_count_file : Union[str, S3Path]
        A location to load the source count dict from. Can be local file
        path, an s3 url string or an S3Path instance.
    res_pos_file : Union[str, S3Path]
        A location to load the residue-position dict from. Can be local file
        path, an s3 url string or an S3Path instance.
    belief_file : Union[str, S3Path]
        A location to load the belief dict from. Can be local file path,
        an s3 url string or an S3Path instance.
    df_file : Optional[Union[str, S3Path]]
        If provided, dump the sif to this location. Can be local file path,
        an s3 url string or an S3Path instance.
    db_res_file : Optional[Union[str, S3Path]]
        If provided, save the db content to this location. Can be local file
        path, an s3 url string or an S3Path instance.
    csv_file : Optional[Union[str, S3Path]]
        If provided, calculate dataframe statistics and save to local file
        or s3. Can be local file path, an s3 url string or an S3Path instance.
    reconvert : bool
        Whether to generate a new DataFrame from the database content or
        to load and return a DataFrame from `df_file`. If False, `df_file`
        must be given. Default: True.
    reload : bool
        If True, load new content from the database and make a new
        dataframe. If False, content can be loaded from provided files.
        Default: True.
    ro : Optional[PrincipalDatabaseManager]
        Provide a DatabaseManager to load database content from. If not
        provided, `get_ro('primary')` will be used.
    normalize_names :
        If True, detect and try to merge name duplicates (same entity with
        different names, e.g. Loratadin vs loratadin). Default: True
    """
    def _load_file(path):
        if isinstance(path, str) and path.startswith('s3:') or \
                isinstance(path, S3Path):
            if isinstance(path, str):
                s3path = S3Path.from_string(path)
            else:
                s3path = path
            if s3path.to_string().endswith('pkl'):
                return load_pickle_from_s3(s3path)
            elif s3path.to_string().endswith('json'):
                return load_json_from_s3(s3path)
            else:
                raise ValueError(f'Unknown file format of {path}')
        else:
            if path.endswith('pkl'):
                with open(path, 'rb') as f:
                    return pickle.load(f)
            elif path.endswith('json'):
                with open(path, 'r') as f:
                    return json.load(f)

    if ro is None:
        ro = get_db('primary')

    # Get the db content from a new DB dump or from file
    db_content = load_db_content(reload=reload, ns_list=NS_LIST,
                                 pkl_filename=db_res_file, ro=ro)

    # Load supporting files
    res_pos = _load_file(res_pos_file)
    src_count = _load_file(src_count_file)
    belief = _load_file(belief_file)

    # Convert the database query result into a set of pairwise relationships
    df = make_dataframe(pkl_filename=df_file, reconvert=reconvert,
                        db_content=db_content, src_count_dict=src_count,
                        res_pos_dict=res_pos, belief_dict=belief,
                        normalize_names=normalize_names)

    if csv_file:
        if isinstance(csv_file, str) and csv_file.startswith('s3:'):
            csv_file = S3Path.from_string(csv_file)
        # Aggregate rows by genes and stmt type
        logger.info("Saving to CSV...")
        filt_df = df.filter(items=['agA_ns', 'agA_id', 'agA_name',
                                   'agB_ns', 'agB_id', 'agB_name',
                                   'stmt_type', 'evidence_count'])
        type_counts = filt_df.groupby(by=['agA_ns', 'agA_id', 'agA_name',
                                          'agB_ns', 'agB_id', 'agB_name',
                                          'stmt_type']).sum()
        # This requires package s3fs under the hood. See:
        # https://pandas.pydata.org/pandas-docs/stable/whatsnew/v0.20.0.html#s3-file-handling
        if isinstance(csv_file, S3Path):
            try:
                type_counts.to_csv(csv_file.to_string())
            except Exception as e:
                try:
                    logger.warning('Failed to upload csv to s3 using direct '
                                   's3 url, trying boto3: %s.' % e)
                    s3 = get_s3_client(unsigned=False)
                    csv_buf = StringIO()
                    type_counts.to_csv(csv_buf)
                    csv_file.upload(s3, csv_buf)
                    logger.info('Uploaded CSV file to s3')
                except Exception as second_e:
                    logger.error('Failed to upload csv file with fallback '
                                 'method')
                    logger.exception(second_e)
        # save locally
        else:
            type_counts.to_csv(csv_file)
    return
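A usage sketch for the dump function above; every location is a placeholder and may equally be a local path, an s3 url string, or an S3Path instance:

dump_sif(src_count_file='s3://my-bucket/dumps/src_counts.pkl',
         res_pos_file='s3://my-bucket/dumps/res_pos.json',
         belief_file='s3://my-bucket/dumps/belief.json',
         df_file='s3://my-bucket/dumps/sif.pkl',
         db_res_file='s3://my-bucket/dumps/db_content.pkl',
         csv_file='s3://my-bucket/dumps/sif_stats.csv',
         reload=True, reconvert=True, normalize_names=True)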
Example #15
def _pseudo_key(fname, ymd_date):
    return S3Path.from_key_parts(S3_SIF_BUCKET, S3_SUBDIR, ymd_date, fname)
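A small sketch of the key builder; S3_SIF_BUCKET and S3_SUBDIR are module-level constants in the source module, and the date and file name below are placeholders:

key = _pseudo_key('sif.pkl', '2021-01-01')
print(key.to_string())  # e.g. s3://<S3_SIF_BUCKET>/<S3_SUBDIR>/2021-01-01/sif.pkl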
Example #16
def make_dataframe(reconvert, db_content, res_pos_dict, src_count_dict,
                   belief_dict, pkl_filename=None,
                   normalize_names: bool = False):
    """Make a pickled DataFrame of the db content, one row per stmt.

    Parameters
    ----------
    reconvert : bool
        Whether to generate a new DataFrame from the database content or
        to load and return a DataFrame from the given pickle file. If False,
        `pkl_filename` must be given.
    db_content : set of tuples
        Set of tuples of agent/stmt data as returned by `load_db_content`.
    res_pos_dict : Dict[str, Dict[str, str]]
        Dict containing residue and position keyed by hash.
    src_count_dict : Dict[str, Dict[str, int]]
        Dict of dicts containing source counts per source api keyed by hash.
    belief_dict : Dict[str, float]
        Dict of belief scores keyed by hash.
    pkl_filename : str
        Name of pickle file to save to (if reconverting) or load from (if not
        reconverting). If an S3 path is given (i.e., pkl_filename starts with
        `s3:`), the file is saved to/loaded from S3. If not given, the
        DataFrame is generated but not saved.
    normalize_names :
        If True, detect and try to merge name duplicates (same entity with
        different names, e.g. Loratadin vs loratadin). Default: False

    Returns
    -------
    pandas.DataFrame
        DataFrame containing the content, with columns: 'agA_ns', 'agA_id',
        'agA_name', 'agB_ns', 'agB_id', 'agB_name', 'stmt_type',
        'evidence_count', 'stmt_hash', 'residue', 'position',
        'source_counts', 'belief'.
    """
    if isinstance(pkl_filename, str) and pkl_filename.startswith('s3:'):
        pkl_filename = S3Path.from_string(pkl_filename)
    if reconvert:
        # Content consists of tuples organized by agent, e.g.
        # (-11421523615931377, 'UP', 'P04792', 1, 1, 'Phosphorylation')
        #
        # First we need to organize by statement, collecting all agents
        # for each statement along with evidence count and type.
        # We also separately store the NAME attribute for each statement
        # agent (indexing by hash/agent_num).
        logger.info("Organizing by statement...")
        stmt_info = {} # Store statement info (agents, ev, type) by hash
        ag_name_by_hash_num = {} # Store name for each stmt agent
        for h, db_nm, db_id, num, n, t in tqdm(db_content):
            db_nm, db_id = fix_id(db_nm, db_id)
            # Populate the 'NAME' dictionary per agent
            if db_nm == 'NAME':
                ag_name_by_hash_num[(h, num)] = db_id
            if h not in stmt_info.keys():
                stmt_info[h] = {'agents': [], 'ev_count': n, 'type': t}
            stmt_info[h]['agents'].append((num, db_nm, db_id))
        # Turn into dataframe with geneA, geneB, type, indexed by hash;
        # expand out complexes to multiple rows

        # Organize by pairs of genes, counting evidence.
        nkey_errors = 0
        error_keys = []
        rows = []
        logger.info("Converting to pairwise entries...")
        # Iterate over each statement
        for hash, info_dict in tqdm(stmt_info.items()):
            # Get the priority grounding for the agents in each position
            agents_by_num = {}
            for num, db_nm, db_id in info_dict['agents']:
                # Agent name is handled separately so we skip it here
                if db_nm == 'NAME':
                    continue
                # For other namespaces, we get the top-priority namespace
                # given all namespaces for the agent
                else:
                    assert db_nm in NS_PRIORITY_LIST
                    db_rank = NS_PRIORITY_LIST.index(db_nm)
                    # If we don't already have an agent for this num, use the
                    # one we've found
                    if num not in agents_by_num:
                        agents_by_num[num] = (num, db_nm, db_id, db_rank)
                    # Otherwise, take the current agent if the identifier type
                    # has a higher rank
                    else:
                        cur_rank = agents_by_num[num][3]
                        if db_rank < cur_rank:
                            agents_by_num[num] = (num, db_nm, db_id, db_rank)
            # Make ordered list of agents for this statement, picking up
            # the agent name from the ag_name_by_hash_num dict that we
            # built earlier
            agents = []
            for num, db_nm, db_id, _ in sorted(agents_by_num.values()):
                # Try to get the agent name
                ag_name = ag_name_by_hash_num.get((hash, num), None)
                # If the name is not found, log it but allow the agent
                # to be included as None
                if ag_name is None:
                    nkey_errors += 1
                    error_keys.append((hash, num))
                    if nkey_errors < 11:
                        logger.warning('Missing key in agent name dict: '
                                       '(%s, %s)' % (hash, num))
                    elif nkey_errors == 11:
                        logger.warning('Got more than 10 key warnings: '
                                       'muting further warnings.')
                agents.append((db_nm, db_id, ag_name))

            # Need at least two agents.
            if len(agents) < 2:
                continue

            # If this is a complex, or there are more than two agents, permute!
            if info_dict['type'] == 'Complex':
                # Skip complexes with 4 or more members
                if len(agents) > 3:
                    continue
                pairs = permutations(agents, 2)
            else:
                pairs = [agents]

            # Add all the pairs, and count up total evidence.
            for pair in pairs:
                row = OrderedDict([
                    ('agA_ns', pair[0][0]),
                    ('agA_id', pair[0][1]),
                    ('agA_name', pair[0][2]),
                    ('agB_ns', pair[1][0]),
                    ('agB_id', pair[1][1]),
                    ('agB_name', pair[1][2]),
                    ('stmt_type', info_dict['type']),
                    ('evidence_count', info_dict['ev_count']),
                    ('stmt_hash', hash),
                    ('residue', res_pos_dict['residue'].get(hash)),
                    ('position', res_pos_dict['position'].get(hash)),
                    ('source_counts', src_count_dict.get(hash)),
                    ('belief', belief_dict.get(str(hash)))
                ])
                rows.append(row)
        if nkey_errors:
            ef = 'key_errors.csv'
            logger.warning('%d KeyErrors. Offending keys found in %s' %
                           (nkey_errors, ef))
            with open(ef, 'w') as f:
                f.write('hash,PaMeta.ag_num\n')
                for kn in error_keys:
                    f.write('%s,%s\n' % kn)
        df = pd.DataFrame.from_dict(rows)

        if pkl_filename:
            if isinstance(pkl_filename, S3Path):
                upload_pickle_to_s3(obj=df, s3_path=pkl_filename)
            else:
                with open(pkl_filename, 'wb') as f:
                    pickle.dump(df, f)
    else:
        if not pkl_filename:
            logger.error('Have to provide pickle file if not reconverting')
            raise FileExistsError
        else:
            if isinstance(pkl_filename, S3Path):
                df = load_pickle_from_s3(pkl_filename)
            else:
                with open(pkl_filename, 'rb') as f:
                    df = pickle.load(f)
    if normalize_names:
        normalize_sif_names(sif_df=df)
    return df
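A usage sketch tying the pieces above together; the input variables are assumed to come from load_db_content and the _load_file helper shown earlier, and the pickle location is a placeholder:

sif_df = make_dataframe(reconvert=True,
                        db_content=content,         # from load_db_content
                        res_pos_dict=res_pos,       # {'residue': {...}, 'position': {...}}
                        src_count_dict=src_counts,  # {hash: {source: count}}
                        belief_dict=belief,         # {str(hash): belief}
                        pkl_filename='s3://my-bucket/dumps/sif.pkl')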
Example #17
    parser.add_argument(
        '--single-proc', action='store_true',
        help='Run all scripts on a single process. This option is good when '
             'debugging or if the environment for some reason does not '
             'support multiprocessing. Default: False.'
    )

    args = parser.parse_args()
    base_path: str = args.base_path
    outdir: str = args.outdir
    single_proc: bool = args.single_proc

    s3 = boto3.client('s3')

    # Set input dir
    if base_path.startswith('s3://'):
        s3_base_path = S3Path.from_string(base_path)
        input_iter = \
            [s3p.to_string() for s3p in s3_base_path.list_objects(s3)
             if s3p.to_string().endswith('.pkl')]
    else:
        local_base_path = Path(base_path)
        input_iter = [f.absolute().as_posix()
                      for f in local_base_path.glob('*.pkl')]

    # Set output dir
    if outdir.startswith('s3://'):
        output_dir = S3Path.from_string(outdir)
    else:
        output_dir = Path(outdir)

    dry = args.dry
def _exists(fpath: str) -> bool:
    if fpath.startswith('s3://'):
        return S3Path.from_string(fpath).exists(s3)
    else:
        return Path(fpath).is_file()
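A small sketch of the helper above; `s3` is the boto3 client created earlier in the script, and the url is a placeholder:

if not _exists('s3://my-bucket/dumps/sif.pkl'):
    raise FileNotFoundError('Input pickle is missing')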
Example #19
    def get_corr_stats_axb(self,
                           z_corr: Optional[Union[str, pd.DataFrame]] = None,
                           max_proc: Optional[int] = None,
                           max_so_pairs_size: int = 10000,
                           mp_pairs: bool = True,
                           run_linear: bool = False) -> Results:
        """Get statistics of the correlations from different explanation types

        Note: the provided options have no effect if the data is loaded
        from cache.

        Parameters
        ----------
        z_corr : Optional[Union[pd.DataFrame, str]]
            A pd.DataFrame containing the correlation z scores used to
            create the statistics in this object.
        max_proc : int > 0
            The maximum number of processes to run in the multiprocessing
            in get_corr_stats_mp. Default: multiprocessing.cpu_count()
        max_so_pairs_size : int
            The maximum number of correlation pairs to process. If the
            number of eligible pairs is larger than this number, a random
            sample of max_so_pairs_size is used. Default: 10 000. If the
            number of pairs to check is smaller than 10 000, no sampling is
            done.
        mp_pairs : bool
            If True, get the pairs to process using multiprocessing if larger
            than 10 000. Default: True.
        run_linear : bool
            If True, gather the data without multiprocessing. This option is
            good when debugging or if the environment for some reason does
            not support multiprocessing. Default: False.

        Returns
        -------
        Results
            A BaseModel containing correlation data for different explanations
        """
        if not self.corr_stats_axb:
            s3 = get_s3_client(unsigned=False)
            try:
                corr_stats_loc = self.get_s3_corr_stats_path()
                if S3Path.from_string(corr_stats_loc).exists(s3):
                    logger.info(f'Found corr stats data at {corr_stats_loc}')
                    corr_stats_json = file_opener(corr_stats_loc)
                    self.corr_stats_axb = Results(**corr_stats_json)
                else:
                    logger.info(f'No corr stats data found at '
                                f'{corr_stats_loc}')
            except ValueError as ve:
                # Raised when s3 location is not set
                logger.warning(ve)

            # If not found on s3 or ValueError was raised
            if not self.corr_stats_axb:
                logger.info('Generating corr stats data')
                # Load correlation matrix
                if z_corr is None:
                    z_corr = self.load_z_corr()
                if isinstance(z_corr, str):
                    z_corr = self.load_z_corr(local_file_path=z_corr)
                # Load reactome if present
                try:
                    reactome = self.load_reactome()
                except FileNotFoundError:
                    logger.info('No reactome file used in script')
                    reactome = None
                self.corr_stats_axb: Results = axb_stats(
                    self.expl_df,
                    self.stats_df,
                    z_corr=z_corr,
                    reactome=reactome,
                    eval_str=False,
                    max_proc=max_proc,
                    max_corr_pairs=max_so_pairs_size,
                    do_mp_pairs=mp_pairs,
                    run_linear=run_linear)
                try:
                    corr_stats_loc = self.get_s3_corr_stats_path()
                    logger.info(f'Uploading corr stats to S3 at '
                                f'{corr_stats_loc}')
                    s3p_loc = S3Path.from_string(corr_stats_loc)
                    s3p_loc.put(s3=s3, body=self.corr_stats_axb.json())
                    logger.info('Finished uploading corr stats to S3')
                except ValueError:
                    logger.warning('Unable to upload corr stats to S3')
        else:
            logger.info('Data already present in corr_stats_axb')
        return self.corr_stats_axb
Example #20
def main(indra_net: str,
         z_score: str,
         outname: str,
         graph_type: str,
         sd_range: Tuple[float, Union[None, float]],
         random: bool = False,
         raw_data: Optional[List[str]] = None,
         raw_corr: Optional[List[str]] = None,
         expl_funcs: Optional[List[str]] = None,
         pb_node_mapping: Optional[Dict[str, Set]] = None,
         n_chunks: Optional[int] = 256,
         is_a_part_of: Optional[List[str]] = None,
         immediate_only: Optional[bool] = False,
         return_unexplained: Optional[bool] = False,
         reactome_path: Optional[str] = None,
         subset_list: Optional[List[Union[str, int]]] = None,
         apriori_explained: Optional[Union[bool, str]] = False,
         allowed_ns: Optional[List[str]] = None,
         allowed_sources: Optional[List[str]] = None,
         info: Optional[Dict[Hashable, Any]] = None,
         indra_date: Optional[str] = None,
         depmap_date: Optional[str] = None,
         sample_size: Optional[int] = None,
         shuffle: Optional[bool] = False,
         overwrite: Optional[bool] = False,
         normalize_names: Optional[bool] = False,
         argparse_dict: Optional[Dict[str, Union[str, float, int,
                                                 List[str]]]] = None):
    """Set up correlation matching of depmap data with an indranet graph

    Parameters
    ----------
    indra_net : str
        Path (local or s3 url) to the graph representation of the indra
        network. Each edge of the loaded graph should have an attribute
        named 'statements' containing a list of sources supporting that
        edge. If signed search, the loaded graph is expected to be an
        nx.MultiDiGraph with edges keyed by (gene, gene, sign) tuples.
    outname : str
        A file path (can be an S3 url) to where to store the final pickle
        file containing the DepmapExplainer
    graph_type : str
        The graph type of the graph used for the explanations. Can be one of
        'unsigned', 'signed', 'pybel'.
    sd_range : Tuple[float, Union[float, None]]
        A tuple of the lower and optionally the upper bound of the z-score
        range to use when getting correlations
    random : bool
        Whether to do a random sampling or not. If True do a random sample
        instead of cutting the correlations of to the given SD range.
    z_score : str
        The path to the correlation DataFrame. If either raw data or raw
        corr are used, this filepath will be used to save the resulting
        DataFrame instead.
    raw_data : Optional[List[str]]
        File paths to CRISPR raw data and RNAi raw data from the DepMap Portal
    raw_corr : Optional[List[str]]
        File paths to raw correlation data (before z-score conversion)
        containing hdf compressed correlation data. These files contain the
        result of running `raw_df.corr()`.
    expl_funcs : Optional[List[str]]
        Provide a list of explanation functions to apply. Default: All
        functions are applied. Currently available functions:
        - 'expl_ab': Explain pair by checking for an edge between a and b
        - 'expl_ba': Explain pair by checking for an edge between b and a
        - 'expl_axb': Explain pair by looking for intermediate nodes
          connecting a to b
        - 'expl_bxa': Explain pair by looking for intermediate nodes
          connecting b to a
        - 'get_sr': Explain pair by finding common upstream nodes
        - 'get_st': Explain pair by finding common downstream nodes
        - 'get_sd': Explain pair by finding common downstream nodes two
          edges from s and o
        - 'find_cp': Explain pair by looking for ontological parents
        - 'apriori_explained': Map entities to a-priori explanations
        - 'common_reactome_paths': Explain pair by matching common reactome
          pathways
    pb_node_mapping : Optional[Union[Dict, Set[Any]]]
        If graph type is "pybel", use this argument to provide a mapping
        from HGNC symbols to pybel nodes in the pybel model
    n_chunks : Optional[int]
        How many chunks to split the data into in the multiprocessing part
        of the script
    is_a_part_of : Optional[Iterable]
        A set of identifiers to look for when applying the common parent
        explanation between a pair of correlating nodes.
    immediate_only : Optional[bool]
        Only look for immediate parents. This option might limit the number
        of results that are returned. Default: False.
    return_unexplained : Optional[bool]
        If True: return explanation data even if there is no set
        intersection of nodes up- or downstream of A, B for shared
        regulators and shared targets. Default: False.
    reactome_path : Optional[str]
        File path to reactome data.
    subset_list : Optional[List[Union[str, int]]]
        Provide a list of entities that defines a subset of the entities in
        the correlation data frame that will be picked as 'a' when the pairs
        (a, b) are generated
    apriori_explained : Optional[Union[bool, str, Dict[str, str]]]
        A mapping from entity names to a string containing a short
        explanation of why the entity is explained. To use the default
        MitoCarta 3.0 file, run the following code:
        >>> from depmap_analysis.scripts.depmap_script2 import mito_file
        >>> from depmap_analysis.preprocessing import get_mitocarta_info
        >>> apriori_mapping = get_mitocarta_info(mito_file)
        then pass `apriori_mapping` as `apriori_explained` when calling this
        function:
        >>> main(apriori_explained=apriori_mapping, ...)
    allowed_ns : Optional[List[str]]
        A list of allowed namespaces for explanations involving
        intermediary nodes. Default: Any namespace.
    allowed_sources : Optional[List[str]]
        The allowed sources for edges. This will not affect subsequent edges
        in explanations involving 2 or more edges. Default: all sources are
        allowed.
    info : Optional[Dict[Hashable, Any]]
        An optional dict in which to save meta data about this run
    indra_date : Optional[str]
        The date of the sif dump used to create the graph
    depmap_date : Optional[str]
        The date (usually a quarter e.g. 19Q4) the depmap data was published
        on depmap.org
    sample_size : Optional[int]
        Approximate number of correlation pairs to get out of the
        correlation matrix after down-sampling it
    shuffle : Optional[bool]
        If True, shuffle the correlation matrix. This is good to do in case
        the input data have some sort of structure that could lead to large
        discrepancies between compute times for the different processes.
        Default: False.
    overwrite : Optional[bool]
        If True, overwrite any output files. Default: False.
    normalize_names : Optional[bool]
        If True, try to normalize the names in the correlation matrix that
        are not found in the provided graph. Default: False.
    argparse_dict : Optional[Dict[str, Union[str, float, int, List[str]]]]
        Provide the argparse options from running this file as a script
    """
    global indranet, hgnc_node_mapping, output_list
    indranet = file_opener(indra_net)
    assert isinstance(indranet, nx.DiGraph)

    assert expl_funcs is None or isinstance(expl_funcs, (list, tuple, set))

    # 1 Check options
    sd_l, sd_u = sd_range if sd_range and len(sd_range) == 2 else \
        ((sd_range[0], None) if sd_range and len(sd_range) == 1 else
         (None, None))

    if not random and not sd_l and not sd_u:
        raise ValueError('Must specify at least a lower bound for the SD '
                         'range or flag run for random explanation')

    if graph_type == 'pybel' and not pb_node_mapping:
        raise ValueError('Must provide PyBEL node mapping with option '
                         'pb_node_mapping if graph type is "pybel"')

    if apriori_explained:
        if apriori_explained is True or mito_file_name in apriori_explained:
            # Run default
            apriori_explained = get_mitocarta_info(mito_file)
        else:
            # Hope it's a csv/tsv
            try:
                expl_df = pd.read_csv(apriori_explained)
                apriori_explained = {
                    e: d
                    for e, d in zip(expl_df.name, expl_df.description)
                }
            except Exception as err:
                raise ValueError('A-priori explained entities must be in a '
                                 'file that can be parsed as CSV/TSV with '
                                 'column names "name" for entity name and '
                                 '"description" for explanation why the '
                                 'entity is explained.') \
                    from err

        logger.info(f'Using explained set with '
                    f'{len(apriori_explained)} entities')

    outname = outname if outname.endswith('.pkl') else \
        outname + '.pkl'
    if not overwrite:
        if outname.startswith('s3://'):
            s3 = get_s3_client(unsigned=False)
            if S3Path.from_string(outname).exists(s3):
                raise FileExistsError(f'File {str(outname)} already exists!')
        elif Path(outname).is_file():
            raise FileExistsError(f'File {str(outname)} already exists!')

    if z_score is not None and Path(z_score).is_file():
        z_corr = pd.read_hdf(z_score)
    else:
        z_sc_options = {
            'crispr_raw': raw_data[0],
            'rnai_raw': raw_data[1],
            'crispr_corr': raw_corr[0],
            'rnai_corr': raw_corr[1],
            'z_corr_path': z_score
        }
        z_corr = run_corr_merge(**z_sc_options)

    if reactome_path:
        up2path, _, pathid2pathname = file_opener(reactome_path)
        reactome_dict = {
            'uniprot_mapping': up2path,
            'pathid_name_mapping': pathid2pathname
        }
    else:
        reactome_dict = None

    # Get mapping of correlation names to pybel nodes
    if graph_type == 'pybel':
        if isinstance(pb_node_mapping, dict):
            hgnc_node_mapping = pb_node_mapping
        elif isinstance(pb_node_mapping, str) and \
                Path(pb_node_mapping).is_file():
            hgnc_node_mapping = file_opener(pb_node_mapping)
        else:
            raise ValueError('Could not load pybel node mapping')

    # 2. Filter to SD range OR run random sampling
    if random:
        logger.info('Doing random sampling through df.sample')
        z_corr = z_corr.sample(142, axis=0)
        z_corr = z_corr.filter(list(z_corr.index), axis=1)
        # Remove correlation values to not confuse with real data
        z_corr.loc[:, :] = 0
    else:
        if sd_l and sd_u:
            logger.info(f'Filtering correlations to {sd_l} - {sd_u} SD')
            z_corr = z_corr[((z_corr > sd_l) & (z_corr < sd_u)) |
                            ((z_corr < -sd_l) & (z_corr > -sd_u))]
        elif isinstance(sd_l, (int, float)) and sd_l and not sd_u:
            logger.info(f'Filtering correlations to {sd_l}+ SD')
            z_corr = z_corr[(z_corr > sd_l) | (z_corr < -sd_l)]

    sd_range = (sd_l, sd_u) if sd_u else (sd_l, None)

    # Pick a sample
    if sample_size is not None and not random:
        logger.info(f'Reducing correlation matrix to a random approximately '
                    f'{sample_size} correlation pairs.')
        z_corr = down_sample_df(z_corr, sample_size)

    # Shuffle corr matrix without removing items
    elif shuffle and not random:
        logger.info('Shuffling correlation matrix...')
        z_corr = z_corr.sample(frac=1, axis=0)
        z_corr = z_corr.filter(list(z_corr.index), axis=1)

    if normalize_names:
        logger.info('Normalizing correlation matrix column names')
        z_corr = normalize_corr_names(z_corr, indranet)
    else:
        logger.info('Leaving correlation matrix column names as is')

    # 4. Add meta data
    info_dict = {}
    if info:
        info_dict['info'] = info

    # Set the script_settings
    script_settings = {
        'raw_data': raw_data,
        'raw_corr': raw_corr,
        'z_score': z_score,
        'random': random,
        'indranet': indra_net,
        'shuffle': shuffle,
        'sample_size': sample_size,
        'n_chunks': n_chunks,
        'outname': outname,
        'apriori_explained': (apriori_explained
                              if isinstance(apriori_explained, str)
                              else 'no info'),
        'graph_type': graph_type,
        'pybel_node_mapping': (pb_node_mapping
                               if isinstance(pb_node_mapping, str)
                               else 'no info'),
        'argparse_info': argparse_dict
    }

    # Create output list in global scope
    output_list = []
    explanations = match_correlations(corr_z=z_corr,
                                      sd_range=sd_range,
                                      script_settings=script_settings,
                                      graph_filepath=indra_net,
                                      z_corr_filepath=z_score,
                                      apriori_explained=apriori_explained,
                                      graph_type=graph_type,
                                      allowed_ns=allowed_ns,
                                      allowed_sources=allowed_sources,
                                      is_a_part_of=is_a_part_of,
                                      expl_funcs=expl_funcs,
                                      reactome_filepath=reactome_path,
                                      indra_date=indra_date,
                                      info=info_dict,
                                      depmap_date=depmap_date,
                                      n_chunks=n_chunks,
                                      immediate_only=immediate_only,
                                      return_unexplained=return_unexplained,
                                      reactome_dict=reactome_dict,
                                      subset_list=subset_list)
    if outname.startswith('s3://'):
        try:
            logger.info(f'Uploading results to s3: {outname}')
            s3 = get_s3_client(unsigned=False)
            s3outpath = S3Path.from_string(outname)
            explanations.s3_location = s3outpath.to_string()
            s3outpath.upload(s3=s3, body=pickle.dumps(explanations))
            logger.info('Finished uploading results to s3')
        except Exception:
            new_path = Path(outname.replace('s3://', ''))
            logger.warning(f'Something went wrong in s3 upload, trying to '
                           f'save locally instead to {new_path}')
            new_path.parent.mkdir(parents=True, exist_ok=True)
            dump_it_to_pickle(fname=new_path.absolute().resolve().as_posix(),
                              pyobj=explanations,
                              overwrite=overwrite)

    else:
        # mkdir in case it doesn't exist
        outpath = Path(outname)
        logger.info(f'Dumping results to {outpath}')
        outpath.parent.mkdir(parents=True, exist_ok=True)
        dump_it_to_pickle(fname=outpath.absolute().resolve().as_posix(),
                          pyobj=explanations,
                          overwrite=overwrite)
    logger.info('Script finished')
    explanations.summarize()
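A minimal invocation sketch for the entry point above; all file locations are placeholders and only a lower z-score bound is set:

main(indra_net='s3://my-bucket/graphs/indranet.pkl',
     z_score='/tmp/z_corr.h5',
     outname='s3://my-bucket/output/expl.pkl',
     graph_type='unsigned',
     sd_range=(3.0, None),
     n_chunks=128,
     overwrite=True)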
Example #21
    def plot_corr_stats(self,
                        outdir: str,
                        z_corr: Optional[Union[str, pd.DataFrame]] = None,
                        show_plot: bool = False,
                        max_proc: Optional[int] = None,
                        index_counter: Optional[Union[Iterator,
                                                      Generator]] = None,
                        max_so_pairs_size: int = 10000,
                        mp_pairs: bool = True,
                        run_linear: bool = False,
                        log_scale_y: bool = False):
        """Plot the results of running explainer.get_corr_stats_axb()

        Parameters
        ----------
        outdir : str
            The output directory to save the plots in. If string starts with
            's3://' upload to s3. outdir must then have the form
            's3://<bucket>/<sub_dir>' where <bucket> must be specified and
            <sub_dir> is optional and may contain subdirectories.
        z_corr : Union[str, pd.DataFrame]
            A pd.DataFrame containing the correlation z scores used to
            create the statistics in this object. If not provided,
            an attempt will be made to load it from the file path present in
            script_settings.
        show_plot : bool
            If True, also show plots after saving them. Default False.
        max_proc : int > 0
            The maximum number of processes to run in the multiprocessing in
            get_corr_stats_mp. Default: multiprocessing.cpu_count()
        index_counter : Union[Iterator, Generator]
            An object which produces a new int by using 'next()' on it. The
            integers are used to separate the figures so as to not append
            new plots in the same figure.
        max_so_pairs_size : int
            The maximum number of correlation pairs to process. If the
            number of eligible pairs is larger than this number, a random
            sample of max_so_pairs_size is used. Default: 10 000.
        mp_pairs : bool
            If True, get the pairs to process using multiprocessing if larger
            than 10 000. Default: True.
        run_linear : bool
            If True, gather the data without multiprocessing. This option is
            good when debugging or if the environment for some reason does
            not support multiprocessing. Default: False.
        log_scale_y : bool
            If True, plot the plots in this method with log10 scale on y-axis.
            Default: False.
        """
        # Local file or s3
        if outdir.startswith('s3://'):
            s3_path = S3Path.from_string(outdir)
            logger.info(f'Outdir path is on S3: {str(s3_path)}')
            od = None
        else:
            s3_path = None
            od = Path(outdir)
            if not od.is_dir():
                logger.info(f'Creating directory/ies for {od}')
                od.mkdir(parents=True, exist_ok=True)

        # Get corr stats
        corr_stats: Results = self.get_corr_stats_axb(
            z_corr=z_corr,
            max_proc=max_proc,
            max_so_pairs_size=max_so_pairs_size,
            mp_pairs=mp_pairs,
            run_linear=run_linear)
        sd_str = self.get_sd_str()
        for m, (plot_type, data) in enumerate(corr_stats.dict().items()):
            if len(data) > 0:
                name = f'{plot_type}_{self.script_settings["graph_type"]}.pdf'
                logger.info(f'Using file name {name}')
                if od is None:
                    fname = BytesIO()
                else:
                    fname = od.joinpath(name).as_posix()
                if isinstance(data[0], tuple):
                    data = [t[-1] for t in data]

                fig_index = next(index_counter) if index_counter else m
                plt.figure(fig_index)
                plt.hist(x=data, bins='auto', log=log_scale_y)
                title = f'{plot_type.replace("_", " ").capitalize()}; '\
                        f'{sd_str} {self.script_settings["graph_type"]}'

                plt.title(title)
                plt.xlabel('combined z-score')
                plt.ylabel('count')

                # Save to file or ByteIO and S3
                plt.savefig(fname, format='pdf')
                if od is None:
                    # Reset pointer
                    fname.seek(0)
                    # Upload to s3
                    full_s3_path = _joinpath(s3_path, name)
                    _upload_bytes_io_to_s3(bytes_io_obj=fname,
                                           s3p=full_s3_path)

                # Show plot
                if show_plot:
                    plt.show()

                # Close figure
                plt.close(fig_index)
            else:
                logger.warning(f'Empty result for {plot_type} in '
                               f'range {sd_str} for graph type '
                               f'{self.script_settings["graph_type"]}')
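A sketch of calling the plotting method on an already-loaded explainer object; the pickle location and output prefix are placeholders, and file_opener is the loader used elsewhere in these snippets:

explainer = file_opener('s3://my-bucket/output/expl.pkl')  # DepmapExplainer instance
explainer.plot_corr_stats(outdir='s3://my-bucket/plots/',
                          max_so_pairs_size=5000,
                          run_linear=True,   # skip multiprocessing while debugging
                          log_scale_y=True)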
Example #22
    def plot_interesting(self,
                         outdir: str,
                         z_corr: Optional[Union[str, pd.DataFrame]] = None,
                         show_plot: Optional[bool] = False,
                         max_proc: Optional[int] = None,
                         index_counter: Optional[Union[Iterator,
                                                       Generator]] = None,
                         max_so_pairs_size: int = 10000,
                         mp_pairs: bool = True,
                         run_linear: bool = False,
                         log_scale_y: bool = False):
        """Plots the same type of plot as plot_dists, but filters A, B

        A, B are filtered to those that fulfill the following:
            - No a-b or b-a explanations
            - Not explained by apriori explanations
            - Without common reactome pathways
            - With a-x-b, b-x-a or shared target explanation

        Parameters
        ----------
        outdir : str
            The output directory to save the plots in. If string starts with
            's3://' upload to s3. outdir must then have the form
            's3://<bucket>/<sub_dir>' where <bucket> must be specified and
            <sub_dir> is optional and may contain subdirectories.
        z_corr : Union[str, pd.DataFrame]
            A pd.DataFrame containing the correlation z scores used to
            create the statistics in this object. If not provided,
            an attempt will be made to load it from the file path present in
            script_settings.
        show_plot : bool
            If True, also show plots after saving them. Default: False.
        max_proc : int > 0
            The maximum number of processes to run in the multiprocessing in
            get_corr_stats_mp. Default: multiprocessing.cpu_count()
        index_counter : Union[Iterator, Generator]
            An object which produces a new int by using 'next()' on it. The
            integers are used to separate the figures so as to not append
            new plots in the same figure.
        max_so_pairs_size : int
            The maximum number of correlation pairs to process. If the
            number of eligible pairs is larger than this number, a random
            sample of max_so_pairs_size is used. Default: 10000.
        mp_pairs : bool
            If True, get the pairs to process using multiprocessing if larger
            than 10 000. Default: True.
        run_linear : bool
            If True, gather the data without multiprocessing. This option is
            good when debugging or if the environment for some reason does
            not support multiprocessing. Default: False.
        log_scale_y : bool
            If True, plot the plots in this method with log10 scale on y-axis.
            Default: False.
        """
        # Local file or s3
        if outdir.startswith('s3://'):
            s3_path = S3Path.from_string(outdir)
            od = None
        else:
            s3_path = None
            od = Path(outdir)
            if not od.is_dir():
                od.mkdir(parents=True, exist_ok=True)

        # Get corr stats
        corr_stats: Results = self.get_corr_stats_axb(
            z_corr=z_corr,
            max_proc=max_proc,
            max_so_pairs_size=max_so_pairs_size,
            mp_pairs=mp_pairs,
            run_linear=run_linear)
        fig_index = next(index_counter) if index_counter \
            else floor(datetime.timestamp(datetime.utcnow()))
        plt.figure(fig_index)
        plt.hist(corr_stats.azfb_avg_corrs,
                 bins='auto',
                 density=True,
                 color='b',
                 alpha=0.3,
                 log=log_scale_y)
        plt.hist(corr_stats.avg_x_filtered_corrs,
                 bins='auto',
                 density=True,
                 color='r',
                 alpha=0.3,
                 log=log_scale_y)
        legend = [
            'Filtered A-X-B for any X', 'Filtered A-X-B for X in network'
        ]

        sd_str = self.get_sd_str()
        title = f'avg X corrs, filtered {sd_str} ' \
                f'({self.script_settings["graph_type"]})'
        plt.title(title)
        plt.ylabel('Norm. Density')
        plt.xlabel('mean(abs(corr(a,x)), abs(corr(x,b))) (SD)')
        plt.legend(legend)
        name = '%s_%s_axb_filtered_hist_comparison.pdf' % \
               (sd_str, self.script_settings['graph_type'])

        # Save to file or ByteIO and S3
        if od is None:
            fname = BytesIO()
        else:
            fname = od.joinpath(name).as_posix()
        plt.savefig(fname, format='pdf')
        if od is None:
            # Reset pointer
            fname.seek(0)
            # Upload to s3
            full_s3_path = _joinpath(s3_path, name)
            _upload_bytes_io_to_s3(bytes_io_obj=fname, s3p=full_s3_path)

        # Show plot
        if show_plot:
            plt.show()

        # Close figure
        plt.close(fig_index)
Example #23
def make_dataframe(reconvert, db_content, pkl_filename=None):
    if isinstance(pkl_filename, str) and pkl_filename.startswith('s3:'):
        pkl_filename = S3Path.from_string(pkl_filename)
    if reconvert:
        # Organize by statement
        logger.info("Organizing by statement...")
        stmt_info = {}
        ag_name_by_hash_num = {}
        for h, db_nm, db_id, num, n, t in db_content:
            # Populate the 'NAME' dictionary per agent
            if db_nm == 'NAME':
                ag_name_by_hash_num[(h, num)] = db_id
            if h not in stmt_info.keys():
                stmt_info[h] = {'agents': [], 'ev_count': n, 'type': t}
            stmt_info[h]['agents'].append((num, db_nm, db_id))
        # Turn into dataframe with geneA, geneB, type, indexed by hash;
        # expand out complexes to multiple rows

        # Organize by pairs of genes, counting evidence.
        nkey_errors = 0
        error_keys = []
        rows = []
        logger.info("Converting to pairwise entries...")
        for hash, info_dict in stmt_info.items():
            # Find roles with more than one agent
            agents_by_num = {}
            for num, db_nm, db_id in info_dict['agents']:
                if db_nm == 'NAME':
                    continue
                else:
                    assert db_nm in NS_PRIORITY_LIST
                    db_rank = NS_PRIORITY_LIST.index(db_nm)
                    # If we don't already have an agent for this num, use the
                    # one we've found
                    if num not in agents_by_num:
                        agents_by_num[num] = (num, db_nm, db_id, db_rank)
                    # Otherwise, take the current agent if the identifier type
                    # has a higher rank
                    else:
                        cur_rank = agents_by_num[num][3]
                        if db_rank < cur_rank:
                            agents_by_num[num] = (num, db_nm, db_id, db_rank)

            agents = []
            for num, db_nm, db_id, _ in sorted(agents_by_num.values()):
                try:
                    agents.append(
                        (db_nm, db_id, ag_name_by_hash_num[(hash, num)]))
                except KeyError:
                    nkey_errors += 1
                    error_keys.append((hash, num))
                    if nkey_errors < 11:
                        logger.warning('Missing key in agent name dict: '
                                       '(%s, %s)' % (hash, num))
                    elif nkey_errors == 11:
                        logger.warning('Got more than 10 key warnings: '
                                       'muting further warnings.')
                    continue

            # Need at least two agents.
            if len(agents) < 2:
                continue

            # If this is a complex, or there are more than two agents, permute!
            if info_dict['type'] == 'Complex':
                # Skip complexes with 4 or more members
                if len(agents) > 3:
                    continue
                pairs = permutations(agents, 2)
            else:
                pairs = [agents]

            # Add all the pairs, and count up total evidence.
            for pair in pairs:
                row = OrderedDict([('agA_ns', pair[0][0]),
                                   ('agA_id', pair[0][1]),
                                   ('agA_name', pair[0][2]),
                                   ('agB_ns', pair[1][0]),
                                   ('agB_id', pair[1][1]),
                                   ('agB_name', pair[1][2]),
                                   ('stmt_type', info_dict['type']),
                                   ('evidence_count', info_dict['ev_count']),
                                   ('stmt_hash', hash)])
                rows.append(row)
        if nkey_errors:
            ef = 'key_errors.csv'
            logger.warning('%d KeyErrors. Offending keys found in %s' %
                           (nkey_errors, ef))
            with open(ef, 'w') as f:
                f.write('hash,PaMeta.ag_num\n')
                for kn in error_keys:
                    f.write('%s,%s\n' % kn)
        df = pd.DataFrame.from_dict(rows)

        if pkl_filename:
            if isinstance(pkl_filename, S3Path):
                upload_pickle_to_s3(obj=df, s3_path=pkl_filename)
            else:
                with open(pkl_filename, 'wb') as f:
                    pickle.dump(df, f)
    else:
        if not pkl_filename:
            logger.error('Have to provide pickle file if not reconverting')
            raise FileExistsError
        else:
            if isinstance(pkl_filename, S3Path):
                df = load_pickle_from_s3(pkl_filename)
            else:
                with open(pkl_filename, 'rb') as f:
                    df = pickle.load(f)
    return df