예제 #1
0
def get_trained_network(n_epochs, oversampling_ratio):
    """Return the VNN for the given settings, switched to eval mode.

    The network is looked up through ``uf.manage_obj``; when that call
    returns exactly ``False`` a new network is trained, saved with
    ``torch.save`` and returned. Otherwise the previously saved network is
    loaded from disk.

    Args:
        n_epochs: Number of epochs; also part of the saved file name (``%d``).
        oversampling_ratio: Ratio forwarded to ``set_inputs``; also part of
            the saved file name (``%.2f``).

    Returns:
        The network object after ``net.eval()`` has been called.
    """
    name_args = (n_epochs, oversampling_ratio)
    # presumably manage_obj reports whether the stored object exists -- TODO confirm
    cache_state = uf.manage_obj('trained_networks/VNN_%d_%.2f' % name_args)
    model_file = '../data/objects/trained_networks/VNN_%d_%.2f.obj' % name_args

    # Anything other than the literal False means "load the stored network".
    if cache_state is not False:
        cached_net = torch.load(model_file)
        cached_net.eval()
        return cached_net

    layer_info, matrix_list, data_set = set_inputs(oversampling_ratio)
    net, criterion, optimizer, device = set_network(layer_info, matrix_list)

    # Only the test loader of the split is used below.
    _split_train_loader, test_loader = uf.get_train_test_valid_set(
        data_set, 0.3)

    # NOTE(review): training iterates over the *whole* data set, so the
    # samples behind test_loader are also seen during training -- confirm
    # this is intended for the final model.
    train_loader = td.DataLoader(
        data_set, batch_size=50, shuffle=True, num_workers=10)

    train_epoch(net, optimizer, criterion, train_loader, test_loader, device,
                n_epochs)

    torch.save(net, model_file)
    net.eval()
    return net
예제 #2
0
def get_RWVNN_path(single_data, is_comb, n_epochs, oversampling_ratio, *drug_name):
    """Run ``framework`` for one drug or for every drug in the DEG directory.

    Args:
        single_data: Truthy to process a single drug (its name must then be
            passed as the one extra positional argument); falsy to process
            every eligible file in the result directory.
        is_comb: Truthy to use the combination-drug directory and target
            table, falsy for the single-drug ones.
        n_epochs: Forwarded to ``vnn.get_trained_network``.
        oversampling_ratio: Forwarded to ``vnn.get_trained_network``.
        *drug_name: Exactly one drug name when ``single_data`` is truthy;
            ignored otherwise.

    Returns:
        ``0`` when the requested single drug has no entry in the target
        dictionary, otherwise ``None``.

    Raises:
        TypeError: If ``single_data`` is truthy and the number of extra
            positional arguments is not exactly one.
    """
    # ``*drug_name`` arrives as a tuple; unpack and validate it up front so a
    # bad call fails before the (potentially expensive) network is obtained.
    requested_drug = None
    if single_data:
        if len(drug_name) != 1:
            raise TypeError(
                'get_RWVNN_path() expects exactly one drug name when '
                'single_data is true, got %d' % len(drug_name))
        requested_drug = drug_name[0]

    net = vnn.get_trained_network(n_epochs, oversampling_ratio)
    if is_comb:
        dir_path = "../RW/RW_combi_DEG_prediction_results"
    else:
        dir_path = "../RW/RW_DEG_prediction_results"

    drug_target_dict = get_target_dict(is_comb)

    if single_data:
        # Check the targets first: no point reading the DEG file for a drug
        # that cannot be processed.
        if requested_drug not in drug_target_dict:
            print('test fail')
            return 0
        deg_rows = uf.read_file(os.path.join(dir_path, requested_drug + ".txt"))
        degs = [gene for row in deg_rows for gene in row]
        framework(is_comb, requested_drug, drug_target_dict[requested_drug],
                  degs, net)
        return None

    # All drugs: every file that is not an "_RW_result.txt" file is read as a
    # DEG list, named by the part of the file name before the first '.'.
    for filename in os.listdir(dir_path):
        if filename.endswith("_RW_result.txt"):
            continue
        current_drug = filename.split('.')[0]
        if current_drug not in drug_target_dict:
            continue

        deg_rows = uf.read_file(os.path.join(dir_path, filename))
        degs = [gene for row in deg_rows for gene in row]
        framework(is_comb, current_drug, drug_target_dict[current_drug],
                  degs, net)
    return None
예제 #3
0
def get_auroc(n_epochs, cross_validation, oversampling_ratio):
    """Train fresh network(s) and report the score from ``uf.get_sigmoid_auroc``.

    Args:
        n_epochs: Number of training epochs; echoed back in each result.
        cross_validation: When exactly ``False`` a single 0.3 split from
            ``uf.get_train_test_valid_set`` is used; any other value runs one
            training per fold from ``uf.cross_validation_index``.
        oversampling_ratio: Forwarded to ``set_inputs``.

    Returns:
        ``[n_epochs, score]`` for the single split, or a list of such pairs
        (one per fold) for cross validation.
    """
    layer_info, matrix_list, data_set = set_inputs(oversampling_ratio)

    if cross_validation is not False:
        loader_options = {'batch_size': 50, 'num_workers': 10}
        fold_results = []

        for train_idx, test_idx in uf.cross_validation_index(data_set):
            # A new network per fold so folds do not share weights.
            net, criterion, optimizer, device = set_network(
                layer_info, matrix_list)

            train_loader = td.DataLoader(
                data_set, sampler=SubsetRandomSampler(train_idx),
                **loader_options)
            test_loader = td.DataLoader(
                data_set, sampler=SubsetRandomSampler(test_idx),
                **loader_options)

            train_epoch(net, optimizer, criterion, train_loader, test_loader,
                        device, n_epochs)

            net.eval()
            fold_score = uf.get_sigmoid_auroc(test_loader, device, net)
            net.train()

            fold_results.append([n_epochs, fold_score])

        return fold_results

    # Single train/test split.
    net, criterion, optimizer, device = set_network(layer_info, matrix_list)
    train_loader, test_loader = uf.get_train_test_valid_set(data_set, 0.3)

    train_epoch(net, optimizer, criterion, train_loader, test_loader, device,
                n_epochs)

    net.eval()
    return [n_epochs, uf.get_sigmoid_auroc(test_loader, device, net)]
예제 #4
0
def get_disease_comb_drug():
    """Return the first column of every combination-drug row flagged "P".

    Rows are read with ``uf.read_file`` from the DCDB combination target
    table; a row is kept when its second column equals ``"P"``.
    """
    rows = uf.read_file("../data/drug/Combination_Drug_Targets_From_DCDB.txt")
    return [row[0] for row in rows if row[1] == "P"]
예제 #5
0
def build_network(file_path: str) -> nx.DiGraph:
    """Build a directed graph from a relation file.

    Each row read by ``uf.read_file`` contributes one edge from column 0 to
    column 1, with column 2 stored as the edge attribute ``type``.
    """
    graph = nx.DiGraph()
    graph.add_edges_from(
        (row[0], row[1], {'type': row[2]})
        for row in uf.read_file(file_path)
    )
    return graph
예제 #6
0
def get_vnn_prediction(degs, net):
    """Return a one-element list holding the network's prediction score.

    The score is the first item returned by
    ``rpf.get_sigmoid_network_params`` (documented in the original code as
    ``[pred_score, BP, MF, net_paramters]``).
    """
    gene_index = uf.set_layer_info(['GE_MF', 'MF_BP', 'BP_PH'])[0]
    network_params = rpf.get_sigmoid_network_params(net, degs, gene_index)
    return [network_params[0]]
예제 #7
0
def framework(is_comb, drug_name, targets, degs, net):
    """Combine random-walk and VNN paths for one drug and write them as TSV.

    Args:
        is_comb: Forwarded to ``get_rw_score_dcit`` to pick the score folder.
        drug_name: Used for the score lookup and the output file name.
        targets: Drug targets passed to ``get_rw_path``.
        degs: Gene list passed to both the random-walk and VNN steps.
        net: Trained network passed to ``get_sigmoid_network_params``.

    The merged result is written to ``path_result/<drug_name>_path.tsv``.
    """
    # --- random-walk side ---
    rw_network = build_network(
        '../data/network/RW/GP_GP_Four_Relation_Types.txt')
    rw_paths = get_rw_path(rw_network, targets, degs)
    rw_score_by_gene = get_rw_score_dcit(is_comb, drug_name)

    # --- VNN side ---
    relation_files = ['GE_MF', 'MF_BP', 'BP_PH']
    layer_info = uf.set_layer_info(relation_files)
    mask_matrices = uf.set_mask_matrics(relation_files, layer_info)
    # Original note on the return value: [pred_score, BP, MF, net_paramters]
    network_params = get_sigmoid_network_params(net, degs, layer_info[0])
    vnn_paths = get_vnn_path(degs, network_params, layer_info, mask_matrices)

    # --- output ---
    uf.write_tsv(get_total_path(rw_paths, rw_score_by_gene, vnn_paths),
                 'path_result/%s_path.tsv' % drug_name)
예제 #8
0
def get_vnn_score(dirname, output_path, is_comb, net):
    """Write a TSV of ``drug_name``, ``label`` and ``prediction`` per DEG file.

    Args:
        dirname: Directory to scan; files ending in ``_RW_result.txt`` are
            skipped, every other file is read as a gene list.
        output_path: Destination handed to ``uf.write_tsv``.
        is_comb: Truthy to label against ``get_disease_comb_drug()``; falsy
            to label against ``uf.get_diseas_drug()`` using the lower-cased
            drug name.
        net: Trained network forwarded to ``get_vnn_prediction``.
    """
    known_drugs = get_disease_comb_drug() if is_comb else uf.get_diseas_drug()

    table = [['drug_name', 'label', 'prediction']]

    for filename in os.listdir(dirname):
        if filename.endswith("_RW_result.txt"):
            continue

        drug_name = filename.split('.')[0]

        # Flatten the rows of the file into one gene list.
        degs = []
        for row in uf.read_file(os.path.join(dirname, filename)):
            degs.extend(row)

        # The output keeps the name as found on disk; only the lookup for
        # non-combination drugs is lower-cased.
        lookup_key = drug_name if is_comb else drug_name.lower()
        label = 1 if lookup_key in known_drugs else 0

        record = [drug_name, label]
        record.extend(get_vnn_prediction(degs, net))
        table.append(record)

    uf.write_tsv(table, output_path)
예제 #9
0
def get_rw_score_dcit(is_comb, drug_name):
    """Map column 0 to ``float(column 1)`` for a drug's ``_RW_result.txt`` file.

    The file is looked up in the combination folder when ``is_comb`` is
    truthy and in the single-drug folder otherwise.
    """
    folder = ("../RW/RW_combi_DEG_prediction_results/" if is_comb
              else "../RW/RW_DEG_prediction_results/")
    score_file = ''.join([folder, drug_name, "_RW_result.txt"])
    return {row[0]: float(row[1]) for row in uf.read_file(score_file)}
예제 #10
0
def get_target_dict(is_comb) -> dict:
    """Return a mapping from column 0 of the target table to its target list.

    For single drugs (``is_comb`` falsy) the targets come from column 1 of
    ``Single_Drug_Target.txt``. For combinations the first row of the DCDB
    table is dropped and the targets come from column 3. In both cases the
    target field is split on ``'|'``.
    """
    if is_comb:
        rows = uf.read_file(
            "../data/drug/Combination_Drug_Targets_From_DCDB.txt")
        # Drop the first row (presumably a header -- TODO confirm).
        del rows[0]
        target_column = 3
    else:
        rows = uf.read_file("../data/drug/Single_Drug_Target.txt")
        target_column = 1

    return {row[0]: row[target_column].split('|') for row in rows}
예제 #11
0
def set_inputs(oversampling_ratio):
    """Prepare the layer info, mask matrices and oversampled training set.

    Args:
        oversampling_ratio: Ratio handed to ``uf.oversampling``.

    Returns:
        ``[layer_info, matrix_list, data_set]``.
    """
    relation_files = ['GE_MF', 'MF_BP', 'BP_PH']

    layer_info = uf.set_layer_info(relation_files)
    gene_index = layer_info[0]
    matrix_list = uf.set_mask_matrics(relation_files, layer_info)

    deg_drug_train = uf.get_DEG_drug(uf.get_diseas_drug())
    raw_data_set = uf.set_input_data(deg_drug_train, gene_index, 'train_data')

    return [layer_info, matrix_list,
            uf.oversampling(raw_data_set, oversampling_ratio)]