def validation(args, model_name, model_in, model_out, T, epoch, boundaries, random_seed_list, fig=None, ax=None):
    # evaluation mode
    model_in.eval()
    model_out.eval()
    T.eval()

    # dataset loader
    valid_loader = get_dataloader(model_in, args, model_in.dataset.validset,
                                  batch_size=args.validation_batch_size, shuffle=False)

    # number of samples in validset
    validset_len = len(model_in.dataset.validset)

    # tensor to molecule smiles
    def input_tensor2string(input_tensor):
        return model_in.tensor2string(input_tensor)

    # train set (both domains) for novelty
    trainset = set(model_in.dataset.trainset).union(model_out.dataset.trainset)

    # generate output molecule from input molecule
    def local_input2output(input_batch):
        return input2output(args, input_batch, model_in, T, model_out, random_seed_list,
                            max_out_len=args.validation_max_len)

    # use general validation function
    avg_similarity, avg_property, avg_SR, avg_validity, avg_novelty, avg_diversity = \
        general_validation(args, local_input2output, input_tensor2string, boundaries, valid_loader,
                           validset_len, model_name, trainset, epoch, fig=fig, ax=ax)

    # back to train mode
    model_in.train()
    model_out.train()
    T.train()

    return avg_similarity, avg_property, avg_SR, avg_validity, avg_novelty, avg_diversity
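
# --- Hedged usage sketch (not part of the original script) ---
# Shows how `validation` above might be called once per epoch from the training
# loop. `model_A`, `model_B`, `T`, `boundaries` and `random_seed_list` are assumed
# to be the objects built in the training script; the model-name string is
# illustrative only.
def _example_validation_per_epoch(args, model_A, model_B, T, epoch, boundaries, random_seed_list):
    # translate A -> B on the validation set and keep SR as an early-stopping criterion
    _, _, avg_SR, _, _, _ = validation(args, 'A2B', model_A, model_B, T, epoch,
                                       boundaries, random_seed_list)
    return avg_SR
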
def check_testset(args, model_in, model_out, input2output_func):
    testset_path = 'dataset/' + args.property + '/' + args.testset_filename
    results_file_path = args.plots_folder + '/' + args.property + '/UGMMT_' + args.property + '_test.txt'
    valid_results_file_path = args.plots_folder + '/' + args.property + '/valid_UGMMT_' + args.property + '_test.txt'
    print(' ')
    print('Loading testset from file => ' + testset_path)

    # train set for novelty
    trainset = set(model_in.dataset.trainset).union(model_out.dataset.trainset)

    # build testset from filename
    testset_list = filname2testset(testset_path, model_in, model_out)
    test_loader = get_dataloader(model_in, args, testset_list, batch_size=args.batch_size, shuffle=False)

    # generate random seeds
    random_seed_list = get_random_list(args.num_retries)

    def input2output_func_aux(input_batch):
        return input2output_func(input_batch, random_seed_list)

    def input2smiles(input):
        return model_in.tensor2string(input)    # to smiles

    # generate results file
    generate_results_file(test_loader, input2output_func_aux, input2smiles, results_file_path)

    # result file -> valid results + property and similarity for output molecules
    process_results_file(results_file_path, args, valid_results_file_path, trainset)

    # calculate metrics
    validity_mean, validity_std, \
    diversity_mean, diversity_std, \
    novelty_mean, novelty_std, \
    property_mean, property_std, \
    similarity_mean, similarity_std, \
    SR_mean, SR_std = \
        valid_results_file_to_metrics(valid_results_file_path, args, len(testset_list))

    # print results
    print(' ')
    print('Property => ' + args.property)
    print('property => mean: ' + str(round(property_mean, 3)) + ' std: ' + str(round(property_std, 3)))
    print('fingerprint Similarity => mean: ' + str(round(similarity_mean, 3)) + ' std: ' + str(round(similarity_std, 3)))
    print('SR => mean: ' + str(round(SR_mean, 3)) + ' std: ' + str(round(SR_std, 3)))
    print('validity => mean: ' + str(round(validity_mean, 3)) + ' std: ' + str(round(validity_std, 3)))
    print('novelty => mean: ' + str(round(novelty_mean, 3)) + ' std: ' + str(round(novelty_std, 3)))
    print('diversity => mean: ' + str(round(diversity_mean, 3)) + ' std: ' + str(round(diversity_std, 3)))
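
# --- Hedged wiring sketch (assumption, not in the original file) ---
# `check_testset` expects a callable of the form f(input_batch, random_seed_list).
# One plausible way to build it is to close over `input2output` with the trained
# models, mirroring `local_input2output` in `validation` above; the choice of
# `args.validation_max_len` as the output length cap is an assumption.
def _example_run_check_testset(args, model_A, model_B, T):
    def a2b_input2output(input_batch, random_seed_list):
        return input2output(args, input_batch, model_A, T, model_B, random_seed_list,
                            max_out_len=args.validation_max_len)
    check_testset(args, model_A, model_B, a2b_input2output)
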
def intermediate_analysis(args, model_in, input2all_func):
    assert args.use_EETN

    # get data
    data_path = 'dataset/' + args.property + '/A_train.txt'
    print(' ')
    print('Intermediate_analyse: Loading data from file => ' + data_path)

    # build dataset from filename
    df = pd.read_csv(data_path, header=None)
    dataset_list = list(df.iloc[:, 0])

    # get random molecules (one per seed)
    random_seed_list = get_random_list(args.num_samples, last=len(dataset_list) - 1)
    for seed in random_seed_list:
        rand_mol = dataset_list.pop(seed)

        # get intermediate for the *** RAND *** molecule
        rand_mol_loader = get_dataloader(model_in, args, [rand_mol], batch_size=1, shuffle=False)
        for i, input_batch in enumerate(rand_mol_loader):
            rand_enc, rand_trans, [rand_out_mol] = input2all_func(input_batch)

        # if source output molecule is invalid - skip to next molecule
        if not is_valid_molecule(rand_out_mol, args.property):
            continue

        # sort the remaining molecules by similarity to the random molecule
        sim_rand_all_sort_list = [(mol, similarity_calc(rand_mol, mol)) for mol in dataset_list]
        sim_rand_all_sort_list.sort(key=lambda x: x[1], reverse=True)
        neighbours_near_list = [mol for (mol, _) in sim_rand_all_sort_list[:args.num_neighbours]]
        neighbours_far_list = [mol for (mol, _) in sim_rand_all_sort_list[args.num_neighbours:2 * args.num_neighbours]]
        # neighbours_far_list = [mol for (mol, _) in sim_rand_all_sort_list[-args.num_neighbours:]]

        # get intermediate for *** NEAR *** molecules
        near_loader = get_dataloader(model_in, args, neighbours_near_list, batch_size=args.num_neighbours, shuffle=False)
        for i, input_batch in enumerate(near_loader):
            near_enc, near_trans, near_out_mol = input2all_func(input_batch)

        # get intermediate for *** FAR *** molecules
        far_loader = get_dataloader(model_in, args, neighbours_far_list, batch_size=args.num_neighbours, shuffle=False)
        for i, input_batch in enumerate(far_loader):
            far_enc, far_trans, far_out_mol = input2all_func(input_batch)

        p = 3
        print(' ')
        print('seed => ' + str(seed))

        # analyse *** FAR ***
        far_input_sim_mean, far_input_sim_std, far_enc_dis_mean, far_enc_dis_std, \
        far_trans_dis_mean, far_trans_dis_std, far_out_sim_mean, far_out_sim_std = \
            analyze_intermediate_outputs(rand_mol, rand_enc, rand_trans, rand_out_mol,
                                         neighbours_far_list, far_enc, far_trans, far_out_mol)
        print('*** FAR (NOT similar) ***')
        print('Input mol similarity => mean: ' + str(round(far_input_sim_mean, p)) + ' std: ' + str(round(far_input_sim_std, p)))
        print('A_enc emb distance => mean: ' + str(round(far_enc_dis_mean, p)) + ' std: ' + str(round(far_enc_dis_std, p)))
        print('Translator emb distance => mean: ' + str(round(far_trans_dis_mean, p)) + ' std: ' + str(round(far_trans_dis_std, p)))
        print('Output mol similarity => mean: ' + str(round(far_out_sim_mean, p)) + ' std: ' + str(round(far_out_sim_std, p)))

        # analyse *** NEAR ***
        near_input_sim_mean, near_input_sim_std, near_enc_dis_mean, near_enc_dis_std, \
        near_trans_dis_mean, near_trans_dis_std, near_out_sim_mean, near_out_sim_std = \
            analyze_intermediate_outputs(rand_mol, rand_enc, rand_trans, rand_out_mol,
                                         neighbours_near_list, near_enc, near_trans, near_out_mol)
        print('*** NEAR (similar) ***')
        print('Input mol similarity => mean: ' + str(round(near_input_sim_mean, p)) + ' std: ' + str(round(near_input_sim_std, p)))
        print('A_enc emb distance => mean: ' + str(round(near_enc_dis_mean, p)) + ' std: ' + str(round(near_enc_dis_std, p)))
        print('Translator emb distance => mean: ' + str(round(near_trans_dis_mean, p)) + ' std: ' + str(round(near_trans_dis_std, p)))
        print('Output mol similarity => mean: ' + str(round(near_out_sim_mean, p)) + ' std: ' + str(round(near_out_sim_std, p)))
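
# --- Hedged aggregation sketch (assumption, not in the original file) ---
# `analyze_intermediate_outputs` is defined elsewhere; judging by the statistics
# printed above, it reduces each neighbour group to mean/std over pairwise values
# such as `similarity_calc(rand_mol, neighbour)`. A minimal version of that
# reduction (helper names are hypothetical):
def _example_mean_std(values):
    # plain mean / population std over a list of floats
    mean = sum(values) / len(values)
    std = (sum((v - mean) ** 2 for v in values) / len(values)) ** 0.5
    return mean, std

def _example_input_similarity_stats(rand_mol, neighbour_list):
    # mean/std of the similarity between the reference molecule and each neighbour
    return _example_mean_std([similarity_calc(rand_mol, mol) for mol in neighbour_list])
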
def discover_approved_drugs(args, model_in, model_out, input2output_func):
    print(' ')
    print('Loading approved drugs dataset from file => ' + args.FDA_approved_filename)

    # build testset from filename
    approved_drugs_list, smiles_name_dict = filname2testset(args.FDA_approved_filename, model_in, model_out, drugs=True)
    approved_drugs_list = [drug for drug in approved_drugs_list if len(drug) <= args.approved_drugs_max_len]
    approved_drugs_set = set(approved_drugs_list)
    print('Final number of drugs (after length filtering): ' + str(len(approved_drugs_set)))
    print(' ')

    # testset loader
    approved_drugs_loader = get_dataloader(model_in, args, approved_drugs_list, batch_size=args.batch_size, shuffle=False)

    all_drug_matches = dict()
    num_valid_molecule_matches = 0
    unique_drugs_found = set()
    unique_valid_molecules_found = set()
    for i, input_batch in enumerate(approved_drugs_loader):
        current_batch_size = len(input_batch)

        # generate random seeds
        random_seed_list = get_random_list(args.approved_drugs_retries)

        # generate output batch
        output_batch = input2output_func(input_batch, random_seed_list)

        # for every molecule
        for j, input in enumerate(input_batch):
            output_unique_molecule_smiles_set = set(output_batch[j::current_batch_size])

            # number of valid molecule matches
            output_valid_unique_molecule_smiles_list = \
                [output_molecule_smiles for output_molecule_smiles in output_unique_molecule_smiles_set
                 if is_valid_molecule(output_molecule_smiles, args.property)]
            num_unique_valid_molecules_per_input = len(output_valid_unique_molecule_smiles_list)
            num_valid_molecule_matches += num_unique_valid_molecules_per_input
            unique_valid_molecules_found = unique_valid_molecules_found.union(output_valid_unique_molecule_smiles_list)

            intersection_set = approved_drugs_set.intersection(output_unique_molecule_smiles_set)
            if intersection_set:    # if intersection_set is not empty
                input_molecule_smiles = model_in.tensor2string(input)    # to smiles
                all_drug_matches[input_molecule_smiles] = intersection_set
                unique_drugs_found = unique_drugs_found.union(intersection_set)

            drug_index = i * current_batch_size + j
            if drug_index % 30 == 0:
                print('Drug #' + str(drug_index + 1) + ': Unique approved drugs found so far: ' + str(len(unique_drugs_found)))

    # print detailed results
    similarity_list = []
    out_property_list = []
    property_improvement = []
    for input_molecule_smiles, output_molecule_smiles_set in all_drug_matches.items():
        print(' ')
        property_value_in = property_calc(input_molecule_smiles, args.property)
        print('Input: ' + str(input_molecule_smiles) + '. Property value: ' + str(round(property_value_in, 4)) +
              ' Drug name: ' + smiles_name_dict[input_molecule_smiles])
        for output_molecule_smiles in output_molecule_smiles_set:
            property_value_out = property_calc(output_molecule_smiles, args.property)
            similarity_in_out = similarity_calc(input_molecule_smiles, output_molecule_smiles)
            print('Output: ' + str(output_molecule_smiles) + '. Property value: ' + str(round(property_value_out, 4)) +
                  ' Drug name: ' + smiles_name_dict[output_molecule_smiles])
            print('Similarity: ' + str(round(similarity_in_out, 4)))
            similarity_list.append(similarity_in_out)
            out_property_list.append(property_value_out)
            property_improvement.append(property_value_out - property_value_in)

    # print final results
    average_drug_matches_similarity = sum(similarity_list) / len(similarity_list) if len(similarity_list) > 0 else -1
    average_out_drugs_property_val = sum(out_property_list) / len(out_property_list) if len(out_property_list) > 0 else -1
    average_property_improvement = sum(property_improvement) / len(property_improvement) if len(property_improvement) > 0 else -1
    unique_drugs_per_unique_legal_molecule = \
        100 * len(unique_drugs_found) / len(unique_valid_molecules_found) if len(unique_valid_molecules_found) > 0 else -1
    drug_matches_per_valid_matches = \
        100 * len(out_property_list) / num_valid_molecule_matches if num_valid_molecule_matches > 0 else -1

    print(' ')
    print('Property => ' + args.property)
    print('Translation direction => ' + args.test_direction)
    print('Unique approved drugs found: ' + str(len(unique_drugs_found)))
    print('Unique legal molecules generated: ' + str(len(unique_valid_molecules_found)))
    print('Unique approved drugs per unique legal molecules generated: ' + str(round(unique_drugs_per_unique_legal_molecule, 5)) + '%')
    print('Drug matches: ' + str(len(out_property_list)))
    print('All valid matches: ' + str(num_valid_molecule_matches))
    print('Unique drug-drug matches per drug-valid molecule matches: ' + str(round(drug_matches_per_valid_matches, 5)) + '%')
    print('Average generated drugs property value => ' + str(round(average_out_drugs_property_val, 3)))
    print('Average drug matches fingerprint Similarity => ' + str(round(average_drug_matches_similarity, 3)))
    print('Average drug matches property value improvement => ' + str(round(average_property_improvement, 3)))
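
# --- Hedged background sketch (assumption, not in the original file) ---
# The "fingerprint Similarity" reported above is assumed to be a Tanimoto
# similarity over Morgan fingerprints, which is what `similarity_calc` (defined
# elsewhere) typically computes in pipelines like this one. A standalone RDKit
# version might look roughly as follows; the radius and bit size are illustrative.
def _example_fingerprint_similarity(smiles_a, smiles_b):
    from rdkit import Chem, DataStructs
    from rdkit.Chem import AllChem
    mol_a, mol_b = Chem.MolFromSmiles(smiles_a), Chem.MolFromSmiles(smiles_b)
    if mol_a is None or mol_b is None:
        return 0.0
    fp_a = AllChem.GetMorganFingerprintAsBitVect(mol_a, 2, nBits=2048)
    fp_b = AllChem.GetMorganFingerprintAsBitVect(mol_b, 2, nBits=2048)
    return DataStructs.TanimotoSimilarity(fp_a, fp_b)
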
# optimizers for ablation
if args.gan_loss:
    optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=args.init_lr, betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=args.init_lr, betas=(0.5, 0.999))

# scheduler
lr_scheduler_T = torch.optim.lr_scheduler.LambdaLR(optimizer_T,
                                                   lr_lambda=LambdaLR(args.epochs, args.epoch_init, args.epoch_decay).step)

# schedulers for ablation
if args.gan_loss:
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A,
                                                         lr_lambda=LambdaLR(args.epochs, args.epoch_init, args.epoch_decay).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B,
                                                         lr_lambda=LambdaLR(args.epochs, args.epoch_init, args.epoch_decay).step)

# train dataloaders
A_train_loader = get_dataloader(model_A, args, model_A.dataset.trainset, args.batch_size, collate_fn=None, shuffle=True)
B_train_loader = get_dataloader(model_B, args, model_B.dataset.trainset, args.batch_size, collate_fn=None, shuffle=True)

# buffer for max_size last fake samples for ablation
if args.gan_loss:
    fake_A_buffer = ReplayBuffer(max_size=50)
    fake_B_buffer = ReplayBuffer(max_size=50)
else:
    fake_A_buffer, fake_B_buffer = None, None

# for early stopping
best_criterion = None
runs_without_improvement = 0
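
# --- Hedged sketch of the decay schedule (assumption, not in the original file) ---
# The custom `LambdaLR` helper passed as `lr_lambda` above is defined elsewhere in
# the repository; CycleGAN-style training scripts usually implement it as a linear
# decay from 1.0 (before `epoch_decay`) down to 0.0 at `epochs`. A sketch of that
# behaviour, under a different name to avoid shadowing the real helper:
class _ExampleLinearDecay:
    def __init__(self, epochs, epoch_init, epoch_decay):
        self.epochs = epochs            # total number of training epochs
        self.epoch_init = epoch_init    # epoch the run (re)starts from
        self.epoch_decay = epoch_decay  # epoch at which the linear decay begins

    def step(self, epoch):
        # multiplicative LR factor: 1.0 until epoch_decay, then linear decay to 0.0
        return 1.0 - max(0, epoch + self.epoch_init - self.epoch_decay) / (self.epochs - self.epoch_decay)

# generate random seeds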