Example #1
# Shared imports for the examples on this page; get_index_dicts,
# get_tpp_ii_pairs, the *Dataset classes and the encoders are
# project-local helpers assumed to be in scope.
import pickle

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader


def mps(hparams, model, datafiles, peptides):
    train_file, test_files = datafiles
    with open(train_file, 'rb') as handle:
        train = pickle.load(handle)
    train_dicts = get_index_dicts(train)
    samples = get_tpp_ii_pairs(datafiles)
    preds = np.zeros((len(samples), len(peptides)))
    key_order = []
    for pep_idx, pep in enumerate(peptides):
        key_order.append(pep)
        # NOTE: the original indexed an undefined `peptide_map` here;
        # `peptides` is assumed to map peptide names to sequences.
        testset = SinglePeptideDataset(samples,
                                       train_dicts,
                                       peptides[pep],
                                       force_peptide=True,
                                       spb_force=False)
        loader = DataLoader(testset,
                            batch_size=64,
                            shuffle=False,
                            num_workers=10,
                            collate_fn=lambda b: testset.collate(
                                b,
                                tcr_encoding=hparams.tcr_encoding_model,
                                cat_encoding=hparams.cat_encoding))
        outputs = []
        model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(loader):
                outputs.append(model.validation_step(batch, batch_idx))
        y_hat = torch.cat([x['y_hat'].detach().cpu() for x in outputs])
        preds[:, pep_idx] = y_hat
    argmax = np.argmax(preds, axis=1)
    predicted_peps = [key_order[i] for i in argmax]
    # TODO: also return accuracy (a sketch follows this example)
    return predicted_peps
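
The trailing TODO asks for an accuracy value. A minimal sketch of how it could be computed from the returned predictions, assuming each sample exposes its true peptide name under a 'peptide' key (that key name is an assumption, not part of the source):

def mps_accuracy(predicted_peps, samples):
    # the 'peptide' key is assumed; adapt to the actual sample layout
    true_peps = [sample['peptide'] for sample in samples]
    correct = sum(p == t for p, t in zip(predicted_peps, true_peps))
    return correct / len(samples)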
Example #2
def load_test(datafiles):
    train_pickle, test_pickle = datafiles
    with open(train_pickle, 'rb') as handle:
        train = pickle.load(handle)
    with open(test_pickle, 'rb') as handle:
        test = pickle.load(handle)
    train_dicts = get_index_dicts(train)
    return test, train_dicts
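
A hedged usage sketch, pairing load_test with the SignedPairsDataset construction shown in Example #3 below; the pickle file names here are placeholders:

test, train_dicts = load_test(('mcpas_train_samples.pickle',
                               'mcpas_test_samples.pickle'))
test_dataset = SignedPairsDataset(test, train_dicts)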
Example #3
 def train_dataloader(self):
     with open('Samples/' + self.dataset + '_train_samples.pickle',
               'rb') as handle:
         train = pickle.load(handle)
     train_dataset = SignedPairsDataset(train, get_index_dicts(train))
     return DataLoader(train_dataset,
                       batch_size=128,
                       shuffle=True,
                       num_workers=10,
                       collate_fn=lambda b: train_dataset.collate(
                           b,
                           tcr_encoding=self.tcr_encoding_model,
                           cat_encoding=self.cat_encoding))
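
One caveat: with num_workers > 0, PyTorch must pickle collate_fn whenever workers are started with the spawn method (the default on Windows and recent macOS), and lambdas are not picklable. A sketch of a picklable alternative, using functools.partial over the bound collate method with the same arguments as above:

from functools import partial

# inside train_dataloader, replacing the lambda:
collate = partial(train_dataset.collate,
                  tcr_encoding=self.tcr_encoding_model,
                  cat_encoding=self.cat_encoding)
return DataLoader(train_dataset, batch_size=128, shuffle=True,
                  num_workers=10, collate_fn=collate)

The same applies to the lambdas in Examples #1, #4 and #5.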
Example #4
 def train_dataloader(self):
     with open(self.dataset + '_train_samples.pickle', 'rb') as handle:
         train = pickle.load(handle)
     train_dataset = YellowFeverDataset(train,
                                        get_index_dicts(train),
                                        weight_factor=self.weight_factor)
     return DataLoader(train_dataset,
                       batch_size=128,
                       shuffle=True,
                       num_workers=10,
                       collate_fn=lambda b: train_dataset.collate(
                           b,
                           tcr_encoding=self.tcr_encoding_model,
                           cat_encoding=self.cat_encoding))
Example #5
def diabetes_mps(hparams, model, testfile, pep_pool):
    with open('mcpas_human_train_samples.pickle', 'rb') as handle:
        train = pickle.load(handle)
    train_dicts = get_index_dicts(train)
    if pep_pool == 4:
        peptide_map = {
            'IGRPp39': 'QLYHFLQIPTHEEHLFYVLS',
            'GADp70': 'KVNFFRMVISNPAATHQDID',
            'GADp15': 'DVMNILLQYVVKSFDRSTKV',
            'IGRPp31': 'KWCANPDWIHIDTTPFAGLV'
        }
    else:
        peptide_map = {}
        with open(pep_pool, 'r') as file:
            file.readline()
            for line in file:
                pep, index, protein = line.strip().split(',')
                if protein in ['GAD', 'IGRP', 'Insulin']:
                    protein += 'p'
                pep_name = protein + index
                peptide_map[pep_name] = pep
    samples = read_known_specificity_test(testfile)
    preds = np.zeros((len(samples), len(peptide_map)))
    key_order = []
    for pep_idx, pep in enumerate(peptide_map):
        key_order.append(pep)
        testset = SinglePeptideDataset(samples,
                                       train_dicts,
                                       peptide_map[pep],
                                       force_peptide=True,
                                       spb_force=False)
        loader = DataLoader(testset,
                            batch_size=10,
                            shuffle=False,
                            num_workers=10,
                            collate_fn=lambda b: testset.collate(
                                b,
                                tcr_encoding=hparams.tcr_encoding_model,
                                cat_encoding=hparams.cat_encoding))
        outputs = []
        model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(loader):
                outputs.append(model.validation_step(batch, batch_idx))
        y_hat = torch.cat([x['y_hat'].detach().cpu() for x in outputs])
        preds[:, pep_idx] = y_hat
    argmax = np.argmax(preds, axis=1)
    predicted_peps = [key_order[i] for i in argmax]
    print(predicted_peps)
    # return the predictions instead of discarding them (the original
    # ended with a bare `pass`)
    return predicted_peps
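
For reference, the else branch above implies a small CSV layout for pep_pool: one header line (skipped), then peptide,index,protein rows. A hypothetical file consistent with the parser:

peptide,index,protein
KVNFFRMVISNPAATHQDID,70,GAD
KWCANPDWIHIDTTPFAGLV,31,IGRP

Parsing the first row yields pep_name == 'GADp70', because 'GAD' is one of the proteins that gets a trailing 'p' before the index is appended.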
Example #6
def get_train_dicts(train_pickle):
    with open(train_pickle, 'rb') as handle:
        train = pickle.load(handle)
    train_dicts = get_index_dicts(train)
    return train_dicts
Example #7
 def __init__(self, hparams):
     super(ERGOLightning, self).__init__()
     self.hparams = hparams
     self.dataset = hparams.dataset
     # Model Type
     self.tcr_encoding_model = hparams.tcr_encoding_model
     self.use_alpha = hparams.use_alpha
     self.use_vj = hparams.use_vj
     self.use_mhc = hparams.use_mhc
     self.use_t_type = hparams.use_t_type
     self.cat_encoding = hparams.cat_encoding
     # Dimensions
     self.aa_embedding_dim = hparams.aa_embedding_dim
     self.cat_embedding_dim = hparams.cat_embedding_dim
     self.lstm_dim = hparams.lstm_dim
     self.encoding_dim = hparams.encoding_dim
     self.dropout_rate = hparams.dropout
     self.lr = hparams.lr
     self.wd = hparams.wd
     # get train indices for V, J, etc.
     if self.cat_encoding == 'embedding':
         with open('Samples/' + self.dataset + '_train_samples.pickle',
                   'rb') as handle:
             train = pickle.load(handle)
         vatox, vbtox, jatox, jbtox, mhctox = get_index_dicts(train)
         self.v_vocab_size = len(vatox) + len(vbtox)
         self.j_vocab_size = len(jatox) + len(jbtox)
         self.mhc_vocab_size = len(mhctox)
     # TCR Encoder
     if self.tcr_encoding_model == 'AE':
         if self.use_alpha:
             self.tcra_encoder = AE_Encoder(encoding_dim=self.encoding_dim,
                                            tcr_type='alpha',
                                            max_len=34)
         self.tcrb_encoder = AE_Encoder(encoding_dim=self.encoding_dim,
                                        tcr_type='beta')
     elif self.tcr_encoding_model == 'LSTM':
         if self.use_alpha:
             self.tcra_encoder = LSTM_Encoder(self.aa_embedding_dim,
                                              self.lstm_dim,
                                              self.dropout_rate)
         self.tcrb_encoder = LSTM_Encoder(self.aa_embedding_dim,
                                          self.lstm_dim, self.dropout_rate)
         self.encoding_dim = self.lstm_dim
     # Peptide Encoder
     self.pep_encoder = LSTM_Encoder(self.aa_embedding_dim, self.lstm_dim,
                                     self.dropout_rate)
     # Categorical
     self.cat_encoding = hparams.cat_encoding
     if hparams.cat_encoding == 'embedding':
         if self.use_vj:
             self.v_embedding = nn.Embedding(self.v_vocab_size,
                                             self.cat_embedding_dim,
                                             padding_idx=0)
             self.j_embedding = nn.Embedding(self.j_vocab_size,
                                             self.cat_embedding_dim,
                                             padding_idx=0)
         if self.use_mhc:
             self.mhc_embedding = nn.Embedding(self.mhc_vocab_size,
                                               self.cat_embedding_dim,
                                               padding_idx=0)
     # different MLP sizes, depending on the model input
     if self.cat_encoding == 'binary':
         self.cat_embedding_dim = 10
     mlp_input_size = self.lstm_dim + self.encoding_dim
     if self.use_vj:
         mlp_input_size += 2 * self.cat_embedding_dim
     if self.use_mhc:
         mlp_input_size += self.cat_embedding_dim
     if self.use_t_type:
         mlp_input_size += 1
     # MLP I (without alpha)
     self.mlp_dim1 = mlp_input_size
     self.hidden_layer1 = nn.Linear(self.mlp_dim1,
                                    int(np.sqrt(self.mlp_dim1)))
     self.relu = torch.nn.LeakyReLU()
     self.output_layer1 = nn.Linear(int(np.sqrt(self.mlp_dim1)), 1)
     self.dropout = nn.Dropout(p=self.dropout_rate)
     # MLP II (with alpha)
     if self.use_alpha:
         mlp_input_size += self.encoding_dim
         if self.use_vj:
             mlp_input_size += 2 * self.cat_embedding_dim
         self.mlp_dim2 = mlp_input_size
         self.hidden_layer2 = nn.Linear(self.mlp_dim2,
                                        int(np.sqrt(self.mlp_dim2)))
         self.output_layer2 = nn.Linear(int(np.sqrt(self.mlp_dim2)), 1)
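
A minimal construction sketch for this class; the attribute names mirror exactly what __init__ reads from hparams, while the concrete values are illustrative assumptions:

from argparse import Namespace

# values are placeholders; __init__ also expects
# Samples/<dataset>_train_samples.pickle to exist when
# cat_encoding == 'embedding'
hparams = Namespace(dataset='mcpas_human',
                    tcr_encoding_model='LSTM',   # or 'AE'
                    use_alpha=True, use_vj=True,
                    use_mhc=True, use_t_type=True,
                    cat_encoding='embedding',    # or 'binary'
                    aa_embedding_dim=10, cat_embedding_dim=50,
                    lstm_dim=500, encoding_dim=100,
                    dropout=0.1, lr=1e-3, wd=0.0)
model = ERGOLightning(hparams)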