Code example #1
File: Evaluations.py Project: louzounlab/ERGO-II
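All of the snippets on this page come from the same Evaluations.py, so the module-level imports never appear. As a minimal sketch, they rely on at least the following; the standard-library and PyTorch imports are certain from the code, while the project-local helpers are only named, since the module they live in is not shown here.

import pickle

import numpy as np
import torch
from torch.utils.data import DataLoader

# Project-local helpers used below (SinglePeptideDataset, SignedPairsDataset,
# get_index_dicts, get_tpp_ii_pairs, read_known_specificity_test) come from
# elsewhere in the ERGO-II code base; the exact import line is an assumption,
# e.g. something like:
# from Loader import SignedPairsDataset, SinglePeptideDataset, get_index_dicts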
def mps(hparams, model, datafiles, peptides):
    train_file, test_files = datafiles
    with open(train_file, 'rb') as handle:
        train = pickle.load(handle)
    train_dicts = get_index_dicts(train)
    samples = get_tpp_ii_pairs(datafiles)
    preds = np.zeros((len(samples), len(peptides)))
    key_order = []
    for pep_idx, pep in enumerate(peptides):
        key_order.append(pep)
        # NOTE: peptide_map is not defined or passed in this function; it is
        # assumed to be a module-level dict mapping peptide names to sequences,
        # like the one built in diabetes_mps (example #3).
        testset = SinglePeptideDataset(samples,
                                       train_dicts,
                                       peptide_map[pep],
                                       force_peptide=True,
                                       spb_force=False)
        loader = DataLoader(testset,
                            batch_size=64,
                            shuffle=False,
                            num_workers=10,
                            collate_fn=lambda b: testset.collate(
                                b,
                                tcr_encoding=hparams.tcr_encoding_model,
                                cat_encoding=hparams.cat_encoding))
        outputs = []
        with torch.no_grad():
            for batch_idx, batch in enumerate(loader):
                model.eval()
                outputs.append(model.validation_step(batch, batch_idx))
        y_hat = torch.cat([x['y_hat'].detach().cpu() for x in outputs])
        preds[:, pep_idx] = y_hat
    argmax = np.argmax(preds, axis=1)
    predicted_peps = [key_order[i] for i in argmax]
    # need to return accuracy
    return predicted_peps
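The heart of this multi-peptide selection (MPS) routine is the last block: each column of preds holds one peptide's scores for every TCR, and the row-wise argmax maps back through key_order to a peptide name. A self-contained sketch of just that step, with made-up scores and the four peptide names from example #3:

import numpy as np

# 3 TCR samples (rows) x 4 candidate peptides (columns); scores are illustrative.
preds = np.array([[0.1, 0.7, 0.2, 0.0],
                  [0.9, 0.3, 0.4, 0.2],
                  [0.2, 0.1, 0.6, 0.8]])
key_order = ['IGRPp39', 'GADp70', 'GADp15', 'IGRPp31']

argmax = np.argmax(preds, axis=1)                # best-scoring column per row
predicted_peps = [key_order[i] for i in argmax]  # map column index back to a name
print(predicted_peps)                            # ['GADp70', 'IGRPp39', 'IGRPp31']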
Code example #2
File: Evaluations.py Project: louzounlab/ERGO-II
def auc_predict(model, test, train_dicts, peptide=None):
    if peptide:
        test_dataset = SinglePeptideDataset(test,
                                            train_dicts,
                                            peptide,
                                            force_peptide=False)
    else:
        test_dataset = SignedPairsDataset(test, train_dicts)
    # print(test_dataset.data)
    loader = DataLoader(test_dataset,
                        batch_size=64,
                        shuffle=False,
                        num_workers=0,
                        collate_fn=lambda b: test_dataset.collate(
                            b,
                            tcr_encoding=model.tcr_encoding_model,
                            cat_encoding=model.cat_encoding))
    outputs = []
    for batch_idx, batch in enumerate(loader):
        output = model.validation_step(batch, batch_idx)
        if output:
            outputs.append(output)
            # print(output['y'])
    auc = model.validation_end(outputs)['val_auc']
    return auc
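The returned score comes from the model's validation_end hook, which this function expects to aggregate the collected outputs into a dictionary with a 'val_auc' entry. As a self-contained illustration of the metric itself (not of the model's own implementation), scikit-learn computes the same ROC-AUC directly from labels and scores:

import numpy as np
from sklearn.metrics import roc_auc_score

# Toy binding labels (1 = TCR binds the peptide) and model scores.
y_true = np.array([1, 0, 1, 1, 0, 0])
y_score = np.array([0.9, 0.2, 0.7, 0.4, 0.3, 0.6])
print(roc_auc_score(y_true, y_score))  # ~0.89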
Code example #3
File: Evaluations.py Project: louzounlab/ERGO-II
def diabetes_mps(hparams, model, testfile, pep_pool):
    with open('mcpas_human_train_samples.pickle', 'rb') as handle:
        train = pickle.load(handle)
    train_dicts = get_index_dicts(train)
    if pep_pool == 4:
        peptide_map = {
            'IGRPp39': 'QLYHFLQIPTHEEHLFYVLS',
            'GADp70': 'KVNFFRMVISNPAATHQDID',
            'GADp15': 'DVMNILLQYVVKSFDRSTKV',
            'IGRPp31': 'KWCANPDWIHIDTTPFAGLV'
        }
    else:
        peptide_map = {}
        with open(pep_pool, 'r') as file:
            file.readline()
            for line in file:
                pep, index, protein = line.strip().split(',')
                if protein in ['GAD', 'IGRP', 'Insulin']:
                    protein += 'p'
                pep_name = protein + index
                peptide_map[pep_name] = pep
    samples = read_known_specificity_test(testfile)
    preds = np.zeros((len(samples), len(peptide_map)))
    key_order = []
    for pep_idx, pep in enumerate(peptide_map):
        key_order.append(pep)
        testset = SinglePeptideDataset(samples,
                                       train_dicts,
                                       peptide_map[pep],
                                       force_peptide=True,
                                       spb_force=False)
        loader = DataLoader(testset,
                            batch_size=10,
                            shuffle=False,
                            num_workers=10,
                            collate_fn=lambda b: testset.collate(
                                b,
                                tcr_encoding=hparams.tcr_encoding_model,
                                cat_encoding=hparams.cat_encoding))
        outputs = []
        with torch.no_grad():
            for batch_idx, batch in enumerate(loader):
                model.eval()
                outputs.append(model.validation_step(batch, batch_idx))
        y_hat = torch.cat([x['y_hat'].detach().cpu() for x in outputs])
        preds[:, pep_idx] = y_hat
    argmax = np.argmax(preds, axis=1)
    predicted_peps = [key_order[i] for i in argmax]
    print(predicted_peps)
    pass
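When pep_pool is a file path instead of the built-in pool of 4, the loop above expects a header line followed by comma-separated rows of the form peptide_sequence,index,protein. A self-contained sketch of that parsing on an in-memory string (the header text and rows are illustrative, reusing sequences from the built-in map):

import io

pep_pool_text = """peptide,index,protein
QLYHFLQIPTHEEHLFYVLS,39,IGRP
KVNFFRMVISNPAATHQDID,70,GAD
DVMNILLQYVVKSFDRSTKV,15,GAD
"""

peptide_map = {}
with io.StringIO(pep_pool_text) as file:
    file.readline()                      # skip the header line
    for line in file:
        pep, index, protein = line.strip().split(',')
        if protein in ['GAD', 'IGRP', 'Insulin']:
            protein += 'p'               # matches names like 'GADp70' above
        peptide_map[protein + index] = pep

print(sorted(peptide_map))  # ['GADp15', 'GADp70', 'IGRPp39']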
Code example #4
File: Evaluations.py Project: louzounlab/ERGO-II
def spb_with_more_negatives(model, datafiles, peptide):
    test = get_tpp_ii_pairs(datafiles)
    # Regular SPB
    # test_dataset = SinglePeptideDataset(test, peptide)
    # More negatives
    test_dataset = SinglePeptideDataset(test,
                                        peptide,
                                        force_peptide=True,
                                        spb_force=True)
    if model.tcr_encoding_model == 'AE':
        collate_fn = test_dataset.ae_collate
    elif model.tcr_encoding_model == 'LSTM':
        collate_fn = test_dataset.lstm_collate
    loader = DataLoader(test_dataset,
                        batch_size=64,
                        shuffle=False,
                        num_workers=10,
                        collate_fn=collate_fn)
    outputs = []
    i = 0
    for batch_idx, batch in enumerate(loader):
        i += 1
        outputs.append(model.validation_step(batch, batch_idx))
    if i:
        print('positives:',
              int(torch.cat([x['y'] for x in outputs]).sum().item()))
        auc = model.validation_end(outputs)['val_auc']
        print(auc)
    pass
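The 'positives' line concatenates the ground-truth labels that validation_step returns for each batch and sums them. A minimal, self-contained sketch of that bookkeeping, with toy outputs shaped like the dicts the loop collects:

import torch

outputs = [{'y': torch.tensor([1., 0., 1.])},
           {'y': torch.tensor([0., 1., 1.])}]

positives = int(torch.cat([x['y'] for x in outputs]).sum().item())
print('positives:', positives)  # 4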
Code example #5
def diabetes_test_set(model):
    # 8 paired samples, 4 peptides
    # tcra, tcrb, pep
    data = [('CAATRTSGTYKYIF', 'CASSPWGAGGTDTQYF', 'IGRPp39'),
            ('CAVGAGYGGATNKLIF', 'CASSFRGGGNPYEQYF', 'GADp70'),
            ('CAERLYGNNRLAF', 'CASTLLWGGDSYEQYF', 'GADp15'),
            ('CAVNPNQAGTALIF', 'CASAPQEAQPQHF', 'IGRPp31'),
            ('CALSDYSGTSYGKLTF', 'CASSLIPYNEQFF', 'GADp15'),
            ('CAVEDLNQAGTALIF', 'CASSLALGQGNQQFF', 'IGRPp31'),
            ('CILRDTISNFGNEKLTF', 'CASSFGSSYYGYTF', 'IGRPp39'),
            ('CAGQTGANNLFF', 'CASSQEVGTVPNQPQHF', 'IGRPp31')]
    peptide_map = {
        'IGRPp39': 'QLYHFLQIPTHEEHLFYVLS',
        'GADp70': 'KVNFFRMVISNPAATHQDID',
        'GADp15': 'DVMNILLQYVVKSFDRSTKV',
        'IGRPp31': 'KWCANPDWIHIDTTPFAGLV'
    }
    true_labels = np.array(
        [list(peptide_map.keys()).index(d[-1]) for d in data])
    print(true_labels)
    samples = []
    for tcra, tcrb, pep in data:
        tcr_data = (tcra, tcrb, 'v', 'j')
        pep_data = (peptide_map[pep], 'mhc', 'protein')
        samples.append((tcr_data, pep_data, 1))
    preds = np.zeros((len(samples), len(peptide_map)))
    for pep_idx, pep in enumerate(peptide_map):
        # signs do not matter here, we only do a forward pass
        dataset = SinglePeptideDataset(samples,
                                       peptide_map[pep],
                                       force_peptide=True)
        if model.tcr_encoding_model == 'AE':
            collate_fn = dataset.ae_collate
        elif model.tcr_encoding_model == 'LSTM':
            collate_fn = dataset.lstm_collate
        loader = DataLoader(dataset,
                            batch_size=64,
                            shuffle=False,
                            num_workers=10,
                            collate_fn=collate_fn)
        outputs = []
        for batch_idx, batch in enumerate(loader):
            outputs.append(model.validation_step(batch, batch_idx))
        y_hat = torch.cat([x['y_hat'].detach().cpu() for x in outputs])
        preds[:, pep_idx] = y_hat
    # print(preds)
    argmax = np.argmax(preds, axis=1)
    print(argmax)
    accuracy = sum((argmax == true_labels).astype(int)) / len(samples)
    print(accuracy)
    # try protein accuracy - IGRP and GAD
    true_labels = np.array(
        [0 if x == 3 else 1 if x == 2 else x for x in true_labels])
    argmax = np.array([0 if x == 3 else 1 if x == 2 else x for x in argmax])
    print(true_labels)
    print(argmax)
    accuracy = sum((argmax == true_labels).astype(int)) / len(samples)
    print(accuracy)
    pass
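The two nested conditional expressions above collapse the four peptide columns into two protein groups: indices 0 (IGRPp39) and 3 (IGRPp31) become IGRP, while indices 1 (GADp70) and 2 (GADp15) become GAD. The same collapse written with an explicit lookup table, plus both accuracy computations, using the true labels of the 8 samples above and illustrative predictions:

import numpy as np

# Column order matches peptide_map above: IGRPp39, GADp70, GADp15, IGRPp31.
pep_to_protein = {0: 'IGRP', 1: 'GAD', 2: 'GAD', 3: 'IGRP'}

true_labels = np.array([0, 1, 2, 3, 2, 3, 0, 3])  # from the 8 paired samples
argmax = np.array([0, 1, 1, 3, 2, 0, 3, 3])       # illustrative model predictions

true_prot = np.array([pep_to_protein[i] for i in true_labels])
pred_prot = np.array([pep_to_protein[i] for i in argmax])

peptide_acc = np.mean(argmax == true_labels)   # exact peptide match
protein_acc = np.mean(pred_prot == true_prot)  # IGRP vs GAD only
print(peptide_acc, protein_acc)                # 0.625 1.0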
Code example #6
def spb(model, datafiles, peptide):
    test = get_tpp_ii_pairs(datafiles)
    test_dataset = SinglePeptideDataset(test, peptide)
    if model.tcr_encoding_model == 'AE':
        collate_fn = test_dataset.ae_collate
    elif model.tcr_encoding_model == 'LSTM':
        collate_fn = test_dataset.lstm_collate
    loader = DataLoader(test_dataset,
                        batch_size=64,
                        shuffle=False,
                        num_workers=10,
                        collate_fn=collate_fn)
    outputs = []
    for batch_idx, batch in enumerate(loader):
        outputs.append(model.validation_step(batch, batch_idx))
    auc = model.validation_end(outputs)['val_auc']
    print(auc)
    pass
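All six examples share one evaluation skeleton: build a dataset, wrap it in a DataLoader, call the model's validation_step on every batch, and aggregate the collected dictionaries. A toy, self-contained version of that skeleton; the TensorDataset, linear layer and validation_step here are stand-ins for the ERGO-II objects, not project code:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data: 20 samples with 8 features and binary labels.
features = torch.randn(20, 8)
labels = torch.randint(0, 2, (20,)).float()
loader = DataLoader(TensorDataset(features, labels), batch_size=4, shuffle=False)

linear = torch.nn.Linear(8, 1)  # stand-in for the trained ERGO-II model

def validation_step(batch, batch_idx):
    x, y = batch
    y_hat = torch.sigmoid(linear(x)).squeeze(1)
    return {'y': y, 'y_hat': y_hat}  # same dict contract the loops above rely on

outputs = []
with torch.no_grad():
    for batch_idx, batch in enumerate(loader):
        outputs.append(validation_step(batch, batch_idx))

y = torch.cat([o['y'] for o in outputs])
y_hat = torch.cat([o['y_hat'] for o in outputs])
print(y.shape, y_hat.shape)  # torch.Size([20]) torch.Size([20])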