def load_pretrained_paccmann(
    self, params_file: str, lang_file: str, weights_file: str,
    batch_size: int, batch_mode: str
):
    params = dict()
    with open(params_file, 'r') as f:
        params.update(json.load(f))
    params['batch_mode'] = batch_mode
    params['batch_size'] = batch_size

    self.selfies = params.get('selfies', False)
    self.device = get_device()
    self.smiles_language = SMILESLanguage.load(lang_file)

    self.gru_encoder = StackGRUEncoder(params).to(self.device)
    self.gru_decoder = StackGRUDecoder(params).to(self.device)
    self.gru_vae = TeacherVAE(self.gru_encoder,
                              self.gru_decoder).to(self.device)
    self.gru_vae.load_state_dict(
        torch.load(weights_file, map_location=self.device))
    self.gru_vae.eval()

    transforms = []
    if self.selfies:
        transforms += [Selfies()]
    transforms += [SMILESToTokenIndexes(smiles_language=self.smiles_language)]
    transforms += [ToTensor(device=self.device)]
    self.transform = Compose(transforms)
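# Hedged usage sketch for the loader above. The file names and the host
# object (`encoder`, an instance of whatever class defines this method) are
# illustrative placeholders, not part of the original code.
def _example_load_pretrained(encoder):
    encoder.load_pretrained_paccmann(
        params_file='model_params.json',
        lang_file='selfies_language.pkl',
        weights_file='weights/best_loss.pt',
        batch_size=1,
        batch_mode='padded',
    )
    # The VAE is now in eval mode and `encoder.transform` composes the
    # (optional) SELFIES, token-index and tensor transforms.
    return encoder.gru_vae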
def load_mca(self, model_path: str):
    """
    Restores pretrained MCA

    Arguments:
        model_path {String} -- Path to the model
    """
    # Restore model
    self.model_path = model_path
    with open(os.path.join(model_path, 'model_params.json')) as f:
        params = json.load(f)

    # Set up language and transforms
    self.smiles_language = SMILESLanguage.load(
        os.path.join(model_path, 'smiles_language.pkl')
    )
    self.transforms = self.compose_smiles_transforms(params)

    # Initialize and restore model weights
    self.model = MODEL_FACTORY['mca'](params)
    self.model.load(
        os.path.join(model_path, 'weights', 'best_ROC-AUC_mca.pt'),
        map_location=get_device()
    )
    self.model.eval()
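# Expected layout of `model_path` for the restore above; the file names are
# the ones hard-coded in the method, the root directory name is illustrative:
#
#   mca_model/
#   ├── model_params.json
#   ├── smiles_language.pkl
#   └── weights/
#       └── best_ROC-AUC_mca.pt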
def __init__(self, filename, smiles_language):
    super().__init__()
    self.filename = filename
    with open(self.filename, 'r') as f:
        self.data = [line.strip().split('\t') for line in f.readlines()]
    self.data = [[(x1, x2), int(y)] for x1, x2, y in self.data]
    if isinstance(smiles_language, str):
        smiles_language = SMILESLanguage.load(smiles_language)
    self.smiles_language = smiles_language
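# The constructor above expects a headerless, tab-separated file with two
# SMILES columns and an integer label per line, e.g. (rows illustrative):
#
#   CCO\tCC(=O)O\t1
#   c1ccccc1\tCCN\t0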
def create_smiles_language(smi_path: str, pretrained_path: str) -> None:
    """
    Create a SMILESLanguage object and save it to disk.

    Args:
        smi_path (str): path to a folder containing .smi files.
        pretrained_path (str): directory to store the language as text files.
    """
    os.makedirs(pretrained_path, exist_ok=True)
    smiles_language = SMILESLanguage()
    smiles_language.add_smis([
        os.path.join(smi_path, smi_filename)
        for smi_filename in os.listdir(smi_path)
        if smi_filename.endswith('.smi')
    ])
    smiles_language.save_pretrained(pretrained_path)
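# Minimal usage sketch for the helper above; both directory names are
# placeholders, not paths from the original project.
def _example_create_language():
    create_smiles_language(
        smi_path='data/smi',                    # folder containing *.smi files
        pretrained_path='pretrained_language',  # output directory
    )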
def main(train_affinity_filepath, test_affinity_filepath, protein_filepath,
         smi_filepath, smiles_language_filepath, protein_language_filepath,
         model_path, params_filepath, training_name):

    logger = logging.getLogger(f'{training_name}')
    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    # Prepare the dataset
    logger.info("Start data preprocessing...")
    device = get_device()

    # Load languages
    smiles_language = SMILESLanguage.load(smiles_language_filepath)
    protein_language = ProteinLanguage.load(protein_language_filepath)

    # Assemble datasets
    train_dataset = DrugAffinityDataset(
        drug_affinity_filepath=train_affinity_filepath,
        smi_filepath=smi_filepath,
        protein_filepath=protein_filepath,
        smiles_language=smiles_language,
        protein_language=protein_language,
        smiles_padding=params.get('smiles_padding', True),
        smiles_padding_length=params.get('smiles_padding_length', None),
        smiles_add_start_and_stop=params.get('smiles_add_start_stop', True),
        smiles_augment=params.get('augment_smiles', False),
        smiles_canonical=params.get('smiles_canonical', False),
        smiles_kekulize=params.get('smiles_kekulize', False),
        smiles_all_bonds_explicit=params.get('smiles_bonds_explicit', False),
        smiles_all_hs_explicit=params.get('smiles_all_hs_explicit', False),
        smiles_remove_bonddir=params.get('smiles_remove_bonddir', False),
        smiles_remove_chirality=params.get('smiles_remove_chirality', False),
        smiles_selfies=params.get('selfies', False),
        protein_amino_acid_dict=params.get('protein_amino_acid_dict', 'iupac'),
        protein_padding=params.get('protein_padding', True),
        protein_padding_length=params.get('protein_padding_length', None),
        protein_add_start_and_stop=params.get('protein_add_start_stop', True),
        protein_augment_by_revert=params.get('protein_augment', False),
        device=device,
        drug_affinity_dtype=torch.float,
        backend='eager'
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )
    test_dataset = DrugAffinityDataset(
        drug_affinity_filepath=test_affinity_filepath,
        smi_filepath=smi_filepath,
        protein_filepath=protein_filepath,
        smiles_language=smiles_language,
        protein_language=protein_language,
        smiles_padding=params.get('smiles_padding', True),
        smiles_padding_length=params.get('smiles_padding_length', None),
        smiles_add_start_and_stop=params.get('smiles_add_start_stop', True),
        smiles_augment=False,
        smiles_canonical=params.get('smiles_canonical', False),
        smiles_kekulize=params.get('smiles_kekulize', False),
        smiles_all_bonds_explicit=params.get('smiles_bonds_explicit', False),
        smiles_all_hs_explicit=params.get('smiles_all_hs_explicit', False),
        smiles_remove_bonddir=params.get('smiles_remove_bonddir', False),
        smiles_remove_chirality=params.get('smiles_remove_chirality', False),
        smiles_selfies=params.get('selfies', False),
        protein_amino_acid_dict=params.get('protein_amino_acid_dict', 'iupac'),
        protein_padding=params.get('protein_padding', True),
        protein_padding_length=params.get('protein_padding_length', None),
        protein_add_start_and_stop=params.get('protein_add_start_stop', True),
        protein_augment_by_revert=False,
        device=device,
        drug_affinity_dtype=torch.float,
        backend='eager'
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )
    logger.info(
        f'Training dataset has {len(train_dataset)} samples, test set has '
        f'{len(test_dataset)}.'
    )
    logger.info(f'Device for data loader is {train_dataset.device} and for '
                f'model is {device}')

    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')
    params.update({
        'smiles_vocabulary_size': smiles_language.number_of_tokens,
        'protein_vocabulary_size': protein_language.number_of_tokens
    })  # yapf: disable

    model_fn = params.get('model_fn', 'bimodal_mca')
    model = MODEL_FACTORY[model_fn](params).to(device)
    model._associate_language(smiles_language)
    model._associate_language(protein_language)

    if os.path.isfile(os.path.join(model_dir, 'weights', 'best_mca.pt')):
        logger.info('Found existing model, restoring now...')
        try:
            model.load(os.path.join(model_dir, 'weights', 'best_mca.pt'))
            with open(os.path.join(model_dir, 'results', 'mse.json'), 'r') as f:
                info = json.load(f)
            max_roc_auc = info['best_roc_auc']
            min_loss = info['test_loss']
        except Exception:
            min_loss, max_roc_auc = 100, 0
    else:
        min_loss, max_roc_auc = 100, 0

    # Define optimizer
    optimizer = OPTIMIZER_FACTORY[params.get('optimizer', 'adam')](
        model.parameters(), lr=params.get('lr', 0.001)
    )
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params.update({'number_of_parameters': num_params})
    logger.info(f'Number of parameters: {num_params}')
    logger.info(f'Model: {model}')

    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    model.save(save_top_model.format('epoch', '0', model_fn))

    for epoch in range(params['epochs']):

        model.train()
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, proteins, y) in enumerate(train_loader):
            if ind % 100 == 0:
                logger.info(f'Batch {ind}/{len(train_loader)}')
            y_hat, pred_dict = model(smiles, proteins)
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            # Apply gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-6)
            optimizer.step()
            train_loss += loss.item()

        logger.info("\t **** TRAINING **** "
                    f"Epoch [{epoch + 1}/{params['epochs']}], "
                    f"loss: {train_loss / len(train_loader):.5f}. "
                    f"This took {time() - t:.1f} secs.")
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, proteins, y) in enumerate(test_loader):
                y_hat, pred_dict = model(smiles.to(device),
                                         proteins.to(device))
                predictions.append(y_hat)
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()

        # NOTE: average the loss over the loader only once (the original
        # divided by len(test_loader) twice).
        test_loss = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc = auc(fpr, tpr)

        # calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        avg_precision = average_precision_score(labels, predictions)

        logger.info(
            f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {test_loss:.5f}, ROC-AUC: {test_roc_auc:.3f}, "
            f"Average precision: {avg_precision:.3f}."
        )

        def save(path, metric, typ, val=None):
            model.save(path.format(typ, metric, model_fn))
            info = {
                'best_roc_auc': str(max_roc_auc),
                'test_loss': str(min_loss)
            }
            with open(
                os.path.join(model_dir, 'results', metric + '.json'), 'w'
            ) as f:
                json.dump(info, f)
            np.save(
                os.path.join(model_dir, 'results', metric + '_preds.npy'),
                np.vstack([predictions, labels])
            )
            if typ == 'best':
                logger.info(f'\t New best performance in "{metric}"'
                            f' with value : {val:.7f} in epoch: {epoch}')

        if test_roc_auc > max_roc_auc:
            max_roc_auc = test_roc_auc
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            ep_roc = epoch
            roc_auc_loss = test_loss

        if test_loss < min_loss:
            min_loss = test_loss
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch
            loss_roc_auc = test_roc_auc

        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info('Overall best performances are: \n \t'
                f'Loss = {min_loss:.4f} in epoch {ep_loss} '
                f'\t (ROC-AUC was {loss_roc_auc:.4f}) \n \t'
                f'ROC-AUC = {max_roc_auc:.4f} in epoch {ep_roc} '
                f'\t (Loss was {roc_auc_loss:.4f})')
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
def load_data(
    filename: str,
    smiles_language: Union[SMILESLanguage, str],
    max_close_neighbors: int = 5,
    max_total_bundle: int = 20,
) -> List[Data]:
    """Loads the data from a pre-processed JSON file and parses it into a
    list of torch_geometric.data.Data objects.

    Each element of the Data object is obtained by expanding a node up to 2
    levels of depth. The sampling for this is controlled by
    `max_close_neighbors` and `max_total_bundle` (i.e. maximum total number
    of elements). This will be applied to all the nodes of the dataset,
    resulting in a returned list of `len` equal to the number of SMILES
    (i.e. total nodes).

    NOTE: Reason for this implementation is that ClusterData is failing when
    running it on a Data object with the entire dataset:
    ```
    # smiles = [nodes2smiles[i] for i in range(len(nodes2smiles))]
    # return Data(
    #     x=torch.tensor(list(range(len(smiles)))).view(-1, 1),
    #     edge_index=torch.tensor(edgelist).T,
    #     # num_nodes=len(smiles)
    # )
    # number_of_nodes = 10
    # cluster_data = ClusterData(
    #     data, num_parts=data.num_nodes // number_of_nodes, log=True
    # )
    # --> This is causing: SIGSEGV (Address boundary error)
    ```

    Args:
        filename (str): Path to the JSON file. It must have the fields
            `smiles2nodes` and `edgelist`.
        max_close_neighbors (int, optional): Number of close (1 jump away)
            nodes that the algorithm samples. Defaults to 5.
        max_total_bundle (int, optional): Maximum total number of elements
            in each Data object. Defaults to 20.

    Returns:
        List[Data]
    """
    with open(filename, 'r') as f:
        raw = json.load(f)
    nodes2smiles = {v: k for k, v in raw['smiles2nodes'].items()}
    edgelist = raw['edgelist']

    if isinstance(smiles_language, str):
        smiles_language = SMILESLanguage.load(smiles_language)

    data = []
    # Custom sampling
    G = nx.from_edgelist(edgelist)
    for root in range(len(G)):
        # FIXME This could have better been a DFS-like approach with fixed
        # depth but I had this half way there and it was easier to first
        # sample the closer hood and then expand to one more step from
        # there. Not even sure if DFS is what I want tho. I can see pros
        # and cons.
        shortlist = []
        close_neighbors = np.random.permutation(
            list(G.neighbors(root))
        )[:max_close_neighbors]
        total_far_neighbors = sum(
            [len(list(G.neighbors(x))) for x in close_neighbors]
        )
        shortlist += [[root, x] for x in close_neighbors]
        counter = 0
        while (len(shortlist) < max_total_bundle
               and counter < total_far_neighbors):
            # TODO Random sampling probability inversely proportional to the
            # degree?
            current_node = close_neighbors[counter % len(close_neighbors)]
            far_node = np.random.choice(list(G.neighbors(current_node)))
            shortlist.append([current_node, far_node])
            counter += 1

        # We need to relabel the nodes, but keep the original location in
        # order to retrieve the corresponding smiles
        sub_graph = nx.convert_node_labels_to_integers(
            nx.from_edgelist(shortlist), label_attribute='original'
        )
        x = [
            torch.tensor(
                smiles_language.smiles_to_token_indexes(
                    nodes2smiles[sub_graph.nodes[i]['original']]
                )
            ) for i in range(len(sub_graph))
        ]
        edge_index = torch.tensor(nx.to_pandas_edgelist(sub_graph).values).T
        data.append(Data(x=x, edge_index=edge_index, num_nodes=len(x)))
    return data
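# Hedged usage sketch for `load_data` above; the two paths are placeholders.
def _example_load_graph_bundles():
    data_list = load_data(
        filename='graph.json',                  # needs 'smiles2nodes' and 'edgelist'
        smiles_language='smiles_language.pkl',  # path or SMILESLanguage instance
        max_close_neighbors=5,
        max_total_bundle=20,
    )
    # One Data object per node; `x` is a list of token-index tensors and
    # `edge_index` holds the relabelled sub-graph edges.
    return data_list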
def main(
    train_scores_filepath, test_scores_filepath, smi_filepath, model_path,
    params_filepath, training_name
):
    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger(f'{training_name}')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()
    device = get_device()

    # Restore pretrained model
    logger.info('Start model restoring.')
    try:
        with open(os.path.join(model_path, 'model_params.json'), 'r') as fp:
            params = json.load(fp)
        smiles_language = SMILESLanguage.load(
            os.path.join(model_path, 'smiles_language.pkl')
        )
        model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
        # Try weight restoring
        try:
            weight_path = os.path.join(
                model_path, 'weights',
                params.get('weights_name', 'best_ROC-AUC_mca.pt')
            )
            model.load(weight_path)
        except Exception:
            try:
                wp = os.listdir(os.path.join(model_path, 'weights'))[0]
                logger.info(
                    f"Weights {weight_path} not found. Try restore {wp}"
                )
                model.load(os.path.join(model_path, 'weights', wp))
            except Exception:
                raise Exception('Error in weight loading.')
    except Exception:
        raise Exception(
            'Error in model restoring. model_path should point to the model '
            'root folder that contains a model_params.json, a '
            'smiles_language.pkl and a weights folder.'
        )
    logger.info('Model restored. Now starting to craft it for the task')

    # Process parameter file:
    model_params = {}
    with open(params_filepath) as fp:
        model_params.update(json.load(fp))

    model = update_mca_model(model, model_params)
    logger.info('Model set up.')

    for idx, (name, param) in enumerate(model.named_parameters()):
        logger.info(
            (idx, name, param.shape, f'Gradients: {param.requires_grad}')
        )

    ft_model_path = os.path.join(model_path, 'finetuned')
    os.makedirs(os.path.join(ft_model_path, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(ft_model_path, 'results'), exist_ok=True)

    logger.info('Now start data preprocessing...')

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('sanitize', True)
    )
    train_dataset = AnnotatedDataset(
        annotations_filepath=train_scores_filepath,
        dataset=smiles_dataset,
        device=get_device()
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )

    # Generally, if sanitize is True molecules are de-kekulized. Augmentation
    # preserves the "kekulization", so if it is used, test data should be
    # sanitized or canonicalized.
    smiles_test_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False)
    )
    # Dump eventually modified SMILES Language
    smiles_language.save(os.path.join(ft_model_path, 'smiles_language.pkl'))

    test_dataset = AnnotatedDataset(
        annotations_filepath=test_scores_filepath,
        dataset=smiles_test_dataset,
        device=get_device()
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=False,
        drop_last=False,
        num_workers=params.get('num_workers', 0)
    )

    save_top_model = os.path.join(ft_model_path, 'weights/{}_{}_{}.pt')

    num_t_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    num_params = sum(p.numel() for p in model.parameters())
    params.update({'params': num_params, 'trainable_params': num_t_params})
    logger.info(
        f'Number of parameters: {num_params} (trainable {num_t_params}).'
    )

    # Define optimizer only for those layers which require gradients
    optimizer = OPTIMIZER_FACTORY[params.get('optimizer', 'adam')](
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=params.get('lr', 0.00001)
    )
    # Dump params.json file with updated parameters.
    with open(os.path.join(ft_model_path, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    min_loss, max_roc_auc = 1000000, 0
    max_precision_recall_score = 0

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, y) in enumerate(train_loader):
            smiles = torch.squeeze(smiles.to(device))
            y_hat, pred_dict = model(smiles)
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        logger.info(
            '\t **** TRAINING **** '
            f"Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {train_loss / len(train_loader):.5f}. '
            f'This took {time() - t:.1f} secs.'
        )
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, y) in enumerate(test_loader):
                smiles = torch.squeeze(smiles.to(device))
                y_hat, pred_dict = model(smiles)
                predictions.append(y_hat)
                # Copy y tensor since loss function applies downstream
                # modification
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()
        # Remove NaNs from labels to compute scores
        predictions = predictions[~np.isnan(labels)]
        labels = labels[~np.isnan(labels)]
        test_loss_a = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc_a = auc(fpr, tpr)
        # calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        # score for precision vs accuracy
        test_precision_recall_score = average_precision_score(
            labels, predictions
        )

        logger.info(
            f"\t **** TEST **** Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {test_loss_a:.5f}, roc_auc: {test_roc_auc_a:.5f}, '
            f'avg precision-recall score: {test_precision_recall_score:.5f}'
        )
        info = {
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
            'test_auc': test_roc_auc_a,
            'best_test_auc': max_roc_auc,
            'test_precision_recall_score': test_precision_recall_score,
            'best_precision_recall_score': max_precision_recall_score,
        }

        def save(path, metric, typ, val=None):
            model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
            if typ == 'best':
                logger.info(
                    f'\t New best performance in {metric}'
                    f' with value : {val:.7f} in epoch: {epoch+1}'
                )

        if test_roc_auc_a > max_roc_auc:
            max_roc_auc = test_roc_auc_a
            info.update({'best_test_auc': max_roc_auc})
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            np.save(
                os.path.join(ft_model_path, 'results', 'best_predictions.npy'),
                predictions
            )
            with open(
                os.path.join(ft_model_path, 'results', 'metrics.json'), 'w'
            ) as f:
                json.dump(info, f)
        if test_precision_recall_score > max_precision_recall_score:
            max_precision_recall_score = test_precision_recall_score
            info.update(
                {'best_precision_recall_score': max_precision_recall_score}
            )
            save(
                save_top_model, 'precision-recall score', 'best',
                max_precision_recall_score
            )
        if test_loss_a < min_loss:
            min_loss = test_loss_a
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch
        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info(
        'Overall best performances are: \n \t'
        f'Loss = {min_loss:.4f} in epoch {ep_loss} '
    )
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
def main(train_scores_filepath, test_scores_filepath, smi_filepath,
         smiles_language_filepath, model_path, params_filepath, training_name,
         embedding_path=None):

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger(f'{training_name}')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()

    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))
    if embedding_path:
        params['embedding_path'] = embedding_path

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    logger.info('Start data preprocessing...')
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Prepare FP processing
    if params.get('model_fn', 'mca') == 'dense':
        morgan_transform = Compose([
            SMILESToMorganFingerprints(
                radius=params.get('fp_radius', 2),
                bits=params.get('num_drug_features', 512),
                chirality=params.get('fp_chirality', True)
            ),
            ToTensor(get_device())
        ])

        def smiles_tensor_batch_to_fp(smiles):
            """To abuse SMILES dataset for FP usage"""
            out = torch.Tensor(
                smiles.shape[0], params.get('num_drug_features', 256)
            )
            for ind, tensor in enumerate(smiles):
                smiles = smiles_language.token_indexes_to_smiles(
                    tensor.tolist()
                )
                out[ind, :] = torch.squeeze(morgan_transform(smiles))
            return out

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('sanitize', True)
    )
    # include arg label_columns if data file has any unwanted columns
    # (such as index) to be ignored.
    train_dataset = AnnotatedDataset(
        annotations_filepath=train_scores_filepath,
        dataset=smiles_dataset,
        device=get_device()
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )

    if (params.get('uncertainty', True)
            and params.get('augment_test_data', False)):
        raise ValueError(
            'Epistemic uncertainty evaluation not supported if augmentation '
            'is enabled for test data.'
        )

    # Generally, if sanitize is True molecules are de-kekulized. Augmentation
    # preserves the "kekulization", so if it is used, test data should be
    # sanitized or canonicalized.
    smiles_test_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False)
    )
    # Dump eventually modified SMILES Language
    smiles_language.save(os.path.join(model_dir, 'smiles_language.pkl'))
    logger.info(smiles_dataset._dataset.transform)
    logger.info(smiles_test_dataset._dataset.transform)

    # include arg label_columns if data file has any unwanted columns
    # (such as index) to be ignored.
    test_dataset = AnnotatedDataset(
        annotations_filepath=test_scores_filepath,
        dataset=smiles_test_dataset,
        device=get_device()
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=False,
        drop_last=False,
        num_workers=params.get('num_workers', 0)
    )

    if params.get('confidence', False):
        smiles_conf_dataset = SMILESDataset(
            smi_filepath,
            smiles_language=smiles_language,
            padding_length=params.get('smiles_padding_length', None),
            padding=params.get('padd_smiles', True),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=True,  # Natively true for epistemic uncertainty estimate
            canonical=params.get('canonical', False),
            kekulize=params.get('kekulize', False),
            all_bonds_explicit=params.get('all_bonds_explicit', False),
            all_hs_explicit=params.get('all_hs_explicit', False),
            randomize=params.get('randomize', False),
            remove_bonddir=params.get('remove_bonddir', False),
            remove_chirality=params.get('remove_chirality', False),
            selfies=params.get('selfies', False)
        )
        conf_dataset = AnnotatedDataset(
            annotations_filepath=test_scores_filepath,
            dataset=smiles_conf_dataset,
            device=get_device()
        )
        conf_loader = torch.utils.data.DataLoader(
            dataset=conf_dataset,
            batch_size=params['batch_size'],
            shuffle=False,
            drop_last=False,
            num_workers=params.get('num_workers', 0)
        )

    if not params.get('embedding', 'learned') == 'pretrained':
        params.update(
            {'smiles_vocabulary_size': smiles_language.number_of_tokens}
        )

    device = get_device()
    logger.info(f'Device for data loader is {train_dataset.device} and for '
                f'model is {device}')
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')

    model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
    logger.info(model)
    logger.info(model.loss_fn.class_weights)
    logger.info('Parameters follow')
    for name, param in model.named_parameters():
        logger.info((name, param.shape))

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params.update({'number_of_parameters': num_params})
    logger.info(f'Number of parameters {num_params}')

    # Define optimizer
    optimizer = OPTIMIZER_FACTORY[params.get('optimizer', 'adam')](
        model.parameters(), lr=params.get('lr', 0.00001)
    )
    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    min_loss, max_roc_auc = 1000000, 0
    max_precision_recall_score = 0

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, y) in enumerate(train_loader):
            smiles = torch.squeeze(smiles.to(device))
            # Transform smiles to FP if needed
            if params.get('model_fn', 'mca') == 'dense':
                smiles = smiles_tensor_batch_to_fp(smiles).to(device)
            y_hat, pred_dict = model(smiles)
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        logger.info('\t **** TRAINING **** '
                    f"Epoch [{epoch + 1}/{params['epochs']}], "
                    f'loss: {train_loss / len(train_loader):.5f}. '
                    f'This took {time() - t:.1f} secs.')
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, y) in enumerate(test_loader):
                smiles = torch.squeeze(smiles.to(device))
                # Transform smiles to FP if needed
                if params.get('model_fn', 'mca') == 'dense':
                    smiles = smiles_tensor_batch_to_fp(smiles).to(device)
                y_hat, pred_dict = model(smiles)
                predictions.append(y_hat)
                # Copy y tensor since loss function applies downstream
                # modification
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()
        # Remove NaNs from labels to compute scores
        predictions = predictions[~np.isnan(labels)]
        labels = labels[~np.isnan(labels)]
        test_loss_a = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc_a = auc(fpr, tpr)
        # calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        # score for precision vs accuracy
        test_precision_recall_score = average_precision_score(
            labels, predictions
        )

        logger.info(
            f"\t **** TEST **** Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {test_loss_a:.5f}, roc_auc: {test_roc_auc_a:.5f}, '
            f'avg precision-recall score: {test_precision_recall_score:.5f}'
        )
        info = {
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
            'test_auc': test_roc_auc_a,
            'best_test_auc': max_roc_auc,
            'test_precision_recall_score': test_precision_recall_score,
            'best_precision_recall_score': max_precision_recall_score,
        }

        def save(path, metric, typ, val=None):
            model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
            if typ == 'best':
                logger.info(f'\t New best performance in {metric}'
                            f' with value : {val:.7f} in epoch: {epoch+1}')

        if test_roc_auc_a > max_roc_auc:
            max_roc_auc = test_roc_auc_a
            info.update({'best_test_auc': max_roc_auc})
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            np.save(
                os.path.join(model_dir, 'results', 'best_predictions.npy'),
                predictions
            )
            with open(
                os.path.join(model_dir, 'results', 'metrics.json'), 'w'
            ) as f:
                json.dump(info, f)

            if params.get('confidence', False):
                # Compute uncertainty estimates and save them
                epistemic_conf = monte_carlo_dropout(
                    model, regime='loader', loader=conf_loader
                )
                aleatoric_conf = test_time_augmentation(
                    model, regime='loader', loader=conf_loader
                )
                np.save(
                    os.path.join(model_dir, 'results', 'epistemic_conf.npy'),
                    epistemic_conf
                )
                np.save(
                    os.path.join(model_dir, 'results', 'aleatoric_conf.npy'),
                    aleatoric_conf
                )

        if test_precision_recall_score > max_precision_recall_score:
            max_precision_recall_score = test_precision_recall_score
            info.update(
                {'best_precision_recall_score': max_precision_recall_score}
            )
            save(
                save_top_model, 'precision-recall score', 'best',
                max_precision_recall_score
            )
        if test_loss_a < min_loss:
            min_loss = test_loss_a
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch
        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info('Overall best performances are: \n \t'
                f'Loss = {min_loss:.4f} in epoch {ep_loss} ')
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
def main(train_sensitivity_filepath, test_sensitivity_filepath, gep_filepath,
         smi_filepath, gene_filepath, smiles_language_filepath, model_path,
         params_filepath, training_name):

    logger = logging.getLogger(f'{training_name}')
    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    # Prepare the dataset
    logger.info("Start data preprocessing...")

    # Load SMILES language
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Load the gene list
    with open(gene_filepath, 'rb') as f:
        gene_list = pickle.load(f)

    # Assemble datasets
    train_dataset = DrugSensitivityDataset(
        drug_sensitivity_filepath=train_sensitivity_filepath,
        smi_filepath=smi_filepath,
        gene_expression_filepath=gep_filepath,
        smiles_language=smiles_language,
        gene_list=gene_list,
        drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True),
        augment=params.get('augment_smiles', True),
        add_start_and_stop=params.get('smiles_start_stop_token', True),
        padding_length=params.get('smiles_padding_length', None),
        device=torch.device(params.get('dataset_device', 'cpu')),
        backend='eager'
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )
    test_dataset = DrugSensitivityDataset(
        drug_sensitivity_filepath=test_sensitivity_filepath,
        smi_filepath=smi_filepath,
        gene_expression_filepath=gep_filepath,
        smiles_language=smiles_language,
        gene_list=gene_list,
        drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True),
        augment=params.get('augment_smiles', True),
        add_start_and_stop=params.get('smiles_start_stop_token', True),
        padding_length=params.get('smiles_padding_length', None),
        device=torch.device(params.get('dataset_device', 'cpu')),
        backend='eager'
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )

    params.update({
        'number_of_genes': len(gene_list),
        'smiles_vocabulary_size': smiles_language.number_of_tokens
    })

    # Set the tensorboard logger
    tb_logger = Logger(os.path.join(model_dir, 'tb_logs'))
    device = get_device()
    logger.info(f'Device for data loader is {train_dataset.device} and for '
                f'model is {device}')
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')

    model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)

    # Define optimizer
    optimizer = OPTIMIZER_FACTORY[params.get('optimizer', 'Adam')](
        model.parameters(), lr=params.get('lr', 0.01)
    )
    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    min_loss, max_pearson = 100, 0

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, gep, y) in enumerate(train_loader):
            y_hat, pred_dict = model(
                torch.squeeze(smiles.to(device)), gep.to(device)
            )
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            # Apply gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-6)
            optimizer.step()
            train_loss += loss.item()

        logger.info("\t **** TRAINING **** "
                    f"Epoch [{epoch + 1}/{params['epochs']}], "
                    f"loss: {train_loss / len(train_loader):.5f}. "
                    f"This took {time() - t:.1f} secs.")
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, gep, y) in enumerate(test_loader):
                y_hat, pred_dict = model(
                    torch.squeeze(smiles.to(device)), gep.to(device)
                )
                predictions.append(y_hat)
                labels.append(y)
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = np.array(
            [p.cpu() for preds in predictions for p in preds]
        )
        labels = np.array([l.cpu() for label in labels for l in label])
        test_pearson_a = pearsonr(predictions, labels)
        test_rmse_a = np.sqrt(np.mean((predictions - labels)**2))
        np.save(
            os.path.join(model_dir, 'results', 'preds.npy'),
            np.hstack([predictions, labels])
        )
        test_loss_a = test_loss / len(test_loader)
        logger.info(
            f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {test_loss_a:.5f}, "
            f"Pearson: {test_pearson_a:.3f}, "
            f"RMSE: {test_rmse_a:.3f}")

        # TensorBoard logging of scalars.
        info = {
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
        }
        for tag, value in info.items():
            tb_logger.scalar_summary(tag, value, epoch + 1)

        def save(path, metric, typ, val=None):
            model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
            if typ == 'best':
                logger.info(f'\t New best performance in "{metric}"'
                            f' with value : {val:.7f} in epoch: {epoch}')

        if test_loss_a < min_loss:
            min_loss = test_loss_a
            min_loss_pearson = test_pearson_a
            save(save_top_model, 'mse', 'best', min_loss)
            ep_loss = epoch
        if test_pearson_a > max_pearson:
            max_pearson = test_pearson_a
            max_pearson_loss = test_loss_a
            save(save_top_model, 'pearson', 'best', max_pearson)
            ep_pearson = epoch
        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info('Overall best performances are: \n \t'
                f'Loss = {min_loss:.4f} in epoch {ep_loss} '
                f'\t (Pearson was {min_loss_pearson:.4f}) \n \t'
                f'Pearson = {max_pearson:.4f} in epoch {ep_pearson} '
                f'\t (Loss was {max_pearson_loss:.2f})')
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
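# Note on the checkpoint names produced by `save` above: save_top_model is
# 'weights/{}_{}_{}.pt' formatted with (typ, metric, model_fn), so e.g.
# save(save_top_model, 'mse', 'best', ...) writes 'weights/best_mse_mca.pt',
# save(save_top_model, 'pearson', 'best', ...) writes
# 'weights/best_pearson_mca.pt', and the periodic save writes
# 'weights/<epoch>_epoch_mca.pt' (assuming the default model_fn 'mca').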
def main(parser_namespace):

    # model loading
    disable_rdkit_logging()
    affinity_path = parser_namespace.affinity_path
    svae_path = parser_namespace.svae_path
    svae_weights_path = os.path.join(svae_path, "weights", "best_rec.pt")
    results_file_name = parser_namespace.optimisation_name
    logger.add(results_file_name + ".log", rotation="10 MB")

    svae_params = dict()
    with open(os.path.join(svae_path, "model_params.json"), "r") as f:
        svae_params.update(json.load(f))

    smiles_language = SMILESLanguage.load(
        os.path.join(svae_path, "selfies_language.pkl"))

    # initialize encoder, decoder, testVAE, and GP_generator_MW
    gru_encoder = StackGRUEncoder(svae_params)
    gru_decoder = StackGRUDecoder(svae_params)
    gru_vae = TeacherVAE(gru_encoder, gru_decoder)
    gru_vae.load_state_dict(
        torch.load(svae_weights_path, map_location=get_device()))
    gru_vae._associate_language(smiles_language)
    gru_vae.eval()
    smiles_generator = SmilesGenerator(gru_vae)

    with open(os.path.join(affinity_path, "model_params.json")) as f:
        predictor_params = json.load(f)
    affinity_predictor = MODEL_FACTORY["bimodal_mca"](predictor_params)
    affinity_predictor.load(
        os.path.join(
            affinity_path,
            f"weights/best_{predictor_params.get('p_metric', 'ROC-AUC')}_bimodal_mca.pt",
        ),
        map_location=get_device(),
    )
    affinity_protein_language = ProteinLanguage.load(
        os.path.join(affinity_path, "protein_language.pkl"))
    affinity_smiles_language = SMILESLanguage.load(
        os.path.join(affinity_path, "smiles_language.pkl"))
    affinity_predictor._associate_language(affinity_smiles_language)
    affinity_predictor._associate_language(affinity_protein_language)
    affinity_predictor.eval()

    erg_protein = "MASTIKEALSVVSEDQSLFECAYGTPHLAKTEMTASSSSDYGQTSKMSPRVPQQDWLSQPPARVTIKMECNPSQVNGSRNSPDECSVAKGGKMVGSPDTVGMNYGSYMEEKHMPPPNMTTNERRVIVPADPTLWSTDHVRQWLEWAVKEYGLPDVNILLFQNIDGKELCKMTKDDFQRLTPSYNADILLSHLHYLRETPLPHLTSDDVDKALQNSPRLMHARNTGGAAFIFPNTSVYPEATQRITTRPDLPYEPPRRSAWTGHGHPTPQSKAAQPSPSTVPKTEDQRPQLDPYQILGPTSSRLANPGSGQIQLWQFLLELLSDSSNSSCITWEGTNGEFKMTDPDEVARRWGERKSKPNMNYDKLSRALRYYYDKNIMTKVHGKRYAYKFDFHGIAQALQPHPPESSLYKYPSDLPYMGSYHAHPQKMNFVAPHPPALPVTSSSFFAAPNPYWNSPTGGIYPNTRLPTSHMPSHLGTYY"

    target_minimization_function = AffinityMinimization(
        smiles_generator, 30, affinity_predictor, erg_protein)
    qed_function = QEDMinimization(smiles_generator, 30)
    sa_function = SAMinimization(smiles_generator, 30)
    combined_minimization = CombinedMinimization(
        [target_minimization_function, qed_function, sa_function], 1,
        [0.75, 1, 0.5])
    target_optimizer = GPOptimizer(combined_minimization.evaluate)

    params = dict(
        dimensions=[(-5.0, 5.0)] * 256,
        acq_func="EI",
        n_calls=20,
        n_initial_points=19,
        initial_point_generator="random",
        random_state=1234,
    )
    logger.info("Optimisation parameters: {params}", params=params)

    # optimisation
    for j in range(5):
        res = target_optimizer.optimize(params)
        latent_point = torch.tensor([[res.x]])

        with open(results_file_name + "_LP" + str(j + 1) + ".pkl", "wb") as f:
            pickle.dump(latent_point, f, protocol=2)

        smile_set = set()
        while len(smile_set) < 20:
            smiles = smiles_generator.generate_smiles(
                latent_point.repeat(1, 30, 1))
            smile_set.update(set(smiles))
        smile_set = list(smile_set)

        pad_smiles_predictor = LeftPadding(
            affinity_predictor.smiles_padding_length,
            affinity_predictor.smiles_language.padding_index,
        )
        to_tensor = ToTensor(get_device())
        smiles_num = [
            torch.unsqueeze(
                to_tensor(
                    pad_smiles_predictor(
                        affinity_predictor.smiles_language.
                        smiles_to_token_indexes(smile))),
                0,
            ) for smile in smile_set
        ]
        smiles_tensor = torch.cat(smiles_num, dim=0)

        pad_protein_predictor = LeftPadding(
            affinity_predictor.protein_padding_length,
            affinity_predictor.protein_language.padding_index,
        )
        protein_num = torch.unsqueeze(
            to_tensor(
                pad_protein_predictor(
                    affinity_predictor.protein_language.
                    sequence_to_token_indexes([erg_protein]))),
            0,
        )
        protein_num = protein_num.repeat(len(smile_set), 1)

        with torch.no_grad():
            pred, _ = affinity_predictor(smiles_tensor, protein_num)
        affinities = torch.squeeze(pred, 1).numpy()

        sas = SAS()
        sa_scores = [sas(smile) for smile in smile_set]
        qed_scores = [qed(Chem.MolFromSmiles(smile)) for smile in smile_set]

        # save to file
        file = results_file_name + str(j + 1) + ".txt"
        logger.info("creating {file}", file=file)
        with open(file, "w") as f:
            f.write(
                f'{"point":<10}{"Affinity":<10}{"QED":<10}{"SA":<10}{"smiles":<15}\n'
            )
            for i in range(20):
                dat = [
                    i + 1, affinities[i], qed_scores[i], sa_scores[i],
                    smile_set[i]
                ]
                f.write(
                    f'{dat[0]:<10}{"%.3f"%dat[1]:<10}{"%.3f"%dat[2]:<10}{"%.3f"%dat[3]:<10}{dat[4]:<15}\n'
                )
# in the thesis we used 40k molecules extracted
# from the PubChem database
# The model is separated in 3 files, PARAMS, WEIGHTS, LANGUAGE.
# See: https://github.com/PaccMann/paccmann_chemistry for more info.
DATA_DIR =  # Directory with the model and the data
PARAM_FILE = os.path.join(DATA_DIR, 'model_params.json')
WEIGHT_FILE = os.path.join(DATA_DIR, 'weights', 'best_loss.pt')
LANG_FILE = os.path.join(DATA_DIR, 'selfies_language.pkl')
MODEL_DIR = './finetuned_model'  # Directory for the finetuned model

# START LOADING
paccmann_vae = Encoder(PARAM_FILE, LANG_FILE, WEIGHT_FILE, 1)
smiles_language = SMILESLanguage.load(LANG_FILE)

SAVE_FOLDER = './'
LATENTS_FILE = latents_file if latents_file is not None \
    else os.path.join(SAVE_FOLDER, f'latents.{model}.npy')
CATALYST_LATENTS_FILE = os.path.join(
    SAVE_FOLDER, f'catalyst_latents.{model}.npy'
)

df = pd.read_csv(
    COMPOUNDS_PATH, header=None, sep='\t', names=['SMILES', 'id']
)
db_catalysts = pd.read_csv(CATALYSTS_PATH)

latents, failed_smiles = calculate_latents(
    paccmann_vae, df.SMILES, LATENTS_FILE
)
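# COMPOUNDS_PATH above is read with header=None, sep='\t' and
# names=['SMILES', 'id'], i.e. a headerless two-column TSV; rows below are
# illustrative only:
#
#   CCO\tmol-0001
#   c1ccccc1O\tmol-0002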
def main(parser_namespace):

    disable_rdkit_logging()

    model_path = parser_namespace.model_path
    data_path = parser_namespace.data_path
    weights_path = os.path.join(model_path, 'weights', 'best_loss.pt')

    device = get_device()
    # read the params json
    params = dict()
    with open(os.path.join(model_path, 'model_params.json'), 'r') as f:
        params.update(json.load(f))
    params['batch_size'] = 1

    # Load SMILES language
    smiles_language = SMILESLanguage.load(
        os.path.join(model_path, 'selfies_language.pkl'))

    data_preparation = get_data_preparation(params.get('batch_mode'))
    device = get_device()

    dataset = SMILESDataset(
        data_path,
        smiles_language=smiles_language,
        padding=False,
        selfies=params.get('selfies', False),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=False,  # params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        backend='lazy',
        device=device)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=params.get('batch_size', 64),
        collate_fn=collate_fn,
        drop_last=True,
        shuffle=True,
        pin_memory=params.get('pin_memory', True),
        num_workers=params.get('num_workers', 8))

    # initialize encoder and decoder
    gru_encoder = StackGRUEncoder(params).to(device)
    gru_decoder = StackGRUDecoder(params).to(device)
    gru_vae = TeacherVAE(gru_encoder, gru_decoder).to(device)

    logger.info('\n****MODEL SUMMARY***\n')
    for name, parameter in gru_vae.named_parameters():
        logger.info(f'Param {name}, shape:\t{parameter.shape}')
    total_params = sum(p.numel() for p in gru_vae.parameters())
    logger.info(f'Total # params: {total_params}')

    gru_vae.load_state_dict(torch.load(weights_path, map_location=device))

    # Updating the vocab size will break the model
    params.update({
        # 'vocab_size': smiles_language.number_of_tokens,
        'pad_index': smiles_language.padding_index
    })  # yapf:disable

    # if params.get('embedding', 'learned') == 'one_hot':
    #     params.update({'embedding_size': params['vocab_size']})

    # train for n_epoch epochs
    logger.info('Model creation, loading and data processing done. '
                'Evaluation starts.')

    gru_vae.eval()
    gru_vae.to(device)

    counter = 0
    with torch.no_grad():
        latent_code = []
        from tqdm import tqdm
        for batch in tqdm(dataloader, total=len(dataloader)):
            (encoder_seq, _, _) = data_preparation(
                batch, input_keep=0., start_index=2, end_index=3,
                device=device)
            try:
                mu, logvar = gru_vae.encode(encoder_seq)
            except RuntimeError:
                # Substitute any new tokens by "<UNK>" tokens
                new_seq = []
                padded_encoder_seq, lengths = (
                    torch.nn.utils.rnn.pad_packed_sequence(
                        encoder_seq, batch_first=True))
                for seq, _len in zip(padded_encoder_seq, lengths):
                    seq = seq[:_len]
                    if any([x >= params['vocab_size'] for x in seq]):
                        seq = torch.tensor([
                            x if x < params['vocab_size'] else
                            smiles_language.unknown_index
                            for x in seq.tolist()
                        ]).short()
                        failed_smiles = smiles_language.selfies_to_smiles(
                            smiles_language.token_indexes_to_smiles(
                                seq.tolist()))
                        logger.warning(
                            f'Out of bounds sample: ~{counter}\t{failed_smiles}'
                        )
                    new_seq.append(seq)
                if new_seq:
                    for _ in range(params['batch_size'] - len(new_seq)):
                        new_seq.append(torch.ones_like(new_seq[-1]))
                    (encoder_seq, _, _) = data_preparation(
                        new_seq, input_keep=0., start_index=2, end_index=3,
                        device=device)
                    mu, logvar = gru_vae.encode(encoder_seq)

            for _mu in mu.tolist():
                latent_code.append([counter, _mu])
                counter += 1

    LATENT_CODE_PATH = os.path.join(
        os.path.dirname(data_path), 'samples_latent_code.tsv')
    with open(LATENT_CODE_PATH, 'w') as f:
        for i, mu in latent_code:
            f.write(f'{i}\t{",".join([str(x) for x in mu[0]])}\n')
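# samples_latent_code.tsv written above has one row per encoded molecule:
# the running sample index, a tab, then the comma-separated latent mean,
# e.g. (values illustrative):
#
#   0\t0.12,-0.53,0.08,...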
def main(parser_namespace):
    try:
        device = get_device()
        disable_rdkit_logging()
        # read the params json
        params = dict()
        with open(parser_namespace.params_filepath) as f:
            params.update(json.load(f))

        # get params
        train_smiles_filepath = parser_namespace.train_smiles_filepath
        test_smiles_filepath = parser_namespace.test_smiles_filepath
        smiles_language_filepath = (
            parser_namespace.smiles_language_filepath
            if parser_namespace.smiles_language_filepath.lower() != 'none'
            else None)
        model_path = parser_namespace.model_path
        training_name = parser_namespace.training_name

        writer = SummaryWriter(f'logs/{training_name}')
        logger.info(f'Model with name {training_name} starts.')

        model_dir = os.path.join(model_path, training_name)
        log_path = os.path.join(model_dir, 'logs')
        val_dir = os.path.join(log_path, 'val_logs')
        os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
        os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
        os.makedirs(log_path, exist_ok=True)
        os.makedirs(val_dir, exist_ok=True)

        # Load SMILES language
        smiles_language = None
        if smiles_language_filepath is not None:
            smiles_language = SMILESLanguage.load(smiles_language_filepath)

        logger.info(f'Smiles filepath: {train_smiles_filepath}')

        # create SMILES eager dataset
        smiles_train_data = SMILESDataset(
            train_smiles_filepath,
            smiles_language=smiles_language,
            padding=False,
            selfies=params.get('selfies', False),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=params.get('augment_smiles', False),
            canonical=params.get('canonical', False),
            kekulize=params.get('kekulize', False),
            all_bonds_explicit=params.get('all_bonds_explicit', False),
            all_hs_explicit=params.get('all_hs_explicit', False),
            remove_bonddir=params.get('remove_bonddir', False),
            remove_chirality=params.get('remove_chirality', False),
            backend='lazy',
            device=device,
        )
        smiles_test_data = SMILESDataset(
            test_smiles_filepath,
            smiles_language=smiles_language,
            padding=False,
            selfies=params.get('selfies', False),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=params.get('augment_smiles', False),
            canonical=params.get('canonical', False),
            kekulize=params.get('kekulize', False),
            all_bonds_explicit=params.get('all_bonds_explicit', False),
            all_hs_explicit=params.get('all_hs_explicit', False),
            remove_bonddir=params.get('remove_bonddir', False),
            remove_chirality=params.get('remove_chirality', False),
            backend='lazy',
            device=device,
        )

        if smiles_language_filepath is None:
            smiles_language = smiles_train_data.smiles_language
            smiles_language.save(
                os.path.join(model_path, f'{training_name}.lang'))
        else:
            smiles_language_filename = os.path.basename(
                smiles_language_filepath)
            smiles_language.save(
                os.path.join(model_dir, smiles_language_filename))

        params.update({
            'vocab_size': smiles_language.number_of_tokens,
            'pad_index': smiles_language.padding_index
        })

        vocab_dict = smiles_language.index_to_token
        params.update({
            'start_index': list(vocab_dict.keys())[
                list(vocab_dict.values()).index('<START>')],
            'end_index': list(vocab_dict.keys())[
                list(vocab_dict.values()).index('<STOP>')]
        })
        if params.get('embedding', 'learned') == 'one_hot':
            params.update({'embedding_size': params['vocab_size']})

        with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
            json.dump(params, fp)

        # create DataLoaders
        train_data_loader = torch.utils.data.DataLoader(
            smiles_train_data,
            batch_size=params.get('batch_size', 64),
            collate_fn=collate_fn,
            drop_last=True,
            shuffle=True,
            pin_memory=params.get('pin_memory', True),
            num_workers=params.get('num_workers', 8))

        test_data_loader = torch.utils.data.DataLoader(
            smiles_test_data,
            batch_size=params.get('batch_size', 64),
            collate_fn=collate_fn,
            drop_last=True,
            shuffle=True,
            pin_memory=params.get('pin_memory', True),
            num_workers=params.get('num_workers', 8))

        # initialize encoder and decoder
        gru_encoder = StackGRUEncoder(params).to(device)
        gru_decoder = StackGRUDecoder(params).to(device)
        gru_vae = TeacherVAE(gru_encoder, gru_decoder).to(device)

        # TODO I haven't managed to get this to work. I will leave it here
        # in case someone (or future me) wants to give it a look and get the
        # tensorboard graph to work
        # if writer and False:
        #     gru_vae.set_batch_mode('padded')
        #     dummy_input = torch.ones(smiles_train_data[0].shape)
        #     dummy_input = dummy_input.unsqueeze(0).to(device)
        #     writer.add_graph(gru_vae, (dummy_input, dummy_input, dummy_input))
        #     gru_vae.set_batch_mode(params.get('batch_mode'))

        logger.info('\n****MODEL SUMMARY***\n')
        for name, parameter in gru_vae.named_parameters():
            logger.info(f'Param {name}, shape:\t{parameter.shape}')
        total_params = sum(p.numel() for p in gru_vae.parameters())
        logger.info(f'Total # params: {total_params}')

        loss_tracker = {
            'test_loss_a': 10e4,
            'test_rec_a': 10e4,
            'test_kld_a': 10e4,
            'ep_loss': 0,
            'ep_rec': 0,
            'ep_kld': 0
        }

        # train for n_epoch epochs
        logger.info(
            'Model creation and data processing done, Training starts.')
        decoder_search = SEARCH_FACTORY[
            params.get('decoder_search', 'sampling')
        ](
            temperature=params.get('temperature', 1.),
            beam_width=params.get('beam_width', 3),
            top_tokens=params.get('top_tokens', 5)
        )  # yapf: disable

        if writer:
            pparams = params.copy()
            pparams['training_file'] = train_smiles_filepath
            pparams['test_file'] = test_smiles_filepath
            pparams['language_file'] = smiles_language_filepath
            pparams['model_path'] = model_path
            # NOTE: iterate over pparams (not params) so the entries added
            # above are kept.
            pparams = {
                k: v if v is not None else 'N.A.' for k, v in pparams.items()
            }
            pparams['training_name'] = training_name
            from pprint import pprint
            pprint(pparams)
            writer.add_hparams(hparam_dict=pparams, metric_dict={})

        for epoch in range(params['epochs'] + 1):
            t = time()
            loss_tracker = train_vae(
                epoch,
                gru_vae,
                train_data_loader,
                test_data_loader,
                smiles_language,
                model_dir,
                search=decoder_search,
                optimizer=params.get('optimizer', 'adadelta'),
                lr=params['learning_rate'],
                kl_growth=params['kl_growth'],
                input_keep=params['input_keep'],
                test_input_keep=params['test_input_keep'],
                generate_len=params['generate_len'],
                log_interval=params['log_interval'],
                save_interval=params['save_interval'],
                eval_interval=params['eval_interval'],
                loss_tracker=loss_tracker,
                logger=logger,
                # writer=writer,
                batch_mode=params.get('batch_mode'))
            logger.info(f'Epoch {epoch}, took {time() - t:.1f}.')

        logger.info('OVERALL: \t Best loss = {0:.4f} in Ep {1}, '
                    'best Rec = {2:.4f} in Ep {3}, '
                    'best KLD = {4:.4f} in Ep {5}'.format(
                        loss_tracker['test_loss_a'], loss_tracker['ep_loss'],
                        loss_tracker['test_rec_a'], loss_tracker['ep_rec'],
                        loss_tracker['test_kld_a'], loss_tracker['ep_kld']))
        logger.info('Training done, shutting down.')

    except Exception:
        logger.exception('Exception occurred while running train_vae.py.')
def main(*, parser_namespace): disable_rdkit_logging() # read the params json params = dict() with open(parser_namespace.params_path) as f: params.update(json.load(f)) # get params, json args take precedence mol_model_path = params.get('mol_model_path', parser_namespace.mol_model_path) protein_model_path = params.get('protein_model_path', parser_namespace.protein_model_path) affinity_model_path = params.get('affinity_model_path', parser_namespace.affinity_model_path) protein_data_path = params.get('protein_data_path', parser_namespace.protein_data_path) model_name = params.get( 'model_name', parser_namespace.model_name ) # yapf: disable test_id = int(params.get( 'test_protein_id', parser_namespace.test_protein_id )) # yapf: disable unbiased_preds_path = params.get( 'unbiased_predictions_path', parser_namespace.unbiased_predictions_path ) # yapf: disable model_name += '_' + str(test_id) logger.info(f'Model with name {model_name} starts.') # passing optional paths to params to possibly update_reward_fn optional_reward_args = [ 'tox21_path', 'organdb_path', 'site', 'clintox_path', 'sider_path' ] for arg in optional_reward_args: if parser_namespace.__dict__[arg]: # json still has presedence params[arg] = params.get(arg, parser_namespace.__dict__[arg]) # Load protein sequence data if protein_data_path.endswith('.smi'): protein_df = read_smi(protein_data_path, names=['Sequence']) elif protein_data_path.endswith('.csv'): protein_df = pd.read_csv(protein_data_path, index_col='entry_name') else: raise TypeError( f"{protein_data_path.split('.')[-1]} files are not supported.") protein_test_name = protein_df.iloc[test_id].name logger.info(f'Test protein is {protein_test_name}') # Restore SMILES Model with open(os.path.join(mol_model_path, 'model_params.json')) as f: mol_params = json.load(f) gru_encoder = StackGRUEncoder(mol_params) gru_decoder = StackGRUDecoder(mol_params) generator = TeacherVAE(gru_encoder, gru_decoder) generator.load(os.path.join( mol_model_path, f"weights/best_{params.get('smiles_metric', 'rec')}.pt"), map_location=get_device()) # Load languages generator_smiles_language = SMILESLanguage.load( os.path.join(mol_model_path, 'selfies_language.pkl')) generator._associate_language(generator_smiles_language) # Restore protein model with open(os.path.join(protein_model_path, 'model_params.json')) as f: protein_params = json.load(f) # Define network protein_encoder = ENCODER_FACTORY['dense'](protein_params) protein_encoder.load(os.path.join( protein_model_path, f"weights/best_{params.get('omics_metric','both')}_encoder.pt"), map_location=get_device()) protein_encoder.eval() # Restore affinity predictor with open(os.path.join(affinity_model_path, 'model_params.json')) as f: predictor_params = json.load(f) predictor = MODEL_FACTORY['bimodal_mca'](predictor_params) predictor.load(os.path.join( affinity_model_path, f"weights/best_{params.get('p_metric', 'ROC-AUC')}_bimodal_mca.pt"), map_location=get_device()) predictor.eval() # Load languages affinity_smiles_language = SMILESLanguage.load( os.path.join(affinity_model_path, 'smiles_language.pkl')) affinity_protein_language = ProteinLanguage.load( os.path.join(affinity_model_path, 'protein_language.pkl')) predictor._associate_language(affinity_smiles_language) predictor._associate_language(affinity_protein_language) # Specifies the baseline model used for comparison unbiased_preds = np.array( pd.read_csv( os.path.join(unbiased_preds_path, protein_test_name + '.csv') )['affinity'].values ) # yapf: disable # Create a fresh model that will be 
optimized gru_encoder_rl = StackGRUEncoder(mol_params) gru_decoder_rl = StackGRUDecoder(mol_params) generator_rl = TeacherVAE(gru_encoder_rl, gru_decoder_rl) generator_rl.load(os.path.join( mol_model_path, f"weights/best_{params.get('metric', 'rec')}.pt"), map_location=get_device()) generator_rl.eval() # Load languages generator_rl._associate_language(generator_smiles_language) protein_encoder_rl = ENCODER_FACTORY['dense'](protein_params) protein_encoder_rl.load(os.path.join( protein_model_path, f"weights/best_{params.get('metric', 'both')}_encoder.pt"), map_location=get_device()) protein_encoder_rl.eval() model_folder_name = model_name learner = ReinforceProtein(generator_rl, protein_encoder_rl, predictor, protein_df, params, model_folder_name, logger) biased_ratios, tox_ratios = [], [] rewards, rl_losses = [], [] gen_mols, gen_prot, gen_affinity, mode = [], [], [], [] logger.info(f'Model stored at {learner.model_path}') for epoch in range(1, params['epochs'] + 1): for step in range(1, params['steps']): # Randomly sample a protein protein_name = np.random.choice(protein_df.index) while protein_name == protein_test_name: protein_name = np.random.choice(protein_df.index) logger.info(f'Current train protein: {protein_name}') rew, loss = learner.policy_gradient(protein_name, epoch, params['batch_size']) logger.info( f"Epoch {epoch:d}/{params['epochs']:d}, step {step:d}/" f"{params['steps']:d}\t loss={loss:.2f}, mean rew={rew:.2f}") rewards.append(rew.item()) rl_losses.append(loss) # Save model if epoch % 10 == 0: learner.save(f'gen_{epoch}.pt', f'enc_{epoch}.pt') logger.info(f'EVAL protein: {protein_test_name}') smiles, preds = (learner.generate_compounds_and_evaluate( epoch, params['eval_batch_size'], protein_test_name)) gs = [s for i, s in enumerate(smiles) if preds[i] > 0.5] gp = preds[preds > 0.5] for p, s in zip(gp, gs): gen_mols.append(s) gen_prot.append(protein_test_name) gen_affinity.append(p) mode.append('eval') inds = np.argsort(gp)[::-1] for i in inds[:5]: logger.info(f'Epoch {epoch:d}, generated {gs[i]} against ' f'{protein_test_name}.\n Predicted IC50 = {gp[i]}. ') plot_and_compare_proteins(unbiased_preds, preds, protein_test_name, epoch, learner.model_path, 'train', params['eval_batch_size']) biased_ratios.append( np.round(100 * (np.sum(preds > 0.5) / len(preds)), 1)) all_toxes = np.array([learner.tox21(s) for s in smiles]) tox_ratios.append( np.round(100 * (np.sum(all_toxes == 1.) / len(all_toxes)), 1)) logger.info(f'Percentage of non-toxic compounds {tox_ratios[-1]}') toxes = [learner.tox21(s) for s in gen_mols] # Save results (good molecules!) in DF df = pd.DataFrame({ 'protein': gen_prot, 'SMILES': gen_mols, 'Binding probability': gen_affinity, 'mode': mode, 'Tox21': toxes }) df.to_csv(os.path.join(learner.model_path, 'results', 'generated.csv')) # Plot loss development loss_df = pd.DataFrame({'loss': rl_losses, 'rewards': rewards}) loss_df.to_csv(learner.model_path + '/results/loss_reward_evolution.csv') plot_loss(rl_losses, rewards, params['epochs'], protein_name, learner.model_path, rolling=5) pd.DataFrame({ 'efficacy_ratio': biased_ratios, 'tox_ratio': tox_ratios }).to_csv(learner.model_path + '/results/ratios.csv')
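# --- Illustrative sketch (not part of the RL script above) -----------------
# ReinforceProtein.policy_gradient performs a REINFORCE-style update:
# molecules are sampled from the generator conditioned on a protein
# embedding, scored by the affinity predictor, and the generator is pushed
# towards high-reward samples by weighting the log-likelihood of the sampled
# tokens with the obtained reward. The toy function below only mirrors that
# weighting step on a generic categorical policy; `dummy_logits` and
# `rewards` are made-up names for illustration, not part of the codebase.
import torch


def _sketch_reinforce_step(dummy_logits: torch.Tensor,
                           rewards: torch.Tensor) -> torch.Tensor:
    """REINFORCE loss for sampled tokens and per-sequence rewards.

    dummy_logits: (batch, seq_len, vocab) unnormalized token scores.
    rewards: (batch,) scalar reward per sampled sequence.
    """
    log_probs = torch.log_softmax(dummy_logits, dim=-1)
    # Sample one token per position and gather its log-probability.
    sampled = torch.distributions.Categorical(logits=dummy_logits).sample()
    sampled_log_probs = log_probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
    # Negative reward-weighted log-likelihood, averaged over the batch.
    return -(rewards.unsqueeze(-1) * sampled_log_probs).sum(dim=-1).mean()


# Usage with random tensors (batch of 4, length 10, vocabulary of 32):
# loss = _sketch_reinforce_step(torch.randn(4, 10, 32), torch.rand(4))
# If the logits came from the generator, loss.backward() would provide the
# gradient used by the policy-gradient step.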
def main( train_sensitivity_filepath, test_sensitivity_filepath, gep_filepath, smi_filepath, gene_filepath, smiles_language_filepath, model_path, params_filepath, training_name ): logger = logging.getLogger(f'{training_name}') # Process parameter file: params = {} with open(params_filepath) as fp: params.update(json.load(fp)) # Create model directory and dump files model_dir = os.path.join(model_path, training_name) os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True) os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True) with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp: json.dump(params, fp, indent=4) # Prepare the dataset logger.info("Start data preprocessing...") # Load SMILES language smiles_language = SMILESLanguage.load(smiles_language_filepath) # Load the gene list with open(gene_filepath, 'rb') as f: gene_list = pickle.load(f) # Assemble datasets train_dataset = DrugSensitivityDataset( drug_sensitivity_filepath=train_sensitivity_filepath, smi_filepath=smi_filepath, gene_expression_filepath=gep_filepath, smiles_language=smiles_language, gene_list=gene_list, drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True), drug_sensitivity_processing_parameters=params.get( 'drug_sensitivity_processing_parameters', {} ), augment=params.get('augment_smiles', True), canonical=params.get('canonical', False), kekulize=params.get('kekulize', False), all_bonds_explicit=params.get('all_bonds_explicit', False), all_hs_explicit=params.get('all_hs_explicit', False), randomize=params.get('randomize', False), remove_bonddir=params.get('remove_bonddir', False), remove_chirality=params.get('remove_chirality', False), selfies=params.get('selfies', False), add_start_and_stop=params.get('smiles_start_stop_token', True), padding_length=params.get('smiles_padding_length', None), gene_expression_standardize=params.get( 'gene_expression_standardize', True ), gene_expression_min_max=params.get('gene_expression_min_max', False), gene_expression_processing_parameters=params.get( 'gene_expression_processing_parameters', {} ), device=torch.device(params.get('dataset_device', 'cpu')), backend='eager' ) train_loader = torch.utils.data.DataLoader( dataset=train_dataset, batch_size=params['batch_size'], shuffle=True, drop_last=True, num_workers=params.get('num_workers', 0) ) test_dataset = DrugSensitivityDataset( drug_sensitivity_filepath=test_sensitivity_filepath, smi_filepath=smi_filepath, gene_expression_filepath=gep_filepath, smiles_language=smiles_language, gene_list=gene_list, drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True), drug_sensitivity_processing_parameters=params.get( 'drug_sensitivity_processing_parameters', train_dataset.drug_sensitivity_processing_parameters ), augment=params.get('augment_test_smiles', False), canonical=params.get('canonical', False), kekulize=params.get('kekulize', False), all_bonds_explicit=params.get('all_bonds_explicit', False), all_hs_explicit=params.get('all_hs_explicit', False), randomize=params.get('randomize', False), remove_bonddir=params.get('remove_bonddir', False), remove_chirality=params.get('remove_chirality', False), selfies=params.get('selfies', False), add_start_and_stop=params.get('smiles_start_stop_token', True), padding_length=params.get('smiles_padding_length', None), gene_expression_standardize=params.get( 'gene_expression_standardize', True ), gene_expression_min_max=params.get('gene_expression_min_max', False), gene_expression_processing_parameters=params.get( 
'gene_expression_processing_parameters', train_dataset.gene_expression_dataset.processing ), device=torch.device(params.get('dataset_device', 'cpu')), backend='eager' ) test_loader = torch.utils.data.DataLoader( dataset=test_dataset, batch_size=params['batch_size'], shuffle=True, drop_last=True, num_workers=params.get('num_workers', 0) ) logger.info( f'Training dataset has {len(train_dataset)} samples, test set has ' f'{len(test_dataset)}.' ) device = get_device() logger.info( f'Device for data loader is {train_dataset.device} and for ' f'model is {device}' ) save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt') params.update({ # yapf: disable 'number_of_genes': len(gene_list), # yapf: disable 'smiles_vocabulary_size': smiles_language.number_of_tokens, 'drug_sensitivity_processing_parameters': train_dataset.drug_sensitivity_processing_parameters, 'gene_expression_processing_parameters': train_dataset.gene_expression_dataset.processing }) model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device) model._associate_language(smiles_language) if os.path.isfile(os.path.join(model_dir, 'weights', 'best_mse_mca.pt')): logger.info('Found existing model, restoring now...') model.load(os.path.join(model_dir, 'weights', 'best_mse_mca.pt')) with open(os.path.join(model_dir, 'results', 'mse.json'), 'r') as f: info = json.load(f) min_rmse = float(info['best_rmse']) max_pearson = float(info['best_pearson']) min_loss = float(info['test_loss']) else: min_loss, min_rmse, max_pearson = 100, 1000, 0 # Define optimizer optimizer = ( OPTIMIZER_FACTORY[params.get('optimizer', 'Adam')] (model.parameters(), lr=params.get('lr', 0.01)) ) num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) params.update({'number_of_parameters': num_params}) logger.info(f'Number of parameters {num_params}') # Overwrite params.json file with updated parameters. with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp: json.dump(params, fp) # Start training logger.info('Training about to start...\n') t = time() model.save( save_top_model.format('epoch', '0', params.get('model_fn', 'mca')) ) for epoch in range(params['epochs']): model.train() logger.info(params_filepath.split('/')[-1]) logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==") train_loss = 0 for ind, (smiles, gep, y) in enumerate(train_loader): y_hat, pred_dict = model( torch.squeeze(smiles.to(device)), gep.to(device) ) loss = model.loss(y_hat, y.to(device)) optimizer.zero_grad() loss.backward() # Apply gradient clipping # torch.nn.utils.clip_grad_norm_(model.parameters(),1e-6) optimizer.step() train_loss += loss.item() logger.info( "\t **** TRAINING **** " f"Epoch [{epoch + 1}/{params['epochs']}], " f"loss: {train_loss / len(train_loader):.5f}. " f"This took {time() - t:.1f} secs." 
) t = time() # Measure validation performance model.eval() with torch.no_grad(): test_loss = 0 predictions = [] labels = [] for ind, (smiles, gep, y) in enumerate(test_loader): y_hat, pred_dict = model( torch.squeeze(smiles.to(device)), gep.to(device) ) predictions.append(y_hat) labels.append(y) loss = model.loss(y_hat, y.to(device)) test_loss += loss.item() predictions = np.array( [p.cpu() for preds in predictions for p in preds] ) labels = np.array([l.cpu() for label in labels for l in label]) test_pearson_a = pearsonr( torch.Tensor(predictions), torch.Tensor(labels) ) test_rmse_a = np.sqrt(np.mean((predictions - labels)**2)) test_loss_a = test_loss / len(test_loader) logger.info( f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], " f"loss: {test_loss_a:.5f}, " f"Pearson: {test_pearson_a:.3f}, " f"RMSE: {test_rmse_a:.3f}" ) def save(path, metric, typ, val=None): model.save(path.format(typ, metric, params.get('model_fn', 'mca'))) with open( os.path.join(model_dir, 'results', metric + '.json'), 'w' ) as f: json.dump(info, f) np.save( os.path.join(model_dir, 'results', metric + '_preds.npy'), np.vstack([predictions, labels]) ) if typ == 'best': logger.info( f'\t New best performance in "{metric}"' f' with value: {val:.7f} in epoch: {epoch}' ) def update_info(): return { 'best_rmse': str(min_rmse), 'best_pearson': str(float(max_pearson)), 'test_loss': str(min_loss), 'predictions': [float(p) for p in predictions] } if test_loss_a < min_loss: min_rmse = test_rmse_a min_loss = test_loss_a min_loss_pearson = test_pearson_a info = update_info() save(save_top_model, 'mse', 'best', min_loss) ep_loss = epoch if test_pearson_a > max_pearson: max_pearson = test_pearson_a max_pearson_loss = test_loss_a info = update_info() save(save_top_model, 'pearson', 'best', max_pearson) ep_pearson = epoch if (epoch + 1) % params.get('save_model', 100) == 0: save(save_top_model, 'epoch', str(epoch)) logger.info( 'Overall best performances are: \n \t' f'Loss = {min_loss:.4f} in epoch {ep_loss} ' f'\t (Pearson was {min_loss_pearson:.4f}) \n \t' f'Pearson = {max_pearson:.4f} in epoch {ep_pearson} ' f'\t (Loss was {max_pearson_loss:.2f})' ) save(save_top_model, 'training', 'done') logger.info('Done with training, models saved, shutting down.')
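# --- Illustrative sketch (not part of the training script above) -----------
# The validation block reports two metrics: the Pearson correlation between
# predicted and measured drug sensitivity, and the RMSE. `pearsonr` there is
# the project's tensor-based helper; the snippet below simply restates both
# quantities with plain NumPy so the values in the log messages are
# unambiguous.
import numpy as np


def _sketch_validation_metrics(predictions: np.ndarray, labels: np.ndarray):
    """Return (pearson, rmse) for 1-D arrays of equal length."""
    pearson = np.corrcoef(predictions, labels)[0, 1]
    rmse = np.sqrt(np.mean((predictions - labels) ** 2))
    return pearson, rmse


# Example:
# _sketch_validation_metrics(np.array([0.1, 0.4, 0.8]),
#                            np.array([0.0, 0.5, 0.9]))
# -> (~0.989, 0.1)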
def main(*, parser_namespace): disable_rdkit_logging() # read the params json params = dict() with open(parser_namespace.params_path) as f: params.update(json.load(f)) # get params mol_model_path = params.get('mol_model_path', parser_namespace.mol_model_path) omics_model_path = params.get('omics_model_path', parser_namespace.omics_model_path) ic50_model_path = params.get('ic50_model_path', parser_namespace.ic50_model_path) omics_data_path = params.get('omics_data_path', parser_namespace.omics_data_path) model_name = params.get( 'model_name', parser_namespace.model_name ) # yapf: disable site = params.get( 'site', parser_namespace.site ) # yapf: disable params['site'] = site logger.info(f'Model with name {model_name} starts.') # Load omics profiles for conditional generation, # complement with avg per site omics_df = pd.read_pickle(omics_data_path) omics_df = add_avg_profile(omics_df) # Restore SMILES Model with open(os.path.join(mol_model_path, 'model_params.json')) as f: mol_params = json.load(f) gru_encoder = StackGRUEncoder(mol_params) gru_decoder = StackGRUDecoder(mol_params) generator = TeacherVAE(gru_encoder, gru_decoder) generator.load(os.path.join( mol_model_path, f"weights/best_{params.get('smiles_metric', 'rec')}.pt"), map_location=get_device()) # Load languages generator_smiles_language = SMILESLanguage.load( os.path.join(mol_model_path, 'selfies_language.pkl')) generator._associate_language(generator_smiles_language) # Restore omics model with open(os.path.join(omics_model_path, 'model_params.json')) as f: cell_params = json.load(f) # Define network cell_encoder = ENCODER_FACTORY['dense'](cell_params) cell_encoder.load(os.path.join( omics_model_path, f"weights/best_{params.get('omics_metric','both')}_encoder.pt"), map_location=get_device()) cell_encoder.eval() # Restore PaccMann with open(os.path.join(ic50_model_path, 'model_params.json')) as f: paccmann_params = json.load(f) paccmann_predictor = MODEL_FACTORY['mca'](paccmann_params) paccmann_predictor.load(os.path.join( ic50_model_path, f"weights/best_{params.get('ic50_metric', 'rmse')}_mca.pt"), map_location=get_device()) paccmann_predictor.eval() paccmann_smiles_language = SMILESLanguage.load( os.path.join(ic50_model_path, 'smiles_language.pkl')) paccmann_predictor._associate_language(paccmann_smiles_language) # Specifies the baseline model used for comparison baseline = ReinforceOmic(generator, cell_encoder, paccmann_predictor, omics_df, params, 'baseline', logger) # Create a fresh model that will be optimized gru_encoder_rl = StackGRUEncoder(mol_params) gru_decoder_rl = StackGRUDecoder(mol_params) generator_rl = TeacherVAE(gru_encoder_rl, gru_decoder_rl) generator_rl.load(os.path.join( mol_model_path, f"weights/best_{params.get('metric', 'rec')}.pt"), map_location=get_device()) generator_rl.eval() generator_rl._associate_language(generator_smiles_language) cell_encoder_rl = ENCODER_FACTORY['dense'](cell_params) cell_encoder_rl.load(os.path.join( omics_model_path, f"weights/best_{params.get('metric', 'both')}_encoder.pt"), map_location=get_device()) cell_encoder_rl.eval() model_folder_name = site + '_' + model_name learner = ReinforceOmic(generator_rl, cell_encoder_rl, paccmann_predictor, omics_df, params, model_folder_name, logger) # Split the samples for conditional generation and initialize training train_omics, test_omics = omics_data_splitter( omics_df, site, params.get('test_fraction', 0.2)) rewards, rl_losses = [], [] gen_mols, gen_cell, gen_ic50, modes = [], [], [], [] logger.info('Models restored, start training.') 
for epoch in range(1, params['epochs'] + 1): for step in range(1, params['steps']): # Randomly sample a cell line: cell_line = np.random.choice(train_omics) rew, loss = learner.policy_gradient(cell_line, epoch, params['batch_size']) print(f"Epoch {epoch:d}/{params['epochs']:d}, step {step:d}/" f"{params['steps']:d}\t loss={loss:.2f}, rew={rew:.2f}") rewards.append(rew.item()) rl_losses.append(loss) # Save model learner.save(f'gen_{epoch}.pt', f'enc_{epoch}.pt') # Compare baseline and trained model on cell line base_smiles, base_preds = baseline.generate_compounds_and_evaluate( epoch, params['eval_batch_size'], cell_line) smiles, preds = learner.generate_compounds_and_evaluate( epoch, params['eval_batch_size'], cell_line) gs = [ s for i, s in enumerate(smiles) if preds[i] < learner.ic50_threshold ] gp = preds[preds < learner.ic50_threshold] for p, s in zip(gp, gs): gen_mols.append(s) gen_cell.append(cell_line) gen_ic50.append(p) modes.append('train') plot_and_compare(base_preds, preds, site, cell_line, epoch, learner.model_path, 'train', params['eval_batch_size']) # Evaluate on a validation cell line. eval_cell_line = np.random.choice(test_omics) base_smiles, base_preds = baseline.generate_compounds_and_evaluate( epoch, params['eval_batch_size'], eval_cell_line) smiles, preds = learner.generate_compounds_and_evaluate( epoch, params['eval_batch_size'], eval_cell_line) plot_and_compare(base_preds, preds, site, eval_cell_line, epoch, learner.model_path, 'test', params['eval_batch_size']) gs = [ s for i, s in enumerate(smiles) if preds[i] < learner.ic50_threshold ] gp = preds[preds < learner.ic50_threshold] for p, s in zip(gp, gs): gen_mols.append(s) gen_cell.append(eval_cell_line) gen_ic50.append(p) modes.append('test') inds = np.argsort(preds) for i in inds[:5]: logger.info(f'Epoch {epoch:d}, generated {smiles[i]} against ' f'{eval_cell_line}.\n Predicted IC50 = {preds[i]}. ') # Save results (good molecules!) in DF df = pd.DataFrame({ 'cell_line': gen_cell, 'SMILES': gen_mols, 'IC50': gen_ic50, 'mode': modes, 'tox21': [learner.tox21(s) for s in gen_mols] }) df.to_csv(os.path.join(learner.model_path, 'results', 'generated.csv')) # Plot loss development loss_df = pd.DataFrame({'loss': rl_losses, 'rewards': rewards}) loss_df.to_csv(learner.model_path + '/results/loss_reward_evolution.csv') plot_loss(rl_losses, rewards, params['epochs'], cell_line, learner.model_path, rolling=5, site=site)
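# --- Illustrative sketch (not part of the RL script above) -----------------
# plot_loss is called with rolling=5, i.e. the loss and reward curves are
# smoothed with a rolling mean before plotting. The helper below is a minimal
# stand-in (output file name and styling are made up) that shows the same
# smoothing with pandas and matplotlib; the actual plot_loss utility may
# differ in layout and labels.
import matplotlib

matplotlib.use('Agg')  # headless backend, suitable for training scripts
import matplotlib.pyplot as plt
import pandas as pd


def _sketch_plot_loss_reward(losses, rewards, rolling: int = 5,
                             out_path: str = 'loss_reward_sketch.png'):
    """Plot rolling means of RL loss and reward on twin y-axes."""
    df = pd.DataFrame({'loss': losses, 'reward': rewards})
    smoothed = df.rolling(rolling, min_periods=1).mean()
    fig, ax_loss = plt.subplots()
    ax_reward = ax_loss.twinx()
    ax_loss.plot(smoothed['loss'], color='tab:red', label='loss')
    ax_reward.plot(smoothed['reward'], color='tab:blue', label='reward')
    ax_loss.set_xlabel('step')
    ax_loss.set_ylabel('loss')
    ax_reward.set_ylabel('reward')
    fig.savefig(out_path)
    plt.close(fig)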
def main(model_path, smi_filepath, labels_filepath, output_folder, smiles_language_filepath, model_id, confidence): logging.basicConfig(level=logging.INFO, format='%(message)s') logger = logging.getLogger('eval_toxicity') logger.setLevel(logging.INFO) disable_rdkit_logging() # Process parameter file: params = {} with open(os.path.join(model_path, 'model_params.json'), 'r') as fp: params.update(json.load(fp)) # Create model directory os.makedirs(output_folder, exist_ok=True) if model_id not in MODEL_FACTORY.keys(): raise KeyError( f'Model ID: Pass one of {MODEL_FACTORY.keys()}, not {model_id}') device = get_device() weights_path = os.path.join(model_path, 'weights', f'best_ROC-AUC_{model_id}.pt') # Restore model model = MODEL_FACTORY[model_id](params).to(device) if os.path.isfile(weights_path): try: model.load(weights_path, map_location=device) except Exception: logger.error(f'Error in model restoring from {weights_path}') else: logger.info(f'Did not find weights at {weights_path}, ' f'name weights: "best_ROC-AUC_{model_id}.pt".') model.eval() logger.info('Model restored. Model specs & parameters follow') for name, param in model.named_parameters(): logger.info((name, param.shape)) # Load language if smiles_language_filepath == '.': smiles_language_filepath = os.path.join(model_path, 'smiles_language.pkl') smiles_language = SMILESLanguage.load(smiles_language_filepath) # Assemble datasets smiles_dataset = SMILESDataset( smi_filepath, smiles_language=smiles_language, padding_length=params.get('smiles_padding_length', None), padding=params.get('padd_smiles', True), add_start_and_stop=params.get('add_start_stop_token', True), augment=params.get('augment_test_smiles', False), canonical=params.get('test_canonical', False), kekulize=params.get('test_kekulize', False), all_bonds_explicit=params.get('test_all_bonds_explicit', False), all_hs_explicit=params.get('test_all_hs_explicit', False), randomize=False, remove_bonddir=params.get('test_remove_bonddir', False), remove_chirality=params.get('test_remove_chirality', False), selfies=params.get('selfies', False), sanitize=params.get('test_sanitize', False), ) logger.info( f'SMILES Padding length is {smiles_dataset._dataset.padding_length}.' 
'Consider setting manually if this looks wrong.') dataset = AnnotatedDataset(annotations_filepath=labels_filepath, dataset=smiles_dataset, device=device) loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=10, shuffle=False, drop_last=False, num_workers=0) if confidence: smiles_aleatoric_dataset = SMILESDataset( smi_filepath, smiles_language=smiles_language, padding_length=params.get('smiles_padding_length', None), padding=params.get('padd_smiles', True), add_start_and_stop=params.get('add_start_stop_token', True), augment=True, # Natively true for aleatoric uncertainity estimate canonical=params.get('test_canonical', False), kekulize=params.get('test_kekulize', False), all_bonds_explicit=params.get('test_all_bonds_explicit', False), all_hs_explicit=params.get('test_all_hs_explicit', False), remove_bonddir=params.get('test_remove_bonddir', False), remove_chirality=params.get('test_remove_chirality', False), selfies=params.get('selfies', False), sanitize=params.get('test_sanitize', False), ) ale_dataset = AnnotatedDataset(annotations_filepath=labels_filepath, dataset=smiles_aleatoric_dataset, device=device) ale_loader = torch.utils.data.DataLoader(dataset=ale_dataset, batch_size=10, shuffle=False, drop_last=False, num_workers=0) smiles_epi_dataset = SMILESDataset( smi_filepath, smiles_language=smiles_language, padding_length=params.get('smiles_padding_length', None), padding=params.get('padd_smiles', True), add_start_and_stop=params.get('add_start_stop_token', True), augment=False, # Natively false for epistemic uncertainity estimate canonical=params.get('test_canonical', False), kekulize=params.get('test_kekulize', False), all_bonds_explicit=params.get('test_all_bonds_explicit', False), all_hs_explicit=params.get('test_all_hs_explicit', False), remove_bonddir=params.get('test_remove_bonddir', False), remove_chirality=params.get('test_remove_chirality', False), selfies=params.get('selfies', False), sanitize=params.get('test_sanitize', False), ) epi_dataset = AnnotatedDataset(annotations_filepath=labels_filepath, dataset=smiles_epi_dataset, device=device) epi_loader = torch.utils.data.DataLoader(dataset=epi_dataset, batch_size=10, shuffle=False, drop_last=False, num_workers=0) logger.info(f'Device for data loader is {dataset.device} and for ' f'model is {device}') # Start evaluation logger.info('Evaluation about to start...\n') preds, labels, attention_scores, smiles = [], [], [], [] for idx, (smiles_batch, labels_batch) in enumerate(loader): pred, pred_dict = model(smiles_batch.to(device)) preds.extend(pred.detach().squeeze().tolist()) attention_scores.extend( torch.stack(pred_dict['smiles_attention'], dim=1).detach()) smiles.extend([ smiles_language.token_indexes_to_smiles(s.tolist()) for s in smiles_batch ]) labels.extend(labels_batch.squeeze().tolist()) # Scores are now 3D: num_samples x num_att_layers x padding_length attention = torch.stack(attention_scores, dim=0).numpy() if confidence: # Compute uncertainity estimates and save them epistemic_conf, epistemic_pred = monte_carlo_dropout(model, regime='loader', loader=epi_loader) aleatoric_conf, aleatoric_pred = test_time_augmentation( model, regime='loader', loader=ale_loader) epi_conf_df = pd.DataFrame( data=epistemic_conf.numpy(), columns=[ f'epistemic_conf_{i}' for i in range(epistemic_conf.shape[1]) ], ) ale_conf_df = pd.DataFrame( data=aleatoric_conf.numpy(), columns=[ f'aleatoric_conf_{i}' for i in range(aleatoric_conf.shape[1]) ], ) epi_pred_df = pd.DataFrame( data=epistemic_pred.numpy(), columns=[ 
f'epistemic_pred_{i}' for i in range(epistemic_pred.shape[1]) ], ) ale_pred_df = pd.DataFrame( data=aleatoric_pred.numpy(), columns=[ f'aleatoric_pred_{i}' for i in range(aleatoric_pred.shape[1]) ], ) logger.info(f'Shape of attention scores {attention.shape}.') np.save(os.path.join(output_folder, 'attention_raw.npy'), attention) attention_avg = np.mean(attention, axis=1) att_df = pd.DataFrame( data=attention_avg, columns=[f'att_idx_{i}' for i in range(attention_avg.shape[1])], ) pred_df = pd.DataFrame( data=preds, columns=[f'pred_{i}' for i in range(len(preds[0]))], ) lab_df = pd.DataFrame( data=labels, columns=[f'label_{i}' for i in range(len(labels[0]))], ) df = pd.concat([pred_df, lab_df], axis=1) if confidence: df = pd.concat( [df, epi_conf_df, ale_conf_df, epi_pred_df, ale_pred_df], axis=1) df = pd.concat([df, att_df], axis=1) df.insert(0, 'SMILES', smiles) df.to_csv(os.path.join(output_folder, 'results.csv'), index=False) logger.info('Done, shutting down.')
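# --- Illustrative sketch (not part of the evaluation script above) ---------
# The confidence branch relies on two utilities: monte_carlo_dropout
# (epistemic uncertainty, dropout kept active at test time) and
# test_time_augmentation (aleatoric uncertainty, predictions averaged over
# augmented SMILES). The snippet below sketches only the Monte Carlo dropout
# idea on a generic torch model; the real utilities have their own signatures
# and return (confidence, prediction) tensors.
import torch


def _sketch_mc_dropout(model: torch.nn.Module, batch: torch.Tensor,
                       n_samples: int = 20):
    """Mean and std of predictions over stochastic forward passes."""
    model.eval()
    # Re-enable dropout layers only, keeping e.g. batch norm in eval mode.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.train()
    with torch.no_grad():
        samples = torch.stack([model(batch) for _ in range(n_samples)])
    return samples.mean(dim=0), samples.std(dim=0)


# Usage on a toy model:
# toy = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.Dropout(0.5),
#                           torch.nn.Linear(4, 1))
# mean, std = _sketch_mc_dropout(toy, torch.randn(16, 8))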