def evaluate(args, model, tcrs, peps, signs):
    """Score ``model`` on the given TCR/peptide/label lists.

    The helper module used for batching and scoring is chosen by
    ``args.model_type``: ``'ae'`` uses ``ae``, ``'lstm'`` uses ``lstm``.

    Args:
        args: Object exposing ``model_type`` and ``device``.
        model: Forwarded unchanged to the selected ``evaluate_full``.
        tcrs: TCR inputs forwarded to the batching helper.
        peps: Peptide inputs forwarded to the batching helper.
        signs: Labels forwarded to the batching helper.

    Returns:
        The two values produced by the selected ``evaluate_full`` call,
        unpacked here as ``(auc, roc)``.

    Note:
        For any other ``model_type`` neither branch runs, so the final
        ``return`` raises ``UnboundLocalError``.
    """
    residues = list('ARNDCEQGHILKMFPSTWYV')
    # 'PAD' takes index 0; used for peptides ('ae') and for both inputs ('lstm').
    padded_index = {aa: pos for pos, aa in enumerate(['PAD'] + residues)}
    max_len, batch_size = 28, 50

    kind = args.model_type
    if kind == 'ae':
        # TCR alphabet has no padding symbol; 'X' is appended after the residues.
        tcr_index = {aa: pos for pos, aa in enumerate(residues + ['X'])}
        batches = ae.get_full_batches(tcrs, peps, signs, tcr_index,
                                      padded_index, batch_size, max_len)
        auc, roc = ae.evaluate_full(model, batches, args.device)
    elif kind == 'lstm':
        # Return value is discarded; presumably converts tcrs/peps in place
        # -- TODO confirm against lstm.convert_data.
        lstm.convert_data(tcrs, peps, padded_index)
        batches = lstm.get_full_batches(tcrs, peps, signs, batch_size,
                                        padded_index)
        auc, roc = lstm.evaluate_full(model, batches, args.device)
    return auc, roc
def single_protein_score(args, model, test_data, protein, protein_peps):
    """Score ``model`` on the test rows whose peptide belongs to ``protein``.

    A row is kept when ``row[1][0]`` is contained in
    ``protein_peps[protein]``.  For kept rows, ``row[0]`` is used as the TCR,
    ``row[1][0]`` as the peptide and ``row[2]`` (``'p'`` or ``'n'``) as the
    label, mapped to ``1.0`` / ``0.0``.

    Args:
        args: Object exposing ``model_type`` (``'ae'`` or ``'lstm'``) and
            ``device``.
        model: Forwarded unchanged to the selected ``evaluate_full``.
        test_data: Sequence of rows indexed as described above.
        protein: Key into ``protein_peps``.
        protein_peps: Mapping from protein to a container of its peptides.

    Returns:
        The two values produced by the selected ``evaluate_full`` call,
        unpacked here as ``(test_auc, roc)``.

    Raises:
        KeyError: If a kept row carries a label other than ``'p'``/``'n'``,
            or ``protein`` is missing from ``protein_peps`` while
            ``test_data`` is non-empty.

    Note:
        For any other ``model_type`` neither branch runs, so the final
        ``return`` raises ``UnboundLocalError``.
    """
    # Get pep-relevant data.  The protein lookup stays inside the filter so an
    # empty test_data never touches protein_peps.
    relevant = [row for row in test_data
                if row[1][0] in protein_peps[protein]]
    label_value = {'n': 0.0, 'p': 1.0}
    tcrs = [row[0] for row in relevant]
    signs = [label_value[row[2]] for row in relevant]
    peps = [row[1][0] for row in relevant]

    # Word to index dictionaries
    letters = list('ARNDCEQGHILKMFPSTWYV')
    with_pad = ['PAD'] + letters
    with_x = letters + ['X']
    max_len = 28
    batch_size = 50

    if args.model_type == 'ae':
        pep_atox = dict(zip(with_pad, range(len(with_pad))))
        tcr_atox = dict(zip(with_x, range(len(with_x))))
        test_batches = ae.get_full_batches(tcrs, peps, signs, tcr_atox,
                                           pep_atox, batch_size, max_len)
        test_auc, roc = ae.evaluate_full(model, test_batches, args.device)
    elif args.model_type == 'lstm':
        amino_to_ix = dict(zip(with_pad, range(len(with_pad))))
        # Return value is discarded; presumably converts the lists in place
        # -- TODO confirm against lstm.convert_data.
        lstm.convert_data(tcrs, peps, amino_to_ix)
        test_batches = lstm.get_full_batches(tcrs, peps, signs, batch_size,
                                             amino_to_ix)
        test_auc, roc = lstm.evaluate_full(model, test_batches, args.device)
    return test_auc, roc
def single_peptide_score(args, model, test_data, pep, neg_type=None):
    """Score ``model`` on the test rows whose peptide equals ``pep``.

    A row is kept when ``row[1][0] == pep``.  For kept rows, ``row[0]`` is
    used as the TCR and ``row[2]`` (``'p'`` or ``'n'``) as the label, mapped
    to ``1.0`` / ``0.0``.  Positives are TCRs that bind this peptide and
    negatives are TCRs that do not.

    Args:
        args: Object exposing ``model_type`` (``'ae'`` or ``'lstm'``) and
            ``device``.
        model: Forwarded unchanged to the selected ``evaluate_full``.
        test_data: Sequence of rows indexed as described above.
        pep: Peptide to select.
        neg_type: Accepted but not used by this function.

    Returns:
        The two values produced by the selected ``evaluate_full`` call,
        unpacked here as ``(test_auc, roc)``.

    Raises:
        KeyError: If a kept row carries a label other than ``'p'``/``'n'``.

    Note:
        For any other ``model_type`` neither branch runs, so the final
        ``return`` raises ``UnboundLocalError``.
    """
    # Get pep-relevant data
    selected = [entry for entry in test_data if entry[1][0] == pep]
    tcrs = [entry[0] for entry in selected]
    to_prob = {'n': 0.0, 'p': 1.0}
    signs = [to_prob[entry[2]] for entry in selected]
    peps = [pep for _ in tcrs]

    # Word to index dictionaries
    alphabet = list('ARNDCEQGHILKMFPSTWYV')
    max_len = 28
    batch_size = 50
    model_type = args.model_type

    if model_type == 'ae':
        pep_atox = {}
        for position, symbol in enumerate(['PAD'] + alphabet):
            pep_atox[symbol] = position
        tcr_atox = {}
        for position, symbol in enumerate(alphabet + ['X']):
            tcr_atox[symbol] = position
        test_batches = ae.get_full_batches(tcrs, peps, signs, tcr_atox,
                                           pep_atox, batch_size, max_len)
        test_auc, roc = ae.evaluate_full(model, test_batches, args.device)
    elif model_type == 'lstm':
        amino_to_ix = {}
        for position, symbol in enumerate(['PAD'] + alphabet):
            amino_to_ix[symbol] = position
        # Return value is discarded; presumably converts the lists in place
        # -- TODO confirm against lstm.convert_data.
        lstm.convert_data(tcrs, peps, amino_to_ix)
        test_batches = lstm.get_full_batches(tcrs, peps, signs, batch_size,
                                             amino_to_ix)
        test_auc, roc = lstm.evaluate_full(model, test_batches, args.device)
    return test_auc, roc
def protein_test(args):
    """Evaluate a trained model separately on each frequent McPAS protein.

    Loads the pickled test pairs and the model checkpoint, reads
    ``McPAS-TCR.csv`` to map each protein (column 9) to its peptides
    (column 11), and scores the model on the test rows of each of the 50
    most frequent proteins.  One line is printed per protein that has test
    rows: the AUC, or ``NA`` when scoring raises ``ValueError``.

    Args:
        args: Object exposing ``protein``, ``model_type`` (``'ae'`` or
            ``'lstm'``), ``dataset`` (must be ``'mcpas'``), ``sampling``,
            ``device``, ``ae_file``, ``test_data_file`` and ``model_file``.
            The three file attributes are rewritten in place when they are
            set to ``'auto'``.

    Returns:
        A list of ``(protein, roc)`` tuples, one per successfully scored
        protein, ordered from most to least frequent protein.

    Raises:
        ValueError: If ``args.model_type`` or ``args.dataset`` is not
            supported.  Raised before any file is read or ``args`` is
            modified.
    """
    assert args.protein
    # Fail fast with a clear message instead of a NameError after file I/O.
    if args.model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                args.model_type))
    if args.dataset != 'mcpas':
        raise ValueError(
            "unsupported dataset {!r}; protein_test only handles "
            "'mcpas'".format(args.dataset))

    # Word to index dictionaries
    amino_acids = list('ARNDCEQGHILKMFPSTWYV')
    if args.model_type == 'lstm':
        amino_to_ix = {
            amino: index for index, amino in enumerate(['PAD'] + amino_acids)
        }
    else:
        pep_atox = {
            amino: index for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index for index, amino in enumerate(amino_acids + ['X'])
        }

    # Resolve 'auto' paths.
    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/'
    results_dir = 'memory_and_protein'
    p_key = 'protein' if args.protein else ''
    if args.test_data_file == 'auto':
        args.test_data_file = results_dir + '/' + '_'.join([
            args.model_type, args.dataset, args.sampling, p_key, 'test.pickle'
        ])
    if args.model_file == 'auto':
        # The trailing '' keeps the trailing underscore in the file name.
        args.model_file = results_dir + '/' + '_'.join(
            [args.model_type, args.dataset, args.sampling, p_key, ''])

    # Read test data.
    # NOTE(security): pickle.load executes arbitrary code from the file; only
    # point test_data_file at trusted data.
    with open(args.test_data_file, 'rb') as handle:
        test = pickle.load(handle)

    device = args.device
    # NOTE(review): the numeric constructor arguments presumably have to
    # match the checkpoint being loaded -- confirm against the model classes.
    if args.model_type == 'ae':
        test_tcrs, test_peps, test_signs = ae_get_lists_from_pairs(test, 28)
        model = AutoencoderLSTMClassifier(10, device, 28, 21, 30, 50,
                                          args.ae_file, False)
    else:
        test_tcrs, test_peps, test_signs = lstm_get_lists_from_pairs(test)
        model = DoubleLSTMClassifier(10, 30, 0.1, device)
    checkpoint = torch.load(args.model_file)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Map every protein to the peptides listed for it (duplicates kept, so
    # the list length equals the number of CSV rows for that protein).
    protein_peps = {}
    with open('McPAS-TCR.csv', 'r', encoding='unicode_escape') as file:
        file.readline()  # skip the header line
        for line in csv.reader(file):
            csv_pep, csv_protein = line[11], line[9]
            if csv_protein == 'NA' or csv_pep == 'NA':
                continue
            protein_peps.setdefault(csv_protein, []).append(csv_pep)

    # Most frequent proteins first.  sorted() is stable, so ties keep the
    # order of first appearance in the CSV.
    proteins = sorted(protein_peps,
                      key=lambda name: len(protein_peps[name]),
                      reverse=True)
    # McPAS most frequent proteins:
    #   NP177                 Influenza
    #   Matrix protein (M1)   Influenza
    #   pp65                  Cytomegalovirus (CMV)
    #   BMLF-1                Epstein Barr virus (EBV)
    #   PB1                   Influenza

    rocs = []
    for protein in proteins[:50]:
        peps_of_protein = protein_peps[protein]
        protein_shows = [
            i for i, test_pep in enumerate(test_peps)
            if test_pep in peps_of_protein
        ]
        if not protein_shows:
            # Nothing to score; skip before building any batches.
            continue
        test_tcrs_protein = [test_tcrs[i] for i in protein_shows]
        test_peps_protein = [test_peps[i] for i in protein_shows]
        test_signs_protein = [test_signs[i] for i in protein_shows]
        if args.model_type == 'ae':
            test_batches_protein = ae.get_full_batches(
                test_tcrs_protein, test_peps_protein, test_signs_protein,
                tcr_atox, pep_atox, 50, 28)
        else:
            # Works on the per-protein copies, so the full test lists are
            # never handed to convert_data.
            lstm.convert_data(test_tcrs_protein, test_peps_protein,
                              amino_to_ix)
            test_batches_protein = lstm.get_full_batches(
                test_tcrs_protein, test_peps_protein, test_signs_protein, 50,
                amino_to_ix)
        try:
            if args.model_type == 'ae':
                test_auc, roc = ae.evaluate_full(model, test_batches_protein,
                                                 device)
            else:
                test_auc, roc = lstm.evaluate_full(model,
                                                   test_batches_protein,
                                                   device)
        except ValueError:
            # Presumably the AUC is undefined for this subset (e.g. a single
            # class) -- TODO confirm; reported as 'NA'.
            print('NA')
        else:
            # Key the ROC by the protein being scored (the previous code
            # stored a stale peptide left over from the CSV loop).
            rocs.append((protein, roc))
            print(str(test_auc))
    return rocs
def pep_test(args):
    """Evaluate a trained model per peptide and pickle the results.

    Loads the pickled test pairs and the model checkpoint, reads the
    peptides in column 1 of ``nine_class_testdata.csv``, and scores the
    model on the test rows of each of the 249 most frequent peptides.  One
    line is printed per peptide that has test rows: the AUC, or ``NA`` when
    scoring raises ``ValueError``.  The collected
    ``[pep, roc, test_auc]`` entries are pickled to
    ``pep_test_result_ae2`` or ``pep_test_result_lstm2``.

    NOTE(review): another ``pep_test`` is defined later in this file; if both
    live in the same module, the later definition replaces this one.

    Args:
        args: Object exposing ``protein``, ``model_type`` (``'ae'`` or
            ``'lstm'``), ``dataset`` (must be ``'mcpas'``), ``sampling``,
            ``device``, ``ae_file``, ``test_data_file`` and ``model_file``.
            The three file attributes are rewritten in place when they are
            set to ``'auto'``.

    Returns:
        None.  Results are written to the pickle file.

    Raises:
        ValueError: If ``args.model_type`` or ``args.dataset`` is not
            supported.  Raised before any file is read or ``args`` is
            modified.
    """
    # Fail fast with a clear message instead of a NameError after file I/O.
    if args.model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                args.model_type))
    if args.dataset != 'mcpas':
        raise ValueError(
            "unsupported dataset {!r}; pep_test only handles "
            "'mcpas'".format(args.dataset))

    # Word to index dictionaries and the output file, both fixed by the
    # model type.  Choosing the file name here means the results can always
    # be written, even when no peptide is scored successfully.
    amino_acids = list('ARNDCEQGHILKMFPSTWYV')
    if args.model_type == 'lstm':
        amino_to_ix = {
            amino: index for index, amino in enumerate(['PAD'] + amino_acids)
        }
        filename = 'pep_test_result_lstm2'
    else:
        pep_atox = {
            amino: index for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index for index, amino in enumerate(amino_acids + ['X'])
        }
        filename = 'pep_test_result_ae2'

    # Resolve 'auto' paths.
    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/'
    results_dir = 'final_results'
    p_key = 'protein' if args.protein else ''
    if args.test_data_file == 'auto':
        args.test_data_file = results_dir + '/' + '_'.join([
            args.model_type, args.dataset, args.sampling, p_key, 'test.pickle'
        ])
    if args.model_file == 'auto':
        # The trailing '' keeps the trailing underscore in the file name.
        args.model_file = results_dir + '/' + '_'.join(
            [args.model_type, args.dataset, args.sampling, p_key, ''])

    # Read test data.
    # NOTE(security): pickle.load executes arbitrary code from the file; only
    # point test_data_file at trusted data.
    with open(args.test_data_file, 'rb') as handle:
        test = pickle.load(handle)

    device = args.device
    # NOTE(review): the numeric constructor arguments presumably have to
    # match the checkpoint being loaded -- confirm against the model classes.
    if args.model_type == 'ae':
        test_tcrs, test_peps, test_signs = ae_get_lists_from_pairs(test, 28)
        model = AutoencoderLSTMClassifier(10, device, 28, 21, 100, 50,
                                          args.ae_file, False)
    else:
        test_tcrs, test_peps, test_signs = lstm_get_lists_from_pairs(test)
        model = DoubleLSTMClassifier(10, 500, 0.1, device)
    checkpoint = torch.load(args.model_file)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Count peptide occurrences in a single pass over the CSV.
    pep_counts = {}
    with open('nine_class_testdata.csv', 'r',
              encoding='unicode_escape') as file:
        file.readline()  # skip the header line
        for line in csv.reader(file):
            csv_pep = line[1]
            if csv_pep == 'NA':
                continue
            pep_counts[csv_pep] = pep_counts.get(csv_pep, 0) + 1

    # Most frequent peptides first.  sorted() is stable, so ties keep the
    # order of first appearance in the CSV.
    peps = sorted(pep_counts, key=pep_counts.get, reverse=True)
    # McPAS most frequent peps:
    #   LPRRSGAAGA   Influenza
    #   GILGFVFTL    Influenza
    #   GLCTLVAML    Epstein Barr virus (EBV)
    #   NLVPMVATV    Cytomegalovirus (CMV)
    #   SSYRRPVGI    Influenza

    auc = []
    for pep in peps[:249]:
        pep_shows = [
            i for i, test_pep in enumerate(test_peps) if pep == test_pep
        ]
        if not pep_shows:
            # Nothing to score; skip before building any batches.
            continue
        test_tcrs_pep = [test_tcrs[i] for i in pep_shows]
        test_peps_pep = [test_peps[i] for i in pep_shows]
        test_signs_pep = [test_signs[i] for i in pep_shows]
        if args.model_type == 'ae':
            test_batches_pep = ae.get_full_batches(
                test_tcrs_pep, test_peps_pep, test_signs_pep, tcr_atox,
                pep_atox, 50, 28)
        else:
            # Works on the per-peptide copies, so the full test lists are
            # never handed to convert_data.
            lstm.convert_data(test_tcrs_pep, test_peps_pep, amino_to_ix)
            test_batches_pep = lstm.get_full_batches(
                test_tcrs_pep, test_peps_pep, test_signs_pep, 50, amino_to_ix)
        try:
            if args.model_type == 'ae':
                test_auc, roc = ae.evaluate_full(model, test_batches_pep,
                                                 device)
            else:
                test_auc, roc = lstm.evaluate_full(model, test_batches_pep,
                                                   device)
        except ValueError:
            # Presumably the AUC is undefined for this subset (e.g. a single
            # class) -- TODO confirm; reported as 'NA'.
            print('NA')
        else:
            auc.append([pep, roc, test_auc])
            print(str(test_auc))

    # Context manager closes the handle even if pickling fails.
    with open(filename, 'wb') as handle:
        pickle.dump(auc, handle)
def pep_test(args):
    """Evaluate a trained model on the whole test set and pickle the result.

    Loads the pickled test pairs and the model checkpoint, scores the model
    on all test rows at once, prints the AUC, and pickles
    ``[[roc, test_auc]]`` to ``tpp_test_result_ae2`` or
    ``tpp_test_result_lstm2``.

    NOTE(review): an earlier function in this file is also named
    ``pep_test``; if both live in the same module, this later definition is
    the one that remains bound to the name.

    Args:
        args: Object exposing ``protein``, ``model_type`` (``'ae'`` or
            ``'lstm'``), ``dataset``, ``sampling``, ``device``, ``ae_file``,
            ``test_data_file`` and ``model_file``.  The three file
            attributes are rewritten in place when they are set to
            ``'auto'``.

    Returns:
        None.  Results are written to the pickle file.

    Raises:
        ValueError: If ``args.model_type`` is not ``'ae'`` or ``'lstm'``.
            Raised before any file is read or ``args`` is modified.
    """
    # Fail fast with a clear message instead of a NameError after file I/O.
    if args.model_type not in ('ae', 'lstm'):
        raise ValueError(
            "unsupported model_type {!r}; expected 'ae' or 'lstm'".format(
                args.model_type))
    use_ae = args.model_type == 'ae'

    # Word to index dictionaries
    amino_acids = list('ARNDCEQGHILKMFPSTWYV')
    if use_ae:
        pep_atox = {
            amino: index for index, amino in enumerate(['PAD'] + amino_acids)
        }
        tcr_atox = {
            amino: index for index, amino in enumerate(amino_acids + ['X'])
        }
    else:
        amino_to_ix = {
            amino: index for index, amino in enumerate(['PAD'] + amino_acids)
        }

    # Resolve 'auto' paths.
    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/'
    results_dir = 'final_results'
    p_key = 'protein' if args.protein else ''
    if args.test_data_file == 'auto':
        args.test_data_file = results_dir + '/' + '_'.join([
            args.model_type, args.dataset, args.sampling, p_key, 'test.pickle'
        ])
    if args.model_file == 'auto':
        # The trailing '' keeps the trailing underscore in the file name.
        args.model_file = results_dir + '/' + '_'.join(
            [args.model_type, args.dataset, args.sampling, p_key, ''])

    # Read test data.
    # NOTE(security): pickle.load executes arbitrary code from the file; only
    # point test_data_file at trusted data.
    with open(args.test_data_file, 'rb') as handle:
        test = pickle.load(handle)

    device = args.device
    # NOTE(review): the numeric constructor arguments presumably have to
    # match the checkpoint being loaded -- confirm against the model classes.
    if use_ae:
        test_tcrs, test_peps, test_signs = ae_get_lists_from_pairs(test, 28)
        model = AutoencoderLSTMClassifier(10, device, 28, 21, 100, 50,
                                          args.ae_file, False)
    else:
        test_tcrs, test_peps, test_signs = lstm_get_lists_from_pairs(test)
        model = DoubleLSTMClassifier(10, 500, 0.1, device)
    checkpoint = torch.load(args.model_file)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    if use_ae:
        test_batches_pep = ae.get_full_batches(test_tcrs, test_peps,
                                               test_signs, tcr_atox, pep_atox,
                                               50, 28)
        test_auc, roc = ae.evaluate_full(model, test_batches_pep, device)
        filename = 'tpp_test_result_ae2'
    else:
        # Return value is discarded; presumably converts the lists in place
        # -- TODO confirm against lstm.convert_data.
        lstm.convert_data(test_tcrs, test_peps, amino_to_ix)
        test_batches_pep = lstm.get_full_batches(test_tcrs, test_peps,
                                                 test_signs, 50, amino_to_ix)
        test_auc, roc = lstm.evaluate_full(model, test_batches_pep, device)
        filename = 'tpp_test_result_lstm2'

    auc = [[roc, test_auc]]
    print(str(test_auc))
    # Context manager closes the handle even if pickling fails.
    with open(filename, 'wb') as handle:
        pickle.dump(auc, handle)