Esempio n. 1
0
def evaluate(args, model, tcrs, peps, signs):
    """Compute AUC and ROC of ``model`` on labelled TCR-peptide pairs.

    Args:
        args: Namespace-like object; ``args.model_type`` selects the 'ae' or
            'lstm' pipeline and ``args.device`` is forwarded to evaluation.
        model: Trained classifier, passed as-is to ``evaluate_full``.
        tcrs: TCR sequences, aligned index-by-index with ``peps``/``signs``.
        peps: Peptide sequences.
        signs: Label for each pair, forwarded to ``get_full_batches``.

    Returns:
        ``(auc, roc)`` as produced by ``ae.evaluate_full`` or
        ``lstm.evaluate_full``.

    Raises:
        ValueError: If ``args.model_type`` is neither 'ae' nor 'lstm'.
    """
    max_len = 28
    batch_size = 50
    amino_acids = list('ARNDCEQGHILKMFPSTWYV')

    if args.model_type == 'ae':
        # Peptide vocabulary reserves index 0 for 'PAD'; the TCR vocabulary
        # instead appends 'X' after the 20 amino acids.
        pep_atox = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index
            for index, amino in enumerate(amino_acids + ['X'])
        }
        test_batches = ae.get_full_batches(tcrs, peps, signs, tcr_atox,
                                           pep_atox, batch_size, max_len)
        auc, roc = ae.evaluate_full(model, test_batches, args.device)
    elif args.model_type == 'lstm':
        amino_to_ix = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        # lstm.convert_data appears to rewrite its list arguments in place
        # (its return value is never used), so convert copies to leave the
        # caller's lists untouched. TODO confirm against lstm.convert_data.
        tcrs, peps = list(tcrs), list(peps)
        lstm.convert_data(tcrs, peps, amino_to_ix)
        test_batches = lstm.get_full_batches(tcrs, peps, signs, batch_size,
                                             amino_to_ix)
        auc, roc = lstm.evaluate_full(model, test_batches, args.device)
    else:
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                args.model_type))
    return auc, roc
Esempio n. 2
0
def single_protein_score(args, model, test_data, protein, protein_peps):
    """Compute AUC and ROC on the test pairs whose peptide maps to ``protein``.

    Positive examples are test TCRs labelled as binding a peptide of the
    protein; negative examples are test TCRs paired with such a peptide but
    labelled as non-binding.

    Args:
        args: Namespace-like object; reads ``args.model_type`` ('ae' or
            'lstm') and ``args.device``.
        model: Trained classifier, passed as-is to ``evaluate_full``.
        test_data: Sequence of records ``p`` where ``p[0]`` is the TCR,
            ``p[1][0]`` the peptide and ``p[2]`` the label ('p' or 'n').
        protein: Key into ``protein_peps``.
        protein_peps: Mapping from protein name to its peptides.

    Returns:
        ``(test_auc, roc)`` from the matching ``evaluate_full`` helper.

    Raises:
        ValueError: If ``args.model_type`` is neither 'ae' nor 'lstm'.
        KeyError: If ``protein`` is not in ``protein_peps`` or a selected
            record's label is not 'p' or 'n'.
    """
    if args.model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                args.model_type))

    # Select the relevant pairs in one pass; a set makes each lookup O(1).
    protein_pep_set = set(protein_peps[protein])
    selected = [p for p in test_data if p[1][0] in protein_pep_set]
    signs_to_prob = {'n': 0.0, 'p': 1.0}
    tcrs = [p[0] for p in selected]
    signs = [signs_to_prob[p[2]] for p in selected]
    peps = [p[1][0] for p in selected]

    max_len = 28
    batch_size = 50
    amino_acids = list('ARNDCEQGHILKMFPSTWYV')
    if args.model_type == 'ae':
        # Peptides reserve index 0 for 'PAD'; TCRs append 'X' at the end.
        pep_atox = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index
            for index, amino in enumerate(amino_acids + ['X'])
        }
        test_batches = ae.get_full_batches(tcrs, peps, signs, tcr_atox,
                                           pep_atox, batch_size, max_len)
        test_auc, roc = ae.evaluate_full(model, test_batches, args.device)
    else:
        amino_to_ix = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        lstm.convert_data(tcrs, peps, amino_to_ix)
        test_batches = lstm.get_full_batches(tcrs, peps, signs, batch_size,
                                             amino_to_ix)
        test_auc, roc = lstm.evaluate_full(model, test_batches, args.device)
    return test_auc, roc
Esempio n. 3
0
def single_peptide_score(args, model, test_data, pep, neg_type=None):
    """Compute AUC and ROC on the test pairs that involve one peptide.

    Positive examples are test TCRs labelled as binding ``pep``; negative
    examples are test TCRs paired with ``pep`` but labelled as non-binding.

    Args:
        args: Namespace-like object; reads ``args.model_type`` ('ae' or
            'lstm') and ``args.device``.
        model: Trained classifier, passed as-is to ``evaluate_full``.
        test_data: Sequence of records ``p`` where ``p[0]`` is the TCR,
            ``p[1][0]`` the peptide and ``p[2]`` the label ('p' or 'n').
        pep: Peptide whose pairs are scored.
        neg_type: Not used by this function; only the negatives already in
            ``test_data`` are scored. Kept for call-site compatibility.

    Returns:
        ``(test_auc, roc)`` from the matching ``evaluate_full`` helper.

    Raises:
        ValueError: If ``args.model_type`` is neither 'ae' nor 'lstm'.
        KeyError: If a selected record's label is not 'p' or 'n'.
    """
    if args.model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                args.model_type))

    # Select the pairs for this peptide in a single pass.
    signs_to_prob = {'n': 0.0, 'p': 1.0}
    selected = [p for p in test_data if p[1][0] == pep]
    tcrs = [p[0] for p in selected]
    signs = [signs_to_prob[p[2]] for p in selected]
    peps = [pep] * len(tcrs)

    max_len = 28
    batch_size = 50
    amino_acids = list('ARNDCEQGHILKMFPSTWYV')
    if args.model_type == 'ae':
        # Peptides reserve index 0 for 'PAD'; TCRs append 'X' at the end.
        pep_atox = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index
            for index, amino in enumerate(amino_acids + ['X'])
        }
        test_batches = ae.get_full_batches(tcrs, peps, signs, tcr_atox,
                                           pep_atox, batch_size, max_len)
        test_auc, roc = ae.evaluate_full(model, test_batches, args.device)
    else:
        amino_to_ix = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        lstm.convert_data(tcrs, peps, amino_to_ix)
        test_batches = lstm.get_full_batches(tcrs, peps, signs, batch_size,
                                             amino_to_ix)
        test_auc, roc = lstm.evaluate_full(model, test_batches, args.device)
    return test_auc, roc
Esempio n. 4
0
def predict(args, model, tcrs, peps):
    """Predict binding scores for TCR-peptide pairs.

    Args:
        args: Namespace-like object; reads ``args.model_type`` ('ae' or
            'lstm') and ``args.device``.
        model: Trained classifier, passed as-is to ``ae.predict`` /
            ``lstm.predict``.
        tcrs: TCR sequences; must have the same length as ``peps``.
        peps: Peptide sequences.

    Returns:
        ``(tcrs_copy, peps_copy, preds)``: copies of the inputs as given and
        the predictions returned by the model-specific ``predict`` helper.

    Raises:
        ValueError: If ``tcrs`` and ``peps`` differ in length, or if
            ``args.model_type`` is neither 'ae' nor 'lstm'.
    """
    # Explicit check rather than ``assert``, which is stripped under -O.
    if len(tcrs) != len(peps):
        raise ValueError(
            'tcrs and peps must have the same length ({} != {})'.format(
                len(tcrs), len(peps)))
    if args.model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                args.model_type))

    # Untouched copies handed back to the caller alongside the predictions.
    tcrs_copy = tcrs.copy()
    peps_copy = peps.copy()
    # Separate working copies: lstm.convert_data appears to rewrite its list
    # arguments in place, which would otherwise modify the caller's lists.
    work_tcrs = tcrs.copy()
    work_peps = peps.copy()
    # Labels are irrelevant for prediction but get_full_batches needs them.
    dummy_signs = [0.0] * len(tcrs)

    max_len = 28
    batch_size = 50
    amino_acids = list('ARNDCEQGHILKMFPSTWYV')
    if args.model_type == 'ae':
        # Peptides reserve index 0 for 'PAD'; TCRs append 'X' at the end.
        pep_atox = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index
            for index, amino in enumerate(amino_acids + ['X'])
        }
        test_batches = ae.get_full_batches(work_tcrs, work_peps, dummy_signs,
                                           tcr_atox, pep_atox, batch_size,
                                           max_len)
        preds = ae.predict(model, test_batches, args.device)
    else:
        amino_to_ix = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        lstm.convert_data(work_tcrs, work_peps, amino_to_ix)
        test_batches = lstm.get_full_batches(work_tcrs, work_peps,
                                             dummy_signs, batch_size,
                                             amino_to_ix)
        preds = lstm.predict(model, test_batches, args.device)
    return tcrs_copy, peps_copy, preds
Esempio n. 5
0
def predict(args):
    """Load a trained model and print predictions for a CSV of pairs.

    Reads ``args.test_data_file`` as CSV rows of ``tcr,pep`` and prints one
    ``tcr<TAB>pep<TAB>prediction`` line per scored pair. For the 'ae' model,
    TCRs of length >= 28 are skipped.

    Args:
        args: Namespace-like object. Reads ``model_type`` ('ae' or 'lstm'),
            ``device``, ``ae_file``, ``model_file``, ``test_data_file`` and,
            when ``model_file`` is 'auto', ``protein``, ``dataset`` and
            ``sampling``. File attributes set to 'auto' are replaced in
            place with default paths.

    Raises:
        ValueError: If ``args.model_type`` is neither 'ae' nor 'lstm', or a
            non-empty CSV row does not have exactly two fields.
    """
    if args.model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                args.model_type))

    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/tcr_autoencoder.pt'
    if args.model_file == 'auto':
        model_dir = 'models'
        p_key = 'protein' if args.protein else ''
        args.model_file = model_dir + '/' + '_'.join(
            [args.model_type, args.dataset, args.sampling, p_key, 'model.pt'])
    if args.test_data_file == 'auto':
        args.test_data_file = 'pairs_example.csv'

    max_len = 28
    batch_size = 50
    device = args.device

    # Read test data. newline='' is what the csv module expects.
    tcrs = []
    peps = []
    with open(args.test_data_file, 'r', newline='') as csv_file:
        for row in csv.reader(csv_file):
            if not row:
                continue  # tolerate blank lines instead of failing to unpack
            tcr, pep = row
            # The autoencoder only encodes TCRs shorter than max_len.
            if args.model_type == 'ae' and len(tcr) >= max_len:
                continue
            tcrs.append(tcr)
            peps.append(pep)
    # Placeholder labels: get_full_batches requires them, predict ignores them.
    signs = [0.0] * len(tcrs)
    tcrs_copy = tcrs.copy()
    peps_copy = peps.copy()

    # Load model
    if args.model_type == 'ae':
        model = AutoencoderLSTMClassifier(10, device, max_len, 21, 30, 50,
                                          args.ae_file, False)
    else:
        model = DoubleLSTMClassifier(10, 30, 0.1, device)
    checkpoint = torch.load(args.model_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Predict
    amino_acids = list('ARNDCEQGHILKMFPSTWYV')
    if args.model_type == 'ae':
        # Peptides reserve index 0 for 'PAD'; TCRs append 'X' at the end.
        pep_atox = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index
            for index, amino in enumerate(amino_acids + ['X'])
        }
        test_batches = ae.get_full_batches(tcrs, peps, signs, tcr_atox,
                                           pep_atox, batch_size, max_len)
        preds = ae.predict(model, test_batches, device)
    else:
        amino_to_ix = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        # convert_data appears to rewrite tcrs/peps in place, hence the
        # copies above for printing.
        lstm.convert_data(tcrs, peps, amino_to_ix)
        test_batches = lstm.get_full_batches(tcrs, peps, signs, batch_size,
                                             amino_to_ix)
        preds = lstm.predict(model, test_batches, device)

    # Print predictions
    for tcr, pep, pred in zip(tcrs_copy, peps_copy, preds):
        print('\t'.join([tcr, pep, str(pred)]))
Esempio n. 6
0
def protein_test(args):
    """Evaluate a saved model separately on the peptides of frequent proteins.

    Proteins are ranked by how many rows reference them in McPAS-TCR.csv.
    For each of the top 50 proteins that has test pairs, the pairs whose
    peptide belongs to that protein are scored. The AUC is printed, or 'NA'
    when evaluation raises ValueError.

    Args:
        args: Namespace-like object. Reads ``protein`` (must be truthy),
            ``model_type`` ('ae' or 'lstm'), ``dataset`` (only 'mcpas' is
            supported), ``device``, ``sampling`` and the ``ae_file``,
            ``test_data_file`` and ``model_file`` paths. Any path set to
            'auto' is replaced in place with a default.

    Returns:
        List of ``(protein, roc)`` tuples for the successfully scored
        proteins.

    Raises:
        ValueError: If ``args.protein`` is falsy, ``args.model_type`` is not
            'ae'/'lstm', or ``args.dataset`` is not 'mcpas'.
    """
    from collections import Counter

    # Validate up front (an ``assert`` would vanish under ``python -O``).
    if not args.protein:
        raise ValueError('protein_test requires args.protein to be set')
    if args.model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                args.model_type))
    if args.dataset != 'mcpas':
        # Only the McPAS annotation file is wired up below.
        raise ValueError(
            "unsupported dataset {!r}; only 'mcpas' is supported".format(
                args.dataset))

    device = args.device
    max_len = 28
    batch_size = 50
    results_dir = 'memory_and_protein'

    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/tcr_autoencoder.pt'
    if args.test_data_file == 'auto':
        args.test_data_file = results_dir + '/' + '_'.join([
            args.model_type, args.dataset, args.sampling, 'protein',
            'test.pickle'
        ])
    if args.model_file == 'auto':
        args.model_file = results_dir + '/' + '_'.join([
            args.model_type, args.dataset, args.sampling, 'protein',
            'model.pt'
        ])

    # Read test data.
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # test_data_file at trusted files.
    with open(args.test_data_file, 'rb') as handle:
        test = pickle.load(handle)

    amino_acids = list('ARNDCEQGHILKMFPSTWYV')
    if args.model_type == 'ae':
        test_tcrs, test_peps, test_signs = ae_get_lists_from_pairs(
            test, max_len)
        model = AutoencoderLSTMClassifier(10, device, max_len, 21, 30, 50,
                                          args.ae_file, False)
        # Peptides reserve index 0 for 'PAD'; TCRs append 'X' at the end.
        pep_atox = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index
            for index, amino in enumerate(amino_acids + ['X'])
        }
    else:
        test_tcrs, test_peps, test_signs = lstm_get_lists_from_pairs(test)
        model = DoubleLSTMClassifier(10, 30, 0.1, device)
        amino_to_ix = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(args.model_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Rank proteins by frequency in McPAS and collect each one's peptides.
    protein_counts = Counter()
    protein_peps = {}
    with open('McPAS-TCR.csv', 'r', encoding='unicode_escape',
              newline='') as file:
        file.readline()  # skip the header row
        for line in csv.reader(file):
            pep, protein = line[11], line[9]
            if protein == 'NA' or pep == 'NA':
                continue
            protein_counts[protein] += 1
            protein_peps.setdefault(protein, []).append(pep)
    # most_common() breaks count ties by first appearance, so the order is
    # deterministic.
    proteins = [protein for protein, _ in protein_counts.most_common()]
    # McPAS most frequent proteins:
    #   NP177 (Influenza), Matrix protein (M1) (Influenza),
    #   pp65 (Cytomegalovirus, CMV), BMLF-1 (Epstein Barr virus, EBV),
    #   PB1 (Influenza)

    rocs = []
    for protein in proteins[:50]:
        protein_pep_set = set(protein_peps[protein])
        protein_shows = [
            i for i, test_pep in enumerate(test_peps)
            if test_pep in protein_pep_set
        ]
        if not protein_shows:
            continue
        tcrs_protein = [test_tcrs[i] for i in protein_shows]
        peps_protein = [test_peps[i] for i in protein_shows]
        signs_protein = [test_signs[i] for i in protein_shows]
        if args.model_type == 'ae':
            batches = ae.get_full_batches(tcrs_protein, peps_protein,
                                          signs_protein, tcr_atox, pep_atox,
                                          batch_size, max_len)
            evaluate_full = ae.evaluate_full
        else:
            lstm.convert_data(tcrs_protein, peps_protein, amino_to_ix)
            batches = lstm.get_full_batches(tcrs_protein, peps_protein,
                                            signs_protein, batch_size,
                                            amino_to_ix)
            evaluate_full = lstm.evaluate_full
        try:
            test_auc, roc = evaluate_full(model, batches, device)
        except ValueError:
            # Presumably raised when the subset has a single label class, so
            # no AUC exists; report it and move on.
            print('NA')
            continue
        # Key the result by the protein being scored.
        rocs.append((protein, roc))
        print(str(test_auc))
    return rocs
Esempio n. 7
0
def pep_test(args):
    """Evaluate a saved model separately on each frequent peptide.

    Peptides are ranked by how often they appear in column 1 of
    nine_class_testdata.csv. For each of the top 249 that has test pairs,
    the AUC is printed, or 'NA' when evaluation raises ValueError. The list
    of ``[pep, roc, test_auc]`` results is pickled to 'pep_test_result_ae2'
    or 'pep_test_result_lstm2', depending on the model type.

    Args:
        args: Namespace-like object. Reads ``model_type`` ('ae' or 'lstm'),
            ``dataset`` (only 'mcpas' is supported), ``device``, and the
            ``ae_file``, ``test_data_file`` and ``model_file`` paths. Any
            path set to 'auto' is replaced in place with a default, built
            from ``sampling`` and ``protein``.

    Raises:
        ValueError: If ``args.model_type`` is not 'ae'/'lstm' or
            ``args.dataset`` is not 'mcpas'.
    """
    from collections import Counter

    if args.model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                args.model_type))
    if args.dataset != 'mcpas':
        # Only the McPAS-derived reference file is wired up below.
        raise ValueError(
            "unsupported dataset {!r}; only 'mcpas' is supported".format(
                args.dataset))

    device = args.device
    max_len = 28
    batch_size = 50
    # Decided up front so the results file is written even when no peptide
    # could be scored.
    filename = ('pep_test_result_ae2'
                if args.model_type == 'ae' else 'pep_test_result_lstm2')
    results_dir = 'final_results'

    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/tcr_ae_dim_100.pt'
    if args.test_data_file == 'auto':
        p_key = 'protein' if args.protein else ''
        args.test_data_file = results_dir + '/' + '_'.join([
            args.model_type, args.dataset, args.sampling, p_key, 'test.pickle'
        ])
    if args.model_file == 'auto':
        p_key = 'protein' if args.protein else ''
        args.model_file = results_dir + '/' + '_'.join(
            [args.model_type, args.dataset, args.sampling, p_key, 'model.pt'])

    # Read test data.
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # test_data_file at trusted files.
    with open(args.test_data_file, 'rb') as handle:
        test = pickle.load(handle)

    amino_acids = list('ARNDCEQGHILKMFPSTWYV')
    if args.model_type == 'ae':
        test_tcrs, test_peps, test_signs = ae_get_lists_from_pairs(
            test, max_len)
        model = AutoencoderLSTMClassifier(10, device, max_len, 21, 100, 50,
                                          args.ae_file, False)
        # Peptides reserve index 0 for 'PAD'; TCRs append 'X' at the end.
        pep_atox = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index
            for index, amino in enumerate(amino_acids + ['X'])
        }
    else:
        test_tcrs, test_peps, test_signs = lstm_get_lists_from_pairs(test)
        model = DoubleLSTMClassifier(10, 500, 0.1, device)
        amino_to_ix = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(args.model_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Rank peptides by frequency in the reference file.
    pep_counts = Counter()
    with open('nine_class_testdata.csv', 'r', encoding='unicode_escape',
              newline='') as file:
        file.readline()  # skip the header row
        for line in csv.reader(file):
            pep = line[1]
            if pep != 'NA':
                pep_counts[pep] += 1
    # most_common() breaks count ties by first appearance (deterministic).
    peps = [pep for pep, _ in pep_counts.most_common()]
    # McPAS most frequent peps:
    #   LPRRSGAAGA (Influenza), GILGFVFTL (Influenza),
    #   GLCTLVAML (Epstein Barr virus, EBV),
    #   NLVPMVATV (Cytomegalovirus, CMV), SSYRRPVGI (Influenza)

    # Group test-pair indices by peptide once instead of rescanning per pep.
    pep_indices = {}
    for i, test_pep in enumerate(test_peps):
        pep_indices.setdefault(test_pep, []).append(i)

    auc = []
    for pep in peps[:249]:
        pep_shows = pep_indices.get(pep)
        if not pep_shows:
            continue
        tcrs_pep = [test_tcrs[i] for i in pep_shows]
        peps_pep = [test_peps[i] for i in pep_shows]
        signs_pep = [test_signs[i] for i in pep_shows]
        if args.model_type == 'ae':
            batches = ae.get_full_batches(tcrs_pep, peps_pep, signs_pep,
                                          tcr_atox, pep_atox, batch_size,
                                          max_len)
            evaluate_full = ae.evaluate_full
        else:
            lstm.convert_data(tcrs_pep, peps_pep, amino_to_ix)
            batches = lstm.get_full_batches(tcrs_pep, peps_pep, signs_pep,
                                            batch_size, amino_to_ix)
            evaluate_full = lstm.evaluate_full
        try:
            test_auc, roc = evaluate_full(model, batches, device)
        except ValueError:
            # Presumably raised when the subset has a single label class, so
            # no AUC exists; report it and move on.
            print('NA')
            continue
        auc.append([pep, roc, test_auc])
        print(str(test_auc))

    with open(filename, 'wb') as handle:
        pickle.dump(auc, handle)
Esempio n. 8
0
def pep_test(args):
    """Evaluate a saved model on the whole test set and pickle the result.

    Despite the name, all test pairs are scored together rather than per
    peptide. The AUC is printed, and ``[[roc, test_auc]]`` is pickled to
    'tpp_test_result_ae2' or 'tpp_test_result_lstm2', depending on the
    model type.

    Args:
        args: Namespace-like object. Reads ``model_type`` ('ae' or 'lstm'),
            ``device``, and the ``ae_file``, ``test_data_file`` and
            ``model_file`` paths. Any path set to 'auto' is replaced in
            place with a default, built from ``dataset``, ``sampling`` and
            ``protein``.

    Raises:
        ValueError: If ``args.model_type`` is neither 'ae' nor 'lstm'.
    """
    if args.model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                args.model_type))

    device = args.device
    max_len = 28
    batch_size = 50
    filename = ('tpp_test_result_ae2'
                if args.model_type == 'ae' else 'tpp_test_result_lstm2')
    results_dir = 'final_results'

    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/tcr_ae_dim_100.pt'
    if args.test_data_file == 'auto':
        p_key = 'protein' if args.protein else ''
        args.test_data_file = results_dir + '/' + '_'.join([
            args.model_type, args.dataset, args.sampling, p_key, 'test.pickle'
        ])
    if args.model_file == 'auto':
        p_key = 'protein' if args.protein else ''
        args.model_file = results_dir + '/' + '_'.join(
            [args.model_type, args.dataset, args.sampling, p_key, 'model.pt'])

    # Read test data.
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # test_data_file at trusted files.
    with open(args.test_data_file, 'rb') as handle:
        test = pickle.load(handle)

    amino_acids = list('ARNDCEQGHILKMFPSTWYV')
    if args.model_type == 'ae':
        test_tcrs, test_peps, test_signs = ae_get_lists_from_pairs(
            test, max_len)
        model = AutoencoderLSTMClassifier(10, device, max_len, 21, 100, 50,
                                          args.ae_file, False)
    else:
        test_tcrs, test_peps, test_signs = lstm_get_lists_from_pairs(test)
        model = DoubleLSTMClassifier(10, 500, 0.1, device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(args.model_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    if args.model_type == 'ae':
        # Peptides reserve index 0 for 'PAD'; TCRs append 'X' at the end.
        pep_atox = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index
            for index, amino in enumerate(amino_acids + ['X'])
        }
        test_batches = ae.get_full_batches(test_tcrs, test_peps, test_signs,
                                           tcr_atox, pep_atox, batch_size,
                                           max_len)
        test_auc, roc = ae.evaluate_full(model, test_batches, device)
    else:
        amino_to_ix = {
            amino: index
            for index, amino in enumerate(['PAD'] + amino_acids)
        }
        lstm.convert_data(test_tcrs, test_peps, amino_to_ix)
        test_batches = lstm.get_full_batches(test_tcrs, test_peps,
                                             test_signs, batch_size,
                                             amino_to_ix)
        test_auc, roc = lstm.evaluate_full(model, test_batches, device)

    print(str(test_auc))
    # Same layout as before: a one-element list of [roc, test_auc].
    with open(filename, 'wb') as handle:
        pickle.dump([[roc, test_auc]], handle)