def load_mca(self, model_path: str):
    """
    Restores pretrained MCA.

    Arguments:
        model_path {String} -- Path to the model
    """
    # Restore model
    self.model_path = model_path
    with open(os.path.join(model_path, 'model_params.json')) as f:
        params = json.load(f)

    # Set up language and transforms
    self.smiles_language = SMILESLanguage.load(
        os.path.join(model_path, 'smiles_language.pkl')
    )
    self.transforms = self.compose_smiles_transforms(params)

    # Initialize and restore model weights
    self.model = MODEL_FACTORY['mca'](params)
    self.model.load(
        os.path.join(model_path, 'weights', 'best_ROC-AUC_mca.pt'),
        map_location=get_device()
    )
    self.model.eval()
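
# For reference, load_mca expects a model directory laid out as below. The
# file names are taken from the calls above; the root folder name is
# arbitrary:
#
#   <model_path>/
#       model_params.json
#       smiles_language.pkl
#       weights/
#           best_ROC-AUC_mca.pt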
def __init__(self, generator, encoder, params, model_name, logger):
    """
    Constructor for the Reinforcement object.

    Args:
        generator (nn.Module): SMILES generator object.
        encoder (nn.Module): An encoder object.
        params (dict): Dict with hyperparameters.
        model_name (str): Name of the model.
        logger: A logger.

    Returns:
        object of type Reinforce used for biasing the properties estimated
        by the predictor of trajectories produced by the generator to
        maximize the custom reward function get_reward.
    """
    super(Reinforce, self).__init__()

    self.generator = generator
    self.generator.eval()

    self.encoder = encoder
    self.encoder.eval()

    self.logger = logger
    self.device = get_device()

    self.optimizer = torch.optim.Adam(
        list(self.generator.decoder.parameters()),
        lr=params.get('learning_rate', 0.0001),
        eps=params.get('eps', 0.0001),
        weight_decay=params.get('weight_decay', 0.00001)
    )

    self.model_name = model_name
    self.model_path = os.path.join(
        params.get('model_folder', 'biased_models'), model_name
    )
    self.weights_path = os.path.join(self.model_path, 'weights/{}')

    self.smiles_to_tensor = ToTensor(self.device)

    # If the model does not yet exist, create it.
    if not os.path.isdir(self.model_path):
        os.makedirs(os.path.join(self.model_path, 'weights'), exist_ok=True)
        os.makedirs(os.path.join(self.model_path, 'results'), exist_ok=True)
        # Save untrained models
        self.save('generator_epoch_0.pt', 'encoder_epoch_0.pt')
        with open(
            os.path.join(self.model_path, 'model_params.json'), 'w'
        ) as f:
            json.dump(params, f)
    else:
        self.logger.warning(
            'Model exists already. Call model.load() to restore weights.'
        )
def __init__(self, params, *args, **kwargs):
    """Constructor.

    Args:
        params (dict): A dictionary containing the parameters to build the
            dense decoder.
            Items in params:
                stacked_dense_hidden_sizes (list[int]): Number of neurons in
                    the hidden layers.
                num_drug_features (int, optional): Number of features for the
                    molecule. Defaults to 512 (bits fingerprint).
                activation_fn (string, optional): Activation function used in
                    all layers for specification in ACTIVATION_FN_FACTORY.
                    Defaults to 'relu'.
                batch_norm (bool, optional): Whether batch normalization is
                    applied. Defaults to True.
                dropout (float, optional): Dropout probability in all except
                    parametric layer. Defaults to 0.0.
        *args, **kwargs: positional and keyword arguments are ignored.
    """
    super(Dense, self).__init__(*args, **kwargs)

    self.device = get_device()
    self.params = params
    self.num_drug_features = params.get('num_drug_features', 512)
    self.num_tasks = params.get('num_tasks', 12)
    self.hidden_sizes = params.get(
        'stacked_dense_hidden_sizes',
        [self.num_drug_features, 5000, 1000, 500]
    )
    self.dropout = params.get('dropout', 0.0)
    self.act_fn = ACTIVATION_FN_FACTORY[params.get('activation_fn', 'relu')]

    self.dense_layers = nn.ModuleList(
        [
            dense_layer(
                self.hidden_sizes[ind],
                self.hidden_sizes[ind + 1],
                act_fn=self.act_fn,
                dropout=self.dropout,
                batch_norm=self.params.get('batch_norm', True)
            ).to(self.device) for ind in range(len(self.hidden_sizes) - 1)
        ]
    )

    self.final_dense = EnsembleLayer(
        typ=params.get('ensemble', 'score'),
        input_size=self.hidden_sizes[-1],
        output_size=self.num_tasks,
        ensemble_size=params.get('ensemble_size', 5),
        fn=ACTIVATION_FN_FACTORY['sigmoid']
    ).to(self.device)

    self.loss_fn = LOSS_FN_FACTORY[
        params.get('loss_fn', 'binary_cross_entropy_ignore_nan_and_sum')
    ]
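
# Usage sketch (not part of the original sources): how the Dense decoder
# above might be instantiated from a plain params dict. The sizes, dropout
# value and random input below are illustrative assumptions.
def _example_dense_usage():
    example_params = {
        'num_drug_features': 512,
        'num_tasks': 12,
        'stacked_dense_hidden_sizes': [512, 256, 64],
        'dropout': 0.3,
        'batch_norm': True,
    }
    model = Dense(example_params)
    # Random floats stand in for fingerprint bits here.
    fingerprints = torch.rand(8, example_params['num_drug_features'])
    # The training scripts unpack (predictions, prediction_dict) from forward.
    predictions, _ = model(fingerprints)
    return predictions.shape  # expected: torch.Size([8, 12])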
def main(
    train_affinity_filepath, test_affinity_filepath, receptor_filepath,
    ligand_filepath, model_path, params_filepath, training_name,
    smiles_language_filepath
):

    logger = logging.getLogger(f'{training_name}')
    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    # Prepare the dataset
    logger.info("Start data preprocessing...")
    device = get_device()

    # Load languages
    if smiles_language_filepath == '':
        smiles_language_filepath = os.path.join(
            os.sep, *metadata.__file__.split(os.sep)[:-1], 'smiles_language'
        )
    smiles_language = SMILESTokenizer.from_pretrained(smiles_language_filepath)
    smiles_language.set_encoding_transforms(
        randomize=None,
        add_start_and_stop=params.get('ligand_start_stop_token', True),
        padding=params.get('ligand_padding', True),
        padding_length=params.get('ligand_padding_length', True),
        device=device,
    )
    smiles_language.set_smiles_transforms(
        augment=params.get('augment_smiles', False),
        canonical=params.get('smiles_canonical', False),
        kekulize=params.get('smiles_kekulize', False),
        all_bonds_explicit=params.get('smiles_bonds_explicit', False),
        all_hs_explicit=params.get('smiles_all_hs_explicit', False),
        remove_bonddir=params.get('smiles_remove_bonddir', False),
        remove_chirality=params.get('smiles_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('sanitize', False)
    )
    if params.get('receptor_embedding', 'learned') == 'predefined':
        protein_language = ProteinFeatureLanguage(
            features=params.get('predefined_embedding', 'blosum')
        )
    else:
        protein_language = ProteinLanguage()

    if params.get('ligand_embedding', 'learned') == 'one_hot':
        logger.warning(
            'ligand_embedding_size parameter in param file is ignored in '
            'one_hot embedding setting, ligand_vocabulary_size used instead.'
        )
    if params.get('receptor_embedding', 'learned') == 'one_hot':
        logger.warning(
            'receptor_embedding_size parameter in param file is ignored in '
            'one_hot embedding setting, receptor_vocabulary_size used instead.'
        )

    # Assemble datasets
    train_dataset = DrugAffinityDataset(
        drug_affinity_filepath=train_affinity_filepath,
        smi_filepath=ligand_filepath,
        protein_filepath=receptor_filepath,
        protein_language=protein_language,
        smiles_language=smiles_language,
        smiles_padding=params.get('ligand_padding', True),
        smiles_padding_length=params.get('ligand_padding_length', None),
        smiles_add_start_and_stop=params.get('ligand_add_start_stop', True),
        smiles_augment=params.get('augment_smiles', False),
        smiles_canonical=params.get('smiles_canonical', False),
        smiles_kekulize=params.get('smiles_kekulize', False),
        smiles_all_bonds_explicit=params.get('smiles_bonds_explicit', False),
        smiles_all_hs_explicit=params.get('smiles_all_hs_explicit', False),
        smiles_remove_bonddir=params.get('smiles_remove_bonddir', False),
        smiles_remove_chirality=params.get('smiles_remove_chirality', False),
        smiles_selfies=params.get('selfies', False),
        protein_amino_acid_dict=params.get('protein_amino_acid_dict', 'iupac'),
        protein_padding=params.get('receptor_padding', True),
        protein_padding_length=params.get('receptor_padding_length', None),
        protein_add_start_and_stop=params.get('receptor_add_start_stop', True),
        protein_augment_by_revert=params.get('protein_augment', False),
        device=device,
        drug_affinity_dtype=torch.float,
        backend='eager',
        iterate_dataset=params.get('iterate_dataset', False),
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0),
    )
    test_dataset = DrugAffinityDataset(
        drug_affinity_filepath=test_affinity_filepath,
        smi_filepath=ligand_filepath,
        protein_filepath=receptor_filepath,
        protein_language=protein_language,
        smiles_language=smiles_language,
        smiles_padding=params.get('ligand_padding', True),
        smiles_padding_length=params.get('ligand_padding_length', None),
        smiles_add_start_and_stop=params.get('ligand_add_start_stop', True),
        smiles_augment=False,
        smiles_canonical=params.get('smiles_test_canonical', False),
        smiles_kekulize=params.get('smiles_kekulize', False),
        smiles_all_bonds_explicit=params.get('smiles_bonds_explicit', False),
        smiles_all_hs_explicit=params.get('smiles_all_hs_explicit', False),
        smiles_remove_bonddir=params.get('smiles_remove_bonddir', False),
        smiles_remove_chirality=params.get('smiles_remove_chirality', False),
        smiles_selfies=params.get('selfies', False),
        protein_amino_acid_dict=params.get('protein_amino_acid_dict', 'iupac'),
        protein_padding=params.get('receptor_padding', True),
        protein_padding_length=params.get('receptor_padding_length', None),
        protein_add_start_and_stop=params.get('receptor_add_start_stop', True),
        protein_augment_by_revert=False,
        device=device,
        drug_affinity_dtype=torch.float,
        backend='eager',
        iterate_dataset=params.get('iterate_dataset', False),
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0),
    )
    logger.info(
        f'Training dataset has {len(train_dataset)} samples, test set has '
        f'{len(test_dataset)}.'
    )
    logger.info(
        f'Device for data loader is {train_dataset.device} and for '
        f'model is {device}'
    )

    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')
    params.update({
        'ligand_vocabulary_size':
            train_dataset.smiles_dataset.smiles_language.number_of_tokens,
        'receptor_vocabulary_size': protein_language.number_of_tokens,
    })
    logger.info(
        f'Receptor vocabulary size is {protein_language.number_of_tokens} and '
        f'ligand vocabulary size is '
        f'{train_dataset.smiles_dataset.smiles_language.number_of_tokens}'
    )

    model_fn = params.get('model_fn', 'bimodal_mca')
    model = MODEL_FACTORY[model_fn](params).to(device)
    model._associate_language(smiles_language)
    model._associate_language(protein_language)

    if os.path.isfile(os.path.join(model_dir, 'weights', 'best_mca.pt')):
        logger.info('Found existing model, restoring now...')
        try:
            model.load(os.path.join(model_dir, 'weights', 'best_mca.pt'))

            with open(
                os.path.join(model_dir, 'results', 'mse.json'), 'r'
            ) as f:
                info = json.load(f)

            max_roc_auc = info['best_roc_auc']
            min_loss = info['test_loss']

        except Exception:
            min_loss, max_roc_auc = 100, 0
    else:
        min_loss, max_roc_auc = 100, 0

    # Define optimizer
    optimizer = OPTIMIZER_FACTORY[params.get('optimizer', 'adam')](
        model.parameters(), lr=params.get('lr', 0.001)
    )
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params.update({'number_of_parameters': num_params})
    logger.info(f'Number of parameters: {num_params}')
    logger.info(f'Model: {model}')

    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()

    model.save(save_top_model.format('epoch', '0', model_fn))

    for epoch in range(params['epochs']):

        model.train()
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (ligands, receptors, y) in enumerate(train_loader):
            if ind % 100 == 0:
                logger.info(f'Batch {ind}/{len(train_loader)}')
            y_hat, pred_dict = model(ligands, receptors)
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            # Apply gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-6)
            optimizer.step()
            train_loss += loss.item()

        logger.info(
            "\t **** TRAINING **** "
            f"Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {train_loss / len(train_loader):.5f}. "
            f"This took {time() - t:.1f} secs."
        )
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (ligands, receptors, y) in enumerate(test_loader):
                y_hat, pred_dict = model(
                    ligands.to(device), receptors.to(device)
                )
                predictions.append(y_hat)
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()

        test_loss = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc = auc(fpr, tpr)
        # Calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        avg_precision = average_precision_score(labels, predictions)
        logger.info(
            f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {test_loss:.5f}, ROC-AUC: {test_roc_auc:.3f}, "
            f"Average precision: {avg_precision:.3f}."
        )

        def save(path, metric, typ, val=None):
            model.save(path.format(typ, metric, model_fn))
            info = {
                'best_roc_auc': str(max_roc_auc),
                'test_loss': str(min_loss),
            }
            with open(
                os.path.join(model_dir, 'results', metric + '.json'), 'w'
            ) as f:
                json.dump(info, f)
            np.save(
                os.path.join(model_dir, 'results', metric + '_preds.npy'),
                np.vstack([predictions, labels]),
            )
            if typ == 'best':
                logger.info(
                    f'\t New best performance in "{metric}"'
                    f' with value : {val:.7f} in epoch: {epoch}'
                )

        if test_roc_auc > max_roc_auc:
            max_roc_auc = test_roc_auc
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            ep_roc = epoch
            roc_auc_loss = test_loss

        if test_loss < min_loss:
            min_loss = test_loss
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch
            loss_roc_auc = test_roc_auc

        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info(
        'Overall best performances are: \n \t'
        f'Loss = {min_loss:.4f} in epoch {ep_loss} '
        f'\t (ROC-AUC was {loss_roc_auc:4f}) \n \t'
        f'ROC-AUC = {max_roc_auc:.4f} in epoch {ep_roc} '
        f'\t (Loss was {roc_auc_loss:4f})'
    )
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
def main(
    train_sensitivity_filepath, test_sensitivity_filepath, gep_filepath,
    smi_filepath, gene_filepath, smiles_language_filepath, model_path,
    params_filepath, training_name
):

    logger = logging.getLogger(f'{training_name}')
    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    # Prepare the dataset
    logger.info("Start data preprocessing...")

    # Load SMILES language
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Load the gene list
    with open(gene_filepath, 'rb') as f:
        gene_list = pickle.load(f)

    # Assemble datasets
    train_dataset = DrugSensitivityDataset(
        drug_sensitivity_filepath=train_sensitivity_filepath,
        smi_filepath=smi_filepath,
        gene_expression_filepath=gep_filepath,
        smiles_language=smiles_language,
        gene_list=gene_list,
        drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True),
        drug_sensitivity_processing_parameters=params.get(
            'drug_sensitivity_processing_parameters', {}
        ),
        augment=params.get('augment_smiles', True),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        add_start_and_stop=params.get('smiles_start_stop_token', True),
        padding_length=params.get('smiles_padding_length', None),
        gene_expression_standardize=params.get(
            'gene_expression_standardize', True
        ),
        gene_expression_min_max=params.get('gene_expression_min_max', False),
        gene_expression_processing_parameters=params.get(
            'gene_expression_processing_parameters', {}
        ),
        device=torch.device(params.get('dataset_device', 'cpu')),
        backend='eager'
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )
    test_dataset = DrugSensitivityDataset(
        drug_sensitivity_filepath=test_sensitivity_filepath,
        smi_filepath=smi_filepath,
        gene_expression_filepath=gep_filepath,
        smiles_language=smiles_language,
        gene_list=gene_list,
        drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True),
        drug_sensitivity_processing_parameters=params.get(
            'drug_sensitivity_processing_parameters',
            train_dataset.drug_sensitivity_processing_parameters
        ),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        add_start_and_stop=params.get('smiles_start_stop_token', True),
        padding_length=params.get('smiles_padding_length', None),
        gene_expression_standardize=params.get(
            'gene_expression_standardize', True
        ),
        gene_expression_min_max=params.get('gene_expression_min_max', False),
        gene_expression_processing_parameters=params.get(
            'gene_expression_processing_parameters',
            train_dataset.gene_expression_dataset.processing
        ),
        device=torch.device(params.get('dataset_device', 'cpu')),
        backend='eager'
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )
    logger.info(
        f'Training dataset has {len(train_dataset)} samples, test set has '
        f'{len(test_dataset)}.'
    )

    device = get_device()
    logger.info(
        f'Device for data loader is {train_dataset.device} and for '
        f'model is {device}'
    )
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')

    params.update({  # yapf: disable
        'number_of_genes': len(gene_list),
        'smiles_vocabulary_size': smiles_language.number_of_tokens,
        'drug_sensitivity_processing_parameters':
            train_dataset.drug_sensitivity_processing_parameters,
        'gene_expression_processing_parameters':
            train_dataset.gene_expression_dataset.processing
    })

    model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
    model._associate_language(smiles_language)

    if os.path.isfile(os.path.join(model_dir, 'weights', 'best_mse_mca.pt')):
        logger.info('Found existing model, restoring now...')
        model.load(os.path.join(model_dir, 'weights', 'best_mse_mca.pt'))

        with open(os.path.join(model_dir, 'results', 'mse.json'), 'r') as f:
            info = json.load(f)

        min_rmse = info['best_rmse']
        max_pearson = info['best_pearson']
        min_loss = info['test_loss']

    else:
        min_loss, min_rmse, max_pearson = 100, 1000, 0

    # Define optimizer
    optimizer = OPTIMIZER_FACTORY[params.get('optimizer', 'Adam')](
        model.parameters(), lr=params.get('lr', 0.01)
    )
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params.update({'number_of_parameters': num_params})
    logger.info(f'Number of parameters {num_params}')

    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()

    model.save(
        save_top_model.format('epoch', '0', params.get('model_fn', 'mca'))
    )

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, gep, y) in enumerate(train_loader):
            y_hat, pred_dict = model(
                torch.squeeze(smiles.to(device)), gep.to(device)
            )
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            # Apply gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-6)
            optimizer.step()
            train_loss += loss.item()

        logger.info(
            "\t **** TRAINING **** "
            f"Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {train_loss / len(train_loader):.5f}. "
            f"This took {time() - t:.1f} secs."
        )
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, gep, y) in enumerate(test_loader):
                y_hat, pred_dict = model(
                    torch.squeeze(smiles.to(device)), gep.to(device)
                )
                predictions.append(y_hat)
                labels.append(y)
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = np.array(
            [p.cpu() for preds in predictions for p in preds]
        )
        labels = np.array([l.cpu() for label in labels for l in label])
        test_pearson_a = pearsonr(
            torch.Tensor(predictions), torch.Tensor(labels)
        )
        test_rmse_a = np.sqrt(np.mean((predictions - labels)**2))
        test_loss_a = test_loss / len(test_loader)
        logger.info(
            f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {test_loss_a:.5f}, "
            f"Pearson: {test_pearson_a:.3f}, "
            f"RMSE: {test_rmse_a:.3f}"
        )

        def save(path, metric, typ, val=None):
            model.save(
                path.format(typ, metric, params.get('model_fn', 'mca'))
            )
            with open(
                os.path.join(model_dir, 'results', metric + '.json'), 'w'
            ) as f:
                json.dump(info, f)
            np.save(
                os.path.join(model_dir, 'results', metric + '_preds.npy'),
                np.vstack([predictions, labels])
            )
            if typ == 'best':
                logger.info(
                    f'\t New best performance in "{metric}"'
                    f' with value : {val:.7f} in epoch: {epoch}'
                )

        def update_info():
            return {
                'best_rmse': str(min_rmse),
                'best_pearson': str(float(max_pearson)),
                'test_loss': str(min_loss),
                'predictions': [float(p) for p in predictions]
            }

        if test_loss_a < min_loss:
            min_rmse = test_rmse_a
            min_loss = test_loss_a
            min_loss_pearson = test_pearson_a
            info = update_info()
            save(save_top_model, 'mse', 'best', min_loss)
            ep_loss = epoch
        if test_pearson_a > max_pearson:
            max_pearson = test_pearson_a
            max_pearson_loss = test_loss_a
            info = update_info()
            save(save_top_model, 'pearson', 'best', max_pearson)
            ep_pearson = epoch
        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info(
        'Overall best performances are: \n \t'
        f'Loss = {min_loss:.4f} in epoch {ep_loss} '
        f'\t (Pearson was {min_loss_pearson:4f}) \n \t'
        f'Pearson = {max_pearson:.4f} in epoch {ep_pearson} '
        f'\t (Loss was {max_pearson_loss:2f})'
    )
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
def __init__(self):
    self.device = get_device()
def main(
    train_scores_filepath, test_scores_filepath, smi_filepath, model_path,
    params_filepath, training_name
):

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger(f'{training_name}')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()
    device = get_device()

    # Restore pretrained model
    logger.info('Start model restoring.')
    try:
        with open(os.path.join(model_path, 'model_params.json'), 'r') as fp:
            params = json.load(fp)
        smiles_language = SMILESLanguage.load(
            os.path.join(model_path, 'smiles_language.pkl')
        )
        model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
        # Try weight restoring
        try:
            weight_path = os.path.join(
                model_path, 'weights',
                params.get('weights_name', 'best_ROC-AUC_mca.pt')
            )
            model.load(weight_path)
        except Exception:
            try:
                wp = os.listdir(os.path.join(model_path, 'weights'))[0]
                logger.info(
                    f"Weights {weight_path} not found. Try restore {wp}"
                )
                model.load(os.path.join(model_path, 'weights', wp))
            except Exception:
                raise Exception('Error in weight loading.')
    except Exception:
        raise Exception(
            'Error in model restoring. model_path should point to the model '
            'root folder that contains a model_params.json, a '
            'smiles_language.pkl and a weights folder.'
        )
    logger.info('Model restored. Now starting to craft it for the task')

    # Process parameter file:
    model_params = {}
    with open(params_filepath) as fp:
        model_params.update(json.load(fp))
    model = update_mca_model(model, model_params)
    logger.info('Model set up.')

    for idx, (name, param) in enumerate(model.named_parameters()):
        logger.info(
            (idx, name, param.shape, f'Gradients: {param.requires_grad}')
        )

    ft_model_path = os.path.join(model_path, 'finetuned')
    os.makedirs(os.path.join(ft_model_path, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(ft_model_path, 'results'), exist_ok=True)

    logger.info('Now start data preprocessing...')

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('sanitize', True)
    )
    train_dataset = AnnotatedDataset(
        annotations_filepath=train_scores_filepath,
        dataset=smiles_dataset,
        device=get_device()
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )
    # Generally, if sanitize is True molecules are de-kekulized. Augmentation
    # preserves the "kekulization", so if it is used, test data should be
    # sanitized or canonicalized.
    smiles_test_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False)
    )
    # Dump eventually modified SMILES Language
    smiles_language.save(os.path.join(ft_model_path, 'smiles_language.pkl'))

    test_dataset = AnnotatedDataset(
        annotations_filepath=test_scores_filepath,
        dataset=smiles_test_dataset,
        device=get_device()
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=False,
        drop_last=False,
        num_workers=params.get('num_workers', 0)
    )

    save_top_model = os.path.join(ft_model_path, 'weights/{}_{}_{}.pt')

    num_t_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    num_params = sum(p.numel() for p in model.parameters())
    params.update({'params': num_params, 'trainable_params': num_t_params})
    logger.info(
        f'Number of parameters: {num_params} (trainable {num_t_params}).'
    )

    # Define optimizer only for those layers which require gradients
    optimizer = OPTIMIZER_FACTORY[params.get('optimizer', 'adam')](
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=params.get('lr', 0.00001)
    )

    # Dump params.json file with updated parameters.
    with open(os.path.join(ft_model_path, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    min_loss, max_roc_auc = 1000000, 0
    max_precision_recall_score = 0

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, y) in enumerate(train_loader):
            smiles = torch.squeeze(smiles.to(device))
            y_hat, pred_dict = model(smiles)
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        logger.info(
            '\t **** TRAINING **** '
            f"Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {train_loss / len(train_loader):.5f}. '
            f'This took {time() - t:.1f} secs.'
        )
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, y) in enumerate(test_loader):
                smiles = torch.squeeze(smiles.to(device))
                y_hat, pred_dict = model(smiles)
                predictions.append(y_hat)
                # Copy y tensor since loss function applies downstream
                # modification
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()
        # Remove NaNs from labels to compute scores
        predictions = predictions[~np.isnan(labels)]
        labels = labels[~np.isnan(labels)]
        test_loss_a = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc_a = auc(fpr, tpr)
        # Calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        # Score for precision vs recall
        test_precision_recall_score = average_precision_score(
            labels, predictions
        )

        logger.info(
            f"\t **** TEST **** Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {test_loss_a:.5f}, roc_auc: {test_roc_auc_a:.5f}, '
            f'avg precision-recall score: {test_precision_recall_score:.5f}'
        )

        info = {
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
            'test_auc': test_roc_auc_a,
            'best_test_auc': max_roc_auc,
            'test_precision_recall_score': test_precision_recall_score,
            'best_precision_recall_score': max_precision_recall_score,
        }

        def save(path, metric, typ, val=None):
            model.save(
                path.format(typ, metric, params.get('model_fn', 'mca'))
            )
            if typ == 'best':
                logger.info(
                    f'\t New best performance in {metric}'
                    f' with value : {val:.7f} in epoch: {epoch+1}'
                )

        if test_roc_auc_a > max_roc_auc:
            max_roc_auc = test_roc_auc_a
            info.update({'best_test_auc': max_roc_auc})
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            np.save(
                os.path.join(ft_model_path, 'results', 'best_predictions.npy'),
                predictions
            )
            with open(
                os.path.join(ft_model_path, 'results', 'metrics.json'), 'w'
            ) as f:
                json.dump(info, f)

        if test_precision_recall_score > max_precision_recall_score:
            max_precision_recall_score = test_precision_recall_score
            info.update(
                {'best_precision_recall_score': max_precision_recall_score}
            )
            save(
                save_top_model, 'precision-recall score', 'best',
                max_precision_recall_score
            )

        if test_loss_a < min_loss:
            min_loss = test_loss_a
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch

        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info(
        'Overall best performances are: \n \t'
        f'Loss = {min_loss:.4f} in epoch {ep_loss} '
    )
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
def update_mca_model(model: MCAMultiTask, params: dict) -> MCAMultiTask:
    """
    Receives a pretrained model (instance of MCAMultiTask), modifies it and
    returns the updated object.

    Args:
        model (MCAMultiTask): Pretrained model to be modified.
        params (dict): Hyperparameter file for the modifications. Needs to
            include:
            - number_of_tunable_layers (how many layers should not be frozen.
                If the number exceeds the number of existing layers, all
                layers are tuned.)
            - fresh_dense_sizes (a list of fresh dense layers to be plugged
                in at the end).
            - num_tasks (number of classification tasks being performed).

    Returns:
        MCAMultiTask: Modified model for finetuning.
    """
    if not isinstance(model, MCAMultiTask):
        raise TypeError(
            f'Wrong model type, was {type(model)}, not MCAMultiTask.'
        )

    # Freeze the correct layers and add new ones.
    # Not strictly speaking all layers, but all param matrices, gradient-req
    # or not.
    num_layers = len(['' for p in model.parameters()])
    num_to_tune = params['number_of_tunable_layers']
    if num_to_tune > num_layers:
        logger.warning(
            f'Model has {num_layers} tunable layers. Given # is larger: '
            f'{num_to_tune}.'
        )
        num_to_tune = num_layers
    fresh_sizes = params['fresh_dense_sizes']
    logger.info(
        f'Model has {num_layers} layers. {num_to_tune} will be finetuned, '
        f'{len(fresh_sizes)} fresh ones will be added (sizes: {fresh_sizes}).'
    )
    # Count the ensemble layers (will be replaced anyways)
    num_ensemble_layers = len(
        list(
            filter(lambda tpl: 'ensemble' in tpl[0], model.named_parameters())
        )
    )
    # Freeze the right layers
    for idx, (name, param) in enumerate(model.named_parameters()):
        if idx < num_layers - num_to_tune - num_ensemble_layers:
            param.requires_grad = False

    # Add more dense layers
    fresh_sizes.insert(0, model.hidden_sizes[-1])
    model.dense_layers = nn.Sequential(
        model.dense_layers,
        nn.Sequential(
            OrderedDict(
                [
                    (
                        'fresh_dense_{}'.format(ind),
                        dense_layer(
                            fresh_sizes[ind],
                            fresh_sizes[ind + 1],
                            act_fn=ACTIVATION_FN_FACTORY[
                                params.get('activation_fn', 'relu')
                            ],
                            dropout=params.get('dropout', 0.5),
                            batch_norm=params.get('batch_norm', True)
                        ).to(get_device())
                    ) for ind in range(len(fresh_sizes) - 1)
                ]
            )
        )
    )
    # Replace final layer
    model.num_tasks = params['num_tasks']
    model.final_dense = EnsembleLayer(
        typ=params.get('ensemble', 'score'),
        input_size=fresh_sizes[-1],
        output_size=model.num_tasks,
        ensemble_size=params.get('ensemble_size', 5),
        fn=ACTIVATION_FN_FACTORY['sigmoid']
    )
    return model
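
# Illustrative finetuning params for update_mca_model above. The key names
# mirror the lookups in the function; the values are made-up examples.
EXAMPLE_FINETUNE_PARAMS = {
    'number_of_tunable_layers': 4,   # unfreeze only the last 4 param matrices
    'fresh_dense_sizes': [256, 64],  # two new dense layers appended at the end
    'num_tasks': 12,                 # number of classification tasks
    'dropout': 0.5,
    'batch_norm': True,
    'ensemble_size': 5,
}
# Hypothetical call: model = update_mca_model(pretrained_model,
#                                             EXAMPLE_FINETUNE_PARAMS)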
def main(
    train_sensitivity_filepath, test_sensitivity_filepath, gep_filepath,
    smi_filepath, gene_filepath, smiles_language_filepath, model_path,
    params_filepath, training_name
):

    logger = logging.getLogger(f'{training_name}')
    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    # Prepare the dataset
    logger.info("Start data preprocessing...")

    # Load SMILES language
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Load the gene list
    with open(gene_filepath, 'rb') as f:
        gene_list = pickle.load(f)

    # Assemble datasets
    train_dataset = DrugSensitivityDataset(
        drug_sensitivity_filepath=train_sensitivity_filepath,
        smi_filepath=smi_filepath,
        gene_expression_filepath=gep_filepath,
        smiles_language=smiles_language,
        gene_list=gene_list,
        drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True),
        augment=params.get('augment_smiles', True),
        add_start_and_stop=params.get('smiles_start_stop_token', True),
        padding_length=params.get('smiles_padding_length', None),
        device=torch.device(params.get('dataset_device', 'cpu')),
        backend='eager'
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )
    test_dataset = DrugSensitivityDataset(
        drug_sensitivity_filepath=test_sensitivity_filepath,
        smi_filepath=smi_filepath,
        gene_expression_filepath=gep_filepath,
        smiles_language=smiles_language,
        gene_list=gene_list,
        drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True),
        augment=params.get('augment_smiles', True),
        add_start_and_stop=params.get('smiles_start_stop_token', True),
        padding_length=params.get('smiles_padding_length', None),
        device=torch.device(params.get('dataset_device', 'cpu')),
        backend='eager'
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )

    params.update({
        'number_of_genes': len(gene_list),
        'smiles_vocabulary_size': smiles_language.number_of_tokens
    })

    # Set the tensorboard logger
    tb_logger = Logger(os.path.join(model_dir, 'tb_logs'))
    device = get_device()
    logger.info(
        f'Device for data loader is {train_dataset.device} and for '
        f'model is {device}'
    )
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')
    model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)

    # Define optimizer
    optimizer = OPTIMIZER_FACTORY[params.get('optimizer', 'Adam')](
        model.parameters(), lr=params.get('lr', 0.01)
    )

    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    min_loss, max_pearson = 100, 0

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, gep, y) in enumerate(train_loader):
            y_hat, pred_dict = model(
                torch.squeeze(smiles.to(device)), gep.to(device)
            )
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            # Apply gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-6)
            optimizer.step()
            train_loss += loss.item()

        logger.info(
            "\t **** TRAINING **** "
            f"Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {train_loss / len(train_loader):.5f}. "
            f"This took {time() - t:.1f} secs."
        )
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, gep, y) in enumerate(test_loader):
                y_hat, pred_dict = model(
                    torch.squeeze(smiles.to(device)), gep.to(device)
                )
                predictions.append(y_hat)
                labels.append(y)
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = np.array(
            [p.cpu() for preds in predictions for p in preds]
        )
        labels = np.array([l.cpu() for label in labels for l in label])
        test_pearson_a = pearsonr(predictions, labels)
        test_rmse_a = np.sqrt(np.mean((predictions - labels)**2))
        np.save(
            os.path.join(model_dir, 'results', 'preds.npy'),
            np.hstack([predictions, labels])
        )
        test_loss_a = test_loss / len(test_loader)
        logger.info(
            f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {test_loss_a:.5f}, "
            f"Pearson: {test_pearson_a:.3f}, "
            f"RMSE: {test_rmse_a:.3f}"
        )

        # TensorBoard logging of scalars.
        info = {
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
        }
        for tag, value in info.items():
            tb_logger.scalar_summary(tag, value, epoch + 1)

        def save(path, metric, typ, val=None):
            model.save(
                path.format(typ, metric, params.get('model_fn', 'mca'))
            )
            if typ == 'best':
                logger.info(
                    f'\t New best performance in "{metric}"'
                    f' with value : {val:.7f} in epoch: {epoch}'
                )

        if test_loss_a < min_loss:
            min_loss = test_loss_a
            min_loss_pearson = test_pearson_a
            save(save_top_model, 'mse', 'best', min_loss)
            ep_loss = epoch
        if test_pearson_a > max_pearson:
            max_pearson = test_pearson_a
            max_pearson_loss = test_loss_a
            save(save_top_model, 'pearson', 'best', max_pearson)
            ep_pearson = epoch
        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info(
        'Overall best performances are: \n \t'
        f'Loss = {min_loss:.4f} in epoch {ep_loss} '
        f'\t (Pearson was {min_loss_pearson:4f}) \n \t'
        f'Pearson = {max_pearson:.4f} in epoch {ep_pearson} '
        f'\t (Loss was {max_pearson_loss:2f})'
    )
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
def __init__(self, params, *args, **kwargs):
    """Constructor.

    Args:
        params (dict): A dictionary containing the parameters to build the
            CNN encoder.
            TODO params should become actual arguments (use **params).
            Items in params:
                num_filters (list[int], optional): Numbers of filters to learn
                    per convolutional layer. Defaults to [10, 20, 50].
                kernel_sizes (list[list[int]], optional): Sizes of kernels per
                    convolutional layer. Defaults to [
                        [3, params['smiles_embedding_size']],
                        [5, params['smiles_embedding_size']],
                        [11, params['smiles_embedding_size']]
                    ]
                activation_fn (string, optional): Activation function used in
                    all layers for specification in ACTIVATION_FN_FACTORY.
                    Defaults to 'relu'.
                batch_norm (bool, optional): Whether batch normalization is
                    applied. Defaults to False.
                dropout (float, optional): Dropout probability in all except
                    parametric layer. Defaults to 0.0.
        *args, **kwargs: positional and keyword arguments are ignored.
    """
    super(CNN, self).__init__(*args, **kwargs)

    # Model Parameter
    self.device = get_device()
    self.params = params
    self.loss_fn = LOSS_FN_FACTORY[params.get('loss_fn', 'cnn')]
    # Mirrors the default used by the other multitask models in this repo.
    self.num_tasks = params.get('num_tasks', 12)

    self.kernel_sizes = params.get(
        'kernel_sizes',
        [
            [3, params['smiles_embedding_size']],
            [5, params['smiles_embedding_size']],
            [11, params['smiles_embedding_size']]
        ]
    )
    self.num_filters = [1] + params.get('num_filters', [10, 20, 50])
    if len(self.num_filters) - 1 != len(self.kernel_sizes):
        raise ValueError(
            'Length of filter and kernel size lists do not match.'
        )

    self.smiles_embedding = nn.Embedding(
        self.params['smiles_vocabulary_size'],
        self.params['smiles_embedding_size'],
        scale_grad_by_freq=params.get('embed_scale_grad', False)
    )
    self.dropout = params.get('dropout', 0.0)
    self.act_fn = ACTIVATION_FN_FACTORY[params.get('activation_fn', 'relu')]

    # ModuleList so that the convolutional parameters are registered.
    self.conv_layers = nn.ModuleList(
        [
            convolutional_layer(
                self.num_filters[layer],
                self.num_filters[layer + 1],
                self.kernel_sizes[layer]
            ) for layer in range(len(self.num_filters) - 1)
        ]
    )
    self.final_dense = nn.Linear(
        self.num_filters[-1] * self.params['smiles_embedding_size'],
        self.num_tasks
    )
    self.final_act_fn = ACTIVATION_FN_FACTORY['sigmoid']
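
# Illustrative params dict for the CNN encoder above (values are assumptions;
# the second entry of each kernel size must equal smiles_embedding_size):
EXAMPLE_CNN_PARAMS = {
    'smiles_vocabulary_size': 28,
    'smiles_embedding_size': 8,
    'num_filters': [10, 20, 50],
    'kernel_sizes': [[3, 8], [5, 8], [11, 8]],
    'activation_fn': 'relu',
    'dropout': 0.2,
    'num_tasks': 12,
}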
def main(
    train_scores_filepath, test_scores_filepath, smi_filepath,
    smiles_language_filepath, model_path, params_filepath, training_name,
    embedding_path=None
):

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger(f'{training_name}')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()

    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))
    if embedding_path:
        params['embedding_path'] = embedding_path

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    logger.info('Start data preprocessing...')
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Prepare FP processing
    if params.get('model_fn', 'mca') == 'dense':
        morgan_transform = Compose([
            SMILESToMorganFingerprints(
                radius=params.get('fp_radius', 2),
                bits=params.get('num_drug_features', 512),
                chirality=params.get('fp_chirality', True)
            ),
            ToTensor(get_device())
        ])

        def smiles_tensor_batch_to_fp(smiles):
            """To abuse SMILES dataset for FP usage"""
            out = torch.Tensor(
                smiles.shape[0], params.get('num_drug_features', 512)
            )
            for ind, tensor in enumerate(smiles):
                smiles_string = smiles_language.token_indexes_to_smiles(
                    tensor.tolist()
                )
                out[ind, :] = torch.squeeze(morgan_transform(smiles_string))
            return out

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('sanitize', True)
    )
    # Include arg label_columns if the data file has any unwanted columns
    # (such as an index) to be ignored.
    train_dataset = AnnotatedDataset(
        annotations_filepath=train_scores_filepath,
        dataset=smiles_dataset,
        device=get_device()
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )

    if (
        params.get('uncertainty', True)
        and params.get('augment_test_data', False)
    ):
        raise ValueError(
            'Epistemic uncertainty evaluation not supported if augmentation '
            'is enabled for test data.'
        )

    # Generally, if sanitize is True molecules are de-kekulized. Augmentation
    # preserves the "kekulization", so if it is used, test data should be
    # sanitized or canonicalized.
    smiles_test_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False)
    )
    # Dump eventually modified SMILES Language
    smiles_language.save(os.path.join(model_dir, 'smiles_language.pkl'))
    logger.info(smiles_dataset._dataset.transform)
    logger.info(smiles_test_dataset._dataset.transform)

    # Include arg label_columns if the data file has any unwanted columns
    # (such as an index) to be ignored.
    test_dataset = AnnotatedDataset(
        annotations_filepath=test_scores_filepath,
        dataset=smiles_test_dataset,
        device=get_device()
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=False,
        drop_last=False,
        num_workers=params.get('num_workers', 0)
    )

    if params.get('confidence', False):
        smiles_conf_dataset = SMILESDataset(
            smi_filepath,
            smiles_language=smiles_language,
            padding_length=params.get('smiles_padding_length', None),
            padding=params.get('padd_smiles', True),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=True,  # Natively true for epistemic uncertainty estimate
            canonical=params.get('canonical', False),
            kekulize=params.get('kekulize', False),
            all_bonds_explicit=params.get('all_bonds_explicit', False),
            all_hs_explicit=params.get('all_hs_explicit', False),
            randomize=params.get('randomize', False),
            remove_bonddir=params.get('remove_bonddir', False),
            remove_chirality=params.get('remove_chirality', False),
            selfies=params.get('selfies', False)
        )
        conf_dataset = AnnotatedDataset(
            annotations_filepath=test_scores_filepath,
            dataset=smiles_conf_dataset,
            device=get_device()
        )
        conf_loader = torch.utils.data.DataLoader(
            dataset=conf_dataset,
            batch_size=params['batch_size'],
            shuffle=False,
            drop_last=False,
            num_workers=params.get('num_workers', 0)
        )

    if not params.get('embedding', 'learned') == 'pretrained':
        params.update(
            {'smiles_vocabulary_size': smiles_language.number_of_tokens}
        )

    device = get_device()
    logger.info(
        f'Device for data loader is {train_dataset.device} and for '
        f'model is {device}'
    )
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')

    model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
    logger.info(model)
    logger.info(model.loss_fn.class_weights)
    logger.info('Parameters follow')
    for name, param in model.named_parameters():
        logger.info((name, param.shape))

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params.update({'number_of_parameters': num_params})
    logger.info(f'Number of parameters {num_params}')

    # Define optimizer
    optimizer = OPTIMIZER_FACTORY[params.get('optimizer', 'adam')](
        model.parameters(), lr=params.get('lr', 0.00001)
    )

    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    min_loss, max_roc_auc = 1000000, 0
    max_precision_recall_score = 0

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, y) in enumerate(train_loader):
            smiles = torch.squeeze(smiles.to(device))
            # Transform smiles to FP if needed
            if params.get('model_fn', 'mca') == 'dense':
                smiles = smiles_tensor_batch_to_fp(smiles).to(device)
            y_hat, pred_dict = model(smiles)
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        logger.info(
            '\t **** TRAINING **** '
            f"Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {train_loss / len(train_loader):.5f}. '
            f'This took {time() - t:.1f} secs.'
        )
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, y) in enumerate(test_loader):
                smiles = torch.squeeze(smiles.to(device))
                # Transform smiles to FP if needed
                if params.get('model_fn', 'mca') == 'dense':
                    smiles = smiles_tensor_batch_to_fp(smiles).to(device)
                y_hat, pred_dict = model(smiles)
                predictions.append(y_hat)
                # Copy y tensor since loss function applies downstream
                # modification
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()
        # Remove NaNs from labels to compute scores
        predictions = predictions[~np.isnan(labels)]
        labels = labels[~np.isnan(labels)]
        test_loss_a = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc_a = auc(fpr, tpr)
        # Calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        # Score for precision vs recall
        test_precision_recall_score = average_precision_score(
            labels, predictions
        )

        logger.info(
            f"\t **** TEST **** Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {test_loss_a:.5f}, roc_auc: {test_roc_auc_a:.5f}, '
            f'avg precision-recall score: {test_precision_recall_score:.5f}'
        )

        info = {
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
            'test_auc': test_roc_auc_a,
            'best_test_auc': max_roc_auc,
            'test_precision_recall_score': test_precision_recall_score,
            'best_precision_recall_score': max_precision_recall_score,
        }

        def save(path, metric, typ, val=None):
            model.save(
                path.format(typ, metric, params.get('model_fn', 'mca'))
            )
            if typ == 'best':
                logger.info(
                    f'\t New best performance in {metric}'
                    f' with value : {val:.7f} in epoch: {epoch+1}'
                )

        if test_roc_auc_a > max_roc_auc:
            max_roc_auc = test_roc_auc_a
            info.update({'best_test_auc': max_roc_auc})
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            np.save(
                os.path.join(model_dir, 'results', 'best_predictions.npy'),
                predictions
            )
            with open(
                os.path.join(model_dir, 'results', 'metrics.json'), 'w'
            ) as f:
                json.dump(info, f)
            if params.get('confidence', False):
                # Compute uncertainty estimates and save them
                epistemic_conf = monte_carlo_dropout(
                    model, regime='loader', loader=conf_loader
                )
                aleatoric_conf = test_time_augmentation(
                    model, regime='loader', loader=conf_loader
                )
                np.save(
                    os.path.join(model_dir, 'results', 'epistemic_conf.npy'),
                    epistemic_conf
                )
                np.save(
                    os.path.join(model_dir, 'results', 'aleatoric_conf.npy'),
                    aleatoric_conf
                )

        if test_precision_recall_score > max_precision_recall_score:
            max_precision_recall_score = test_precision_recall_score
            info.update(
                {'best_precision_recall_score': max_precision_recall_score}
            )
            save(
                save_top_model, 'precision-recall score', 'best',
                max_precision_recall_score
            )

        if test_loss_a < min_loss:
            min_loss = test_loss_a
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch

        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info(
        'Overall best performances are: \n \t'
        f'Loss = {min_loss:.4f} in epoch {ep_loss} '
    )
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
def main(model_path, smi_filepath, labels_filepath, output_folder, smiles_language_filepath, model_id, confidence): logging.basicConfig(level=logging.INFO, format='%(message)s') logger = logging.getLogger('eval_toxicity') logger.setLevel(logging.INFO) disable_rdkit_logging() # Process parameter file: params = {} with open(os.path.join(model_path, 'model_params.json'), 'r') as fp: params.update(json.load(fp)) # Create model directory os.makedirs(output_folder, exist_ok=True) if model_id not in MODEL_FACTORY.keys(): raise KeyError( f'Model ID: Pass one of {MODEL_FACTORY.keys()}, not {model_id}') device = get_device() weights_path = os.path.join(model_path, 'weights', f'best_ROC-AUC_{model_id}.pt') # Restore model model = MODEL_FACTORY[model_id](params).to(device) if os.path.isfile(weights_path): try: model.load(weights_path, map_location=device) except Exception: logger.error(f'Error in model restoring from {weights_path}') else: logger.info(f'Did not find weights at {weights_path}, ' f'name weights: "best_ROC-AUC_{model_id}.pt".') model.eval() logger.info('Model restored. Model specs & parameters follow') for name, param in model.named_parameters(): logger.info((name, param.shape)) # Load language if smiles_language_filepath == '.': smiles_language_filepath = os.path.join(model_path, 'smiles_language.pkl') smiles_language = SMILESLanguage.load(smiles_language_filepath) # Assemble datasets smiles_dataset = SMILESDataset( smi_filepath, smiles_language=smiles_language, padding_length=params.get('smiles_padding_length', None), padding=params.get('padd_smiles', True), add_start_and_stop=params.get('add_start_stop_token', True), augment=params.get('augment_test_smiles', False), canonical=params.get('test_canonical', False), kekulize=params.get('test_kekulize', False), all_bonds_explicit=params.get('test_all_bonds_explicit', False), all_hs_explicit=params.get('test_all_hs_explicit', False), randomize=False, remove_bonddir=params.get('test_remove_bonddir', False), remove_chirality=params.get('test_remove_chirality', False), selfies=params.get('selfies', False), sanitize=params.get('test_sanitize', False), ) logger.info( f'SMILES Padding length is {smiles_dataset._dataset.padding_length}.' 
'Consider setting manually if this looks wrong.') dataset = AnnotatedDataset(annotations_filepath=labels_filepath, dataset=smiles_dataset, device=device) loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=10, shuffle=False, drop_last=False, num_workers=0) if confidence: smiles_aleatoric_dataset = SMILESDataset( smi_filepath, smiles_language=smiles_language, padding_length=params.get('smiles_padding_length', None), padding=params.get('padd_smiles', True), add_start_and_stop=params.get('add_start_stop_token', True), augment=True, # Natively true for aleatoric uncertainity estimate canonical=params.get('test_canonical', False), kekulize=params.get('test_kekulize', False), all_bonds_explicit=params.get('test_all_bonds_explicit', False), all_hs_explicit=params.get('test_all_hs_explicit', False), remove_bonddir=params.get('test_remove_bonddir', False), remove_chirality=params.get('test_remove_chirality', False), selfies=params.get('selfies', False), sanitize=params.get('test_sanitize', False), ) ale_dataset = AnnotatedDataset(annotations_filepath=labels_filepath, dataset=smiles_aleatoric_dataset, device=device) ale_loader = torch.utils.data.DataLoader(dataset=ale_dataset, batch_size=10, shuffle=False, drop_last=False, num_workers=0) smiles_epi_dataset = SMILESDataset( smi_filepath, smiles_language=smiles_language, padding_length=params.get('smiles_padding_length', None), padding=params.get('padd_smiles', True), add_start_and_stop=params.get('add_start_stop_token', True), augment=False, # Natively false for epistemic uncertainity estimate canonical=params.get('test_canonical', False), kekulize=params.get('test_kekulize', False), all_bonds_explicit=params.get('test_all_bonds_explicit', False), all_hs_explicit=params.get('test_all_hs_explicit', False), remove_bonddir=params.get('test_remove_bonddir', False), remove_chirality=params.get('test_remove_chirality', False), selfies=params.get('selfies', False), sanitize=params.get('test_sanitize', False), ) epi_dataset = AnnotatedDataset(annotations_filepath=labels_filepath, dataset=smiles_epi_dataset, device=device) epi_loader = torch.utils.data.DataLoader(dataset=epi_dataset, batch_size=10, shuffle=False, drop_last=False, num_workers=0) logger.info(f'Device for data loader is {dataset.device} and for ' f'model is {device}') # Start evaluation logger.info('Evaluation about to start...\n') preds, labels, attention_scores, smiles = [], [], [], [] for idx, (smiles_batch, labels_batch) in enumerate(loader): pred, pred_dict = model(smiles_batch.to(device)) preds.extend(pred.detach().squeeze().tolist()) attention_scores.extend( torch.stack(pred_dict['smiles_attention'], dim=1).detach()) smiles.extend([ smiles_language.token_indexes_to_smiles(s.tolist()) for s in smiles_batch ]) labels.extend(labels_batch.squeeze().tolist()) # Scores are now 3D: num_samples x num_att_layers x padding_length attention = torch.stack(attention_scores, dim=0).numpy() if confidence: # Compute uncertainity estimates and save them epistemic_conf, epistemic_pred = monte_carlo_dropout(model, regime='loader', loader=epi_loader) aleatoric_conf, aleatoric_pred = test_time_augmentation( model, regime='loader', loader=ale_loader) epi_conf_df = pd.DataFrame( data=epistemic_conf.numpy(), columns=[ f'epistemic_conf_{i}' for i in range(epistemic_conf.shape[1]) ], ) ale_conf_df = pd.DataFrame( data=aleatoric_conf.numpy(), columns=[ f'aleatoric_conf_{i}' for i in range(aleatoric_conf.shape[1]) ], ) epi_pred_df = pd.DataFrame( data=epistemic_pred.numpy(), columns=[ 
f'epistemic_pred_{i}' for i in range(epistemic_pred.shape[1]) ], ) ale_pred_df = pd.DataFrame( data=aleatoric_pred.numpy(), columns=[ f'aleatoric_pred_{i}' for i in range(aleatoric_pred.shape[1]) ], ) logger.info(f'Shape of attention scores {attention.shape}.') np.save(os.path.join(output_folder, 'attention_raw.npy'), attention) attention_avg = np.mean(attention, axis=1) att_df = pd.DataFrame( data=attention_avg, columns=[f'att_idx_{i}' for i in range(attention_avg.shape[1])], ) pred_df = pd.DataFrame( data=preds, columns=[f'pred_{i}' for i in range(len(preds[0]))], ) lab_df = pd.DataFrame( data=labels, columns=[f'label_{i}' for i in range(len(labels[0]))], ) df = pd.concat([pred_df, lab_df], axis=1) if confidence: df = pd.concat( [df, epi_conf_df, ale_conf_df, epi_pred_df, ale_pred_df], axis=1) df = pd.concat([df, att_df], axis=1) df.insert(0, 'SMILES', smiles) df.to_csv(os.path.join(output_folder, 'results.csv'), index=False) logger.info('Done, shutting down.')
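The evaluation script writes per-task predictions and labels (`pred_i` / `label_i` columns) alongside attention scores and, optionally, the confidence estimates into `results.csv`. As a quick sanity check of such an output file, per-task ROC-AUC can be computed while skipping missing annotations, mirroring the NaN-ignoring loss used during training. The snippet below is a minimal, illustrative sketch: the `results/results.csv` path and the use of scikit-learn are assumptions, not part of this repository.
```
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score

# Hypothetical output location; point this to the output_folder passed to the script.
df = pd.read_csv('results/results.csv')

pred_cols = [c for c in df.columns if c.startswith('pred_')]
label_cols = [c for c in df.columns if c.startswith('label_')]

aucs = {}
for pred_col, label_col in zip(pred_cols, label_cols):
    labels = df[label_col].to_numpy()
    preds = df[pred_col].to_numpy()
    mask = ~np.isnan(labels)  # skip unlabeled entries, as the training loss does
    if mask.sum() > 0 and len(np.unique(labels[mask])) == 2:
        aucs[label_col] = roc_auc_score(labels[mask], preds[mask])
print(aucs)
```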
import logging from typing import Iterable import torch import torch.nn as nn from paccmann_predictor.utils.utils import get_device logger = logging.getLogger(__name__) DEVICE = get_device() class BCEIgnoreNaN(nn.Module): """Wrapper for BCE function that ignores NaNs""" def __init__(self, reduction: str, class_weights: tuple = (1, 1)) -> None: """ Args: reduction (str): Reduction applied in loss function. Either sum or mean. class_weights (tuple, optional): Class weights for loss function. Defaults to (1, 1), i.e. equal class weights. """ super(BCEIgnoreNaN, self).__init__() self.loss = nn.BCELoss(reduction='none') if reduction != 'sum' and reduction != 'mean': raise ValueError( f'Choose reduction type as mean or sum, not {reduction}' ) self.reduction = reduction
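Only the constructor of `BCEIgnoreNaN` is shown here. To make its role concrete, the following is a minimal sketch of what the forward pass of such a wrapper could look like: element-wise BCE is computed with `reduction='none'`, entries whose label is NaN are zeroed out, class weights are applied, and only then is the requested reduction taken. This is an illustrative sketch that assumes the constructor also stores `self.class_weights = class_weights`; the actual implementation in the repository may differ in detail.
```
import torch

def forward(self, yhat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Sketch: BCE over labelled entries only; NaN labels contribute no loss."""
    mask = ~torch.isnan(y)                        # True where a label exists
    y_filled = torch.where(mask, y, torch.zeros_like(y))
    loss = self.loss(yhat, y_filled)              # element-wise BCE (reduction='none')
    # Weight positive and negative labels differently if class_weights != (1, 1).
    weights = torch.where(
        y_filled == 1.0,
        torch.full_like(loss, float(self.class_weights[1])),
        torch.full_like(loss, float(self.class_weights[0])),
    )
    loss = loss * weights * mask.float()          # drop unlabeled entries
    if self.reduction == 'sum':
        return loss.sum()
    return loss.sum() / mask.float().sum()        # mean over labelled entries
```
With this masking, a sample that lacks annotations for some tasks still contributes a gradient for the tasks it is labelled for.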
def __init__(self, params: dict, *args, **kwargs): """Constructor. Args: params (dict): A dictionary containing the parameters to build the model. Items in params: smiles_embedding_size (int): Dimension of the token embeddings. smiles_vocabulary_size (int): Size of the token vocabulary. activation_fn (string, optional): Activation function used in all layers for specification in ACTIVATION_FN_FACTORY. Defaults to 'relu'. batch_norm (bool, optional): Whether batch normalization is applied. Defaults to True. dropout (float, optional): Dropout probability in all except parametric layer. Defaults to 0.5. filters (list[int], optional): Numbers of filters to learn per convolutional layer. Defaults to [64, 64, 64]. kernel_sizes (list[list[int]], optional): Sizes of kernels per convolutional layer. Defaults to [ [3, params['smiles_embedding_size']], [5, params['smiles_embedding_size']], [11, params['smiles_embedding_size']] ] NOTE: The second dimension of each kernel size should match smiles_embedding_size. If the latter is 8, each SMILES is processed as a t x 8 image, with the 8 embedding dimensions treated like the channels of an RGB image. multiheads (list[int], optional): Number of attention heads per SMILES layer. Should have length len(filters)+1. Defaults to [4, 4, 4, 4]. stacked_dense_hidden_sizes (list[int], optional): Sizes of the hidden dense layers. Defaults to [1024, 512]. smiles_attention_size (int, optional): Size of the attention layer for the SMILES sequence. Defaults to 64. Example params: ``` { "smiles_attention_size": 8, "smiles_vocabulary_size": 28, "smiles_embedding_size": 8, "filters": [128, 128], "kernel_sizes": [[3, 8], [5, 8]], "multiheads": [4, 4, 4], "stacked_dense_hidden_sizes": [1024, 512] } ``` """ super(MCAMultiTask, self).__init__(*args, **kwargs) # Model parameters self.device = get_device() self.params = params self.num_tasks = params.get('num_tasks', 12) self.smiles_attention_size = params.get('smiles_attention_size', 64) # Model architecture (hyperparameters) self.multiheads = params.get('multiheads', [4, 4, 4, 4]) self.filters = params.get('filters', [64, 64, 64]) self.dropout = params.get('dropout', 0.5) self.use_batch_norm = self.params.get('batch_norm', True) self.act_fn = ACTIVATION_FN_FACTORY[params.get('activation_fn', 'relu')] # Build the model.
First the embeddings if params.get('embedding', 'learned') == 'learned': self.smiles_embedding = nn.Embedding( self.params['smiles_vocabulary_size'], self.params['smiles_embedding_size'], scale_grad_by_freq=params.get('embed_scale_grad', False)) elif params.get('embedding', 'learned') == 'one_hot': self.smiles_embedding = nn.Embedding( self.params['smiles_vocabulary_size'], self.params['smiles_vocabulary_size']) # Plug in one hot-vectors and freeze weights self.smiles_embedding.load_state_dict({ 'weight': torch.nn.functional.one_hot( torch.arange(self.params['smiles_vocabulary_size'])) }) self.smiles_embedding.weight.requires_grad = False elif params.get('embedding', 'learned') == 'pretrained': # Load the pretrained embeddings try: with open(params['embedding_path'], 'rb') as f: embeddings = pickle.load(f) except KeyError: raise KeyError('Path for embeddings is missing in params.') # Plug into layer self.smiles_embedding = nn.Embedding(embeddings.shape[0], embeddings.shape[1]) self.smiles_embedding.load_state_dict( {'weight': torch.Tensor(embeddings)}) if params.get('fix_embeddings', True): self.smiles_embedding.weight.requires_grad = False else: raise ValueError(f"Unknown embedding type: {params['embedding']}") self.kernel_sizes = params.get( 'kernel_sizes', [[3, self.smiles_embedding.weight.shape[1]], [5, self.smiles_embedding.weight.shape[1]], [11, self.smiles_embedding.weight.shape[1]]]) self.hidden_sizes = ([ self.multiheads[0] * self.smiles_embedding.weight.shape[1] + sum([h * f for h, f in zip(self.multiheads[1:], self.filters)]) ] + params.get('stacked_hidden_sizes', [1024, 512])) if len(self.filters) != len(self.kernel_sizes): raise ValueError( 'Length of filter and kernel size lists do not match.') if len(self.filters) + 1 != len(self.multiheads): raise ValueError( 'Length of filter and multihead lists do not match') self.convolutional_layers = nn.Sequential( OrderedDict([(f'convolutional_{index}', convolutional_layer( num_kernel, kernel_size, act_fn=self.act_fn, dropout=self.dropout, batch_norm=self.use_batch_norm).to(self.device)) for index, (num_kernel, kernel_size) in enumerate( zip(self.filters, self.kernel_sizes))])) smiles_hidden_sizes = [self.smiles_embedding.weight.shape[1] ] + self.filters self.smiles_projections = nn.Sequential( OrderedDict([ (f'smiles_projection_{self.multiheads[0]*layer+index}', smiles_projection(smiles_hidden_sizes[layer], self.smiles_attention_size)) for layer in range(len(self.multiheads)) for index in range(self.multiheads[layer]) ])) self.alpha_projections = nn.Sequential( OrderedDict([(f'alpha_projection_{self.multiheads[0]*layer+index}', alpha_projection(self.smiles_attention_size)) for layer in range(len(self.multiheads)) for index in range(self.multiheads[layer])])) if self.use_batch_norm: self.batch_norm = nn.BatchNorm1d(self.hidden_sizes[0]) self.dense_layers = nn.Sequential( OrderedDict([ ('dense_{}'.format(ind), dense_layer(self.hidden_sizes[ind], self.hidden_sizes[ind + 1], act_fn=self.act_fn, dropout=self.dropout, batch_norm=self.use_batch_norm).to(self.device)) for ind in range(len(self.hidden_sizes) - 1) ])) if params.get('ensemble', 'None') not in ['score', 'prob', 'None']: raise NotImplementedError( "Choose ensemble type from ['score', 'prob', 'None']") if params.get('ensemble', 'None') == 'None': params['ensemble_size'] = 1 self.final_dense = EnsembleLayer(typ=params.get('ensemble', 'score'), input_size=self.hidden_sizes[-1], output_size=self.num_tasks, ensemble_size=params.get( 'ensemble_size', 5), 
fn=ACTIVATION_FN_FACTORY['sigmoid']) self.loss_fn = LOSS_FN_FACTORY[ params.get('loss_fn', 'binary_cross_entropy_ignore_nan_and_sum') ] # yapf: disable # Set class weights manually if 'binary_cross_entropy_ignore_nan' in params.get( 'loss_fn', 'binary_cross_entropy_ignore_nan_and_sum'): self.loss_fn.class_weights = params.get('class_weights', [1, 1])
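To tie the pieces together, the following is a minimal, illustrative instantiation of `MCAMultiTask` with a small parameter dictionary modelled on the docstring example above; the values are arbitrary and not tuned, and the exact import path of the class is not shown in this file. Note that the docstring advertises `stacked_dense_hidden_sizes` while the constructor reads `stacked_hidden_sizes`, so the sketch omits that key and relies on the default [1024, 512]. The forward call follows the usage in the evaluation loop above, which unpacks a prediction tensor and a dictionary containing the attention scores.
```
import torch

# Assumed import; adjust to the actual module exposing MCAMultiTask.
# from toxsmi.models import MCAMultiTask

params = {
    'smiles_vocabulary_size': 28,
    'smiles_embedding_size': 8,
    'smiles_attention_size': 8,
    'filters': [128, 128],
    'kernel_sizes': [[3, 8], [5, 8]],
    'multiheads': [4, 4, 4],
    'num_tasks': 12,
    'ensemble': 'score',
    'ensemble_size': 5,
}
model = MCAMultiTask(params)
model = model.to(model.device)  # ensure all submodules live on one device
model.eval()

# Dummy batch: 4 sequences of 100 token indices (padded SMILES).
tokens = torch.randint(0, params['smiles_vocabulary_size'], (4, 100)).to(model.device)
with torch.no_grad():
    predictions, prediction_dict = model(tokens)
print(predictions.shape)                         # expected: (4, 12) sigmoid task scores
print(len(prediction_dict['smiles_attention']))  # one entry per attention head (cf. evaluation loop)
```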