def train_recurrent(model_out, pubtator_file, directional_distant_directory,
                    symmetric_distant_directory, distant_entity_a_col,
                    distant_entity_b_col, distant_rel_col, entity_a, entity_b):
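    """Train the recurrent relation-extraction model with distant supervision.

    Distant relations are loaded from the directional and symmetric knowledge
    base directories, training instances are built from the sentences in
    pubtator_file for the given entity_a/entity_b pair, and the recurrent
    model is trained and written under model_out. Returns the path of the
    trained model.
    """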

    # get distant relations from the external knowledge base directories
    distant_interactions, reverse_distant_interactions = load_data.load_distant_directories(
        directional_distant_directory, symmetric_distant_directory,
        distant_entity_a_col, distant_entity_b_col, distant_rel_col)

    key_order = sorted(distant_interactions)
    # get pmids, sentences, and entity mention text from the PubTator file
    training_pmids, training_forward_sentences, training_reverse_sentences, entity_a_text, entity_b_text = load_data.load_pubtator_abstract_sentences(
        pubtator_file, entity_a, entity_b)

    # train the full model
    training_instances, \
    dep_path_list_dictionary, \
    dep_word_dictionary, \
    word2vec_embeddings = load_data.build_instances_training(
        training_forward_sentences,
        training_reverse_sentences,
        distant_interactions,
        reverse_distant_interactions,
        entity_a_text,
        entity_b_text,
        key_order,
        True)

    dep_path_list_features, dep_word_features, dep_type_path_length, dep_word_path_length, labels = load_data.build_recurrent_arrays(
        training_instances)
    features = [
        dep_path_list_features, dep_word_features, dep_type_path_length,
        dep_word_path_length
    ]

    if os.path.exists(model_out):
        shutil.rmtree(model_out)

    pickle.dump([dep_path_list_dictionary, dep_word_dictionary, key_order],
                open(model_out + 'a.pickle', 'wb'))

    trained_model_path = rnn.recurrent_train(features, labels,
                                             len(dep_path_list_dictionary),
                                             len(dep_word_dictionary),
                                             model_out + '/', key_order,
                                             word2vec_embeddings)

    print("trained model")

    return trained_model_path


def parallel_k_fold_cross_validation(batch_id, k, pmids, forward_sentences,
                                     reverse_sentences, distant_interactions,
                                     reverse_distant_interactions,
                                     entity_a_text, entity_b_text,
                                     hidden_array, key_order, recurrent):
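    """Run one fold (batch_id) of k-fold cross validation.

    The pmids are split into k chunks; chunk batch_id is held out as the test
    set and the remaining chunks are used for training. Depending on the
    recurrent flag, either a feed-forward or a recurrent model is trained, and
    the predicted probabilities and test instances for the held-out fold are
    returned.
    """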

    pmids = list(pmids)
    # split pmids into k folds for cross validation
    fold_length = len(pmids) // k
    all_chunks = [
        pmids[i:i + fold_length]
        for i in range(0, len(pmids), fold_length)
    ]

    total_test = []  # test labels for instances
    total_predicted_prob = []  # predicted probabilities for instances
    total_instances = []

    fold_chunks = all_chunks[:]
    fold_test_abstracts = set(fold_chunks.pop(batch_id))
    fold_training_abstracts = set(
        list(itertools.chain.from_iterable(fold_chunks)))

    fold_training_forward_sentences = {}
    fold_training_reverse_sentences = {}
    fold_test_forward_sentences = {}
    fold_test_reverse_sentences = {}

    for key in forward_sentences:
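        # assign each sentence to this fold's training or test set based on
        # the pmid prefix of its key ('<pmid>|...')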
        if key.split('|')[0] in fold_training_abstracts:
            fold_training_forward_sentences[key] = forward_sentences[key]
        elif key.split('|')[0] in fold_test_abstracts:
            fold_test_forward_sentences[key] = forward_sentences[key]

    for key in reverse_sentences:
        if key.split('|')[0] in fold_training_abstracts:
            fold_training_reverse_sentences[key] = reverse_sentences[key]
        elif key.split('|')[0] in fold_test_abstracts:
            fold_test_reverse_sentences[key] = reverse_sentences[key]

    if recurrent is False:
        fold_training_instances, \
        fold_dep_dictionary, \
        fold_dep_word_dictionary,\
        fold_dep_element_dictionary,\
        fold_between_word_dictionary = load_data.build_instances_training(fold_training_forward_sentences,
                                                                      fold_training_reverse_sentences,
                                                                      distant_interactions,
                                                                      reverse_distant_interactions,
                                                                      entity_a_text,
                                                                      entity_b_text,key_order)

        #train model
        X = []
        y = []
        for t in fold_training_instances:
            X.append(t.features)
            y.append(t.label)

        fold_train_X = np.array(X)
        fold_train_y = np.array(y)

        model_dir = (os.path.dirname(os.path.realpath(__file__)) +
                     '/model_building_meta_data/test' + str(batch_id) +
                     str(time.time()).replace('.', ''))
        if os.path.exists(model_dir):
            shutil.rmtree(model_dir)

        fold_test_instances = load_data.build_instances_testing(
            fold_test_forward_sentences, fold_test_reverse_sentences,
            fold_dep_dictionary, fold_dep_word_dictionary,
            fold_dep_element_dictionary, fold_between_word_dictionary,
            distant_interactions, reverse_distant_interactions, entity_a_text,
            entity_b_text, key_order)

        # group instances by pmid and build feature array
        fold_test_features = []
        fold_test_labels = []
        pmid_test_instances = {}
        for test_index in range(len(fold_test_instances)):
            fti = fold_test_instances[test_index]
            if fti.sentence.pmid not in pmid_test_instances:
                pmid_test_instances[fti.sentence.pmid] = []
            pmid_test_instances[fti.sentence.pmid].append(test_index)
            fold_test_features.append(fti.features)
            fold_test_labels.append(fti.label)

        fold_test_X = np.array(fold_test_features)
        fold_test_y = np.array(fold_test_labels)

        test_model = snn.feed_forward_train(fold_train_X, fold_train_y,
                                            fold_test_X, fold_test_y,
                                            hidden_array, model_dir + '/',
                                            key_order)

        fold_test_predicted_prob = snn.feed_forward_test(
            fold_test_X, fold_test_y, test_model)

        total_predicted_prob = fold_test_predicted_prob.tolist()
        total_test = fold_test_y.tolist()
        total_instances = fold_test_instances

        total_test = np.array(total_test)
        total_predicted_prob = np.array(total_predicted_prob)

        return total_predicted_prob, total_instances

    else:
        fold_training_instances, \
        fold_dep_path_list_dictionary, \
        fold_dep_word_dictionary, \
        word2vec_embeddings = load_data.build_instances_training(
            fold_training_forward_sentences,
            fold_training_reverse_sentences,
            distant_interactions,
            reverse_distant_interactions,
            entity_a_text,
            entity_b_text,
            key_order,
            True)

        dep_path_list_features, dep_word_features, dep_type_path_length, dep_word_path_length, labels = load_data.build_recurrent_arrays(
            fold_training_instances)

        features = [
            dep_path_list_features, dep_word_features, dep_type_path_length,
            dep_word_path_length
        ]

        model_dir = (os.path.dirname(os.path.realpath(__file__)) +
                     '/model_building_meta_data/test' + str(batch_id) +
                     str(time.time()).replace('.', ''))
        if os.path.exists(model_dir):
            shutil.rmtree(model_dir)

        trained_model_path = rnn.recurrent_train(
            features, labels, len(fold_dep_path_list_dictionary),
            len(fold_dep_word_dictionary), model_dir + '/', key_order,
            word2vec_embeddings)

        fold_test_instances = load_data.build_instances_testing(
            fold_test_forward_sentences, fold_test_reverse_sentences, None,
            fold_dep_word_dictionary, None, None, distant_interactions,
            reverse_distant_interactions, entity_a_text, entity_b_text,
            key_order, fold_dep_path_list_dictionary)

        # group instances by pmid and build feature array
        test_dep_path_list_features, test_dep_word_features, test_dep_type_path_length, test_dep_word_path_length, test_labels = load_data.build_recurrent_arrays(
            fold_test_instances)

        test_features = [
            test_dep_path_list_features, test_dep_word_features,
            test_dep_type_path_length, test_dep_word_path_length
        ]
        print(trained_model_path)
        fold_test_predicted_prob, fold_test_labels = rnn.recurrent_test(
            test_features, test_labels, trained_model_path)

        assert (np.array_equal(fold_test_labels, test_labels))

        total_predicted_prob = fold_test_predicted_prob.tolist()
        total_test = fold_test_labels.tolist()
        total_instances = fold_test_instances

        total_test = np.array(total_test)
        total_predicted_prob = np.array(total_predicted_prob)

        return total_predicted_prob, total_instances


def one_fold_cross_validation(model_out,
                              pmids,
                              forward_sentences,
                              reverse_sentences,
                              distant_interactions,
                              reverse_distant_interactions,
                              entity_a_text,
                              entity_b_text,
                              hidden_array,
                              key_order,
                              recurrent,
                              pubtator_labels=None):
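    """Train on a random half of the abstracts and evaluate on the other half.

    When pubtator_labels is given, instances are built from the manual labels
    rather than from distant supervision alone. Returns the test instances
    together with dictionaries of per-instance predictions, gradients, and
    hidden activations.
    """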
    pmids = list(pmids)
    # split pmids 50/50 into training and test abstracts
    testlength = int(len(pmids) * 0.50)
    random.shuffle(pmids)
    fold_test_abstracts = pmids[:testlength]
    fold_training_abstracts = pmids[testlength:]

    fold_training_forward_sentences = {}
    fold_training_reverse_sentences = {}
    fold_test_forward_sentences = {}
    fold_test_reverse_sentences = {}

    for key in forward_sentences:
        if key.split('|')[0] in fold_training_abstracts:
            fold_training_forward_sentences[key] = forward_sentences[key]
        elif key.split('|')[0] in fold_test_abstracts:
            fold_test_forward_sentences[key] = forward_sentences[key]

    for key in reverse_sentences:
        if key.split('|')[0] in fold_training_abstracts:
            fold_training_reverse_sentences[key] = reverse_sentences[key]
        elif key.split('|')[0] in fold_test_abstracts:
            fold_test_reverse_sentences[key] = reverse_sentences[key]

    if recurrent is False:
        fold_training_instances, \
        fold_dep_dictionary, \
        fold_dep_word_dictionary, \
        fold_dep_element_dictionary, \
        fold_between_word_dictionary = load_data.build_instances_training(fold_training_forward_sentences,
                                                                    fold_training_reverse_sentences,
                                                                    distant_interactions,
                                                                    reverse_distant_interactions,
                                                                    entity_a_text, entity_b_text, key_order)

        if pubtator_labels:
            fold_training_instances, \
            fold_dep_dictionary, \
            fold_dep_word_dictionary, \
            fold_dep_element_dictionary, \
            fold_between_word_dictionary = load_data.build_instances_labelled(fold_training_forward_sentences,
                                                                         fold_training_reverse_sentences,
                                                                         pubtator_labels, key_order[0],
                                                                         entity_a_text,
                                                                         entity_b_text,
                                                                         key_order)

        pickle.dump([
            fold_dep_dictionary, fold_dep_word_dictionary,
            fold_dep_element_dictionary, fold_between_word_dictionary,
            key_order
        ], open(model_out + 'a.pickle', 'wb'))

        # train model
        X = []
        y = []
        for t in fold_training_instances:
            X.append(t.features)
            y.append(t.label)

        fold_train_X = np.array(X)
        fold_train_y = np.array(y)

        model_dir = model_out
        if os.path.exists(model_dir):
            shutil.rmtree(model_dir)

        fold_test_instances = load_data.build_instances_testing(
            fold_test_forward_sentences, fold_test_reverse_sentences,
            fold_dep_dictionary, fold_dep_word_dictionary,
            fold_dep_element_dictionary, fold_between_word_dictionary,
            distant_interactions, reverse_distant_interactions, entity_a_text,
            entity_b_text, key_order)
        if pubtator_labels:
            fold_test_instances = load_data.build_labelled_testing(
                fold_test_forward_sentences, fold_test_reverse_sentences,
                fold_dep_dictionary, fold_dep_word_dictionary,
                fold_dep_element_dictionary, fold_between_word_dictionary,
                pubtator_labels, key_order[0], entity_a_text, entity_b_text,
                key_order)

        # group instances by pmid and build feature array
        fold_test_features = []
        fold_test_labels = []
        pmid_test_instances = {}
        for test_index in range(len(fold_test_instances)):
            fti = fold_test_instances[test_index]
            if fti.sentence.pmid not in pmid_test_instances:
                pmid_test_instances[fti.sentence.pmid] = []
            pmid_test_instances[fti.sentence.pmid].append(test_index)
            fold_test_features.append(fti.features)
            fold_test_labels.append(fti.label)

        fold_test_X = np.array(fold_test_features)
        fold_test_y = np.array(fold_test_labels)
        print(fold_test_y.shape)
        print(fold_test_X.shape)

        test_model = snn.feed_forward_train(fold_train_X, fold_train_y,
                                            fold_test_X, fold_test_y,
                                            hidden_array, model_dir + '/',
                                            key_order)

        group_instances = load_data.batch_instances(fold_test_instances)

        probability_dict = {}
        label_dict = {}
        cs_grad_dict = {}
        cs_hidden_act_dict = {}

        for g in group_instances:
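            # run the feed-forward test on this batch of instances and record
            # per-instance probabilities, labels, gradients, and hidden
            # activations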
            fold_test_features = []
            fold_test_labels = []
            for ti in group_instances[g]:
                fold_test_features.append(fold_test_instances[ti].features)
                fold_test_labels.append(fold_test_instances[ti].label)
            fold_test_X = np.array(fold_test_features)
            fold_test_y = np.array(fold_test_labels)

            fold_test_predicted_prob, fold_test_labels, fold_test_cs_grads, fold_test_cs_hidden_activations = snn.feed_forward_test(
                fold_test_X, fold_test_y, test_model)
            probability_dict[g] = fold_test_predicted_prob
            label_dict[g] = fold_test_labels
            for i in range(len(fold_test_cs_grads)):
                print(fold_test_cs_grads[i])
                print(fold_test_predicted_prob[i])
                cs_grad_dict[group_instances[g][i]] = [
                    fold_test_predicted_prob[i], fold_test_labels[i],
                    fold_test_cs_grads[i], group_instances[g]
                ]
                cs_hidden_act_dict[group_instances[g][i]] = [
                    fold_test_predicted_prob[i], fold_test_labels[i],
                    fold_test_cs_hidden_activations[i], group_instances[g]
                ]

    else:
        fold_training_instances, \
        fold_dep_path_list_dictionary, \
        fold_dep_word_dictionary, word2vec_embeddings = load_data.build_instances_training(fold_training_forward_sentences,
                                                                                      fold_training_reverse_sentences,
                                                                                      distant_interactions,
                                                                                      reverse_distant_interactions,
                                                                                      entity_a_text,
                                                                                      entity_b_text,
                                                                                      key_order, True)

        if pubtator_labels:
            fold_training_instances, \
            fold_dep_path_list_dictionary, \
            fold_dep_word_dictionary, word2vec_embeddings = load_data.build_instances_labelled(fold_training_forward_sentences,
                                                                                          fold_training_reverse_sentences,
                                                                                          pubtator_labels, key_order[0],
                                                                                          entity_a_text,
                                                                                          entity_b_text,
                                                                                          key_order, True)

        pickle.dump([
            fold_dep_path_list_dictionary, fold_dep_word_dictionary, key_order
        ], open(model_out + 'a.pickle', 'wb'))

        dep_path_list_features, dep_word_features, dep_type_path_length, dep_word_path_length, labels = load_data.build_recurrent_arrays(
            fold_training_instances)

        features = [
            dep_path_list_features, dep_word_features, dep_type_path_length,
            dep_word_path_length
        ]

        model_dir = model_out
        if os.path.exists(model_dir):
            shutil.rmtree(model_dir)

        test_model = rnn.recurrent_train(features, labels,
                                         len(fold_dep_path_list_dictionary),
                                         len(fold_dep_word_dictionary),
                                         model_dir + '/', key_order,
                                         word2vec_embeddings)

        fold_test_instances = load_data.build_instances_testing(
            fold_test_forward_sentences, fold_test_reverse_sentences, None,
            fold_dep_word_dictionary, None, None, distant_interactions,
            reverse_distant_interactions, entity_a_text, entity_b_text,
            key_order, fold_dep_path_list_dictionary)

        if pubtator_labels:
            fold_test_instances = load_data.build_labelled_testing(
                fold_test_forward_sentences, fold_test_reverse_sentences, None,
                fold_dep_word_dictionary, None, None, pubtator_labels,
                key_order[0], entity_a_text, entity_b_text, key_order,
                fold_dep_path_list_dictionary)

        group_instances = load_data.batch_instances(fold_test_instances)

        test_dep_path_list_features, test_dep_word_features, test_dep_type_path_length, test_dep_word_path_length, test_labels = load_data.build_recurrent_arrays(
            fold_test_instances)

        test_features = [
            test_dep_path_list_features, test_dep_word_features,
            test_dep_type_path_length, test_dep_word_path_length
        ]

        fold_test_predicted_prob, fold_test_labels, fold_test_cs_grads, fold_test_cs_hidden_activations = rnn.recurrent_test(
            test_features, test_labels, test_model)
        cs_grad_dict = {}
        cs_hidden_act_dict = {}
        for g in group_instances:
            for ti in group_instances[g]:
                print(fold_test_predicted_prob[ti])
                cs_grad_dict[ti] = [
                    fold_test_predicted_prob[ti], fold_test_labels[ti], [],
                    group_instances[g]
                ]

    return fold_test_instances, cs_grad_dict, cs_hidden_act_dict


def train_labelled_and_distant(model_out, distant_pubtator_file,
                               labelled_pubtator_file, pubtator_labels,
                               directional_distant_directory,
                               symmetric_distant_directory,
                               distant_entity_a_col, distant_entity_b_col,
                               distant_rel_col, entity_a, entity_b, recurrent):
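    """Train a full model on the union of distantly supervised and labelled data.

    Sentences and entity mention text from the distant and labelled PubTator
    files are merged before instances are built. A recurrent model is trained
    when recurrent is True, otherwise a feed-forward model, and the path of
    the trained model is returned.
    """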

    # get distant relations from the external knowledge base directories
    distant_interactions, reverse_distant_interactions = load_data.load_distant_directories(
        directional_distant_directory, symmetric_distant_directory,
        distant_entity_a_col, distant_entity_b_col, distant_rel_col)

    key_order = sorted(distant_interactions)

    distant_pmids, distant_forward_sentences, distant_reverse_sentences, distant_entity_a_text, distant_entity_b_text = load_data.load_pubtator_abstract_sentences(
        distant_pubtator_file, entity_a, entity_b)

    labelled_pmids, labelled_forward_sentences, labelled_reverse_sentences, labelled_entity_a_text, labelled_entity_b_text = load_data.load_pubtator_abstract_sentences(
        labelled_pubtator_file, entity_a, entity_b, True)

    pmids = distant_pmids.union(labelled_pmids)

    training_forward_sentences = distant_forward_sentences.copy()
    training_forward_sentences.update(labelled_forward_sentences)

    training_reverse_sentences = distant_reverse_sentences.copy()
    training_reverse_sentences.update(labelled_reverse_sentences)

    entity_a_text = distant_entity_a_text.copy()
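    # merge entity mention text from both corpora, taking the union of
    # surface forms for keys present in both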
    for k in labelled_entity_a_text:
        if k in entity_a_text:
            entity_a_text[k] = entity_a_text[k].union(
                labelled_entity_a_text[k])
        else:
            entity_a_text[k] = labelled_entity_a_text[k]

    entity_b_text = distant_entity_b_text.copy()
    for k in labelled_entity_b_text:
        if k in entity_b_text:
            entity_b_text[k] = entity_b_text[k].union(
                labelled_entity_b_text[k])
        else:
            entity_b_text[k] = labelled_entity_b_text[k]

    print('training forward sentences: ' + str(len(training_forward_sentences)))

    if recurrent:
        # training full model
        training_instances, \
        dep_path_list_dictionary, \
        dep_word_dictionary, word2vec_embeddings = load_data.build_instances_labelled_and_distant(training_forward_sentences,
                                                                     training_reverse_sentences, distant_interactions, reverse_distant_interactions,
                                                                     pubtator_labels,'binds',
                                                                     entity_a_text,
                                                                     entity_b_text,
                                                                     key_order,True)

        dep_path_list_features, dep_word_features, dep_type_path_length, dep_word_path_length, labels = load_data.build_recurrent_arrays(
            training_instances)
        features = [
            dep_path_list_features, dep_word_features, dep_type_path_length,
            dep_word_path_length
        ]

        if os.path.exists(model_out):
            shutil.rmtree(model_out)

        trained_model_path = rnn.recurrent_train(features, labels,
                                                 len(dep_path_list_dictionary),
                                                 len(dep_word_dictionary),
                                                 model_out + '/', key_order,
                                                 word2vec_embeddings)

        pickle.dump([dep_path_list_dictionary, dep_word_dictionary, key_order],
                    open(model_out + 'a.pickle', 'wb'))
        print("trained model")

        return trained_model_path
    else:

        # hidden layer structure
        hidden_array = [256]

        # k-cross val
        # instance_predicts, single_instances= cv.k_fold_cross_validation(10,training_pmids,training_forward_sentences,
        #                                                                training_reverse_sentences,distant_interactions,
        #                                                                reverse_distant_interactions,entity_a_text,
        #                                                                entity_b_text,hidden_array,key_order)

        # cv.write_cv_output(model_out+'_predictions',instance_predicts,single_instances,key_order)

        # training full model
        training_instances, \
        dep_dictionary, \
        dep_word_dictionary, \
        dep_element_dictionary, \
        between_word_dictionary = load_data.build_instances_labelled_and_distant(training_forward_sentences,
                                                                     training_reverse_sentences,distant_interactions, reverse_distant_interactions,
                                                                     pubtator_labels,'binds',
                                                                     entity_a_text,
                                                                     entity_b_text,
                                                                     key_order)

        X = []
        y = []
        instance_sentences = set()
        for t in training_instances:
            instance_sentences.add(' '.join(t.sentence.sentence_words))
            X.append(t.features)
            y.append(t.label)

        X_train = np.array(X)
        y_train = np.array(y)

        print(X_train.shape)
        print(y_train.shape)

        if os.path.exists(model_out):
            shutil.rmtree(model_out)

        trained_model_path = ffnn.feed_forward_train(X_train, y_train, None,
                                                     None, hidden_array,
                                                     model_out + '/',
                                                     key_order)

        print('Number of Sentences')
        print(len(instance_sentences))
        print('Number of Instances')
        print(len(training_instances))
        print('Number of dependency paths ')
        print(len(dep_dictionary))
        print('Number of dependency words')
        print(len(dep_word_dictionary))
        print('Number of between words')
        print(len(between_word_dictionary))
        print('Number of elements')
        print(len(dep_element_dictionary))
        print('length of feature space')
        print(
            len(dep_dictionary) + len(dep_word_dictionary) +
            len(dep_element_dictionary) + len(between_word_dictionary))
        pickle.dump([
            dep_dictionary, dep_word_dictionary, dep_element_dictionary,
            between_word_dictionary, key_order
        ], open(model_out + 'a.pickle', 'wb'))
        print("trained model")

        return trained_model_path
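

# Hypothetical usage sketch. The paths, column indices, and entity types below
# are placeholders (not values from this repository); they only illustrate how
# train_recurrent is wired together.
#
#   model_path = train_recurrent(
#       model_out='/tmp/example_model',
#       pubtator_file='/data/abstracts.pubtator',
#       directional_distant_directory='/data/kb/directional/',
#       symmetric_distant_directory='/data/kb/symmetric/',
#       distant_entity_a_col=0,
#       distant_entity_b_col=1,
#       distant_rel_col=2,
#       entity_a='CHEMICAL',
#       entity_b='GENE')
#   print(model_path)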