Example #1
def collect_data_parallel(directory, sids, fetch_all=False):
    sids = [sids] if isinstance(sids, str) else sids

    if not file.is_dir(directory):
        file.make_dir(directory)

    with requests.Session() as s:
        pseq(sids).filter(lambda sid: fetch_all or not file.is_file(_file_name_(sid)))\
                  .for_each(lambda sid: collect_id(s, sid))
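
The snippet above leans on project-specific helpers that are not shown: the file wrapper, _file_name_ and collect_id. A minimal sketch of plausible stand-ins, purely as assumptions so the example can run on its own (the URL and cache path are illustrative, not the project's):

# Imports the example above relies on, plus hypothetical stand-ins for its
# project-specific helpers (file, _file_name_, collect_id).
import os

import requests
from functional import pseq


class file:
    """Assumed thin wrapper around os / os.path."""
    is_dir = staticmethod(os.path.isdir)
    is_file = staticmethod(os.path.isfile)

    @staticmethod
    def make_dir(path):
        os.makedirs(path, exist_ok=True)


def _file_name_(sid):
    # assumed location of the cached record for one id
    return os.path.join("data", "{}.json".format(sid))


def collect_id(session, sid):
    # assumed: fetch a single record and cache it to disk
    response = session.get("https://example.com/api/{}".format(sid))
    with open(_file_name_(sid), "wb") as out:
        out.write(response.content)
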
Example #2
def main():
    fund_tickers = ["MHN", "MYN", "NVG", "NRK", "NAD", "RGT", "RMT", "JMF", "NML",
                    "JPS", "GGZ", "GDV", "GDL", "GGO", "NID", "BIT", "BTT"]
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)s %(module)s %(funcName)s %(lineno)d: '
               '%(message)s'
    )
    c = pycef.Client()
    # keep only the funds whose 2-sigma-plus discount check passes
    discounted_funds: List[pycef.Fund] = (pseq(fund_tickers)
                                          .map(lambda t: c.get_fund_by_ticker(t))
                                          .filter(lambda f: f.is_present_discount_2sigma_plus())
                                          .to_list())

    with open(os.path.join(os.getcwd(), *["bin", "prod.yaml"]), 'r') as stream:
        try:
            email_configs = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)

    if len(discounted_funds) == 0:
        logging.info("no discounted funds today")
    else:
        pay_load = '<br><br>'.join(list(seq(discounted_funds).map(lambda f: str(f))))
        # noinspection PyUnboundLocalVariable
        send_email(user=email_configs['email_login'],
                   pwd=email_configs["email_pwd"],
                   recipient=email_configs['recipient'],
                   subject="CEF daily report {}".format(datetime.date.today().strftime('%Y-%m-%d')),
                   body=pay_load)
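
send_email and pycef.Client are supplied by the surrounding project and are not shown (the next example reuses both). For reference, a minimal SMTP-based sketch of what send_email might look like; the mail host, port and HTML handling here are assumptions, not the project's actual implementation:

# Hypothetical mailer matching the call signature used above; point it at your
# own SMTP provider.
import smtplib
from email.mime.text import MIMEText


def send_email(user, pwd, recipient, subject, body):
    msg = MIMEText(body, "html")  # the pay_load above is HTML joined with <br> tags
    msg["Subject"] = subject
    msg["From"] = user
    msg["To"] = recipient
    with smtplib.SMTP_SSL("smtp.example.com", 465) as server:
        server.login(user, pwd)
        server.sendmail(user, [recipient], msg.as_string())
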
Example #3
def main():
    fund_tickers = [
        "MHN", "MYN", "NVG", "NRK", "NAD", "RGT", "RMT", "JMF", "NML", "JPS",
        "GGZ", "GDV", "GDL", "GGO", "NID", "BIT", "BTT"
    ]
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)s %(module)s %(funcName)s %(lineno)d: '
        '%(message)s')
    c = pycef.Client()
    # fetch every fund on 2 worker processes and convert each to a plain dict
    all_funds_json: List[dict] = (pseq(fund_tickers, processes=2, partition_size=10)
                                  .map(lambda t: c.get_fund_by_ticker(t))
                                  .map(lambda f: {f: f.is_present_discount_2sigma_plus()})
                                  .map(lambda d: list(d.keys())[0].to_dict())
                                  .to_list())
    all_funds_df = pandas.json_normalize(all_funds_json)[[
        'Name', 'Ticker', 'Date', 'M2M', 'Nav', 'Premium', '52 wk avg',
        '1 Sigma'
    ]]
    pay_load = '<br><br>' + all_funds_df.to_html(index=False)

    with open(os.path.join(os.getcwd(), *["bin", "prod.yaml"]), 'r') as stream:
        try:
            email_configs = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
    # noinspection PyUnboundLocalVariable
    send_email(user=email_configs['email_login'],
               pwd=email_configs["email_pwd"],
               recipient=email_configs['recipient'],
               subject="CEF bi-weekly report {}".format(
                   datetime.date.today().strftime('%Y-%m-%d')),
               body=pay_load)
Example #4
 def update_questions(self):
     def update_question_tuple(q: QBApiQuestion):
         position, text = self.request_text(q.id, q.position + 1)
         if q.text != '':
             text = q.text + ' ' + text
         while '.' not in text:
             position, new_text = self.request_text(q.id, position + 1)
             text += ' ' + new_text
         print("Question text: {0}".format(text))
         return QBApiQuestion(q.fold, q.id, q.word_count, position, text, q.guess, None)
     self.questions = pseq(self.questions).map(update_question_tuple).cache()
Example #5
File: guesser.py  Project: Pinafore/qb
def read_guesser_reports(guesser):
    report_paths = glob.glob(f'output/guesser/{guesser}/*/guesser_report_guessdev.pickle', recursive=True)
    reports = pseq(report_paths).map(parse_report).list()
    hyper_params = set()
    fake_params = {'random_seed', 'training_time', 'config_num'}

    for r in reports:
        for p in r['guesser_params']:
            if p not in fake_params:
                hyper_params.add(p)
    hyper_params = list(hyper_params)
    return reports, hyper_params
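
parse_report is defined elsewhere in guesser.py. Since the report files end in .pickle and the loop above reads r['guesser_params'], one plausible stand-in (an assumption, not the project's code) is simply:

# Hypothetical reader: each guesser_report_guessdev.pickle is assumed to hold a
# pickled dict with at least a 'guesser_params' mapping.
import pickle


def parse_report(path):
    with open(path, "rb") as f:
        return pickle.load(f)
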
Example #6
def read_guesser_reports(guesser):
    report_paths = glob.glob(
        f"output/guesser/{guesser}/*/guesser_report_guessdev.pickle",
        recursive=True)
    reports = pseq(report_paths).map(parse_report).list()
    hyper_params = set()
    fake_params = {"random_seed", "training_time", "config_num"}

    for r in reports:
        for p in r["guesser_params"]:
            if p not in fake_params:
                hyper_params.add(p)
    hyper_params = list(hyper_params)
    return reports, hyper_params
Example #7
def qanta_2012_stats():
    """
    This computes and prints dataset statistics for prior versions from EMNLP 2012.
    Published results use private NAQT data; these stats are computed using only public data.
    Use nltk for word tokenization to be consistent with prior analysis.
    Use spacy for sentence tokenization to be consistent with qanta dataset preprocessing.
    (We don't use word tokenization in dataset preprocessing; we consider it a model detail.)
    """
    with open('data/external/emnlp_2012_questions.csv') as f:
        questions_2012 = list(csv.reader(f))

    eprint('N EMNLP 2012 Questions', len(questions_2012))
    questions_2012 = [q[4] for q in questions_2012]
    tokenized_2012 = pseq(questions_2012).map(nltk.word_tokenize).list()
    n_tokens_2012 = sum(len(q) for q in tokenized_2012)
    eprint('N EMNLP 2012 Tokens', n_tokens_2012)
    n_sentences = [len(list(nlp(q).sents)) for q in tqdm(questions_2012)]  # count sentences, not tokens
    eprint('N EMNLP 2012 Sentences', sum(n_sentences))
Example #8
def compute_stats():
    qdb = QuestionDatabase(QB_QUESTION_DB)
    ir = IrExtractor()
    questions = qdb.guess_questions()
    test_guesses = pseq(questions, partition_size=100)\
        .filter(lambda q: q.fold == 'test')\
        .map(lambda q: (q.page, ir.text_guess(q.flatten_text())))
    correct = 0
    close = 0
    total = 0
    for page, guesses in test_guesses:
        top_guess = max(guesses.items(), key=lambda x: x[1], default=None)
        if top_guess is not None and page == top_guess[0]:
            correct += 1
        elif page in guesses:
            close += 1
        total += 1
    print("Total Correct: {0}, Percent Correct: {1}".format(correct, correct / total))
    print("Total Close: {0}, Percent Close: {1}".format(close, close / total))
Example #9
def compute_stats():
    qdb = QuestionDatabase(QB_QUESTION_DB)
    ir = IrExtractor()
    questions = qdb.guess_questions()
    test_guesses = pseq(questions, partition_size=100)\
        .filter(lambda q: q.fold == 'test')\
        .map(lambda q: (q.page, ir.text_guess(q.flatten_text())))
    correct = 0
    close = 0
    total = 0
    for page, guesses in test_guesses:
        top_guess = max(guesses.items(), key=lambda x: x[1], default=None)
        if top_guess is not None and page == top_guess[0]:
            correct += 1
        elif page in guesses:
            close += 1
        total += 1
    print("Total Correct: {0}, Percent Correct: {1}".format(
        correct, correct / total))
    print("Total Close: {0}, Percent Close: {1}".format(close, close / total))
Example #10
 def get_api_questions(self):
     url = 'http://{domain}/qb-api/v1/questions'.format(domain=self.domain)
     response = seq(requests.get(url).json()['questions'])
     return pseq(response)\
         .map(lambda r: QBApiQuestion(**r, position=-1, text='', guess=None, all_guesses=None))\
         .cache()
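
QBApiQuestion is defined elsewhere in the project. Judging from the fields referenced in examples #4 and #10, it is presumably a namedtuple along these lines (an assumption based only on usage):

# Hypothetical record type consistent with how QBApiQuestion is constructed above.
from collections import namedtuple

QBApiQuestion = namedtuple(
    "QBApiQuestion",
    ["fold", "id", "word_count", "position", "text", "guess", "all_guesses"])
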
Example #11
 def get_constituent_prices_and_free_float(self):
     self.components = (pseq(self.components, processes=4, partition_size=130)
                        .map(lambda stock: stock.get_price_and_float())
                        .to_list())
Example #12
File: jmlr_diversity.py  Project: NPSDC/qb
def collect_parses(question_sentences):
    return pseq(question_sentences).map(fetch_parse).list()
Example #13
def save_all_emails(start=START_ID, stop=STOP_ID):
    base_url = 'https://wikileaks.org/dnc-emails/get/{id}'
    urls = [(i, base_url.format(id=i)) for i in range(start, stop + 1)]
    pseq(urls).map(
        lambda kv: (kv[0], kv[1], save_email(kv[0], kv[1]))
    ).to_json('data/results.json')
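
START_ID, STOP_ID and save_email are defined elsewhere in the script. A minimal sketch of save_email, assuming it downloads one email page and stores it on disk (the output directory and return value are illustrative):

# Hypothetical helper: fetch one email page and persist it, returning a small
# status record that ends up in data/results.json.
import os
import requests


def save_email(email_id, url, out_dir="data/emails"):
    os.makedirs(out_dir, exist_ok=True)
    response = requests.get(url, timeout=30)
    path = os.path.join(out_dir, "{}.html".format(email_id))
    with open(path, "wb") as out:
        out.write(response.content)
    return {"status": response.status_code, "path": path}
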
Example #14
File: jmlr_diversity.py  Project: NPSDC/qb
node = Forward()
root = Group(part + leaf + rpar)
node << OneOrMore(root | Group(part + node + rpar))
node.enablePackrat()


def try_parse(corenlp_parse):
    if corenlp_parse is None:
        return None
    try:
        return node.parseString(corenlp_parse).asList()
    except ParseException:
        return None


qb_parse_results = pseq(qb_parses).map(try_parse).list()
sq_parse_results = pseq(sq_parses).map(try_parse).list()
squad_parse_results = pseq(squad_parses).map(try_parse).list()
tqa_parse_results = pseq(tqa_parses).map(try_parse).list()
jeopardy_parse_results = pseq(jeopardy_parses).map(try_parse).list()


def compute_pcfg(parse_results):
    non_terminals = Counter()
    transitions = defaultdict(Counter)
    for tree in parse_results:
        if tree is not None:
            root = tree[0]
            update_transitions(root, non_terminals, transitions)
    return non_terminals, transitions
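
The grammar pieces part, leaf and rpar, the *_parses inputs, and update_transitions all come from earlier in jmlr_diversity.py and are not shown. One plausible way to define the three pyparsing tokens for Penn-Treebank style bracketed parses, offered only as an assumption:

# Hypothetical token definitions for the grammar above: "(LABEL" opens a
# constituent, a leaf is any non-parenthesis token, ")" closes a constituent.
from pyparsing import Regex, Suppress

lpar = Suppress("(")
rpar = Suppress(")")
label = Regex(r"[^\s()]+")  # non-terminal or POS tag, e.g. NP, VP, NN
leaf = Regex(r"[^\s()]+")   # terminal token
part = lpar + label         # start of a constituent: "(" plus its label
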
Example #15
    def __init__(self,
                 path,
                 qanta_id_field,
                 sent_field,
                 page_field,
                 text_field,
                 unigram_field,
                 bigram_field,
                 trigram_field,
                 example_mode='sentence',
                 use_wiki=False,
                 n_wiki_sentences=3,
                 replace_title_mentions='',
                 **kwargs):
        from unidecode import unidecode

        if use_wiki and 'train' in path:
            base_path = os.path.dirname(path)
            filename = os.path.basename(s3_wiki)
            output_file = os.path.join(base_path, filename)
            if not os.path.exists(output_file):
                download_from_url(s3_wiki, output_file)
            with open(output_file) as f:
                self.wiki_lookup = json.load(f)
        else:
            self.wiki_lookup = {}
        self.path = path
        self.example_mode = example_mode

        text_dependent_fields = []
        if text_field is not None:
            text_dependent_fields.append(('text', text_field))
        if unigram_field is not None:
            text_dependent_fields.append(('unigram', unigram_field))
        if bigram_field is not None:
            text_dependent_fields.append(('bigram', bigram_field))
        if trigram_field is not None:
            text_dependent_fields.append(('trigram', trigram_field))

        example_fields = {
            'qanta_id': [('qanta_id', qanta_id_field)],
            'sent': [('sent', sent_field)],
            'page': [('page', page_field)],
            'text': text_dependent_fields
        }

        examples = []
        answer_set = set()
        with open(path) as f:
            for ex in json.load(f)['questions']:
                if example_mode == 'sentence':
                    sentences = [
                        ex['text'][start:end]
                        for start, end in ex['tokenizations']
                    ]
                    for i, s in enumerate(sentences):
                        examples.append(
                            Example.fromdict(
                                {
                                    'qanta_id': ex['qanta_id'],
                                    'sent': i,
                                    'text': unidecode(s),
                                    'page': ex['page']
                                }, example_fields))
                        answer_set.add(ex['page'])
                elif example_mode == 'question':
                    examples.append(
                        Example.fromdict(
                            {
                                'qanta_id': ex['qanta_id'],
                                'sent': -1,
                                'text': unidecode(ex['text']),
                                'page': ex['page']
                            }, example_fields))
                    answer_set.add(ex['page'])
                else:
                    raise ValueError(
                        f"Valid modes are 'sentence' and 'question', but '{example_mode}' was given"
                    )

        if use_wiki and n_wiki_sentences > 0 and 'train' in path:
            print('Loading wikipedia')
            pages = [(p, self.wiki_lookup[p]['text']) for p in answer_set
                     if p in self.wiki_lookup]

            def extract(args):
                title, text = args
                sentences = extract_wiki_sentences(
                    title,
                    text,
                    n_wiki_sentences,
                    replace_title_mentions=replace_title_mentions)
                return title, sentences

            for page, sentences in pseq(pages).map(extract).list():
                for i, s in enumerate(sentences):
                    examples.append(
                        Example.fromdict(
                            {
                                'qanta_id': -1,
                                'sent': i,
                                'text': s,
                                'page': page
                            }, example_fields))

        dataset_fields = {
            'qanta_id': qanta_id_field,
            'sent': sent_field,
            'page': page_field,
        }
        if text_field is not None:
            dataset_fields['text'] = text_field
        if unigram_field is not None:
            dataset_fields['unigram'] = unigram_field
        if bigram_field is not None:
            dataset_fields['bigram'] = bigram_field
        if trigram_field is not None:
            dataset_fields['trigram'] = trigram_field

        super(QuizBowl, self).__init__(examples, dataset_fields, **kwargs)
Example #16
File: dataset.py  Project: NPSDC/qb
    def __init__(
        self,
        path,
        qanta_id_field,
        sent_field,
        page_field,
        text_field,
        unigram_field,
        bigram_field,
        trigram_field,
        example_mode="sentence",
        use_wiki=False,
        n_wiki_sentences=3,
        replace_title_mentions="",
        **kwargs,
    ):
        from unidecode import unidecode

        if use_wiki and "train" in path:
            base_path = os.path.dirname(path)
            filename = os.path.basename(s3_wiki)
            output_file = os.path.join(base_path, filename)
            if not os.path.exists(output_file):
                download_from_url(s3_wiki, output_file)
            with open(output_file) as f:
                self.wiki_lookup = json.load(f)
        else:
            self.wiki_lookup = {}
        self.path = path
        self.example_mode = example_mode

        text_dependent_fields = []
        if text_field is not None:
            text_dependent_fields.append(("text", text_field))
        if unigram_field is not None:
            text_dependent_fields.append(("unigram", unigram_field))
        if bigram_field is not None:
            text_dependent_fields.append(("bigram", bigram_field))
        if trigram_field is not None:
            text_dependent_fields.append(("trigram", trigram_field))

        example_fields = {
            "qanta_id": [("qanta_id", qanta_id_field)],
            "sent": [("sent", sent_field)],
            "page": [("page", page_field)],
            "text": text_dependent_fields,
        }

        examples = []
        answer_set = set()
        with open(path) as f:
            for ex in json.load(f)["questions"]:
                if example_mode == "sentence":
                    sentences = [
                        ex["text"][start:end]
                        for start, end in ex["tokenizations"]
                    ]
                    for i, s in enumerate(sentences):
                        examples.append(
                            Example.fromdict(
                                {
                                    "qanta_id": ex["qanta_id"],
                                    "sent": i,
                                    "text": unidecode(s),
                                    "page": ex["page"],
                                },
                                example_fields,
                            ))
                        answer_set.add(ex["page"])
                elif example_mode == "question":
                    examples.append(
                        Example.fromdict(
                            {
                                "qanta_id": ex["qanta_id"],
                                "sent": -1,
                                "text": unidecode(ex["text"]),
                                "page": ex["page"],
                            },
                            example_fields,
                        ))
                    answer_set.add(ex["page"])
                else:
                    raise ValueError(
                        f"Valid modes are 'sentence' and 'question', but '{example_mode}' was given"
                    )

        if use_wiki and n_wiki_sentences > 0 and "train" in path:
            print("Loading wikipedia")
            pages = [(p, self.wiki_lookup[p]["text"]) for p in answer_set
                     if p in self.wiki_lookup]

            def extract(args):
                title, text = args
                sentences = extract_wiki_sentences(
                    title,
                    text,
                    n_wiki_sentences,
                    replace_title_mentions=replace_title_mentions,
                )
                return title, sentences

            for page, sentences in pseq(pages).map(extract).list():
                for i, s in enumerate(sentences):
                    examples.append(
                        Example.fromdict(
                            {
                                "qanta_id": -1,
                                "sent": i,
                                "text": s,
                                "page": page
                            },
                            example_fields,
                        ))

        dataset_fields = {
            "qanta_id": qanta_id_field,
            "sent": sent_field,
            "page": page_field,
        }
        if text_field is not None:
            dataset_fields["text"] = text_field
        if unigram_field is not None:
            dataset_fields["unigram"] = unigram_field
        if bigram_field is not None:
            dataset_fields["bigram"] = bigram_field
        if trigram_field is not None:
            dataset_fields["trigram"] = trigram_field

        super(QuizBowl, self).__init__(examples, dataset_fields, **kwargs)
Example #17
def deeplucia_eval(keras_model_filename, config_json_filename):

    import os
    import json
    import functools
    import itertools

    from types import SimpleNamespace  # used to expose the config as attributes

    from pathlib import Path

    from functional import seq
    from functional import pseq  # seq with parallelism

    from deeplucia_toolkit import prep_matrix
    from deeplucia_toolkit import prep_label
    from deeplucia_toolkit import make_dataset
    from deeplucia_toolkit import make_model
    from deeplucia_toolkit import misc

    import tensorflow as tf

    from tensorflow import keras
    from tensorflow.keras import layers
    from tensorflow.keras.models import Model

    import numpy

    from sklearn.metrics import precision_score
    from sklearn.metrics import recall_score
    from sklearn.metrics import f1_score

    from sklearn.metrics import average_precision_score
    from sklearn.metrics import roc_auc_score
    from sklearn.metrics import confusion_matrix
    from sklearn.metrics import matthews_corrcoef

    # open config file
    with open(config_json_filename) as config_json_file:
        config_dict = json.load(config_json_file)
        config = SimpleNamespace(**config_dict)

    # obvious parameter setting
    local_path_base = Path(Path.cwd() / "Features")
    chrominfo_filename = local_path_base / "ChromInfo_hg19.txt"
    chrom_set = set([
        "chr1", "chr2", "chr3", "chr4", "chr5", "chr6", "chr7", "chr8", "chr9",
        "chr10", "chr11", "chr12", "chr13", "chr14", "chr15", "chr16", "chr17",
        "chr18", "chr19", "chr20", "chr21", "chr22", "chrX"
    ])

    test_chrom_set = set(config.test_chrom_list)

    n2p_ratio = config.n2p_ratio
    num_pos = max(1, int(config.num_pos / config.n2p_ratio))

    # set filepath
    loop_label_txtgz_filename = Path(
        local_path_base) / "Label" / "isHC.loop_list.txt.gz"
    anchor_label_txtgz_filename = Path(
        local_path_base) / "Label" / "isHC.anchor_list.txt.gz"

    seq_numpy_filename = Path(
        local_path_base) / "GenomicFeature" / "isHC" / "isHC.seq_onehot.npy"

    sample_list = config.sample_id_list
    mark_list = [
        "DNase", "H2AFZ", "H3K27ac", "H3K27me3", "H3K36me3", "H3K4me1",
        "H3K4me2", "H3K4me3", "H3K79me2", "H3K9ac", "H3K9me3", "H4K20me1"
    ]
    sample_mark_to_epi_numpy_filename = {}
    for epi_numpy_filename in Path(local_path_base).glob(
            "EpigenomicFeature/*/*.npy"):
        sample_mark = tuple(epi_numpy_filename.name.split(".")[:2])
        sample_mark_to_epi_numpy_filename[sample_mark] = epi_numpy_filename

    sample_to_sample_index = misc.get_sample_index(sample_list)

    # load feature array
    seq_array = prep_matrix.load_seq_array(seq_numpy_filename)
    multisample_multimark_epi_array = prep_matrix.load_multisample_multimark_epi_array(
        sample_list,
        mark_list,
        sample_mark_to_epi_numpy_filename,
        cap_crit=0.95)
    print("load feature array => DONE")

    # load loop label
    chrom_to_size = misc.load_chrominfo(chrominfo_filename)

    pos_sample_locus_index_pair_set = prep_label.load_loop_label(
        sample_list, loop_label_txtgz_filename)
    anchor_locus_index_set = prep_label.get_anchor_locus(
        sample_list, pos_sample_locus_index_pair_set)
    chrom_to_locus_index_range = prep_label.get_chrom_range(
        anchor_label_txtgz_filename, chrom_to_size)

    index_max = max(
        itertools.chain(*list(chrom_to_locus_index_range.values())))

    test_locus_index_range_set = {
        chrom_to_locus_index_range[chrom]
        for chrom in test_chrom_set
    }
    print("load loop label => DONE")

    # curryize basic functions
    is_intra_chrom = functools.partial(
        misc.is_intra_chrom,
        chrom_to_locus_index_range=chrom_to_locus_index_range)
    is_anchor_bearing = functools.partial(
        misc.is_anchor_bearing, anchor_locus_index_set=anchor_locus_index_set)
    permute_same_distance = functools.partial(prep_label.permute_same_distance,
                                              index_max=index_max)
    gen_neg_sample_locus_index_pair = functools.partial(
        prep_label.gen_neg_sample_locus_index_pair,
        sample_list=sample_list,
        index_max=index_max)

    is_in_test_range_set = functools.partial(
        misc.is_in_desired_range_set,
        index_range_set=test_locus_index_range_set)

    def permute_pos_sample_locus_index_pair(pos_sample_locus_index_pair,
                                            n2p_ratio, is_in_range_set):
        neg_sample_locus_index_pair_list = (seq(
            gen_neg_sample_locus_index_pair(pos_sample_locus_index_pair)
        ).filter(is_in_range_set).filter(is_intra_chrom).filter_not(
            is_anchor_bearing).take(n2p_ratio).to_list())
        return neg_sample_locus_index_pair_list

    permute_pos_sample_locus_index_pair_test = functools.partial(
        permute_pos_sample_locus_index_pair,
        n2p_ratio=n2p_ratio,
        is_in_range_set=is_in_test_range_set)
    print("curryize basic functions => DONE")

    # split train/validate label
    test_pos_sample_locus_index_pair_list = (
        pseq(pos_sample_locus_index_pair_set).filter(
            is_in_test_range_set).to_list())
    test_neg_sample_locus_index_pair_list = (
        pseq(test_pos_sample_locus_index_pair_list).flat_map(
            permute_pos_sample_locus_index_pair_test).to_list())

    print("split test/train/validate label => DONE")

    # merge pos/neg label
    test_sample_locus_index_pair_list = test_pos_sample_locus_index_pair_list + test_neg_sample_locus_index_pair_list

    print("merge pos/neg label => DONE")

    # prepare model
    model = keras.models.load_model(keras_model_filename)
    print("load model => DONE")

    prob_pred_list = []
    label_pred_list = []
    label_true_list = []

    chunk_size = 10000

    for _, chunk in itertools.groupby(
            enumerate(test_sample_locus_index_pair_list),
            lambda x: x[0] // chunk_size):
        test_sample_locus_index_pair_sublist = [
            indexed_sample_locus_index_pair[1]
            for indexed_sample_locus_index_pair in chunk
        ]
        feature_test, label_true = make_dataset.extract_seq_epi_dataset_unshuffled(
            test_sample_locus_index_pair_sublist,
            pos_sample_locus_index_pair_set, seq_array,
            multisample_multimark_epi_array, sample_to_sample_index)
        output = model.predict(feature_test)
        prob_pred = numpy.squeeze(output, axis=1)
        label_pred = list(map(lambda prob: int(round(prob)), prob_pred))

        prob_pred_list.append(prob_pred)
        label_pred_list.append(label_pred)
        label_true_list.append(label_true)

    prob_pred = list(itertools.chain(*prob_pred_list))
    label_pred = list(itertools.chain(*label_pred_list))
    label_true = list(itertools.chain(*label_true_list))

    f1 = f1_score(label_true, label_pred)
    mcc = matthews_corrcoef(label_true, label_pred)
    au_ro_curve = roc_auc_score(label_true, prob_pred)
    au_pr_curve = average_precision_score(label_true, prob_pred)

    model_evaluation_filename = "eval_model/" + keras_model_filename.split(
        "/")[-1] + "." + config.log_id + ".xls"

    with open(model_evaluation_filename, "wt") as model_evaluation_file:

        model_evaluation_file.write(
            keras_model_filename.split("/")[-1][:-20] + "\t" + config.log_id +
            "\tAUROC\t" + str(au_ro_curve) + "\n")
        model_evaluation_file.write(
            keras_model_filename.split("/")[-1][:-20] + "\t" + config.log_id +
            "\tAUPRC\t" + str(au_pr_curve) + "\n")
        model_evaluation_file.write(
            keras_model_filename.split("/")[-1][:-20] + "\t" + config.log_id +
            "\tMCC\t" + str(mcc) + "\n")
        model_evaluation_file.write(
            keras_model_filename.split("/")[-1][:-20] + "\t" + config.log_id +
            "\tF1\t" + str(f1) + "\n")
Example #18
def deeplucia_train(config_json_filename):

	import os
	import json
	import functools
	import itertools

	from types import SimpleNamespace  # used to expose the config as attributes

	from pathlib import Path

	from functional import seq
	from functional import pseq # seq with parallelism 

	from deeplucia_toolkit import prep_matrix
	from deeplucia_toolkit import prep_label
	from deeplucia_toolkit import make_dataset
	from deeplucia_toolkit import make_model
	from deeplucia_toolkit import misc

	import tensorflow as tf

	from tensorflow import keras
	from tensorflow.keras import layers
	from tensorflow.keras import regularizers
	from tensorflow.keras.models import Model
	from tensorflow.keras.callbacks import ModelCheckpoint
	from tensorflow.keras.callbacks import EarlyStopping
	from tensorflow.keras.callbacks import CSVLogger
	from tensorflow.keras.callbacks import Callback


	if not os.path.isdir("modelOUT/"):
		os.makedirs("modelOUT/")
	if not os.path.isdir("history/"):
		os.makedirs("history/")

	# open config file 
	with open(config_json_filename) as config_json_file:
		config_dict = json.load(config_json_file)
		config = SimpleNamespace(**config_dict)

	# obvious parameter setting 
	local_path_base = Path (Path.cwd() / "Features")
	chrominfo_filename = local_path_base / "ChromInfo_hg19.txt"
	chrom_set = set(["chr1","chr2","chr3","chr4","chr5","chr6","chr7","chr8","chr9","chr10","chr11","chr12","chr13","chr14","chr15","chr16","chr17","chr18","chr19","chr20","chr21","chr22","chrX"])
	
	validate_chrom_set =  set(config.val_chrom_list)
	test_chrom_set = set(config.test_chrom_list)
	train_chrom_set = chrom_set - test_chrom_set - validate_chrom_set
	
	n2p_ratio = config.n2p_ratio
	#num_pos = config.num_pos 
	num_pos = max(1,int(config.num_pos/config.n2p_ratio))

	# set filepath
	loop_label_txtgz_filename = Path(local_path_base) / "Label" / "isHC.loop_list.txt.gz"
	anchor_label_txtgz_filename = Path(local_path_base) / "Label" / "isHC.anchor_list.txt.gz"

	seq_numpy_filename = Path(local_path_base) / "GenomicFeature" / "isHC" / "isHC.seq_onehot.npy"

	sample_list = config.sample_id_list
	mark_list = ["DNase","H2AFZ","H3K27ac","H3K27me3","H3K36me3","H3K4me1","H3K4me2","H3K4me3","H3K79me2","H3K9ac","H3K9me3","H4K20me1"]
	sample_mark_to_epi_numpy_filename = {}
	for epi_numpy_filename in Path(local_path_base).glob("EpigenomicFeature/*/*.npy"):
		sample_mark = tuple(epi_numpy_filename.name.split(".")[:2])
		sample_mark_to_epi_numpy_filename[sample_mark] = epi_numpy_filename

	sample_to_sample_index = misc.get_sample_index(sample_list)

	# load feature array
	seq_array = prep_matrix.load_seq_array(seq_numpy_filename)
	multisample_multimark_epi_array = prep_matrix.load_multisample_multimark_epi_array(sample_list,mark_list,sample_mark_to_epi_numpy_filename,cap_crit=0.95)
	print("load feature array => DONE")

	# load loop label
	chrom_to_size = misc.load_chrominfo(chrominfo_filename)

	pos_sample_locus_index_pair_set = prep_label.load_loop_label(sample_list,loop_label_txtgz_filename)
	anchor_locus_index_set = prep_label.get_anchor_locus(sample_list,pos_sample_locus_index_pair_set)
	chrom_to_locus_index_range= prep_label.get_chrom_range(anchor_label_txtgz_filename,chrom_to_size)

	index_max = max(itertools.chain(*list(chrom_to_locus_index_range.values())))

	test_locus_index_range_set = {chrom_to_locus_index_range[chrom] for chrom in test_chrom_set}
	train_locus_index_range_set = {chrom_to_locus_index_range[chrom] for chrom in train_chrom_set}
	validate_locus_index_range_set = {chrom_to_locus_index_range[chrom] for chrom in validate_chrom_set}
	print("load loop label => DONE")

	# curryize basic functions
	is_intra_chrom = functools.partial(misc.is_intra_chrom,chrom_to_locus_index_range=chrom_to_locus_index_range)
	is_anchor_bearing = functools.partial(misc.is_anchor_bearing,anchor_locus_index_set=anchor_locus_index_set)
	permute_same_distance = functools.partial(prep_label.permute_same_distance,index_max = index_max)
	gen_neg_sample_locus_index_pair = functools.partial(prep_label.gen_neg_sample_locus_index_pair,sample_list=sample_list,index_max=index_max)

	is_in_test_range_set = functools.partial(misc.is_in_desired_range_set,index_range_set=test_locus_index_range_set)
	is_in_train_range_set = functools.partial(misc.is_in_desired_range_set,index_range_set=train_locus_index_range_set)
	is_in_validate_range_set = functools.partial(misc.is_in_desired_range_set,index_range_set=validate_locus_index_range_set)

	def permute_pos_sample_locus_index_pair(pos_sample_locus_index_pair,n2p_ratio,is_in_range_set):
		neg_sample_locus_index_pair_list = (
			seq(gen_neg_sample_locus_index_pair(pos_sample_locus_index_pair))
			.filter(is_in_range_set)
			.filter(is_intra_chrom)
			.filter_not(is_anchor_bearing)
			.take(n2p_ratio)
			.to_list())
		return neg_sample_locus_index_pair_list

	permute_pos_sample_locus_index_pair_test = functools.partial(permute_pos_sample_locus_index_pair,n2p_ratio = n2p_ratio,is_in_range_set = is_in_test_range_set)
	permute_pos_sample_locus_index_pair_train = functools.partial(permute_pos_sample_locus_index_pair,n2p_ratio = n2p_ratio,is_in_range_set = is_in_train_range_set)
	permute_pos_sample_locus_index_pair_validate = functools.partial(permute_pos_sample_locus_index_pair,n2p_ratio = n2p_ratio,is_in_range_set = is_in_validate_range_set)
	print("curryize basic functions => DONE")

	# split train/validate label
	train_pos_sample_locus_index_pair_list = (pseq(pos_sample_locus_index_pair_set).filter(is_in_train_range_set).to_list())
	train_neg_sample_locus_index_pair_list = (pseq(train_pos_sample_locus_index_pair_list).flat_map(permute_pos_sample_locus_index_pair_train).to_list())

	validate_pos_sample_locus_index_pair_list = (pseq(pos_sample_locus_index_pair_set).filter(is_in_validate_range_set).to_list())
	validate_neg_sample_locus_index_pair_list = (pseq(validate_pos_sample_locus_index_pair_list).flat_map(permute_pos_sample_locus_index_pair_validate).to_list())
	print("split test/train/validate label => DONE")

	# merge pos/neg label
	train_sample_locus_index_pair_list_gen = prep_label.gen_sample_locus_index_pair_list(train_pos_sample_locus_index_pair_list,train_neg_sample_locus_index_pair_list,num_pos,n2p_ratio)
	validate_sample_locus_index_pair_list_gen = prep_label.gen_sample_locus_index_pair_list(validate_pos_sample_locus_index_pair_list,validate_neg_sample_locus_index_pair_list,num_pos,n2p_ratio)
	print("merge pos/neg label => DONE")

	# make dataset
	train_dataset_gen = make_dataset.gen_seq_epi_dataset(train_sample_locus_index_pair_list_gen,pos_sample_locus_index_pair_set,seq_array,multisample_multimark_epi_array,sample_to_sample_index)
	validate_dataset_gen = make_dataset.gen_seq_epi_dataset(validate_sample_locus_index_pair_list_gen,pos_sample_locus_index_pair_set,seq_array,multisample_multimark_epi_array,sample_to_sample_index)
	print("make dataset => DONE")

	# prepare model
	model = make_model.seq_epi_20210105_v1_model()
	print("load model => DONE")

	# keras parameter setting 

	adam_optimizer = keras.optimizers.Adam(learning_rate=0.0001)
	modelCheckpoint = ModelCheckpoint(filepath='modelOUT/model_seq_epi_20210105_v1_' + config.log_id + '_{epoch:05d}.h5', verbose=1)
	csvLogger = CSVLogger("history/model_seq_epi_20210105_v1_" + config.log_id + ".csv")

	weight_for_0 = ( n2p_ratio + 1 )/(2*n2p_ratio)
	weight_for_1 = ( n2p_ratio + 1 )/2
	class_weight = {0: weight_for_0, 1: weight_for_1}

	print("keras parameter setting => DONE")


	# model compile and fitting
	model.compile(
					loss='binary_crossentropy',
					metrics=["accuracy", "Precision" , "Recall" , "TruePositives" , "FalsePositives" , "FalseNegatives" , "AUC"],
					optimizer=adam_optimizer)

	model_history = model.fit(train_dataset_gen,
					steps_per_epoch = 400,
					epochs=1000,
					validation_data = validate_dataset_gen,
					validation_steps= 100,
					class_weight=class_weight,
					callbacks=[modelCheckpoint,csvLogger]
					)