def _save(fpath: str, expl_inst: DepMapExplainer):
    """Persist an explainer to S3 or to a local pickle file.

    Parameters
    ----------
    fpath : str
        Destination of the explainer. A value starting with 's3://'
        triggers an S3 upload; any other value is passed on as the file
        name of a local pickle dump.
    expl_inst : DepMapExplainer
        The explainer instance to serialize.

    Notes
    -----
    For S3 targets the upload destination is whatever
    `expl_inst.get_s3_path()` returns. `fpath` is only stored on the
    instance as `s3_location` when that attribute is missing.
    """
    if not fpath.startswith('s3://'):
        # Local target: write a plain pickle file and stop here
        dump_it_to_pickle(fname=fpath, pyobj=expl_inst)
        return

    # Backwards compatibility: instances lacking `s3_location` get the
    # requested path assigned before uploading
    if not hasattr(expl_inst, 's3_location'):
        expl_inst.s3_location = fpath

    logger.info(f'Uploading to {expl_inst.s3_location}')
    remote_path = expl_inst.get_s3_path()
    remote_path.upload(s3=s3, body=pickle.dumps(expl_inst))
Exemplo n.º 2
0
def load_indra_graph(dir_graph_path,
                     multi_digraph_path=None,
                     update=False,
                     belief_dict=None,
                     strat_ev_dict=None,
                     include_entity_hierarchies=True,
                     verbosity=0):
    """Return DiGraph and MultiDiGraph representations of an INDRA DB dump

    With `update` False (the default) the graphs are read from the pickle
    files at the given paths. With `update` True the graphs are rebuilt
    from freshly loaded DB content, pickled to the given paths, and the
    module level cache path variables are updated.
    WARNING: updating typically requires a lot of RAM and might slow down
    your computer significantly.

    Parameters
    ----------
    dir_graph_path : str
        Path of the pickled directed graph to read (or write on update).
    multi_digraph_path : str
        Optional path of the pickled multi-digraph to read (or write on
        update). If not given, an empty nx.MultiDiGraph is returned as the
        second item.
    update : bool
        If True, rebuild the graphs instead of loading them.
    belief_dict, strat_ev_dict, include_entity_hierarchies, verbosity
        Passed on unchanged to `nf.sif_dump_df_to_nx_digraph` when
        rebuilding; unused when loading.

    Returns
    -------
    tuple
        (directed graph, multi-digraph)
    """
    global INDRA_DG_CACHE, INDRA_MDG_CACHE
    # Placeholder returned when no multi-digraph path is supplied
    indra_multi_digraph = nx.MultiDiGraph()

    if not update:
        # Load previously pickled graphs
        mdg_suffix = 'and ' + multi_digraph_path if multi_digraph_path else ''
        logger.info('Loading indra networks %s %s' %
                    (dir_graph_path, mdg_suffix))
        indra_dir_graph = pickle_open(dir_graph_path)
        if multi_digraph_path:
            indra_multi_digraph = pickle_open(multi_digraph_path)
        logger.info('Finished loading indra networks.')
        return indra_dir_graph, indra_multi_digraph

    # Rebuild the graphs from freshly loaded DB content
    sif_df = make_dataframe(True, load_db_content(True, NS_LIST))
    graph_kwargs = dict(
        df=sif_df,
        belief_dict=belief_dict,
        strat_ev_dict=strat_ev_dict,
        include_entity_hierarchies=include_entity_hierarchies,
        verbosity=verbosity,
    )

    indra_dir_graph = nf.sif_dump_df_to_nx_digraph(multi=False,
                                                   **graph_kwargs)
    dump_it_to_pickle(dir_graph_path, indra_dir_graph)
    INDRA_DG_CACHE = path.join(CACHE, dir_graph_path)

    if multi_digraph_path:
        indra_multi_digraph = nf.sif_dump_df_to_nx_digraph(multi=True,
                                                           **graph_kwargs)
        dump_it_to_pickle(multi_digraph_path, indra_multi_digraph)
        INDRA_MDG_CACHE = path.join(CACHE, multi_digraph_path)

    return indra_dir_graph, indra_multi_digraph
Exemplo n.º 3
0
def main(indra_net: str,
         z_score: str,
         outname: str,
         graph_type: str,
         sd_range: Tuple[float, Union[None, float]],
         random: bool = False,
         raw_data: Optional[List[str]] = None,
         raw_corr: Optional[List[str]] = None,
         expl_funcs: Optional[List[str]] = None,
         pb_node_mapping: Optional[Dict[str, Set]] = None,
         n_chunks: Optional[int] = 256,
         is_a_part_of: Optional[List[str]] = None,
         immediate_only: Optional[bool] = False,
         return_unexplained: Optional[bool] = False,
         reactome_path: Optional[str] = None,
         subset_list: Optional[List[Union[str, int]]] = None,
         apriori_explained: Optional[Union[bool, str, Dict[str, str]]] = False,
         allowed_ns: Optional[List[str]] = None,
         allowed_sources: Optional[List[str]] = None,
         info: Optional[Dict[Hashable, Any]] = None,
         indra_date: Optional[str] = None,
         depmap_date: Optional[str] = None,
         sample_size: Optional[int] = None,
         shuffle: Optional[bool] = False,
         overwrite: Optional[bool] = False,
         normalize_names: Optional[bool] = False,
         argparse_dict: Optional[Dict[str, Union[str, float, int,
                                                 List[str]]]] = None):
    """Set up correlation matching of depmap data with an indranet graph

    The resulting explainer object is pickled to `outname` (S3 or local
    file) and its `summarize()` method is called before returning.

    Parameters
    ----------
    indra_net : str
        File path to the graph representation of the indra network. The
        file is loaded with `file_opener` and the loaded object must be an
        nx.DiGraph instance. Each edge should have an attribute named
        'statements' containing a list of sources supporting that edge. If
        signed search, the graph is expected to be an nx.MultiDiGraph with
        edges keyed by (gene, gene, sign) tuples.
    z_score : str
        The path to the correlation DataFrame (read with `pd.read_hdf` if
        the file exists). If the file does not exist, the raw data and raw
        corr files are merged instead and this path is passed on as the
        location for the resulting DataFrame.
    outname : str
        A file path (can be an S3 url) to where to store the final pickle
        file containing the DepmapExplainer. '.pkl' is appended if missing.
    graph_type : str
        The graph type of the graph used for the explanations. Can be one of
        'unsigned', 'signed', 'pybel'.
    sd_range : Tuple[float, Union[float, None]]
        A tuple of the lower and optionally the upper bound of the z-score
        range to use when getting correlations
    random : bool
        Whether to do a random sampling or not. If True do a random sample
        instead of cutting the correlations off to the given SD range.
    raw_data : Optional[List[str]]
        File paths to CRISPR raw data and RNAi raw data (in that order)
        from the DepMap Portal. Only read if the `z_score` file is missing.
    raw_corr : Optional[List[str]]
        File paths to raw correlation data (before z-score conversion),
        CRISPR first and RNAi second, containing hdf compressed correlation
        data. These files contain the result of running `raw_df.corr()`.
        Only read if the `z_score` file is missing.
    expl_funcs : Optional[List[str]]
        Provide a list of explanation functions to apply. Default: All
        functions are applied. Currently available functions:
        - 'expl_ab': Explain pair by checking for an edge between a and b
        - 'expl_ba': Explain pair by checking for an edge between b and a
        - 'expl_axb': Explain pair by looking for intermediate nodes
          connecting a to b
        - 'expl_bxa': Explain pair by looking for intermediate nodes
          connecting b to a
        - 'get_sr': Explain pair by finding common upstream nodes
        - 'get_st': Explain pair by finding common downstream nodes
        - 'get_sd': Explain pair by finding common downstream nodes two
          edges from s and o
        - 'find_cp': Explain pair by looking for ontological parents
        - 'apriori_explained': Map entities to a-priori explanations
        - 'common_reactome_paths': Explain pair by matching common reactome
          pathways
    pb_node_mapping : Optional[Union[Dict, str]]
        If graph type is "pybel", use this argument to provide a mapping
        from HGNC symbols to pybel nodes in the pybel model, either as a
        dict or as a path to a file that `file_opener` can load.
    n_chunks : Optional[int]
        How many chunks to split the data into in the multiprocessing part
        of the script
    is_a_part_of : Optional[Iterable]
        A set of identifiers to look for when applying the common parent
        explanation between a pair of correlating nodes.
    immediate_only : Optional[bool]
        Only look for immediate parents. This option might limit the number
        of results that are returned. Default: False.
    return_unexplained : Optional[bool]
        If True: return explanation data even if there is no set
        intersection of nodes up- or downstream of A, B for shared
        regulators and shared targets. Default: False.
    reactome_path : Optional[str]
        File path to reactome data.
    subset_list :  Optional[List[Union[str, int]]]
        Provide a list of entities that defines a subset of the entities in
        the correlation data frame that will be picked as 'a' when the pairs
        (a, b) are generated
    apriori_explained : Optional[Union[bool, str, Dict[str, str]]]
        A-priori explanations of entities. One of:
        - True, or a string containing the default MitoCarta file name:
          load the default MitoCarta file.
        - Any other string: path to a file readable by `pd.read_csv` with
          the columns "name" (entity name) and "description" (explanation
          of why the entity is explained).
        - A dict mapping entity names to a short explanation string; it is
          used as is. To build one from the default MitoCarta 3.0 file,
          run the following code:
        >>> from depmap_analysis.scripts.depmap_script2 import mito_file
        >>> from depmap_analysis.preprocessing import get_mitocarta_info
        >>> apriori_mapping = get_mitocarta_info(mito_file)
        then pass `apriori_mapping` as `apriori_explained` when calling this
        function:
        >>> main(apriori_explained=apriori_mapping, ...)
    allowed_ns : Optional[List[str]]
        A list of allowed name spaces for explanations involving
        intermediary nodes. Default: Any namespace.
    allowed_sources : Optional[List[str]]
        The allowed sources for edges. This will not affect subsequent edges
        in explanations involving 2 or more edges. Default: all sources are
        allowed.
    info : Optional[Dict[Hashable, Any]]
        An optional dict in which to save meta data about this run
    indra_date : Optional[str]
        The date of the sif dump used to create the graph
    depmap_date : Optional[str]
        The date (usually a quarter e.g. 19Q4) the depmap data was published
        on depmap.org
    sample_size : Optional[int]
        Number of correlation pairs to approximately get out of the
        correlation matrix after down sampling it. Ignored if `random` is
        True.
    shuffle : Optional[bool]
        If True, shuffle the correlation matrix. This is good to do in case
        the input data have some sort of structure that could lead to large
        discrepancies between compute times for the different processes.
        Only applied if no `sample_size` is given and `random` is False.
        Default: False.
    overwrite : Optional[bool]
        If True, overwrite any output files. Default: False.
    normalize_names : Optional[bool]
        If True, try to normalize the names in the correlation matrix that
        are not found in the provided graph. Default: False.
    argparse_dict : Optional[Dict[str, Union[str, float, int, List[str]]]]
        Provide the argparse options from running this file as a script

    Raises
    ------
    ValueError
        If neither an SD range nor `random` is given, if the pybel node
        mapping is missing or cannot be loaded for graph type "pybel", if
        the a-priori explained file cannot be parsed, or if the `z_score`
        file is missing and `raw_data`/`raw_corr` are not both provided.
    FileExistsError
        If `overwrite` is False and the output file already exists.
    """
    global indranet, hgnc_node_mapping, output_list
    indranet = file_opener(indra_net)
    assert isinstance(indranet, nx.DiGraph)

    assert expl_funcs is None or isinstance(expl_funcs, (list, tuple, set))

    # 1. Check options
    if sd_range and len(sd_range) == 2:
        sd_l, sd_u = sd_range
    elif sd_range and len(sd_range) == 1:
        sd_l, sd_u = sd_range[0], None
    else:
        sd_l, sd_u = None, None

    if not random and not sd_l and not sd_u:
        raise ValueError('Must specify at least a lower bound for the SD '
                         'range or flag run for random explanation')

    if graph_type == 'pybel' and not pb_node_mapping:
        raise ValueError('Must provide PyBEL node mapping with option '
                         'pb_node_mapping if graph type is "pybel"')

    # Remember where the a-priori explanations came from *before* the
    # argument is replaced by the loaded mapping below; otherwise the file
    # path can never end up in the script settings.
    apriori_source = apriori_explained \
        if isinstance(apriori_explained, str) else 'no info'

    if apriori_explained:
        # A ready-made mapping (as described in the docstring) is used as is
        if not isinstance(apriori_explained, dict):
            if apriori_explained is True or \
                    mito_file_name in apriori_explained:
                # Run default
                apriori_explained = get_mitocarta_info(mito_file)
            else:
                # Hope it's a csv/tsv
                try:
                    expl_df = pd.read_csv(apriori_explained)
                    apriori_explained = dict(zip(expl_df.name,
                                                 expl_df.description))
                except Exception as err:
                    raise ValueError('A-priori explained entities must be in '
                                     'a file that can be parsed as CSV/TSV '
                                     'with column names "name" for entity '
                                     'name and "description" for explanation '
                                     'why the entity is explained.') \
                        from err

        logger.info(f'Using explained set with '
                    f'{len(apriori_explained)} entities')

    outname = outname if outname.endswith('.pkl') else \
        outname + '.pkl'
    if not overwrite:
        # Fail early, before any heavy computation, if the output exists
        if outname.startswith('s3://'):
            s3 = get_s3_client(unsigned=False)
            if S3Path.from_string(outname).exists(s3):
                raise FileExistsError(f'File {str(outname)} already exists!')
        elif Path(outname).is_file():
            raise FileExistsError(f'File {str(outname)} already exists!')

    if z_score is not None and Path(z_score).is_file():
        z_corr = pd.read_hdf(z_score)
    else:
        # The z-score file must be created from the raw files: make sure
        # they were actually provided instead of failing on indexing None
        incomplete = [
            arg_name for arg_name, files in (('raw_data', raw_data),
                                             ('raw_corr', raw_corr))
            if files is None or len(files) < 2
        ]
        if incomplete:
            raise ValueError(
                f'No z-score file found at {z_score}. To create it, '
                f'two file paths (CRISPR, RNAi) must be provided for '
                f'each of raw_data and raw_corr; missing or incomplete: '
                f'{", ".join(incomplete)}')
        z_sc_options = {
            'crispr_raw': raw_data[0],
            'rnai_raw': raw_data[1],
            'crispr_corr': raw_corr[0],
            'rnai_corr': raw_corr[1],
            'z_corr_path': z_score
        }
        z_corr = run_corr_merge(**z_sc_options)

    if reactome_path:
        # The loaded object is unpacked as a 3-tuple; the middle item is
        # not used here
        up2path, _, pathid2pathname = file_opener(reactome_path)
        reactome_dict = {
            'uniprot_mapping': up2path,
            'pathid_name_mapping': pathid2pathname
        }
    else:
        reactome_dict = None

    # Get mapping of correlation names to pybel nodes
    if graph_type == 'pybel':
        if isinstance(pb_node_mapping, dict):
            hgnc_node_mapping = pb_node_mapping
        elif isinstance(pb_node_mapping, str) and \
                Path(pb_node_mapping).is_file():
            hgnc_node_mapping = file_opener(pb_node_mapping)
        else:
            raise ValueError('Could not load pybel node mapping')

    # 2. Filter to SD range OR run random sampling
    if random:
        logger.info('Doing random sampling through df.sample')
        # NOTE: the number of sampled rows is hard-coded; `sample_size` is
        # not consulted in random mode
        z_corr = z_corr.sample(142, axis=0)
        # Keep the matrix square: same entities in columns as in the rows
        z_corr = z_corr.filter(list(z_corr.index), axis=1)
        # Remove correlation values to not confuse with real data
        z_corr.loc[:, :] = 0
    else:
        if sd_l and sd_u:
            logger.info(f'Filtering correlations to {sd_l} - {sd_u} SD')
            z_corr = z_corr[((z_corr > sd_l) & (z_corr < sd_u)) |
                            ((z_corr < -sd_l) & (z_corr > -sd_u))]
        elif isinstance(sd_l, (int, float)) and sd_l and not sd_u:
            logger.info(f'Filtering correlations to {sd_l}+ SD')
            z_corr = z_corr[(z_corr > sd_l) | (z_corr < -sd_l)]

    sd_range = (sd_l, sd_u) if sd_u else (sd_l, None)

    # 3. Pick a sample
    if sample_size is not None and not random:
        logger.info(f'Reducing correlation matrix to a random approximately '
                    f'{sample_size} correlation pairs.')
        z_corr = down_sample_df(z_corr, sample_size)

    # Shuffle corr matrix without removing items
    elif shuffle and not random:
        logger.info('Shuffling correlation matrix...')
        z_corr = z_corr.sample(frac=1, axis=0)
        z_corr = z_corr.filter(list(z_corr.index), axis=1)

    if normalize_names:
        logger.info('Normalizing correlation matrix column names')
        z_corr = normalize_corr_names(z_corr, indranet)
    else:
        logger.info('Leaving correlation matrix column names as is')

    # 4. Add meta data
    info_dict = {}
    if info:
        info_dict['info'] = info

    # Set the script_settings
    script_settings = {
        'raw_data': raw_data,
        'raw_corr': raw_corr,
        'z_score': z_score,
        'random': random,
        'indranet': indra_net,
        'shuffle': shuffle,
        'sample_size': sample_size,
        'n_chunks': n_chunks,
        'outname': outname,
        'apriori_explained': apriori_source,
        'graph_type': graph_type,
        'pybel_node_mapping':
            pb_node_mapping if isinstance(pb_node_mapping, str) else 'no info',
        'argparse_info': argparse_dict
    }

    # Create output list in global scope
    output_list = []
    explanations = match_correlations(corr_z=z_corr,
                                      sd_range=sd_range,
                                      script_settings=script_settings,
                                      graph_filepath=indra_net,
                                      z_corr_filepath=z_score,
                                      apriori_explained=apriori_explained,
                                      graph_type=graph_type,
                                      allowed_ns=allowed_ns,
                                      allowed_sources=allowed_sources,
                                      is_a_part_of=is_a_part_of,
                                      expl_funcs=expl_funcs,
                                      reactome_filepath=reactome_path,
                                      indra_date=indra_date,
                                      info=info_dict,
                                      depmap_date=depmap_date,
                                      n_chunks=n_chunks,
                                      immediate_only=immediate_only,
                                      return_unexplained=return_unexplained,
                                      reactome_dict=reactome_dict,
                                      subset_list=subset_list)
    if outname.startswith('s3://'):
        try:
            logger.info(f'Uploading results to s3: {outname}')
            s3 = get_s3_client(unsigned=False)
            s3outpath = S3Path.from_string(outname)
            explanations.s3_location = s3outpath.to_string()
            s3outpath.upload(s3=s3, body=pickle.dumps(explanations))
            logger.info('Finished uploading results to s3')
        except Exception:
            # Deliberate best effort: do not lose the results if the upload
            # fails, save them to a local path mirroring the S3 key instead
            new_path = Path(outname.replace('s3://', ''))
            logger.warning(f'Something went wrong in s3 upload, trying to '
                           f'save locally instead to {new_path}')
            new_path.parent.mkdir(parents=True, exist_ok=True)
            dump_it_to_pickle(fname=new_path.absolute().resolve().as_posix(),
                              pyobj=explanations,
                              overwrite=overwrite)

    else:
        # mkdir in case it doesn't exist
        outpath = Path(outname)
        logger.info(f'Dumping results to {outpath}')
        outpath.parent.mkdir(parents=True, exist_ok=True)
        dump_it_to_pickle(fname=outpath.absolute().resolve().as_posix(),
                          pyobj=explanations,
                          overwrite=overwrite)
    logger.info('Script finished')
    explanations.summarize()
def main(args):

    global any_expl, any_expl_not_sr, common_parent, ab_expl_count, \
        directed_im_expl_count, both_im_dir_expl_count, \
        any_axb_non_sr_expl_count, sr_expl_count, \
        shared_regulator_only_expl_count, explanations_of_pairs, unexplained, \
        explained_nested_dict, id1, id2, nested_dict_statements, dataset_dict, \
        avg_corr, dir_node_set, nx_dir_graph, explained_set, part_of_explained,\
        sr_explanations, any_expl_ign_sr

    if args.cell_line_filter and not len(args.cell_line_filter) > 2:
        logger.info('Filtering to provided cell lines in correlation '
                    'calculations.')
        cell_lines = _parse_cell_filter(*args.cell_line_filter)
        assert len(cell_lines) > 0
    elif args.cell_line_filter and len(args.cell_line_filter) > 2:
        sys.exit('Argument --cell-line-filter only takes one or two arguments')
    # No cell line dictionary and rnai data and filtering is requested
    elif args.cell_line_filter and len(args.cell_line_filter) == 1 and \
            args.rnai_data_file:
        sys.exit('Need a translation dictionary if RNAi data is provided and '
                 'filter is requested')
    else:
        # Should be empty only when --cell-line-filter is not provided
        logger.info('No cell line filter provided. Using all cell lines in '
                    'correlation calculations.')
        cell_lines = []

    # Parse "explained genes"
    if args.explained_set and len(args.explained_set) == 2:
        explained_set = _parse_explained_genes(
            gene_set_file=args.explained_set[0],
            check_column=args.explained_set[1])
        logger.info('Loading "explained pairs."')
    elif args.explained_set and len(args.explained_set) != 2:
        sys.exit('Argument --explained-set takes exactly two arguments: '
                 '--explained-set <file> <column name>')

    # Check if belief dict is provided
    if not args.belief_score_dict and not args.nested_dict_in:
        logger.error('Belief dict must be provided through the `-b ('
                     '--belief-score-dict)` argument if no nested dict '
                     'of statements with belief score is provided through the '
                     '`-ndi (--nested-dict-in)` argument.')
        raise FileNotFoundError

    # Get dict of {hash: belief score}
    belief_dict = None  # ToDo use api to query belief scores if not loaded
    if args.belief_score_dict:
        if args.belief_score_dict.endswith('.json'):
            belief_dict = io.json_open(args.belief_score_dict)
        elif args.belief_score_dict.endswith('.pkl'):
            belief_dict = io.pickle_open(args.belief_score_dict)

    args_dict = _arg_dict(args)
    npairs = 0

    gene_filter_list = []
    all_hgnc_ids = set()
    stmts_all = []
    master_corr_dict = dnf.create_nested_dict()

    filter_settings = {
        'gene_set_filter':
        args.gene_set_filter,
        'strict':
        args.strict,
        'cell_line_filter':
        cell_lines,
        'cell_line_translation_dict':
        io.pickle_open(args.cell_line_filter[1])
        if args.cell_line_filter and len(args.cell_line_filter) == 2 else None,
        'margin':
        args.margin,
        'filter_type': (args.filter_type if args.filter_type else None)
    }

    output_settings = {
        'dump_unique_pairs': args.dump_unique_pairs,
        'outbasename': args.outbasename
    }

    # Parse CRISPR and/or RNAi data
    if args_dict.get('crispr') or args_dict.get('rnai'):
        if not filter_settings['filter_type'] and \
            args.crispr_data_file and \
                args.rnai_data_file:
            logger.info('No merge filter set. Output will be intersection of '
                        'the two data sets.')
        elif filter_settings.get('filter_type'):
            logger.info('Using filter type "%s"' %
                        filter_settings['filter_type'])
        master_corr_dict, all_hgnc_ids, stats_dict = \
            dnf.get_combined_correlations(dict_of_data_sets=args_dict,
                                          filter_settings=filter_settings,
                                          output_settings=output_settings)

        # Count pairs in merged correlation dict and dump it
        npairs = dnf._dump_master_corr_dict_to_pairs_in_csv(
            fname=args.outbasename + '_merged_corr_pairs.csv',
            nest_dict=master_corr_dict)

        if args.gene_set_filter:
            gene_filter_list = None
            if args_dict.get('crispr') and not args_dict.get('rnai'):
                gene_filter_list = io.read_gene_set_file(
                    gf=filter_settings['gene_set_filter'],
                    data=pd.read_csv(args_dict['crispr']['data'],
                                     index_col=0,
                                     header=0))
            elif args_dict.get('rnai') and not args_dict.get('crispr'):
                gene_filter_list = io.read_gene_set_file(
                    gf=filter_settings['gene_set_filter'],
                    data=pd.read_csv(args_dict['rnai']['data'],
                                     index_col=0,
                                     header=0))
            elif args_dict.get('crispr') and args_dict.get('rnai'):
                gene_filter_list = \
                    set(io.read_gene_set_file(
                        gf=filter_settings['gene_set_filter'],
                        data=pd.read_csv(args_dict['crispr']['data'],
                                         index_col=0, header=0))) & \
                    set(io.read_gene_set_file(
                        gf=filter_settings['gene_set_filter'],
                        data=pd.read_csv(args_dict['rnai']['data'],
                                         index_col=0, header=0)))
            assert gene_filter_list is not None

        else:
            gene_filter_list = None
    else:
        stats_dict = None

    # LOADING INDRA STATEMENTS
    # Get statements from file or from database that contain any gene from
    # provided list as set unless you're already loading a pre-calculated
    # nested dict and/or precalculated directed graph.

    if not (args.light_weight_stmts or args.nested_dict_in):
        if args.statements_in:  # Get statments from file
            stmts_all = set(ac.load_statements(args.statements_in))
        # Use api to get statements. _NOT_ the same as querying for each ID
        else:
            if args.gene_set_filter and gene_filter_list:
                stmts_all = dnf.dbc_load_statements(gene_filter_list)
            else:
                # if there is no gene set file, restrict to gene ids in
                # input data
                stmts_all = dnf.dbc_load_statements(list(all_hgnc_ids))

        # Dump statements to pickle file if output name has been given
        if args.statements_out:
            logger.info('Dumping read raw statements')
            ac.dump_statements(stmts=stmts_all, fname=args.statements_out)

    # Get nested dicts from statements
    if args.light_weight_stmts:
        hash_df = pd.read_csv(args.light_weight_stmts, delimiter='\t')
        nested_dict_statements = dnf.nested_hash_dict_from_pd_dataframe(
            hash_df)
    elif args.nested_dict_in:
        nested_dict_statements = io.pickle_open(args.nested_dict_in)
    else:
        nested_dict_statements = dnf.dedupl_nested_dict_gen(
            stmts_all, belief_dict)
        if args.nested_dict_out:
            io.dump_it_to_pickle(fname=args.nested_dict_out,
                                 pyobj=nested_dict_statements)

    # Get directed simple graph
    if args.directed_graph_in:
        with open(args.directed_graph_in, 'rb') as rpkl:
            nx_dir_graph = pkl.load(rpkl)
    else:
        # Create directed graph from statement dict
        nx_dir_graph = dnf.nested_stmt_dict_to_nx_digraph(
            nest_d=nested_dict_statements, belief_dict=belief_dict)
        # Save as pickle file
        if args.directed_graph_out:
            io.dump_it_to_pickle(fname=args.directed_graph_out,
                                 pyobj=nx_dir_graph)
    dir_node_set = set(nx_dir_graph.nodes)

    # LOOP THROUGH THE UNIQUE CORRELATION PAIRS, MATCH WITH INDRA NETWORK
    any_expl = 0  # Count if any explanation per (A,B) correlation found
    any_expl_not_sr = 0  # Count any explanation, exlcuding when shared
    # regulator is the only explanation
    any_expl_ign_sr = 0  # Count any explanation, ingoring shared regulator
    # explanations
    common_parent = 0  # Count if common parent found per set(A,B)
    part_of_explained = 0  # Count pairs part the "explained set"
    ab_expl_count = 0  # Count A-B/B-A as one per set(A,B)
    directed_im_expl_count = 0  # Count any A->X->B,B->X->A as one per set(A,B)
    any_axb_non_sr_expl_count = 0  # Count if shared target found per set(A,B)
    sr_expl_count = 0  # Count if shared regulator found per set(A,B)
    shared_regulator_only_expl_count = 0  # Count if only shared regulator found
    explanations_of_pairs = []  # Saves all non shared regulator explanations
    sr_explanations = []  # Saves all shared regulator explanations
    unexplained = []  # Unexplained correlations
    skipped = 0

    # The explained nested dict: (1st key = subj, 2nd key = obj, 3rd key =
    # connection type or correlation).
    #
    # directed: any A->B or B->A
    # undirected: any of complex, selfmodification, parent
    # x_is_intermediary: A->X->B or B->X->A
    # x_is_downstream: A->X<-B
    # x_is_upstream: A<-X->B
    #
    # d[subj][obj] = {correlation: {gene_set1: corr, gene_set2: corr, ...},
    #                 directed: [(stmt/stmt hash, belief score)],
    #                 undirected: [(stmt/stmt hash, belief score)],
    #                 common_parents: [list of parents]
    #                 x_is_intermediary: [(X, belief rank)],
    #                 x_is_downstream: [(X, belief rank)],
    #                 x_is_upstream: [(X, belief rank)]}
    #
    # Then in javascript you can for example do:
    # if SUBJ_is_subj_dict.obj.direct.length <-- should return zero if []
    #
    # Used to get: directed graph
    # 1. all nodes of directed graph -> 1st dropdown
    # 2. dir -> undir graph -> jsons to check all corr neighbors -> 2nd dropdown
    # 3. jsons to check if connection is direct or intermediary

    # Using the following loop structure for counter variables:
    # a = 2
    # def for_loop_body():
    #     global a
    #     a += 1
    # # Then loop like:
    # if dict:
    #     for pairs in dict:
    #         for_loop_body(args)
    # elif random:
    #     for random pair:
    #         for_loop_body(args)

    explained_nested_dict = dnf.create_nested_dict()

    # Loop rnai and/or crispr only
    if args_dict.get('rnai') or args_dict.get('crispr') and \
            not args.brca_dependencies:
        logger.info('Gene pairs generated from DepMap knockout screening data '
                    'sets')
        logger.info('Looking for connections between %i pairs' %
                    (npairs if npairs > 0 else args.max_pairs))
        for outer_id, do in master_corr_dict.items():
            for inner_id, dataset_dict in do.items():
                if len(dataset_dict.keys()) == 0:
                    skipped += 1
                    if args.verbosity:
                        logger.info('Skipped outer_id=%s and inner_id=%s' %
                                    (outer_id, inner_id))
                    continue

                id1, id2 = outer_id, inner_id
                loop_body(args)

    # Loop rnai and/or crispr AND BRCA cell line dependencies
    elif args_dict.get('rnai') or args_dict.get('crispr') and \
            args.brca_dependencies:
        logger.info('Gene pairs generated from combined knockout screens. '
                    'Output data will incluide BRCA cell line dependency\n'
                    'data as well as correlation data from knockout screens.')
        logger.info('Looking for connections between %i pairs' %
                    (npairs if npairs > 0 else args.max_pairs))

        # Load BRCA dependency data
        brca_data_set = pd.read_csv(args.brca_dependencies, header=0)
        depend_in_breast_genes = brca_data_set.drop(
            axis=1, labels=['Url Label',
                            'Type'])[brca_data_set['Type'] == 'gene']
        genes = set(depend_in_breast_genes['Gene/Compound'].values)

        for outer_id, do in master_corr_dict.items():
            for inner_id, knockout_dict in do.items():
                if len(knockout_dict.keys()) == 0:
                    skipped += 1
                    if args.verbosity:
                        logger.info('Skipped outer_id=%s and inner_id=%s' %
                                    (outer_id, inner_id))
                    continue

                id1, id2 = outer_id, inner_id
                dataset_dict = {}
                gene1_data = []
                gene2_data = []

                # Get BRCA dep data
                if id1 in genes:
                    for row in depend_in_breast_genes[
                            depend_in_breast_genes['Gene/Compound'] ==
                            id1].iterrows():
                        gene1_data.append(
                            (row[1]['Dataset'], row[1]['T-Statistic'],
                             row[1]['P-Value']))
                if id2 in genes:
                    for row in depend_in_breast_genes[
                            depend_in_breast_genes['Gene/Compound'] ==
                            id2].iterrows():
                        gene2_data.append(
                            (row[1]['Dataset'], row[1]['T-Statistic'],
                             row[1]['P-Value']))

                dataset_dict[id1] = gene1_data
                dataset_dict[id2] = gene2_data

                dataset_dict['crispr'] = (knockout_dict['crispr']
                                          if knockout_dict.get('crispr') else
                                          None),
                dataset_dict['rnai'] = (knockout_dict['rnai']
                                        if knockout_dict.get('rnai') else None)

                if id1 not in genes and id2 not in genes:
                    dataset_dict = knockout_dict

                # Run loop body
                loop_body(args)

    # loop brca dependency ONLY
    elif args.brca_dependencies and not \
            (args_dict.get('rnai') or args_dict.get('crispr')):
        logger.info(
            'Gene pairs generated from BRCA gene enrichment data only.')
        brca_data_set = pd.read_csv(args.brca_dependencies, header=0)
        depend_in_breast_genes = brca_data_set.drop(
            axis=1, labels=['Url Label',
                            'Type'])[brca_data_set['Type'] == 'gene']
        genes = set(depend_in_breast_genes['Gene/Compound'].values)
        npairs = len(list(itt.combinations(genes, 2)))
        logger.info('Looking for connections between %i pairs' %
                    (npairs if npairs > 0 else args.max_pairs))
        for id1, id2 in itt.combinations(genes, 2):
            gene1_data = []
            gene2_data = []
            # For each non-diagonal pair in file, insert in dataset_dict:
            # geneA, geneB,
            # dataset for A, dataset for B,
            # T-stat for A, T-stat for B,
            # P-value for A, P-value
            for row in depend_in_breast_genes[
                    depend_in_breast_genes['Gene/Compound'] == id1].iterrows():
                gene1_data.append((row[1]['Dataset'], row[1]['T-Statistic'],
                                   row[1]['P-Value']))

            for row in depend_in_breast_genes[
                    depend_in_breast_genes['Gene/Compound'] == id2].iterrows():
                gene2_data.append((row[1]['Dataset'], row[1]['T-Statistic'],
                                   row[1]['P-Value']))
            # dataset_dict = {id1:
            #                 [(dataset1, T-stat1, P-value1),
            #                  (dataset2, T-stat2, P-value2)],
            #                 id2:
            #                  [(..., ...)],
            #                  ...}
            dataset_dict = {id1: gene1_data, id2: gene2_data}
            loop_body(args)

    # loop random pairs from data set
    elif args_dict.get('sampling_gene_file'):
        logger.info('Gene pairs generated at random from %s' %
                    args_dict['sampling_gene_file'])
        with open(args_dict['sampling_gene_file'], 'r') as fi:
            rnd_gene_set = [l.strip() for l in fi.readlines()]

        npairs = args.max_pairs
        dataset_dict = None
        logger.info('Looking for connections between %i pairs' %
                    (npairs if npairs > 0 else args.max_pairs))
        for _ in range(npairs):
            id1, id2 = _rnd_pair_gen(rnd_gene_set)
            assert not isinstance(id1, list)
            loop_body(args)

    long_string = ''
    long_string += '-' * 63 + '\n'
    long_string += 'Summary for matching INDRA network to correlation pairs:'\
                   + '\n\n'
    long_string += '> Total number of correlation pairs checked: %i' % npairs\
                   + '\n'
    if args.verbosity:
        long_string += '> Skipped %i empty doublets in corr dict\n' % skipped

    long_string += '> Total correlations unexplained: %i' % len(unexplained)\
                   + '\n'
    long_string += '> Total correlations explained: %i' % any_expl + '\n'
    long_string += '> Total correlations explained, ignoring shared ' \
                   'regulator: %i' % any_expl_ign_sr + '\n'
    long_string += '> Total correlations explained, excluding shared ' \
                   'regulator (total - shared only): %i' % \
                   (any_expl - shared_regulator_only_expl_count) + '\n'
    long_string += '>    %i correlations have an explanation involving a ' \
                   'common parent' % common_parent + '\n'
    if args.explained_set:
        long_string += '>    %i gene pairs were considered explained as part ' \
                       'of the "explained set"' % part_of_explained + '\n'
    long_string += '>    %i explanations involving direct connection or ' \
                   'complex' % ab_expl_count + '\n'
    long_string += '>    %i correlations have a directed explanation ' \
                   'involving an intermediate node (A->X->B/A<-X<-B)' \
                   % directed_im_expl_count + '\n'
    long_string += '>    %i correlations have an explanation involving an ' \
                   'intermediate node excluding shared regulators' % \
                   any_axb_non_sr_expl_count + '\n'
    long_string += '>    %i correlations have an explanation involving a ' \
                   'shared regulator (A<-X->B)' % sr_expl_count + '\n'
    long_string += '>    %i correlations have shared regulator as only ' \
                   'explanation' % shared_regulator_only_expl_count + '\n\n'

    if stats_dict and (stats_dict.get('rnai') or stats_dict.get('crispr')):
        long_string += 'Statistics of input data:' + '\n\n'
    if stats_dict and stats_dict.get('rnai'):
        long_string += '  RNAi data ' + '\n'
        long_string += ' -----------' + '\n'
        long_string += '> mean: %f\n' % stats_dict['rnai']['mean']
        long_string += '> SD: %f\n' % stats_dict['rnai']['sigma']
        long_string += '> lower bound: %.3f*SD = %.4f\n' % (
            args_dict['rnai']['ll'],
            args_dict['rnai']['ll'] * stats_dict['rnai']['sigma'])
        if args_dict['rnai']['ul']:
            long_string += '> upper bound: %.3f*SD = %.4f\n\n' % (
                args_dict['rnai']['ul'],
                args_dict['rnai']['ul'] * stats_dict['rnai']['sigma'])
    if stats_dict and stats_dict.get('crispr'):
        long_string += '  CRISPR data ' + '\n'
        long_string += ' -------------' + '\n'
        long_string += '> mean: %f\n' % stats_dict['crispr']['mean']
        long_string += '> SD: %f\n' % stats_dict['crispr']['sigma']
        long_string += '> lower bound: %.3f*SD = %.4f\n' % (
            args_dict['crispr']['ll'],
            args_dict['crispr']['ll'] * stats_dict['crispr']['sigma'])
        if args_dict['crispr']['ul']:
            long_string += '> upper bound: %.3f*SD = %.4f\n\n' % (
                args_dict['crispr']['ul'],
                args_dict['crispr']['ul'] * stats_dict['crispr']['sigma'])
    long_string += '-' * 63 + '\n\n'

    logger.info('\n' + long_string)

    # Here create directed graph from explained nested dict
    nx_expl_dir_graph = dnf.nested_stmt_explained_dict_nx_digraph(
        nest_d=explained_nested_dict)

    if not args.no_web_files:
        # 'explained_nodes' are used to produce first drop down
        explained_nodes = list(nx_expl_dir_graph.nodes)
        logger.info('Dumping json "explainable_ids.json" for first dropdown.')
        io.dump_it_to_json(args.outbasename + '_explainable_ids.json',
                           explained_nodes)

        # Get undir graph and save each neighbor lookup as json for 2nd dropdown
        nx_expl_undir_graph = nx_expl_dir_graph.to_undirected()
        dnf.nx_undir_to_neighbor_lookup_json(
            expl_undir_graph=nx_expl_undir_graph, outbasename=args.outbasename)

    # Easiest way to check if pairs are explained or not is to loop explained
    # dict. Skip shared regulators.
    _dump_nest_dict_to_csv(fname=args.outbasename +
                           '_explained_correlations.csv',
                           nested_dict=explained_nested_dict,
                           header=['gene1', 'gene2', 'meta_data'],
                           excl_sr=True)

    io.dump_it_to_pickle(fname=args.outbasename + '_explained_nest_dict.pkl',
                         pyobj=explained_nested_dict)
    headers = ['subj', 'obj', 'type', 'X', 'meta_data']
    io.dump_it_to_csv(fname=args.outbasename + '_explanations_of_pairs.csv',
                      pyobj=explanations_of_pairs,
                      header=headers)
    io.dump_it_to_csv(fname=args.outbasename +
                      '_explanations_of_shared_regulators.csv',
                      pyobj=sr_explanations,
                      header=headers)
    io.dump_it_to_csv(fname=args.outbasename + '_unexpl_correlations.csv',
                      pyobj=unexplained,
                      header=headers[:-2])
    with open(args.outbasename + '_script_summary.txt', 'w') as fo:
        fo.write(long_string)
    return 0
def normalize_corr_names(corr_m: pd.DataFrame,
                         graph: Union[DiGraph, MultiDiGraph],
                         ns: str = None,
                         name_mapping: Dict[str, str] = None,
                         dump_mapping: bool = False,
                         dump_name: str = None) -> pd.DataFrame:
    """Rename the columns and index of a correlation matrix

    Each column name of `corr_m` is resolved in the following order:

    1. If `name_mapping` is provided and contains the name, the mapped
       value is used.
    2. If the name is a node in `graph`, the name is kept as is.
    3. Otherwise `gilda_normalization` is called with the name; the third
       element of its result is used if it is truthy, else the name is
       kept as is.

    The resulting list of names is assigned to both the columns and the
    index of `corr_m`. Note that `corr_m` is modified in place; the
    returned object is the same dataframe that was passed in.

    Parameters
    ----------
    corr_m : pd.DataFrame
        A square pandas dataframe representing a correlation matrix. It is
        assumed that columns and indices are identical.
    graph : Union[DiGraph, MultiDiGraph]
        A graph to look in to see if the names are there
    ns : str
        The assumed namespace of the names in corr_m. Currently not used
        by this function.
    name_mapping : Optional[Dict[str, str]]
        If provided try to map names from this dict
    dump_mapping : bool
        If True, save the mapping to pickle. The mapping is only recorded
        (and therefore only saved) when `name_mapping` is None.
        Default: False.
    dump_name : Optional[str]
        The file path to save the mapping at

    Returns
    -------
    pd.DataFrame
        `corr_m`, with its columns and index set to the resolved names.

    Raises
    ------
    ValueError
        If `dump_mapping` is True and `name_mapping` is None but no
        `dump_name` is provided.
    """
    # todo:
    #  - Move this function, together with get_ns_id,
    #    get_ns_id_pybel_node, normalize_entitites to net_functions
    #  - Provide ns and id to the correlation matrix here too (requires
    #    overhaul of depmap script)
    #  - Add support for pybel
    #  - If ns is provided loop through results and get the data for the
    #    matching name space, otherwise use the function used on Agent
    #    normalization

    # The mapping is only built up when the caller did not supply one
    record_mapping = dump_mapping and name_mapping is None
    if record_mapping and not dump_name:
        raise ValueError('Must provide file path with variable `dump_name` '
                         'if name mapping is dumped')

    normalized_names = []
    mapping = {}
    for name in corr_m.columns.values:
        # If mapping is provided and name is in the mapping, use it directly
        if name_mapping and name in name_mapping:
            normalized_names.append(name_mapping[name])
            continue

        if name in graph.nodes:
            # Name already exists in the graph, keep it
            norm_name = name
        else:
            # Otherwise use gilda; fall back to the original name if no
            # normalized name came back
            _, _, gilda_name = gilda_normalization(name)
            norm_name = gilda_name if gilda_name else name

        normalized_names.append(norm_name)
        if record_mapping:
            mapping[name] = norm_name

    # Reset the normalized names
    corr_m.columns = normalized_names
    corr_m.index = normalized_names

    if dump_mapping and mapping:
        dump_it_to_pickle(fname=dump_name, pyobj=mapping, overwrite=True)

    return corr_m