import torch
import torch.nn as nn
import torch.optim as optim
from random import shuffle

# PaddingAutoencoder, train_epoch and convert_data are assumed to be defined in, or
# imported from, the project's TCR autoencoder module; they are not defined in this file.


def train_model(batches, batch_size, max_len, encoding_dim, epochs, device):
    # Train the padding autoencoder over 20 amino acids + the padding symbol (input dim 21)
    model = PaddingAutoencoder(max_len, 20 + 1, encoding_dim)
    model.to(device)
    loss_function = torch.nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999),
                           eps=1e-8, weight_decay=0)
    for epoch in range(epochs):
        print('epoch:', epoch + 1)
        train_epoch(batches, batch_size, model, loss_function, optimizer, device)
    return model
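
# Illustrative sketch (not part of the original module): one way to call train_model and
# save a checkpoint in the {'model_state_dict': ...} format that the classifiers below
# load from ae_file. The dimensions, epoch count and file name are assumptions.
def _example_train_autoencoder(batches, device='cpu'):
    max_len, encoding_dim, batch_size = 28, 30, 50  # assumed hyper-parameters
    model = train_model(batches, batch_size, max_len, encoding_dim, epochs=10, device=device)
    torch.save({'model_state_dict': model.state_dict()}, 'tcr_autoencoder.pt')
    return model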

def get_batches(tcrs, tcr_atox, pathologies, params, args):
    """
    Get batches from the data
    """
    max_length = params['max_len']
    batch_size = params['batch_size']
    # Load autoencoder
    autoencoder = PaddingAutoencoder(max_length, 21, params['enc_dim'])
    checkpoint = torch.load(args['ae_file'])
    autoencoder.load_state_dict(checkpoint['model_state_dict'])
    for param in autoencoder.parameters():
        param.requires_grad = False
    autoencoder.eval()
    # Shuffle
    z = list(zip(tcrs, pathologies))
    shuffle(z)
    tcrs, pathologies = zip(*z)
    tcrs = list(tcrs)
    pathologies = list(pathologies)
    # Initialization
    batches = []
    index = 0
    convert_data(tcrs, tcr_atox, max_length)
    # Go over all data
    while index < len(tcrs) // batch_size * batch_size:
        # Get batch sequences and matching tags, then add the batch to the list
        batch_tcrs = tcrs[index:index + batch_size]
        tcr_tensor = torch.zeros((batch_size, max_length, 21))
        for i in range(batch_size):
            tcr_tensor[i] = batch_tcrs[i]
        concat = tcr_tensor.view(batch_size, max_length * 21)
        encoded_tcrs = autoencoder.encoder(concat)
        batch_pathologies = pathologies[index:index + batch_size]
        batches.append((encoded_tcrs, batch_pathologies))
        # Update index
        index += batch_size
    '''
    # Pad data in the last batch
    missing = batch_size - len(tcrs) + index
    padding_tcrs = ['X'] * missing
    padding_pathologies = [class_limit] * missing
    convert_data(padding_tcrs, tcr_atox, max_length)
    batch_tcrs = tcrs[index:] + padding_tcrs
    tcr_tensor = torch.zeros((batch_size, max_length, 21))
    for i in range(batch_size):
        tcr_tensor[i] = batch_tcrs[i]
    batch_pathologies = pathologies[index:] + padding_pathologies
    batches.append((tcr_tensor, batch_pathologies))
    # Update index
    index += batch_size
    '''
    # Return list of all batches
    return batches
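
# Illustrative sketch (not part of the original module): get_batches only reads the keys
# 'max_len', 'batch_size' and 'enc_dim' from params and 'ae_file' from args. The concrete
# values and checkpoint path used here are assumptions for demonstration.
def _example_get_batches(tcrs, tcr_atox, pathologies):
    params = {'max_len': 28, 'batch_size': 50, 'enc_dim': 30}  # assumed dimensions
    args = {'ae_file': 'tcr_autoencoder.pt'}                   # assumed checkpoint path
    batches = get_batches(tcrs, tcr_atox, pathologies, params, args)
    # Each batch is a tuple (encoded_tcrs, batch_pathologies), where encoded_tcrs has
    # shape (batch_size, enc_dim).
    return batches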

class PathologyClassifier(nn.Module):
    def __init__(self, embedding_dim, device, max_len, input_dim, encoding_dim, output_dim,
                 batch_size, ae_file, train_ae=True):
        super(PathologyClassifier, self).__init__()
        # GPU
        self.device = device
        # Dimensions
        self.embedding_dim = embedding_dim
        self.max_len = max_len
        self.input_dim = input_dim
        self.batch_size = batch_size
        # TCR autoencoder, restored from a saved checkpoint
        self.autoencoder = PaddingAutoencoder(max_len, input_dim, encoding_dim)
        checkpoint = torch.load(ae_file)
        self.autoencoder.load_state_dict(checkpoint['model_state_dict'])
        if train_ae is False:
            # Freeze the autoencoder so only the MLP head is trained
            for param in self.autoencoder.parameters():
                param.requires_grad = False
            self.autoencoder.eval()
        # MLP classification head over the TCR encoding
        self.mlp = nn.Sequential(nn.Linear(encoding_dim, 50),
                                 nn.Tanh(),
                                 nn.Linear(50, output_dim),
                                 nn.Softmax(dim=1))

    def forward(self, padded_tcrs):
        # TCR encoder: flatten the padded one-hot TCRs and encode them
        concat = padded_tcrs.view(self.batch_size, self.max_len * self.input_dim)
        encoded_tcrs = self.autoencoder.encoder(concat)
        # MLP classifier
        mlp_output = self.mlp(encoded_tcrs)
        return mlp_output
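
# Illustrative sketch (not part of the original module): a dummy forward pass through
# PathologyClassifier. The dimensions and checkpoint path are assumptions; the input is
# expected as padded one-hot TCRs of shape (batch_size, max_len, input_dim).
def _example_pathology_forward(device='cpu'):
    max_len, input_dim, enc_dim, n_classes, batch_size = 28, 21, 30, 5, 50  # assumed
    model = PathologyClassifier(10, device, max_len, input_dim, enc_dim, n_classes,
                                batch_size, 'tcr_autoencoder.pt', train_ae=False)
    dummy_tcrs = torch.zeros(batch_size, max_len, input_dim)
    probs = model(dummy_tcrs)  # shape (batch_size, n_classes); each row sums to 1
    return probs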

class AutoencoderLSTMClassifier(nn.Module):
    def __init__(self, embedding_dim, device, max_len, input_dim, encoding_dim, batch_size,
                 ae_file, train_ae):
        super(AutoencoderLSTMClassifier, self).__init__()
        # GPU
        self.device = device
        # Dimensions
        self.embedding_dim = embedding_dim
        self.lstm_dim = encoding_dim
        self.max_len = max_len
        self.input_dim = input_dim
        self.batch_size = batch_size
        # TCR autoencoder, restored from a saved checkpoint
        self.autoencoder = PaddingAutoencoder(max_len, input_dim, encoding_dim)
        checkpoint = torch.load(ae_file)
        self.autoencoder.load_state_dict(checkpoint['model_state_dict'])
        if train_ae is False:
            # Freeze the autoencoder so only the peptide encoder and MLP are trained
            for param in self.autoencoder.parameters():
                param.requires_grad = False
            self.autoencoder.eval()
        # Embedding matrix - 20 amino acids + padding
        self.pep_embedding = nn.Embedding(20 + 1, embedding_dim, padding_idx=0)
        # RNN - LSTM
        self.pep_lstm = nn.LSTM(embedding_dim, self.lstm_dim, num_layers=2,
                                batch_first=True, dropout=0.1)
        # MLP
        self.hidden_layer = nn.Linear(self.lstm_dim * 2, self.lstm_dim)
        self.relu = torch.nn.LeakyReLU()
        self.output_layer = nn.Linear(self.lstm_dim, 1)
        self.dropout = nn.Dropout(p=0.1)

    def init_hidden(self, batch_size):
        # Zero-initialized hidden and cell states for the two-layer LSTM
        return (torch.zeros(2, batch_size, self.lstm_dim, device=self.device),
                torch.zeros(2, batch_size, self.lstm_dim, device=self.device))

    def lstm_pass(self, lstm, padded_embeds, lengths):
        # pack_padded_sequence expects the batch sorted by descending sequence length
        lengths, perm_idx = lengths.sort(0, descending=True)
        padded_embeds = padded_embeds[perm_idx]
        # Pack the batch and ignore the padding (lengths must live on the CPU)
        padded_embeds = torch.nn.utils.rnn.pack_padded_sequence(
            padded_embeds, lengths.cpu(), batch_first=True)
        # Initialize the hidden state
        batch_size = len(lengths)
        hidden = self.init_hidden(batch_size)
        # Feed into the RNN
        lstm_out, hidden = lstm(padded_embeds, hidden)
        # Unpack the batch after the RNN
        lstm_out, lengths = torch.nn.utils.rnn.pad_packed_sequence(
            lstm_out, batch_first=True)
        # The outputs are sorted; restore the original ordering
        _, unperm_idx = perm_idx.sort(0)
        lstm_out = lstm_out[unperm_idx]
        lengths = lengths[unperm_idx.cpu()]
        return lstm_out

    def forward(self, padded_tcrs, peps, pep_lens):
        # TCR encoder: flatten the padded one-hot TCRs and encode them
        concat = padded_tcrs.view(self.batch_size, self.max_len * self.input_dim)
        encoded_tcrs = self.autoencoder.encoder(concat)
        # Peptide encoder: embed the amino-acid indices
        pep_embeds = self.pep_embedding(peps)
        # LSTM acceptor: keep the output at the last real (non-padding) position
        pep_lstm_out = self.lstm_pass(self.pep_lstm, pep_embeds, pep_lens)
        pep_last_cell = torch.cat([
            pep_lstm_out[i, j.data - 1] for i, j in enumerate(pep_lens)
        ]).view(len(pep_lens), self.lstm_dim)
        # MLP classifier over the concatenated TCR and peptide representations
        tcr_pep_concat = torch.cat([encoded_tcrs, pep_last_cell], 1)
        hidden_output = self.dropout(
            self.relu(self.hidden_layer(tcr_pep_concat)))
        mlp_output = self.output_layer(hidden_output)
        output = torch.sigmoid(mlp_output)
        return output
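
# Illustrative sketch (not part of the original module): a dummy forward pass through
# AutoencoderLSTMClassifier. All dimensions, the checkpoint path and the random inputs are
# assumptions; peps holds amino-acid indices (1..20, 0 = padding) and pep_lens holds the
# true length of each peptide.
def _example_lstm_forward(device='cpu'):
    max_len, input_dim, enc_dim, emb_dim, batch_size = 28, 21, 30, 10, 50  # assumed
    model = AutoencoderLSTMClassifier(emb_dim, device, max_len, input_dim, enc_dim,
                                      batch_size, 'tcr_autoencoder.pt', train_ae=False)
    model.eval()  # disable dropout for the demo
    dummy_tcrs = torch.zeros(batch_size, max_len, input_dim)
    dummy_peps = torch.randint(1, 21, (batch_size, 12))            # peptide index tensor
    dummy_lens = torch.full((batch_size,), 12, dtype=torch.long)   # peptide lengths
    scores = model(dummy_tcrs, dummy_peps, dummy_lens)             # shape (batch_size, 1)
    return scores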