# Standard-library and PyTorch imports used by the functions below; the project's own
# modules (model classes, the `ae`/`lstm` utilities, train_epoch, evaluate, ...) are
# assumed to be imported elsewhere in this file.
import argparse
import csv
import pickle
import time

import torch
import torch.nn as nn
import torch.optim as optim


def train_model(batches, test_batches, device, args, params):
    """
    Train and evaluate the model
    """
    losses = []
    # We use binary cross-entropy loss
    loss_function = nn.BCELoss()
    # Set model with relevant parameters
    if args['siamese']:
        model = SiameseLSTMClassifier(params['emb_dim'], params['lstm_dim'], device)
    else:
        model = DoubleLSTMClassifier(params['emb_dim'], params['lstm_dim'], params['dropout'], device)
    # Move to GPU
    model.to(device)
    # We use Adam optimizer
    optimizer = optim.Adam(model.parameters(), lr=params['lr'], weight_decay=params['wd'])
    # Train several epochs
    best_auc = 0
    best_roc = None
    for epoch in range(params['epochs']):
        print('epoch:', epoch + 1)
        epoch_time = time.time()
        # Train model and get loss
        loss = train_epoch(batches, model, loss_function, optimizer, device)
        losses.append(loss)
        # Compute AUC
        train_auc = evaluate(model, batches, device)[0]
        print('train auc:', train_auc)
        with open(args['train_auc_file'], 'a+') as file:
            file.write(str(train_auc) + '\n')
        if params['option'] == 2:
            test_w, test_c = test_batches
            test_auc_w = evaluate(model, test_w, device)[0]
            print('test auc w:', test_auc_w)
            with open(args['test_auc_file_w'], 'a+') as file:
                file.write(str(test_auc_w) + '\n')
            test_auc_c = evaluate(model, test_c, device)[0]
            print('test auc c:', test_auc_c)
            with open(args['test_auc_file_c'], 'a+') as file:
                file.write(str(test_auc_c) + '\n')
        else:
            test_auc, roc = evaluate(model, test_batches, device)
            # nni.report_intermediate_result(test_auc)
            if test_auc > best_auc:
                best_auc = test_auc
                best_roc = roc
            print('test auc:', test_auc)
            with open(args['test_auc_file'], 'a+') as file:
                file.write(str(test_auc) + '\n')
        print('one epoch time:', time.time() - epoch_time)
    return model, best_auc, best_roc
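# Usage sketch for train_model() (hypothetical values): the dictionary keys below are
# exactly the ones train_model() reads; the numbers and file names are placeholders,
# not values from the original experiments, and `train_batches`/`test_batches` are
# assumed to be batch lists built by the project's batching utilities.
def _example_train_model(train_batches, test_batches, device):
    args = {
        'siamese': False,                    # use DoubleLSTMClassifier
        'train_auc_file': 'train_auc.txt',   # per-epoch train AUC log (hypothetical name)
        'test_auc_file': 'test_auc.txt',     # per-epoch test AUC log (hypothetical name)
    }
    params = {
        'emb_dim': 10, 'lstm_dim': 30, 'dropout': 0.1,   # model sizes (placeholders)
        'lr': 1e-3, 'wd': 0, 'epochs': 100,              # optimizer and training schedule
        'option': 0,                                     # != 2, so test_batches is a single batch list
    }
    model, best_auc, best_roc = train_model(train_batches, test_batches, device, args, params)
    return model, best_auc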
def load_model_and_data(args):
    # Train data path
    if args.train_data_file == 'auto':
        results_dir = 'save_results'
        p_key = 'protein' if args.protein else ''
        args.train_data_file = results_dir + '/' + '_'.join([args.model_type, args.dataset, args.sampling, p_key, 'train.pickle'])
    # Test data path
    if args.test_data_file == 'auto':
        results_dir = 'save_results'
        p_key = 'protein' if args.protein else ''
        args.test_data_file = results_dir + '/' + '_'.join([args.model_type, args.dataset, args.sampling, p_key, 'test.pickle'])
    # Read train data
    with open(args.train_data_file, "rb") as file:
        train_data = pickle.load(file)
    # Read test data
    with open(args.test_data_file, "rb") as file:
        test_data = pickle.load(file)
    # Trained model path
    if args.model_file == 'auto':
        results_dir = 'save_results'
        p_key = 'protein' if args.protein else ''
        args.model_file = results_dir + '/' + '_'.join([args.model_type, args.dataset, args.sampling, p_key, 'model.pt'])
    # Load model
    device = args.device
    if args.model_type == 'ae':
        checkpoint = torch.load(args.model_file, map_location=device)
        params = checkpoint['params']
        args.ae_file = 'TCR_Autoencoder/tcr_ae_dim_' + str(params['enc_dim']) + '.pt'
        model = AutoencoderLSTMClassifier(params['emb_dim'], device, 28, 21, params['enc_dim'],
                                          params['batch_size'], args.ae_file, False)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
    if args.model_type == 'lstm':
        checkpoint = torch.load(args.model_file, map_location=device)
        params = checkpoint['params']
        model = DoubleLSTMClassifier(params['emb_dim'], params['lstm_dim'], params['dropout'], device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
    data = [train_data, test_data]
    return model, data
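# load_model_and_data() expects a checkpoint dict containing 'model_state_dict' and
# 'params' (with 'emb_dim', 'lstm_dim', 'dropout' for the LSTM model, or 'enc_dim' and
# 'batch_size' for the autoencoder). A minimal sketch of writing such a checkpoint,
# assuming `model` and `params` come from a finished training run; the file name is
# hypothetical.
def _example_save_checkpoint(model, params, path='save_results/lstm_example_model.pt'):
    torch.save({'model_state_dict': model.state_dict(), 'params': params}, path)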
def predict(args):
    # Word to index dictionaries
    amino_acids = [letter for letter in 'ARNDCEQGHILKMFPSTWYV']
    if args.model_type == 'lstm':
        amino_to_ix = {amino: index for index, amino in enumerate(['PAD'] + amino_acids)}
    if args.model_type == 'ae':
        pep_atox = {amino: index for index, amino in enumerate(['PAD'] + amino_acids)}
        tcr_atox = {amino: index for index, amino in enumerate(amino_acids + ['X'])}
    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/tcr_autoencoder.pt'
    if args.model_file == 'auto':
        models_dir = 'models'
        p_key = 'protein' if args.protein else ''
        args.model_file = models_dir + '/' + '_'.join([args.model_type, args.dataset, args.sampling, p_key, 'model.pt'])
    if args.test_data_file == 'auto':
        args.test_data_file = 'pairs_example.csv'
    # Read test data
    tcrs = []
    peps = []
    signs = []
    max_len = 28
    with open(args.test_data_file, 'r') as csv_file:
        reader = csv.reader(csv_file)
        for line in reader:
            tcr, pep = line
            if args.model_type == 'ae' and len(tcr) >= max_len:
                continue
            tcrs.append(tcr)
            peps.append(pep)
            signs.append(0.0)
    tcrs_copy = tcrs.copy()
    peps_copy = peps.copy()
    # Load model
    device = args.device
    if args.model_type == 'ae':
        model = AutoencoderLSTMClassifier(10, device, 28, 21, 30, 50, args.ae_file, False)
        checkpoint = torch.load(args.model_file, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
    if args.model_type == 'lstm':
        model = DoubleLSTMClassifier(10, 30, 0.1, device)
        checkpoint = torch.load(args.model_file, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
    # Predict
    batch_size = 50
    if args.model_type == 'ae':
        test_batches = ae.get_full_batches(tcrs, peps, signs, tcr_atox, pep_atox, batch_size, max_len)
        preds = ae.predict(model, test_batches, device)
    if args.model_type == 'lstm':
        lstm.convert_data(tcrs, peps, amino_to_ix)
        test_batches = lstm.get_full_batches(tcrs, peps, signs, batch_size, amino_to_ix)
        preds = lstm.predict(model, test_batches, device)
    # Print predictions
    for tcr, pep, pred in zip(tcrs_copy, peps_copy, preds):
        print('\t'.join([tcr, pep, str(pred)]))
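# predict() reads a header-less CSV with one "tcr,pep" pair per row (see the parsing
# loop above) and prints "tcr<TAB>pep<TAB>score" lines. A minimal invocation sketch with
# hypothetical argument values; only the attribute names are taken from the code above,
# and a trained checkpoint is assumed to exist at the resolved model path.
def _example_predict():
    args = argparse.Namespace(
        model_type='lstm',                  # 'lstm' or 'ae'
        dataset='mcpas',                    # used only to resolve the 'auto' model path
        sampling='specific',                # placeholder sampling name
        protein=False,
        model_file='auto',                  # resolved under models/ from the fields above
        ae_file='auto',                     # only used when model_type == 'ae'
        test_data_file='pairs_example.csv',
        device='cpu',
    )
    predict(args)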
def protein_test(args):
    assert args.protein
    # Word to index dictionaries
    amino_acids = [letter for letter in 'ARNDCEQGHILKMFPSTWYV']
    if args.model_type == 'lstm':
        amino_to_ix = {amino: index for index, amino in enumerate(['PAD'] + amino_acids)}
    if args.model_type == 'ae':
        pep_atox = {amino: index for index, amino in enumerate(['PAD'] + amino_acids)}
        tcr_atox = {amino: index for index, amino in enumerate(amino_acids + ['X'])}
    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/tcr_autoencoder.pt'
    if args.test_data_file == 'auto':
        data_dir = 'memory_and_protein'
        p_key = 'protein' if args.protein else ''
        args.test_data_file = data_dir + '/' + '_'.join([args.model_type, args.dataset, args.sampling, p_key, 'test.pickle'])
    if args.model_file == 'auto':
        data_dir = 'memory_and_protein'
        p_key = 'protein' if args.protein else ''
        args.model_file = data_dir + '/' + '_'.join([args.model_type, args.dataset, args.sampling, p_key, 'model.pt'])
    # Read test data
    with open(args.test_data_file, 'rb') as handle:
        test = pickle.load(handle)
    # Load trained model
    device = args.device
    if args.model_type == 'ae':
        test_tcrs, test_peps, test_signs = ae_get_lists_from_pairs(test, 28)
        model = AutoencoderLSTMClassifier(10, device, 28, 21, 30, 50, args.ae_file, False)
        checkpoint = torch.load(args.model_file, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
    if args.model_type == 'lstm':
        test_tcrs, test_peps, test_signs = lstm_get_lists_from_pairs(test)
        model = DoubleLSTMClassifier(10, 30, 0.1, device)
        checkpoint = torch.load(args.model_file, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
    # Get frequent proteins list and their peptides
    if args.dataset == 'mcpas':
        datafile = 'McPAS-TCR.csv'
    p = []
    protein_peps = {}
    with open(datafile, 'r', encoding='unicode_escape') as file:
        file.readline()
        reader = csv.reader(file)
        for line in reader:
            pep, protein = line[11], line[9]
            if protein == 'NA' or pep == 'NA':
                continue
            p.append(protein)
            try:
                protein_peps[protein].append(pep)
            except KeyError:
                protein_peps[protein] = [pep]
    d = {t: p.count(t) for t in set(p)}
    sorted_d = sorted(d.items(), key=lambda k: k[1], reverse=True)
    proteins = [t[0] for t in sorted_d]
    """
    McPAS most frequent proteins
    NP177                 Influenza
    Matrix protein (M1)   Influenza
    pp65                  Cytomegalovirus (CMV)
    BMLF-1                Epstein Barr virus (EBV)
    PB1                   Influenza
    """
    rocs = []
    for protein in proteins[:50]:
        protein_shows = [i for i in range(len(test_peps)) if test_peps[i] in protein_peps[protein]]
        test_tcrs_protein = [test_tcrs[i] for i in protein_shows]
        test_peps_protein = [test_peps[i] for i in protein_shows]
        test_signs_protein = [test_signs[i] for i in protein_shows]
        if args.model_type == 'ae':
            test_batches_protein = ae.get_full_batches(test_tcrs_protein, test_peps_protein, test_signs_protein,
                                                       tcr_atox, pep_atox, 50, 28)
        if args.model_type == 'lstm':
            lstm.convert_data(test_tcrs_protein, test_peps_protein, amino_to_ix)
            test_batches_protein = lstm.get_full_batches(test_tcrs_protein, test_peps_protein, test_signs_protein,
                                                         50, amino_to_ix)
        if len(protein_shows):
            try:
                if args.model_type == 'ae':
                    test_auc, roc = ae.evaluate_full(model, test_batches_protein, device)
                if args.model_type == 'lstm':
                    test_auc, roc = lstm.evaluate_full(model, test_batches_protein, device)
                rocs.append((protein, roc))
                print(str(test_auc))
            except ValueError:
                print('NA')
    return rocs
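# protein_test() returns a list of (protein, roc) pairs. Assuming each `roc` is the
# (fpr, tpr, thresholds) triple produced by sklearn.metrics.roc_curve inside
# evaluate_full (that helper is not shown in this file, so this is an assumption),
# the per-protein curves could be plotted like this:
def _example_plot_protein_rocs(rocs):
    import matplotlib.pyplot as plt
    for protein, (fpr, tpr, _thresholds) in rocs[:5]:
        plt.plot(fpr, tpr, label=protein)
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.legend()
    plt.show()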
def pep_test(args):
    # Word to index dictionaries
    amino_acids = [letter for letter in 'ARNDCEQGHILKMFPSTWYV']
    if args.model_type == 'lstm':
        amino_to_ix = {amino: index for index, amino in enumerate(['PAD'] + amino_acids)}
    if args.model_type == 'ae':
        pep_atox = {amino: index for index, amino in enumerate(['PAD'] + amino_acids)}
        tcr_atox = {amino: index for index, amino in enumerate(amino_acids + ['X'])}
    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/tcr_ae_dim_100.pt'
    if args.test_data_file == 'auto':
        results_dir = 'final_results'
        p_key = 'protein' if args.protein else ''
        args.test_data_file = results_dir + '/' + '_'.join([args.model_type, args.dataset, args.sampling, p_key, 'test.pickle'])
    if args.model_file == 'auto':
        results_dir = 'final_results'
        p_key = 'protein' if args.protein else ''
        args.model_file = results_dir + '/' + '_'.join([args.model_type, args.dataset, args.sampling, p_key, 'model.pt'])
    # Read test data
    with open(args.test_data_file, 'rb') as handle:
        test = pickle.load(handle)
    # Load trained model
    device = args.device
    if args.model_type == 'ae':
        test_tcrs, test_peps, test_signs = ae_get_lists_from_pairs(test, 28)
        model = AutoencoderLSTMClassifier(10, device, 28, 21, 100, 50, args.ae_file, False)
        checkpoint = torch.load(args.model_file, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
    if args.model_type == 'lstm':
        test_tcrs, test_peps, test_signs = lstm_get_lists_from_pairs(test)
        model = DoubleLSTMClassifier(10, 500, 0.1, device)
        checkpoint = torch.load(args.model_file, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
    # Get frequent peps list
    if args.dataset == 'mcpas':
        datafile = 'nine_class_testdata.csv'
    p = []
    with open(datafile, 'r', encoding='unicode_escape') as file:
        file.readline()
        reader = csv.reader(file)
        for line in reader:
            pep = line[1]
            if pep == 'NA':
                continue
            p.append(pep)
    d = {t: p.count(t) for t in set(p)}
    sorted_d = sorted(d.items(), key=lambda k: k[1], reverse=True)
    peps = [t[0] for t in sorted_d]
    """
    McPAS most frequent peps
    LPRRSGAAGA   Influenza
    GILGFVFTL    Influenza
    GLCTLVAML    Epstein Barr virus (EBV)
    NLVPMVATV    Cytomegalovirus (CMV)
    SSYRRPVGI    Influenza
    """
    filename = 'pep_test_result_ae2' if args.model_type == 'ae' else 'pep_test_result_lstm2'
    rocs = []
    auc = []
    for pep in peps[:249]:
        pep_shows = [i for i in range(len(test_peps)) if pep == test_peps[i]]
        test_tcrs_pep = [test_tcrs[i] for i in pep_shows]
        test_peps_pep = [test_peps[i] for i in pep_shows]
        test_signs_pep = [test_signs[i] for i in pep_shows]
        if args.model_type == 'ae':
            test_batches_pep = ae.get_full_batches(test_tcrs_pep, test_peps_pep, test_signs_pep,
                                                   tcr_atox, pep_atox, 50, 28)
        if args.model_type == 'lstm':
            lstm.convert_data(test_tcrs_pep, test_peps_pep, amino_to_ix)
            test_batches_pep = lstm.get_full_batches(test_tcrs_pep, test_peps_pep, test_signs_pep, 50, amino_to_ix)
        if len(pep_shows):
            try:
                if args.model_type == 'ae':
                    test_auc, roc = ae.evaluate_full(model, test_batches_pep, device)
                if args.model_type == 'lstm':
                    test_auc, roc = lstm.evaluate_full(model, test_batches_pep, device)
                rocs.append((pep, roc))
                auc.append([pep, roc, test_auc])
                print(str(test_auc))
            except ValueError:
                print('NA')
    with open(filename, 'wb') as handle:
        pickle.dump(auc, handle)
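# pep_test() pickles a list of [pep, roc, test_auc] entries. A small sketch for reading
# the results back and listing the per-peptide AUCs; the default file name matches the
# LSTM branch above.
def _example_read_pep_results(path='pep_test_result_lstm2'):
    with open(path, 'rb') as handle:
        results = pickle.load(handle)
    for pep, _roc, auc_value in results:
        print(pep, auc_value)
    return results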
def tpp_test(args):
    # Evaluate on the full test set (TPP), as opposed to the per-peptide evaluation in pep_test
    # Word to index dictionaries
    amino_acids = [letter for letter in 'ARNDCEQGHILKMFPSTWYV']
    if args.model_type == 'lstm':
        amino_to_ix = {amino: index for index, amino in enumerate(['PAD'] + amino_acids)}
    if args.model_type == 'ae':
        pep_atox = {amino: index for index, amino in enumerate(['PAD'] + amino_acids)}
        tcr_atox = {amino: index for index, amino in enumerate(amino_acids + ['X'])}
    if args.ae_file == 'auto':
        args.ae_file = 'TCR_Autoencoder/tcr_ae_dim_100.pt'
    if args.test_data_file == 'auto':
        results_dir = 'final_results'
        p_key = 'protein' if args.protein else ''
        args.test_data_file = results_dir + '/' + '_'.join([args.model_type, args.dataset, args.sampling, p_key, 'test.pickle'])
    if args.model_file == 'auto':
        results_dir = 'final_results'
        p_key = 'protein' if args.protein else ''
        args.model_file = results_dir + '/' + '_'.join([args.model_type, args.dataset, args.sampling, p_key, 'model.pt'])
    # Read test data
    with open(args.test_data_file, 'rb') as handle:
        test = pickle.load(handle)
    # Load trained model
    device = args.device
    if args.model_type == 'ae':
        test_tcrs, test_peps, test_signs = ae_get_lists_from_pairs(test, 28)
        model = AutoencoderLSTMClassifier(10, device, 28, 21, 100, 50, args.ae_file, False)
        checkpoint = torch.load(args.model_file, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
    if args.model_type == 'lstm':
        test_tcrs, test_peps, test_signs = lstm_get_lists_from_pairs(test)
        model = DoubleLSTMClassifier(10, 500, 0.1, device)
        checkpoint = torch.load(args.model_file, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
    # Evaluate on the full test set
    rocs = []
    auc = []
    if args.model_type == 'ae':
        test_batches = ae.get_full_batches(test_tcrs, test_peps, test_signs, tcr_atox, pep_atox, 50, 28)
        test_auc, roc = ae.evaluate_full(model, test_batches, device)
        filename = 'tpp_test_result_ae2'
    if args.model_type == 'lstm':
        lstm.convert_data(test_tcrs, test_peps, amino_to_ix)
        test_batches = lstm.get_full_batches(test_tcrs, test_peps, test_signs, 50, amino_to_ix)
        test_auc, roc = lstm.evaluate_full(model, test_batches, device)
        filename = 'tpp_test_result_lstm2'
    rocs.append(roc)
    auc.append([roc, test_auc])
    print(str(test_auc))
    with open(filename, 'wb') as handle:
        pickle.dump(auc, handle)