def mps(hparams, model, datafiles, peptides):
    """Multi-peptide selection: score every test sample against every
    candidate peptide and predict, per sample, the best-scoring one.

    :param hparams: namespace exposing ``tcr_encoding_model`` and
        ``cat_encoding`` (passed through to the dataset collate).
    :param model: trained model exposing ``validation_step``.
    :param datafiles: ``(train_file, test_files)``; the train pickle is
        loaded only to rebuild the categorical index dicts.
    :param peptides: iterable of peptide sequences to score against.
    :return: list with one predicted peptide per test sample.
    """
    train_file, test_files = datafiles
    with open(train_file, 'rb') as handle:
        train = pickle.load(handle)
    train_dicts = get_index_dicts(train)
    samples = get_tpp_ii_pairs(datafiles)
    # One score column per candidate peptide.
    preds = np.zeros((len(samples), len(peptides)))
    key_order = []
    for pep_idx, pep in enumerate(peptides):
        key_order.append(pep)
        # BUG FIX: original indexed an undefined `peptide_map[pep]`
        # (copy-paste from diabetes_mps, which builds that map locally);
        # here the iterated value is already the peptide sequence.
        # NOTE(review): assumes no module-level `peptide_map` was intended
        # — confirm against the callers of this function.
        testset = SinglePeptideDataset(samples, train_dicts, pep,
                                       force_peptide=True, spb_force=False)
        loader = DataLoader(
            testset, batch_size=64, shuffle=False, num_workers=10,
            collate_fn=lambda b: testset.collate(
                b, tcr_encoding=hparams.tcr_encoding_model,
                cat_encoding=hparams.cat_encoding))
        outputs = []
        with torch.no_grad():
            for batch_idx, batch in enumerate(loader):
                model.eval()
                outputs.append(model.validation_step(batch, batch_idx))
        y_hat = torch.cat([x['y_hat'].detach().cpu() for x in outputs])
        preds[:, pep_idx] = y_hat
    # Per sample, pick the peptide with the highest predicted score.
    argmax = np.argmax(preds, axis=1)
    predicted_peps = [key_order[i] for i in argmax]
    # TODO: also compute and return accuracy against known labels.
    return predicted_peps
def auc_predict(model, test, train_dicts, peptide=None):
    """Evaluate `model` on `test` samples and return the validation AUC.

    When `peptide` is provided, evaluation is restricted to that single
    peptide (without forcing it onto every sample); otherwise all signed
    pairs are used.
    """
    if peptide:
        dataset = SinglePeptideDataset(test, train_dicts, peptide,
                                       force_peptide=False)
    else:
        dataset = SignedPairsDataset(test, train_dicts)
    loader = DataLoader(
        dataset, batch_size=64, shuffle=False, num_workers=0,
        collate_fn=lambda b: dataset.collate(
            b, tcr_encoding=model.tcr_encoding_model,
            cat_encoding=model.cat_encoding))
    outputs = []
    for step_idx, batch in enumerate(loader):
        step_out = model.validation_step(batch, step_idx)
        # Skip batches for which the model produced no output.
        if step_out:
            outputs.append(step_out)
    return model.validation_end(outputs)['val_auc']
def diabetes_mps(hparams, model, testfile, pep_pool):
    """Multi-peptide selection on a diabetes known-specificity test set.

    :param hparams: namespace exposing ``tcr_encoding_model`` and
        ``cat_encoding``.
    :param model: trained model exposing ``validation_step``.
    :param testfile: path passed to ``read_known_specificity_test``.
    :param pep_pool: either the int 4 (use the built-in 4-peptide pool)
        or a path to a CSV file of ``pep,index,protein`` rows.
    :return: list with one predicted peptide name per test sample.
    """
    with open('mcpas_human_train_samples.pickle', 'rb') as handle:
        train = pickle.load(handle)
    train_dicts = get_index_dicts(train)
    if pep_pool == 4:
        # Hard-coded default pool of 4 diabetes-associated peptides.
        peptide_map = {
            'IGRPp39': 'QLYHFLQIPTHEEHLFYVLS',
            'GADp70': 'KVNFFRMVISNPAATHQDID',
            'GADp15': 'DVMNILLQYVVKSFDRSTKV',
            'IGRPp31': 'KWCANPDWIHIDTTPFAGLV'
        }
    else:
        # Otherwise pep_pool is a CSV path: header line, then
        # pep,index,protein rows.
        peptide_map = {}
        with open(pep_pool, 'r') as file:
            file.readline()
            for line in file:
                pep, index, protein = line.strip().split(',')
                # Normalize name to the '<protein>p<index>' convention.
                if protein in ['GAD', 'IGRP', 'Insulin']:
                    protein += 'p'
                pep_name = protein + index
                peptide_map[pep_name] = pep
    samples = read_known_specificity_test(testfile)
    preds = np.zeros((len(samples), len(peptide_map)))
    key_order = []
    for pep_idx, pep in enumerate(peptide_map):
        key_order.append(pep)
        testset = SinglePeptideDataset(samples, train_dicts, peptide_map[pep],
                                       force_peptide=True, spb_force=False)
        loader = DataLoader(
            testset, batch_size=10, shuffle=False, num_workers=10,
            collate_fn=lambda b: testset.collate(
                b, tcr_encoding=hparams.tcr_encoding_model,
                cat_encoding=hparams.cat_encoding))
        outputs = []
        with torch.no_grad():
            for batch_idx, batch in enumerate(loader):
                model.eval()
                outputs.append(model.validation_step(batch, batch_idx))
        y_hat = torch.cat([x['y_hat'].detach().cpu() for x in outputs])
        preds[:, pep_idx] = y_hat
    argmax = np.argmax(preds, axis=1)
    predicted_peps = [key_order[i] for i in argmax]
    print(predicted_peps)
    # BUG FIX: original ended with `pass`, discarding the predictions.
    return predicted_peps
def spb_with_more_negatives(model, datafiles, peptide):
    """Single-peptide-binding evaluation with extra forced negatives.

    :param model: trained model exposing ``validation_step``,
        ``validation_end`` and a ``tcr_encoding_model`` attribute.
    :param datafiles: passed to ``get_tpp_ii_pairs`` to build the test set.
    :param peptide: the single peptide to evaluate binding against.
    :return: validation AUC for the peptide.
    :raises ValueError: if ``model.tcr_encoding_model`` is not 'AE'/'LSTM'.
    """
    test = get_tpp_ii_pairs(datafiles)
    # force_peptide + spb_force adds extra negatives beyond the regular SPB
    # setup (which would be SinglePeptideDataset(test, peptide)).
    test_dataset = SinglePeptideDataset(test, peptide,
                                        force_peptide=True, spb_force=True)
    if model.tcr_encoding_model == 'AE':
        collate_fn = test_dataset.ae_collate
    elif model.tcr_encoding_model == 'LSTM':
        collate_fn = test_dataset.lstm_collate
    else:
        # BUG FIX: collate_fn was left unbound for unknown encodings,
        # causing a late NameError; fail fast instead.
        raise ValueError('unknown tcr_encoding_model: '
                         + str(model.tcr_encoding_model))
    loader = DataLoader(test_dataset, batch_size=64, shuffle=False,
                        num_workers=10, collate_fn=collate_fn)
    outputs = []
    for batch_idx, batch in enumerate(loader):
        outputs.append(model.validation_step(batch, batch_idx))
    # BUG FIX: the original re-printed the cumulative positives count on
    # every batch (guarded by an always-true counter); report it once.
    if outputs:
        print('positives:',
              int(torch.cat([x['y'] for x in outputs]).sum().item()))
    auc = model.validation_end(outputs)['val_auc']
    print(auc)
    # BUG FIX: original ended with `pass`, discarding the AUC.
    return auc
def diabetes_test_set(model):
    """Evaluate the model on a tiny hard-coded diabetes test set.

    8 paired samples (tcra, tcrb, peptide name) over 4 peptides; each
    sample is scored against every peptide and the argmax is compared to
    the true peptide, then to the coarser protein (IGRP vs GAD) label.

    :param model: trained model exposing ``validation_step`` and a
        ``tcr_encoding_model`` attribute.
    :return: ``(peptide_accuracy, protein_accuracy)``.
    :raises ValueError: if ``model.tcr_encoding_model`` is not 'AE'/'LSTM'.
    """
    data = [('CAATRTSGTYKYIF', 'CASSPWGAGGTDTQYF', 'IGRPp39'),
            ('CAVGAGYGGATNKLIF', 'CASSFRGGGNPYEQYF', 'GADp70'),
            ('CAERLYGNNRLAF', 'CASTLLWGGDSYEQYF', 'GADp15'),
            ('CAVNPNQAGTALIF', 'CASAPQEAQPQHF', 'IGRPp31'),
            ('CALSDYSGTSYGKLTF', 'CASSLIPYNEQFF', 'GADp15'),
            ('CAVEDLNQAGTALIF', 'CASSLALGQGNQQFF', 'IGRPp31'),
            ('CILRDTISNFGNEKLTF', 'CASSFGSSYYGYTF', 'IGRPp39'),
            ('CAGQTGANNLFF', 'CASSQEVGTVPNQPQHF', 'IGRPp31')]
    peptide_map = {
        'IGRPp39': 'QLYHFLQIPTHEEHLFYVLS',
        'GADp70': 'KVNFFRMVISNPAATHQDID',
        'GADp15': 'DVMNILLQYVVKSFDRSTKV',
        'IGRPp31': 'KWCANPDWIHIDTTPFAGLV'
    }
    true_labels = np.array(
        [list(peptide_map.keys()).index(d[-1]) for d in data])
    print(true_labels)
    # Placeholder v/j/mhc/protein fields — only tcra/tcrb/pep matter here.
    samples = []
    for tcra, tcrb, pep in data:
        tcr_data = (tcra, tcrb, 'v', 'j')
        pep_data = (peptide_map[pep], 'mhc', 'protein')
        samples.append((tcr_data, pep_data, 1))
    preds = np.zeros((len(samples), len(peptide_map)))
    for pep_idx, pep in enumerate(peptide_map):
        # Signs do not matter here, we only do a forward pass.
        # NOTE(review): unlike the other evaluators in this file, this call
        # passes no train_dicts to SinglePeptideDataset — confirm the
        # intended constructor signature.
        dataset = SinglePeptideDataset(samples, peptide_map[pep],
                                       force_peptide=True)
        if model.tcr_encoding_model == 'AE':
            collate_fn = dataset.ae_collate
        elif model.tcr_encoding_model == 'LSTM':
            collate_fn = dataset.lstm_collate
        else:
            # BUG FIX: collate_fn was left unbound for unknown encodings.
            raise ValueError('unknown tcr_encoding_model: '
                             + str(model.tcr_encoding_model))
        loader = DataLoader(dataset, batch_size=64, shuffle=False,
                            num_workers=10, collate_fn=collate_fn)
        outputs = []
        for batch_idx, batch in enumerate(loader):
            outputs.append(model.validation_step(batch, batch_idx))
        y_hat = torch.cat([x['y_hat'].detach().cpu() for x in outputs])
        preds[:, pep_idx] = y_hat
    argmax = np.argmax(preds, axis=1)
    print(argmax)
    peptide_accuracy = float(np.mean(argmax == true_labels))
    print(peptide_accuracy)
    # Protein-level accuracy: collapse IGRPp31->IGRPp39 (index 3->0) and
    # GADp15->GADp70 (index 2->1), i.e. IGRP vs GAD.
    protein_true = np.array(
        [0 if x == 3 else 1 if x == 2 else x for x in true_labels])
    protein_pred = np.array(
        [0 if x == 3 else 1 if x == 2 else x for x in argmax])
    print(protein_true)
    print(protein_pred)
    protein_accuracy = float(np.mean(protein_pred == protein_true))
    print(protein_accuracy)
    # BUG FIX: original ended with `pass`, discarding both accuracies.
    return peptide_accuracy, protein_accuracy
def spb(model, datafiles, peptide):
    """Standard single-peptide-binding evaluation.

    :param model: trained model exposing ``validation_step``,
        ``validation_end`` and a ``tcr_encoding_model`` attribute.
    :param datafiles: passed to ``get_tpp_ii_pairs`` to build the test set.
    :param peptide: the single peptide to evaluate binding against.
    :return: validation AUC for the peptide.
    :raises ValueError: if ``model.tcr_encoding_model`` is not 'AE'/'LSTM'.
    """
    test = get_tpp_ii_pairs(datafiles)
    test_dataset = SinglePeptideDataset(test, peptide)
    if model.tcr_encoding_model == 'AE':
        collate_fn = test_dataset.ae_collate
    elif model.tcr_encoding_model == 'LSTM':
        collate_fn = test_dataset.lstm_collate
    else:
        # BUG FIX: collate_fn was left unbound for unknown encodings,
        # causing a late NameError; fail fast instead.
        raise ValueError('unknown tcr_encoding_model: '
                         + str(model.tcr_encoding_model))
    loader = DataLoader(test_dataset, batch_size=64, shuffle=False,
                        num_workers=10, collate_fn=collate_fn)
    outputs = []
    for batch_idx, batch in enumerate(loader):
        outputs.append(model.validation_step(batch, batch_idx))
    auc = model.validation_end(outputs)['val_auc']
    print(auc)
    # BUG FIX: original ended with `pass`, discarding the AUC.
    return auc