Example #1
def main_execution():
    combo_to_drugs_ids, combo_to_side_effects = load_drug_bank_combo_side_effect_file(
        fichier='polypharmacy/drugbank/drugbank-combo.csv')
    nodes = {u for e in combo_to_drugs_ids.values() for u in e}
    n_drugs = len(nodes)
    relation_types = set(combo_to_side_effects.values())
    n_drugdrug_rel_types = len(relation_types)
    drugs_to_positions_in_matrices_dict = {
        node: i
        for i, node in enumerate(nodes)
    }
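    # NB: `nodes` is a set, so the enumeration order above (and hence each
    # drug's row/column index) can vary across runs unless the loader is
    # deterministic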

    drug_drug_adj_list = []  # one adjacency matrix per drug-drug side effect
    for el in relation_types:  # for each side effect
        mat = np.zeros((n_drugs, n_drugs))
        for d1, d2 in combinations(list(nodes), 2):
            # NB: only the '{d1}_{d2}' key ordering is checked; this assumes
            # the loader normalizes the pair order in the keys
            temp_cle = '{}_{}'.format(d1, d2)
            if combo_to_side_effects.get(temp_cle) == el:
                # record a symmetric entry: this drug pair has this side effect
                mat[drugs_to_positions_in_matrices_dict[d1], drugs_to_positions_in_matrices_dict[d2]] = \
                    mat[drugs_to_positions_in_matrices_dict[d2], drugs_to_positions_in_matrices_dict[d1]] = 1.
        drug_drug_adj_list.append(sp.csr_matrix(mat))
    drug_degrees_list = [
        np.array(drug_adj.sum(axis=0)).squeeze()
        for drug_adj in drug_drug_adj_list
    ]

    # (source node type, destination node type) -> list of adjacency matrices;
    # appending the transposes doubles the relation count (forward + inverse)
    adj_mats_orig = {
        (0, 0):
        drug_drug_adj_list +
        [x.transpose(copy=True) for x in drug_drug_adj_list],
    }
    degrees = {
        0: drug_degrees_list + drug_degrees_list,
    }

    # features (drugs): one-hot identity matrix (i.e., featureless drugs)
    drug_feat = sp.identity(n_drugs)
    drug_nonzero_feat, drug_num_feat = drug_feat.shape
    drug_feat = preprocessing.sparse_to_tuple(drug_feat.tocoo())

    # data representation
    num_feat = {
        0: drug_num_feat,
    }
    nonzero_feat = {
        0: drug_nonzero_feat,
    }
    feat = {
        0: drug_feat,
    }

    edge_type2dim = {
        k: [adj.shape for adj in adjs]
        for k, adjs in adj_mats_orig.items()
    }
    edge_type2decoder = {
        (0, 0): 'dedicom',
    }

    edge_types = {k: len(v) for k, v in adj_mats_orig.items()}
    num_edge_types = sum(edge_types.values())
    print("Edge types:", "%d" % num_edge_types)
    print("Defining placeholders")
    placeholders = construct_placeholders(edge_types)

    ###########################################################
    #
    # Create minibatch iterator, model and optimizer
    #
    ###########################################################
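
    # Not defined in this snippet; values as in the companion example (Example #3):
    val_test_size = 0.1  # 10% each for validation and test, per the paper
    PRINT_PROGRESS_EVERY = 10000  # validation metrics are expensive to compute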

    print("Create minibatch iterator")
    minibatch = EdgeMinibatchIterator(adj_mats=adj_mats_orig,
                                      feat=feat,
                                      edge_types=edge_types,
                                      batch_size=FLAGS.batch_size,
                                      val_test_size=val_test_size)

    print("Create model")
    model = DecagonModel(
        placeholders=placeholders,
        num_feat=num_feat,
        nonzero_feat=nonzero_feat,
        edge_types=edge_types,
        decoders=edge_type2decoder,
    )

    print("Create optimizer")
    with tf.name_scope('optimizer'):
        opt = DecagonOptimizer(embeddings=model.embeddings,
                               latent_inters=model.latent_inters,
                               latent_varies=model.latent_varies,
                               degrees=degrees,
                               edge_types=edge_types,
                               edge_type2dim=edge_type2dim,
                               placeholders=placeholders,
                               batch_size=FLAGS.batch_size,
                               margin=FLAGS.max_margin)

    print("Initialize session")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    feed_dict = {}

    ###########################################################
    #
    # Train model
    #
    ###########################################################

    print("Train model")
    for epoch in range(FLAGS.epochs):

        minibatch.shuffle()
        itr = 0
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict(
                placeholders=placeholders)
            feed_dict = minibatch.update_feed_dict(feed_dict=feed_dict,
                                                   dropout=FLAGS.dropout,
                                                   placeholders=placeholders)

            t = time.time()

            # Training step: run single weight update
            outs = sess.run([opt.opt_op, opt.cost, opt.batch_edge_type_idx],
                            feed_dict=feed_dict)
            train_cost = outs[1]
            batch_edge_type = outs[2]

            if itr % PRINT_PROGRESS_EVERY == 0:
                val_auc, val_auprc, val_apk = get_accuracy_scores(
                    feed_dict, placeholders, sess, opt, minibatch,
                    adj_mats_orig, minibatch.val_edges,
                    minibatch.val_edges_false,
                    minibatch.idx2edge_type[minibatch.current_edge_type_idx])

                print("Epoch:", "%04d" % (epoch + 1), "Iter:",
                      "%04d" % (itr + 1), "Edge:", "%04d" % batch_edge_type,
                      "train_loss=", "{:.5f}".format(train_cost), "val_roc=",
                      "{:.5f}".format(val_auc), "val_auprc=",
                      "{:.5f}".format(val_auprc), "val_apk=",
                      "{:.5f}".format(val_apk), "time=",
                      "{:.5f}".format(time.time() - t))

            itr += 1

    print("Optimization finished!")

    for et in range(num_edge_types):
        roc_score, auprc_score, apk_score = get_accuracy_scores(
            feed_dict, placeholders, sess, opt, minibatch, adj_mats_orig,
            minibatch.test_edges, minibatch.test_edges_false,
            minibatch.idx2edge_type[et])
        print("Edge type=", "[%02d, %02d, %02d]" % minibatch.idx2edge_type[et])
        print("Edge type:", "%04d" % et, "Test AUROC score",
              "{:.5f}".format(roc_score))
        print("Edge type:", "%04d" % et, "Test AUPRC score",
              "{:.5f}".format(auprc_score))
        print("Edge type:", "%04d" % et, "Test AP@k score",
              "{:.5f}".format(apk_score))
        print()
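
A note on the construction above: the pairwise scan visits every drug pair for every side effect, which is O(n_drugs^2) per relation. Example #3 below instead iterates only over the recorded combos. A minimal sketch of that approach in terms of Example #1's structures (it assumes each value of combo_to_drugs_ids is a pair of drug ids and each combo maps to one side effect, as the code above implies; drug_pos stands for drugs_to_positions_in_matrices_dict):

import scipy.sparse as sp

def build_adj_list(nodes, relation_types, combo_to_drugs_ids,
                   combo_to_side_effects, drug_pos):
    # index the relation types once, then visit each combo exactly once
    rel_index = {rel: k for k, rel in enumerate(relation_types)}
    mats = [sp.lil_matrix((len(nodes), len(nodes))) for _ in relation_types]
    for combo, (d1, d2) in combo_to_drugs_ids.items():
        k = rel_index[combo_to_side_effects[combo]]
        i, j = drug_pos[d1], drug_pos[d2]
        mats[k][i, j] = mats[k][j, i] = 1.
    return [m.tocsr() for m in mats]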
Example #2
###########################################################
#
# Train model
#
###########################################################

print('Train model')
print_every = 1
for epoch in range(FLAGS.epochs):

    minibatch.shuffle()
    itr = 0
    while not minibatch.end():
        # Construct feed dictionary
        feed_dict = minibatch.next_minibatch_feed_dict(placeholders=placeholders)
        feed_dict = minibatch.update_feed_dict(
            feed_dict=feed_dict,
            dropout=FLAGS.dropout,
            placeholders=placeholders)

        t = time.time()

        # Training step: run single weight update
        outs = sess.run([opt.opt_op, opt.cost, opt.batch_edge_type_idx], feed_dict=feed_dict)
        train_cost = outs[1]
        batch_edge_type = outs[2]

        if itr % print_every == 0:
            val_auc, val_auprc, val_apk = get_accuracy_scores(
                minibatch.val_edges, minibatch.val_edges_false,
                minibatch.idx2edge_type[minibatch.current_edge_type_idx])
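
None of the examples define construct_placeholders, although all three call it. A minimal sketch that is consistent with these call sites (TF1-era API; an assumption inferred from usage, not code taken from this page):

import tensorflow as tf

def construct_placeholders(edge_types):
    # scalar bookkeeping placeholders that the minibatch iterator fills in
    placeholders = {
        'batch': tf.placeholder(tf.int32, name='batch'),
        'batch_edge_type_idx':
            tf.placeholder(tf.int32, shape=(), name='batch_edge_type_idx'),
        'batch_row_edge_type':
            tf.placeholder(tf.int32, shape=(), name='batch_row_edge_type'),
        'batch_col_edge_type':
            tf.placeholder(tf.int32, shape=(), name='batch_col_edge_type'),
        'degrees': tf.placeholder(tf.int32),
        'dropout': tf.placeholder_with_default(0., shape=()),
    }
    # one sparse placeholder per adjacency matrix and per node-type feature
    placeholders.update({
        'adj_mats_%d,%d,%d' % (i, j, k): tf.sparse_placeholder(tf.float32)
        for (i, j), n in edge_types.items() for k in range(n)})
    placeholders.update({
        'feat_%d' % i: tf.sparse_placeholder(tf.float32)
        for i, _ in edge_types})
    return placeholders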
Example #3
def main(args):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--decagon_data_file_directory",
        type=str,
        help=
        "path to directory where bio-decagon-*.csv files are located, with trailing slash. "
        "Default is current directory",
        default='./')
    parser.add_argument(
        "--saved_files_directory",
        type=str,
        help=
        "path to directory where saved files files are located, with trailing slash. "
        "Default is current directory. If a decagon_model.ckpt* exists in this directory, it will "
        "be loaded and evaluated, and no training will be done.",
        default='./')
    parser.add_argument("--verbose",
                        help="increase output verbosity",
                        action="store_true",
                        default=False)
    args = parser.parse_args(args)

    decagon_data_file_directory = args.decagon_data_file_directory
    verbose = args.verbose
    script_start_time = datetime.now()

    # create a pre-processed file that keeps only side effects with >= 500 occurrences
    all_combos_df = pd.read_csv('%sbio-decagon-combo.csv' %
                                decagon_data_file_directory)
    side_effects_500 = all_combos_df["Polypharmacy Side Effect"].value_counts()
    side_effects_500 = side_effects_500[side_effects_500 >= 500].index.tolist()
    all_combos_df = all_combos_df[
        all_combos_df["Polypharmacy Side Effect"].isin(side_effects_500)]
    all_combos_df.to_csv('%sbio-decagon-combo-over500only.csv' %
                         decagon_data_file_directory,
                         index=False)

    # use the pre-processed file that only contains the most common side effects (those with >= 500 drug pairs)
    drug_drug_net, combo2stitch, combo2se, se2name = load_combo_se(
        fname=('%sbio-decagon-combo-over500only.csv' %
               decagon_data_file_directory))
    # net is a networkx graph with genes(proteins) as nodes and protein-protein-interactions as edges
    # node2idx maps node id to node index
    gene_net, node2idx = load_ppi(fname=('%sbio-decagon-ppi.csv' %
                                         decagon_data_file_directory))
    # stitch2se maps (individual) stitch ids to a list of side effect ids
    # se2name_mono maps side effect ids that occur in the mono file to side effect names (shorter than se2name)
    stitch2se, se2name_mono = load_mono_se(fname=('%sbio-decagon-mono.csv' %
                                                  decagon_data_file_directory))
    # stitch2proteins maps stitch ids (drug) to protein (gene) ids
    drug_gene_net, stitch2proteins = load_targets(
        fname=('%sbio-decagon-targets-all.csv' % decagon_data_file_directory))

    # this was 0.05 in the original code, but the paper says that 10% each are used for testing and validation
    val_test_size = 0.1
    n_genes = gene_net.number_of_nodes()
    gene_adj = nx.adjacency_matrix(gene_net)
    gene_degrees = np.array(gene_adj.sum(axis=0)).squeeze()

    ordered_list_of_drugs = list(drug_drug_net.nodes.keys())
    ordered_list_of_side_effects = list(se2name.keys())
    ordered_list_of_proteins = list(gene_net.nodes.keys())

    n_drugs = len(ordered_list_of_drugs)

    # build the sparse drug-gene target matrix directly (no dense intermediate)
    drug_gene_adj = sp.lil_matrix((n_drugs, n_genes))
    for drug in stitch2proteins:
        for protein in stitch2proteins[drug]:
            # quite a few drugs here aren't in our list of 645, and some
            # proteins aren't in our list of 19081, so filter them out
            if drug in ordered_list_of_drugs and protein in ordered_list_of_proteins:
                drug_index = ordered_list_of_drugs.index(drug)
                gene_index = ordered_list_of_proteins.index(protein)
                drug_gene_adj[drug_index, gene_index] = 1

    drug_gene_adj = drug_gene_adj.tocsr()

    # transpose the 645x19081 drug-gene matrix to get the 19081x645 gene-drug view
    gene_drug_adj = drug_gene_adj.transpose(copy=True)

    drug_drug_adj_list = []
    if not os.path.isfile("adjacency_matrices/sparse_matrix0000.npz"):
        # pre-initialize all the matrices
        print("Initializing drug-drug adjacency matrix list")
        start_time = datetime.now()
        print("Starting at %s" % str(start_time))

        n = len(ordered_list_of_side_effects)
        for i in range(n):
            drug_drug_adj_list.append(sp.lil_matrix((n_drugs, n_drugs)))
            if verbose:
                print("%s percent done" % str(100.0 * i / n))
        print("Done initializing at %s after %s" %
              (datetime.now(), datetime.now() - start_time))

        start_time = datetime.now()
        combo_finish_time = start_time
        print("Creating adjacency matrices for side effects")
        print("Starting at %s" % str(start_time))
        combo_count = len(combo2se)
        combo_counter = 0

        for combo in combo2se.keys():
            side_effect_list = combo2se[combo]
            for present_side_effect in side_effect_list:
                # find the matrix we need to update
                side_effect_number = ordered_list_of_side_effects.index(
                    present_side_effect)
                # find the drugs for which we need to make the update
                drug_tuple = combo2stitch[combo]
                drug1_index = ordered_list_of_drugs.index(drug_tuple[0])
                drug2_index = ordered_list_of_drugs.index(drug_tuple[1])
                # update symmetrically: drug-drug interactions are undirected
                # (Example #1 sets both directions as well)
                drug_drug_adj_list[side_effect_number][drug1_index,
                                                       drug2_index] = 1
                drug_drug_adj_list[side_effect_number][drug2_index,
                                                       drug1_index] = 1

            if verbose and combo_counter % 1000 == 0:
                print(
                    "Finished combo %s after %s . %d percent of combos done" %
                    (combo_counter, str(combo_finish_time - start_time),
                     (100.0 * combo_counter / combo_count)))
            combo_finish_time = datetime.now()
            combo_counter = combo_counter + 1

        print("Done creating adjacency matrices at %s after %s" %
              (datetime.now(), datetime.now() - start_time))

        start_time = datetime.now()
        print("Saving matrices to file")
        print("Starting at %s" % str(start_time))

        # save matrices to file
        if not os.path.isdir("adjacency_matrices"):
            os.mkdir("adjacency_matrices")
        for i in range(len(drug_drug_adj_list)):
            sp.save_npz('adjacency_matrices/sparse_matrix%04d.npz' % (i, ),
                        drug_drug_adj_list[i].tocoo())
        print("Done saving matrices to file at %s after %s" %
              (datetime.now(), datetime.now() - start_time))
    else:
        print("Loading adjacency matrices from file.")
        for i in range(len(ordered_list_of_side_effects)):
            drug_drug_adj_list.append(
                sp.load_npz('adjacency_matrices/sparse_matrix%04d.npz' % i))

    for i in range(len(drug_drug_adj_list)):
        drug_drug_adj_list[i] = drug_drug_adj_list[i].tocsr()

    start_time = datetime.now()
    print("Setting up for training")
    print("Starting at %s" % str(start_time))

    drug_degrees_list = [
        np.array(drug_adj.sum(axis=0)).squeeze()
        for drug_adj in drug_drug_adj_list
    ]

    # data representation
    global adj_mats_orig
    adj_mats_orig = {
        (0, 0): [gene_adj, gene_adj.transpose(copy=True)
                 ],  # protein-protein interactions (and inverses)
        (0, 1):
        [gene_drug_adj],  # protein-drug relationships (inverse of targets)
        (1, 0): [drug_gene_adj],  # drug-protein relationships (targets)
        # This creates an "inverse" relationship for every polypharmacy side effect, using the transpose of the
        # relationship's adjacency matrix, resulting in 2x the number of side effects (and adjacency matrices).
        (1, 1):
        drug_drug_adj_list +
        [x.transpose(copy=True) for x in drug_drug_adj_list],
    }
    degrees = {
        0: [gene_degrees, gene_degrees],
        1: drug_degrees_list + drug_degrees_list,
    }

    # featureless (genes)
    gene_feat = sp.identity(n_genes)
    gene_nonzero_feat, gene_num_feat = gene_feat.shape
    gene_feat = preprocessing.sparse_to_tuple(gene_feat.tocoo())

    # features (drugs)
    drug_feat = sp.identity(n_drugs)
    drug_nonzero_feat, drug_num_feat = drug_feat.shape
    drug_feat = preprocessing.sparse_to_tuple(drug_feat.tocoo())

    # data representation
    num_feat = {
        0: gene_num_feat,
        1: drug_num_feat,
    }
    nonzero_feat = {
        0: gene_nonzero_feat,
        1: drug_nonzero_feat,
    }
    feat = {
        0: gene_feat,
        1: drug_feat,
    }

    edge_type2dim = {
        k: [adj.shape for adj in adjs]
        for k, adjs in adj_mats_orig.items()
    }
    edge_type2decoder = {
        (0, 0): 'bilinear',  # protein-protein interactions
        (0, 1): 'bilinear',  # protein-drug relationships
        (1, 0): 'bilinear',  # drug-protein relationships (targets)
        # DEDICOM shares global factors across the many drug-drug relations
        (1, 1): 'dedicom',
    }

    edge_types = {k: len(v) for k, v in adj_mats_orig.items()}
    global num_edge_types
    num_edge_types = sum(edge_types.values())
    print("Edge types:", "%d" % num_edge_types)

    ###########################################################
    #
    # Settings and placeholders
    #
    ###########################################################

    # Important -- Do not evaluate/print validation performance every iteration as it can take
    # substantial amount of time
    PRINT_PROGRESS_EVERY = 10000

    print("Defining placeholders")
    construct_placeholders(edge_types)

    ###########################################################
    #
    # Create minibatch iterator, model and optimizer
    #
    ###########################################################

    global minibatch_iterator
    iterator_pickle_file_name = args.saved_files_directory + "minibatch_iterator.pickle"
    if os.path.isfile(iterator_pickle_file_name):
        print("Load minibatch iterator pickle")
        with open(iterator_pickle_file_name, 'rb') as pickle_file:
            minibatch_iterator = pickle.load(pickle_file)
    else:
        print("Create minibatch iterator")
        minibatch_iterator = EdgeMinibatchIterator(adj_mats=adj_mats_orig,
                                                   feat=feat,
                                                   edge_types=edge_types,
                                                   batch_size=FLAGS.batch_size,
                                                   val_test_size=val_test_size)
        print("Pickling minibatch iterator")
        with open(iterator_pickle_file_name, 'wb') as pickle_file:
            pickle.dump(minibatch_iterator, pickle_file)
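    # NB: loading the pickled iterator reuses the original train/val/test split,
    # which keeps evaluation of a restored model consistent across runs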

    print("Create model")
    model = DecagonModel(
        placeholders=placeholders,
        num_feat=num_feat,
        nonzero_feat=nonzero_feat,
        edge_types=edge_types,
        decoders=edge_type2decoder,
    )

    print("Create optimizer")
    global optimizer
    with tf.name_scope('optimizer'):
        optimizer = DecagonOptimizer(embeddings=model.embeddings,
                                     latent_inters=model.latent_inters,
                                     latent_varies=model.latent_varies,
                                     degrees=degrees,
                                     edge_types=edge_types,
                                     edge_type2dim=edge_type2dim,
                                     placeholders=placeholders,
                                     batch_size=FLAGS.batch_size,
                                     margin=FLAGS.max_margin)

    print("Done setting up at %s after %s" %
          (datetime.now(), datetime.now() - start_time))

    print("Initialize session")
    global sess
    sess = tf.Session()

    decagon_model_file_name = args.saved_files_directory + "decagon_model.ckpt"
    saved_model_available = os.path.isfile(decagon_model_file_name + ".index")
    if saved_model_available:
        saver = tf.train.Saver()
        saver.restore(sess, decagon_model_file_name)
        print("Model restored.")
    if not saved_model_available:
        print("Training model")
        start_time = datetime.now()
        print("Starting at %s" % str(start_time))

        sess.run(tf.global_variables_initializer())
        feed_dict = {}

        ###########################################################
        #
        # Train model
        #
        ###########################################################

        saver = tf.train.Saver()

        print("Train model")
        epoch_losses = []
        for epoch in range(FLAGS.epochs):

            minibatch_iterator.shuffle()
            itr = 0
            while not minibatch_iterator.end():
                # Construct feed dictionary
                feed_dict = minibatch_iterator.next_minibatch_feed_dict(
                    placeholders=placeholders)
                feed_dict = minibatch_iterator.update_feed_dict(
                    feed_dict=feed_dict,
                    dropout=FLAGS.dropout,
                    placeholders=placeholders)

                t = time.time()

                # Training step: run single weight update
                outs = sess.run([
                    optimizer.opt_op, optimizer.cost,
                    optimizer.batch_edge_type_idx
                ],
                                feed_dict=feed_dict)
                train_cost = outs[1]
                batch_edge_type = outs[2]

                if itr % PRINT_PROGRESS_EVERY == 0:
                    val_auc, val_auprc, val_apk = get_accuracy_scores(
                        minibatch_iterator.val_edges,
                        minibatch_iterator.val_edges_false,
                        minibatch_iterator.idx2edge_type[
                            minibatch_iterator.current_edge_type_idx],
                        feed_dict)

                    print("Epoch:", "%04d" % (epoch + 1), "Iter:",
                          "%04d" % (itr + 1), "Edge:",
                          "%04d" % batch_edge_type, "train_loss=",
                          "{:.5f}".format(train_cost), "val_roc=",
                          "{:.5f}".format(val_auc), "val_auprc=",
                          "{:.5f}".format(val_auprc), "val_apk=",
                          "{:.5f}".format(val_apk), "time=",
                          "{:.5f}".format(time.time() - t))

                itr += 1
            validation_loss = get_validation_loss(
                edges_pos=minibatch_iterator.val_edges,
                edges_neg=minibatch_iterator.val_edges_false,
                feed_dict=feed_dict)
            print(
                "Epoch:", "%04d" % (epoch + 1),
                "Validation loss (average cross entropy): {}".format(
                    validation_loss))

            epoch_losses.append(validation_loss)
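            # simple early stopping: quit once the rounded validation loss has
            # failed to decrease for two consecutive epochs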
            if len(epoch_losses) >= 3:
                if round(epoch_losses[-1], 3) >= round(
                        epoch_losses[-2], 3) >= round(epoch_losses[-3], 3):
                    break

            print("Saving model after epoch:", epoch)
            save_path = saver.save(
                sess, args.saved_files_directory + "decagon_model" +
                str(epoch) + ".ckpt")
            print("Model saved in path: %s" % save_path)

        print("Optimization finished!")
        print("Done training model %s after %s" %
              (datetime.now(), datetime.now() - start_time))

        print("Saving model")
        save_path = saver.save(sess, decagon_model_file_name)
        print("Model saved in path: %s" % save_path)

        print("Pickling minibatch iterator")
        with open(iterator_pickle_file_name, 'wb') as pickle_file:
            pickle.dump(minibatch_iterator, pickle_file)

    start_time = datetime.now()
    print("Evaluating model")
    print("Starting at %s" % str(start_time))

    for edge_type in range(num_edge_types):
        # get all edges in test set with this type
        feed_dict = minibatch_iterator.test_feed_dict(
            edge_type, placeholders=placeholders)
        feed_dict = minibatch_iterator.update_feed_dict(
            feed_dict, FLAGS.dropout, placeholders)
        edge_tuple = minibatch_iterator.idx2edge_type[edge_type]

        _, _, all_scores, all_labels, subjects, predicates, objects = get_predictions(
            edges_pos=minibatch_iterator.test_edges,
            edges_neg=minibatch_iterator.test_edges_false,
            edge_type=edge_tuple,
            feed_dict=feed_dict)

        print("subject\tpredicate\tobject\tpredicted\tactual")
        for i in range(len(all_scores)):
            subject = subjects[i]
            if edge_tuple[0] == 1:
                subject = ordered_list_of_drugs[subject]
            else:
                subject = ordered_list_of_proteins[subject]

            obj = objects[i]  # renamed to avoid shadowing the builtin 'object'
            if edge_tuple[1] == 1:
                obj = ordered_list_of_drugs[obj]
            else:
                obj = ordered_list_of_proteins[obj]

            predicate = predicates[i]
            if edge_tuple[:2] == (1, 1):
                side_effect_index = edge_tuple[2]
                is_inverse = False
                # indices >= 963 (the number of side-effect types kept after
                # the >=500 filter) refer to the transposed "inverse" matrices
                # appended in adj_mats_orig
                if side_effect_index >= 963:
                    side_effect_index = side_effect_index - 963
                    is_inverse = True
                predicate = ordered_list_of_side_effects[side_effect_index]
                if is_inverse:
                    predicate = predicate + "_2"

            print("{}\t{}\t{}\t{}\t{}".format(subject, predicate, object,
                                              all_scores[i], all_labels[i]))

    print()

    print("Done evaluating at %s after %s" %
          (datetime.now(), datetime.now() - start_time))

    print("Script running time: %s" % (datetime.now() - script_start_time))