def test_all_base_for_indexed_methods(self):
    """Run the generic indexed-dataset checks against an AnnotatedDataset.

    Additionally verifies that duplicated keys are visible on the underlying
    SMILES datasource while the annotated dataset itself reports none.
    """
    with TestFileContent(self.smiles_content) as smiles_file:
        with TestFileContent(self.annotated_content) as annotation_file:
            tokenized = SMILESTokenizerDataset(
                smiles_file.filename, add_start_and_stop=True, backend='eager'
            )
            annotated = AnnotatedDataset(
                annotation_file.filename,
                dataset=tokenized,
                index_col=0,
                label_columns=['label_1'],
            )
            # Concatenating the SMILES dataset with itself produces a
            # datasource in which every key occurs twice.
            duplicated = AnnotatedDataset(
                annotation_file.filename,
                dataset=tokenized + tokenized,
            )

            # Expected keys: last comma-separated field of every annotation
            # row, skipping the header row.
            expected_keys = []
            for line in self.annotated_content.split(os.linesep)[1:]:
                expected_keys.append(line.split(',')[-1])

            self._test_indexed(annotated, expected_keys, -1)

            # duplicates in datasource can be checked directly
            self.assertTrue(duplicated.datasource.has_duplicate_keys)
            # DataFrame is the dataset, so no duplicates are reported here
            self.assertFalse(duplicated.has_duplicate_keys)
def test___getitem___from_indexed_annotation(self) -> None:
    """Test __getitem__ with index in the annotation file.

    Checks the token sequence and the label values of the first and the last
    annotated sample.
    """
    with TestFileContent(self.smiles_content) as smiles_file:
        with TestFileContent(self.annotated_content) as annotation_file:
            smiles_dataset = SMILESTokenizerDataset(
                smiles_file.filename, add_start_and_stop=True, backend='eager'
            )
            annotated_dataset = AnnotatedDataset(
                annotation_file.filename, dataset=smiles_dataset
            )

            language = smiles_dataset.smiles_language
            pad = language.padding_index
            start = language.start_index
            stop = language.stop_index
            vocab = language.token_to_index
            c_tok, o_tok, n_tok, s_tok = (
                vocab['C'], vocab['O'], vocab['N'], vocab['S']
            )

            # (position, expected token ids, expected label values)
            cases = [
                # first sample (padded at the front)
                (0, [pad, start, c_tok, c_tok, o_tok, stop], [2.3, 3.4]),
                # last sample
                (2, [start, n_tok, c_tok, c_tok, s_tok, stop], [6.7, 7.8]),
            ]
            for position, want_tokens, want_labels in cases:
                tokens, labels = annotated_dataset[position]
                self.assertEqual(
                    tokens.numpy().flatten().tolist(), want_tokens
                )
                self.assertTrue(
                    np.allclose(labels.numpy().flatten().tolist(), want_labels)
                )
def test_return_integer_index(self) -> None:
    """Test __getitem__ with index in dataset.

    Compares three setups: no ``indexed`` wrapper, ``indexed`` around the
    annotated dataset (outer), and ``indexed`` around the SMILES dataset
    (inner).
    """
    with TestFileContent(self.smiles_content) as smiles_file:
        with TestFileContent(self.annotated_content) as annotation_file:
            smiles_dataset = SMILESTokenizerDataset(smiles_file.filename)
            plain = AnnotatedDataset(
                annotation_file.filename, dataset=smiles_dataset
            )
            outer = indexed(plain)
            inner = AnnotatedDataset(
                annotation_file.filename, dataset=indexed(smiles_dataset)
            )

            # Plain dataset: wrapping `smiles_dataset` with `indexed` above
            # must not have mutated it, so items are still (tokens, labels).
            tokens, labels = plain[2]
            self.check_CHEMBL602(tokens, labels)

            # Outer wrapper: the index is the annotation row position.
            (tokens, labels), position = outer[2]
            self.assertEqual(position, 2)
            self.check_CHEMBL602(tokens, labels)

            (tokens, labels), position = outer.get_item_from_key('CHEMBL602')
            self.assertEqual(position, 2)
            self.check_CHEMBL602(tokens, labels)

            # Inner wrapper: the index comes from the SMILES dataset; per the
            # assertions below, annotation row 2 maps to SMILES row 3.
            (tokens, position), labels = inner[0]
            self.assertEqual(position, 0)

            (tokens, position), labels = inner[2]
            self.assertEqual(position, 3)
            self.check_CHEMBL602(tokens, labels)

            (tokens, position), labels = inner.get_item_from_key('CHEMBL602')
            self.assertEqual(position, 3)
            self.check_CHEMBL602(tokens, labels)
def test_return_key_index_stacked(self) -> None:
    """Test __getitem__ with key in dataset.

    Both the SMILES dataset and the annotated dataset are wrapped with the
    ``keyed`` / ``indexed`` helpers. Every layer's index and key is unpacked
    and checked, as is the CHEMBL602 payload (tokens and labels) that must
    pass through all wrapper layers unchanged.
    """
    with TestFileContent(self.smiles_content) as smiles_file:
        with TestFileContent(self.annotated_content) as annotation_file:
            smiles_dataset = keyed(
                indexed(SMILESTokenizerDataset(smiles_file.filename))
            )
            annotated_dataset = indexed(
                keyed(
                    AnnotatedDataset(
                        annotation_file.filename,
                        dataset=smiles_dataset,
                    )
                )
            )

            # SMILES dataset alone yields ((tokens, index), key).
            (smiles_tokens, smiles_index), smiles_key = smiles_dataset[3]
            self.assertEqual(smiles_key, 'CHEMBL602')
            self.assertEqual(smiles_index, 3)
            self.check_CHEMBL602(smiles_tokens)

            (smiles_tokens, smiles_index), smiles_key = (
                smiles_dataset.get_item_from_key('CHEMBL602')
            )
            self.assertEqual(smiles_key, 'CHEMBL602')
            self.assertEqual(smiles_index, 3)
            self.check_CHEMBL602(smiles_tokens)

            # Stacked annotated dataset: the SMILES item is nested inside the
            # annotation's (key, index) wrappers.
            (
                (
                    (((smiles_tokens, smiles_index), smiles_key), labels),  # inner
                    annotation_key,
                ),
                annotation_index,
            ) = annotated_dataset[2]
            self.assertEqual(smiles_key, 'CHEMBL602')
            self.assertEqual(smiles_index, 3)
            self.assertEqual(annotation_index, 2)
            self.assertEqual(annotation_key, 'CHEMBL602')
            # The payload itself must also be correct, not only the metadata.
            self.check_CHEMBL602(smiles_tokens, labels)

            (
                (
                    (((smiles_tokens, smiles_index), smiles_key), labels),  # inner
                    annotation_key,
                ),
                annotation_index,
            ) = annotated_dataset.get_item_from_key('CHEMBL602')
            self.assertEqual(smiles_key, 'CHEMBL602')
            self.assertEqual(smiles_index, 3)
            self.assertEqual(annotation_index, 2)
            self.assertEqual(annotation_key, 'CHEMBL602')
            self.check_CHEMBL602(smiles_tokens, labels)
def main(
    train_scores_filepath, test_scores_filepath, smi_filepath, model_path,
    params_filepath, training_name
):
    """Fine-tune a pretrained model on new train/test annotations.

    The pretrained model is restored from ``model_path``, which must contain
    ``model_params.json``, ``smiles_language.pkl`` and a ``weights`` folder.
    It is adapted with ``update_mca_model`` using the JSON in
    ``params_filepath`` and then trained. Checkpoints, the SMILES language,
    parameters, best predictions and metrics are written below
    ``<model_path>/finetuned``.

    Args:
        train_scores_filepath: Annotation file of the training split.
        test_scores_filepath: Annotation file of the evaluation split.
        smi_filepath: SMILES file used by both splits.
        model_path: Root folder of the pretrained model.
        params_filepath: JSON file with the fine-tuning parameters.
        training_name: Name of the logger.

    Raises:
        Exception: If the pretrained model or its weights cannot be restored
            (the underlying error is chained as ``__cause__``).
    """
    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger(f'{training_name}')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()
    device = get_device()

    # Restore pretrained model
    logger.info('Start model restoring.')
    try:
        with open(os.path.join(model_path, 'model_params.json'), 'r') as fp:
            params = json.load(fp)
        smiles_language = SMILESLanguage.load(
            os.path.join(model_path, 'smiles_language.pkl')
        )
        model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
        # Try weight restoring
        weights_dir = os.path.join(model_path, 'weights')
        weight_path = os.path.join(
            weights_dir, params.get('weights_name', 'best_ROC-AUC_mca.pt')
        )
        try:
            model.load(weight_path)
        except Exception:
            try:
                # Fall back to the first file in the weights folder; sorted so
                # the choice does not depend on filesystem listing order.
                wp = sorted(os.listdir(weights_dir))[0]
                logger.info(
                    f"Weights {weight_path} not found. Try restore {wp}"
                )
                model.load(os.path.join(weights_dir, wp))
            except Exception as err:
                raise Exception('Error in weight loading.') from err
    except Exception as err:
        raise Exception(
            'Error in model restoring. model_path should point to the model root '
            'folder that contains a model_params.json, a smiles_language.pkl and a '
            'weights folder.'
        ) from err

    logger.info('Model restored. Now starting to craft it for the task')
    # Process parameter file:
    model_params = {}
    with open(params_filepath) as fp:
        model_params.update(json.load(fp))

    model = update_mca_model(model, model_params)
    logger.info('Model set up.')
    for idx, (name, param) in enumerate(model.named_parameters()):
        logger.info(
            (idx, name, param.shape, f'Gradients: {param.requires_grad}')
        )

    ft_model_path = os.path.join(model_path, 'finetuned')
    os.makedirs(os.path.join(ft_model_path, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(ft_model_path, 'results'), exist_ok=True)

    logger.info('Now start data preprocessing...')

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('sanitize', True)
    )

    train_dataset = AnnotatedDataset(
        annotations_filepath=train_scores_filepath,
        dataset=smiles_dataset,
        device=device
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )

    # Generally, if sanitize is True molecules are de-kekulized. Augmentation
    # preserves the "kekulization", so if it is used, test data should be
    # sanitized or canonicalized.
    smiles_test_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False)
    )

    # Dump eventually modified SMILES Language
    smiles_language.save(os.path.join(ft_model_path, 'smiles_language.pkl'))

    test_dataset = AnnotatedDataset(
        annotations_filepath=test_scores_filepath,
        dataset=smiles_test_dataset,
        device=device
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=False,
        drop_last=False,
        num_workers=params.get('num_workers', 0)
    )

    save_top_model = os.path.join(ft_model_path, 'weights/{}_{}_{}.pt')

    num_t_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    num_params = sum(p.numel() for p in model.parameters())
    params.update({'params': num_params, 'trainable_params': num_t_params})
    logger.info(
        f'Number of parameters: {num_params} (trainable {num_t_params}).'
    )

    # Define optimizer only for those layers which require gradients
    optimizer = OPTIMIZER_FACTORY[params.get('optimizer', 'adam')](
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=params.get('lr', 0.00001)
    )
    # Dump params.json file with updated parameters.
    with open(os.path.join(ft_model_path, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    def save(path, metric, typ, val=None):
        """Save the model to ``path`` and log when a new best is reached.

        ``epoch`` is read from the enclosing scope; it is only needed (and
        only bound) for ``typ == 'best'``, which is called inside the loop.
        """
        model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
        if typ == 'best':
            logger.info(
                f'\t New best performance in {metric}'
                f' with value : {val:.7f} in epoch: {epoch+1}'
            )

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    # inf so any finite loss counts as an improvement; ep_loss stays None if
    # the loss never improves (e.g. NaN losses or zero epochs).
    min_loss, max_roc_auc = float('inf'), 0
    max_precision_recall_score = 0
    ep_loss = None

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for smiles, y in train_loader:
            smiles = torch.squeeze(smiles.to(device))
            y_hat, pred_dict = model(smiles)
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        logger.info(
            '\t **** TRAINING ****   '
            f"Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {train_loss / len(train_loader):.5f}. '
            f'This took {time() - t:.1f} secs.'
        )
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for smiles, y in test_loader:
                smiles = torch.squeeze(smiles.to(device))
                y_hat, pred_dict = model(smiles)
                predictions.append(y_hat)
                # Copy y tensor since loss function applies downstream modification
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()

        # Remove NaNs from labels to compute scores (mask predictions first,
        # while `labels` still holds the unfiltered values).
        predictions = predictions[~np.isnan(labels)]
        labels = labels[~np.isnan(labels)]
        test_loss_a = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc_a = auc(fpr, tpr)

        # calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        # score for precision vs accuracy
        test_precision_recall_score = average_precision_score(
            labels, predictions
        )

        logger.info(
            f"\t **** TEST **** Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {test_loss_a:.5f}, , roc_auc: {test_roc_auc_a:.5f}, '
            f'avg precision-recall score: {test_precision_recall_score:.5f}'
        )
        info = {
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
            'test_auc': test_roc_auc_a,
            'best_test_auc': max_roc_auc,
            'test_precision_recall_score': test_precision_recall_score,
            'best_precision_recall_score': max_precision_recall_score,
        }

        if test_roc_auc_a > max_roc_auc:
            max_roc_auc = test_roc_auc_a
            info.update({'best_test_auc': max_roc_auc})
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            np.save(
                os.path.join(ft_model_path, 'results', 'best_predictions.npy'),
                predictions
            )
            with open(
                os.path.join(ft_model_path, 'results', 'metrics.json'), 'w'
            ) as f:
                json.dump(info, f)

        if test_precision_recall_score > max_precision_recall_score:
            max_precision_recall_score = test_precision_recall_score
            info.update(
                {'best_precision_recall_score': max_precision_recall_score}
            )
            save(
                save_top_model, 'precision-recall score', 'best',
                max_precision_recall_score
            )

        if test_loss_a < min_loss:
            min_loss = test_loss_a
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch
        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info(
        'Overall best performances are: \n \t'
        f'Loss = {min_loss:.4f} in epoch {ep_loss} '
    )
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
def main(train_scores_filepath,
         test_scores_filepath,
         smi_filepath,
         smiles_language_filepath,
         model_path,
         params_filepath,
         training_name,
         embedding_path=None):
    """Train a model from a parameter file and save checkpoints and metrics.

    Outputs go to ``<model_path>/<training_name>``: ``model_params.json``,
    ``smiles_language.pkl``, checkpoints in ``weights/``, and best
    predictions, metrics and (optionally) confidence estimates in
    ``results/``.

    Args:
        train_scores_filepath: Annotation file of the training split.
        test_scores_filepath: Annotation file of the evaluation split.
        smi_filepath: SMILES file used by all splits.
        smiles_language_filepath: Pickled SMILES language to load.
        model_path: Parent folder for the training output directory.
        params_filepath: JSON file with training/model parameters.
        training_name: Output sub-folder name and logger name.
        embedding_path: Optional value stored as ``params['embedding_path']``.

    Raises:
        ValueError: If both ``uncertainty`` and ``augment_test_data`` are set.
    """
    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger(f'{training_name}')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()

    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    if embedding_path:
        params['embedding_path'] = embedding_path

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    logger.info('Start data preprocessing...')
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Prepare FP processing
    if params.get('model_fn', 'mca') == 'dense':
        # One value for both the fingerprint size and the output buffer width
        # so that the row assignment below always has matching shapes.
        num_fp_bits = params.get('num_drug_features', 512)
        morgan_transform = Compose([
            SMILESToMorganFingerprints(
                radius=params.get('fp_radius', 2),
                bits=num_fp_bits,
                chirality=params.get('fp_chirality', True)),
            ToTensor(get_device())
        ])

        def smiles_tensor_batch_to_fp(smiles):
            """Convert a batch of token-index tensors into Morgan fingerprints.

            Lets the SMILES dataset be reused for fingerprint-based models:
            each row is decoded back to a SMILES string and fingerprinted.
            """
            out = torch.Tensor(smiles.shape[0], num_fp_bits)
            for ind, tensor in enumerate(smiles):
                smiles_string = smiles_language.token_indexes_to_smiles(
                    tensor.tolist())
                out[ind, :] = torch.squeeze(morgan_transform(smiles_string))
            return out

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('sanitize', True))

    # include arg label_columns if data file has any unwanted columns (such as index) to be ignored.
    train_dataset = AnnotatedDataset(
        annotations_filepath=train_scores_filepath,
        dataset=smiles_dataset,
        device=get_device())
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0))

    # NOTE(review): this raises when test augmentation IS enabled, while the
    # message says "is not enabled"; also 'augment_test_data' is not the key
    # used to build the test dataset ('augment_test_smiles'). Confirm intent.
    if (params.get('uncertainty', True)
            and params.get('augment_test_data', False)):
        raise ValueError(
            'Epistemic uncertainty evaluation not supported if augmentation '
            'is not enabled for test data.')

    # Generally, if sanitize is True molecules are de-kekulized. Augmentation
    # preserves the "kekulization", so if it is used, test data should be
    # sanitized or canonicalized.
    smiles_test_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False))

    # Dump eventually modified SMILES Language
    smiles_language.save(os.path.join(model_dir, 'smiles_language.pkl'))

    logger.info(smiles_dataset._dataset.transform)
    logger.info(smiles_test_dataset._dataset.transform)

    # include arg label_columns if data file has any unwanted columns (such as index) to be ignored.
    test_dataset = AnnotatedDataset(
        annotations_filepath=test_scores_filepath,
        dataset=smiles_test_dataset,
        device=get_device())
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=False,
        drop_last=False,
        num_workers=params.get('num_workers', 0))

    if params.get('confidence', False):
        smiles_conf_dataset = SMILESDataset(
            smi_filepath,
            smiles_language=smiles_language,
            padding_length=params.get('smiles_padding_length', None),
            padding=params.get('padd_smiles', True),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=True,  # Natively true for epistemic uncertainty estimate
            canonical=params.get('canonical', False),
            kekulize=params.get('kekulize', False),
            all_bonds_explicit=params.get('all_bonds_explicit', False),
            all_hs_explicit=params.get('all_hs_explicit', False),
            randomize=params.get('randomize', False),
            remove_bonddir=params.get('remove_bonddir', False),
            remove_chirality=params.get('remove_chirality', False),
            selfies=params.get('selfies', False))
        conf_dataset = AnnotatedDataset(
            annotations_filepath=test_scores_filepath,
            dataset=smiles_conf_dataset,
            device=get_device())
        conf_loader = torch.utils.data.DataLoader(
            dataset=conf_dataset,
            batch_size=params['batch_size'],
            shuffle=False,
            drop_last=False,
            num_workers=params.get('num_workers', 0))

    if not params.get('embedding', 'learned') == 'pretrained':
        params.update(
            {'smiles_vocabulary_size': smiles_language.number_of_tokens})

    device = get_device()
    logger.info(f'Device for data loader is {train_dataset.device} and for '
                f'model is {device}')
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')

    model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
    logger.info(model)
    logger.info(model.loss_fn.class_weights)

    logger.info('Parameters follow')
    for name, param in model.named_parameters():
        logger.info((name, param.shape))

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params.update({'number_of_parameters': num_params})
    logger.info(f'Number of parameters {num_params}')

    # Define optimizer
    optimizer = OPTIMIZER_FACTORY[params.get('optimizer', 'adam')](
        model.parameters(), lr=params.get('lr', 0.00001))
    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    def save(path, metric, typ, val=None):
        """Save the model to ``path`` and log when a new best is reached.

        ``epoch`` is read from the enclosing scope; it is only needed (and
        only bound) for ``typ == 'best'``, which is called inside the loop.
        """
        model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
        if typ == 'best':
            logger.info(f'\t New best performance in {metric}'
                        f' with value : {val:.7f} in epoch: {epoch+1}')

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    # inf so any finite loss counts as an improvement; ep_loss stays None if
    # the loss never improves (e.g. NaN losses or zero epochs).
    min_loss, max_roc_auc = float('inf'), 0
    max_precision_recall_score = 0
    ep_loss = None
    use_fingerprints = params.get('model_fn', 'mca') == 'dense'

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for smiles, y in train_loader:
            smiles = torch.squeeze(smiles.to(device))
            # Transform smiles to FP if needed
            if use_fingerprints:
                smiles = smiles_tensor_batch_to_fp(smiles).to(device)

            y_hat, pred_dict = model(smiles)

            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        logger.info('\t **** TRAINING ****   '
                    f"Epoch [{epoch + 1}/{params['epochs']}], "
                    f'loss: {train_loss / len(train_loader):.5f}. '
                    f'This took {time() - t:.1f} secs.')
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for smiles, y in test_loader:
                smiles = torch.squeeze(smiles.to(device))
                # Transform smiles to FP if needed
                if use_fingerprints:
                    smiles = smiles_tensor_batch_to_fp(smiles).to(device)

                y_hat, pred_dict = model(smiles)
                predictions.append(y_hat)
                # Copy y tensor since loss function applies downstream
                # modification
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()

        # Remove NaNs from labels to compute scores (mask predictions first,
        # while `labels` still holds the unfiltered values).
        predictions = predictions[~np.isnan(labels)]
        labels = labels[~np.isnan(labels)]
        test_loss_a = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc_a = auc(fpr, tpr)

        # calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        # score for precision vs accuracy
        test_precision_recall_score = average_precision_score(
            labels, predictions)

        logger.info(
            f"\t **** TEST **** Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {test_loss_a:.5f}, , roc_auc: {test_roc_auc_a:.5f}, '
            f'avg precision-recall score: {test_precision_recall_score:.5f}')
        info = {
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
            'test_auc': test_roc_auc_a,
            'best_test_auc': max_roc_auc,
            'test_precision_recall_score': test_precision_recall_score,
            'best_precision_recall_score': max_precision_recall_score,
        }

        if test_roc_auc_a > max_roc_auc:
            max_roc_auc = test_roc_auc_a
            info.update({'best_test_auc': max_roc_auc})
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            np.save(os.path.join(model_dir, 'results', 'best_predictions.npy'),
                    predictions)
            with open(os.path.join(model_dir, 'results', 'metrics.json'),
                      'w') as f:
                json.dump(info, f)
            if params.get('confidence', False):
                # Compute uncertainty estimates and save them.
                # NOTE(review): confirm the return type of monte_carlo_dropout
                # / test_time_augmentation; if they return (conf, pred) tuples,
                # np.save stores both members.
                epistemic_conf = monte_carlo_dropout(model,
                                                     regime='loader',
                                                     loader=conf_loader)
                aleatoric_conf = test_time_augmentation(model,
                                                        regime='loader',
                                                        loader=conf_loader)

                np.save(
                    os.path.join(model_dir, 'results', 'epistemic_conf.npy'),
                    epistemic_conf)
                np.save(
                    os.path.join(model_dir, 'results', 'aleatoric_conf.npy'),
                    aleatoric_conf)

        if test_precision_recall_score > max_precision_recall_score:
            max_precision_recall_score = test_precision_recall_score
            info.update(
                {'best_precision_recall_score': max_precision_recall_score})
            save(save_top_model, 'precision-recall score', 'best',
                 max_precision_recall_score)

        if test_loss_a < min_loss:
            min_loss = test_loss_a
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch
        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info('Overall best performances are: \n \t'
                f'Loss = {min_loss:.4f} in epoch {ep_loss} ')
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
def main(model_path, smi_filepath, labels_filepath, output_folder,
         smiles_language_filepath, model_id, confidence):
    """Evaluate a trained model and write per-sample results to a CSV file.

    ``results.csv`` contains one row per sample: the SMILES string,
    predictions, labels, optional confidence estimates, and attention
    averaged over attention layers. The raw attention array is saved as
    ``attention_raw.npy``.

    Args:
        model_path: Folder with ``model_params.json`` and ``weights/``.
        smi_filepath: SMILES file to evaluate.
        labels_filepath: Annotation file with the labels.
        output_folder: Destination folder (created if missing).
        smiles_language_filepath: Pickled SMILES language, or ``'.'`` to use
            ``<model_path>/smiles_language.pkl``.
        model_id: Key into ``MODEL_FACTORY``; also used in the weights name.
        confidence: If truthy, also compute epistemic/aleatoric estimates.

    Raises:
        KeyError: If ``model_id`` is not a key of ``MODEL_FACTORY``.
    """
    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger('eval_toxicity')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()

    def _to_numpy(tensor):
        """Detach ``tensor`` and move it to CPU before NumPy conversion."""
        return tensor.detach().cpu().numpy()

    # Process parameter file:
    params = {}
    with open(os.path.join(model_path, 'model_params.json'), 'r') as fp:
        params.update(json.load(fp))

    # Create model directory
    os.makedirs(output_folder, exist_ok=True)

    if model_id not in MODEL_FACTORY:
        raise KeyError(
            f'Model ID: Pass one of {MODEL_FACTORY.keys()}, not {model_id}')

    device = get_device()
    weights_path = os.path.join(model_path, 'weights',
                                f'best_ROC-AUC_{model_id}.pt')

    # Restore model
    model = MODEL_FACTORY[model_id](params).to(device)
    if os.path.isfile(weights_path):
        try:
            model.load(weights_path, map_location=device)
        except Exception:
            # Best-effort as before (evaluation continues), but include the
            # traceback so the reason for the failed restore is visible.
            logger.exception(f'Error in model restoring from {weights_path}')
    else:
        logger.info(f'Did not find weights at {weights_path}, '
                    f'name weights: "best_ROC-AUC_{model_id}.pt".')
    model.eval()

    logger.info('Model restored. Model specs & parameters follow')
    for name, param in model.named_parameters():
        logger.info((name, param.shape))

    # Load language
    if smiles_language_filepath == '.':
        smiles_language_filepath = os.path.join(model_path,
                                                'smiles_language.pkl')
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False),
    )
    logger.info(
        f'SMILES Padding length is {smiles_dataset._dataset.padding_length}.'
        'Consider setting manually if this looks wrong.')
    dataset = AnnotatedDataset(annotations_filepath=labels_filepath,
                               dataset=smiles_dataset,
                               device=device)
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=10,
                                         shuffle=False,
                                         drop_last=False,
                                         num_workers=0)
    if confidence:
        smiles_aleatoric_dataset = SMILESDataset(
            smi_filepath,
            smiles_language=smiles_language,
            padding_length=params.get('smiles_padding_length', None),
            padding=params.get('padd_smiles', True),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=True,  # Natively true for aleatoric uncertainty estimate
            canonical=params.get('test_canonical', False),
            kekulize=params.get('test_kekulize', False),
            all_bonds_explicit=params.get('test_all_bonds_explicit', False),
            all_hs_explicit=params.get('test_all_hs_explicit', False),
            remove_bonddir=params.get('test_remove_bonddir', False),
            remove_chirality=params.get('test_remove_chirality', False),
            selfies=params.get('selfies', False),
            sanitize=params.get('test_sanitize', False),
        )
        ale_dataset = AnnotatedDataset(annotations_filepath=labels_filepath,
                                       dataset=smiles_aleatoric_dataset,
                                       device=device)
        ale_loader = torch.utils.data.DataLoader(dataset=ale_dataset,
                                                 batch_size=10,
                                                 shuffle=False,
                                                 drop_last=False,
                                                 num_workers=0)
        smiles_epi_dataset = SMILESDataset(
            smi_filepath,
            smiles_language=smiles_language,
            padding_length=params.get('smiles_padding_length', None),
            padding=params.get('padd_smiles', True),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=False,  # Natively false for epistemic uncertainty estimate
            canonical=params.get('test_canonical', False),
            kekulize=params.get('test_kekulize', False),
            all_bonds_explicit=params.get('test_all_bonds_explicit', False),
            all_hs_explicit=params.get('test_all_hs_explicit', False),
            remove_bonddir=params.get('test_remove_bonddir', False),
            remove_chirality=params.get('test_remove_chirality', False),
            selfies=params.get('selfies', False),
            sanitize=params.get('test_sanitize', False),
        )
        epi_dataset = AnnotatedDataset(annotations_filepath=labels_filepath,
                                       dataset=smiles_epi_dataset,
                                       device=device)
        epi_loader = torch.utils.data.DataLoader(dataset=epi_dataset,
                                                 batch_size=10,
                                                 shuffle=False,
                                                 drop_last=False,
                                                 num_workers=0)

    logger.info(f'Device for data loader is {dataset.device} and for '
                f'model is {device}')
    # Start evaluation
    logger.info('Evaluation about to start...\n')

    preds, labels, attention_scores, smiles = [], [], [], []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for smiles_batch, labels_batch in loader:
            pred, pred_dict = model(smiles_batch.to(device))
            # Keep the batch dimension and flatten the rest. A plain
            # `squeeze()` collapses a 1-sample batch (or a single-task
            # output) into bare floats, which breaks `len(preds[0])` below.
            preds.extend(
                pred.detach().cpu().reshape(pred.shape[0], -1).tolist())
            # Move to CPU so the stacked array can be converted to NumPy.
            attention_scores.extend(
                torch.stack(pred_dict['smiles_attention'],
                            dim=1).detach().cpu())
            smiles.extend([
                smiles_language.token_indexes_to_smiles(s.tolist())
                for s in smiles_batch
            ])
            labels.extend(
                labels_batch.reshape(labels_batch.shape[0], -1).tolist())

    # Scores are now 3D: num_samples x num_att_layers x padding_length
    attention = torch.stack(attention_scores, dim=0).numpy()

    if confidence:
        # Compute uncertainty estimates and save them
        epistemic_conf, epistemic_pred = monte_carlo_dropout(
            model, regime='loader', loader=epi_loader)
        aleatoric_conf, aleatoric_pred = test_time_augmentation(
            model, regime='loader', loader=ale_loader)

        epi_conf_df = pd.DataFrame(
            data=_to_numpy(epistemic_conf),
            columns=[
                f'epistemic_conf_{i}' for i in range(epistemic_conf.shape[1])
            ],
        )
        ale_conf_df = pd.DataFrame(
            data=_to_numpy(aleatoric_conf),
            columns=[
                f'aleatoric_conf_{i}' for i in range(aleatoric_conf.shape[1])
            ],
        )
        epi_pred_df = pd.DataFrame(
            data=_to_numpy(epistemic_pred),
            columns=[
                f'epistemic_pred_{i}' for i in range(epistemic_pred.shape[1])
            ],
        )
        ale_pred_df = pd.DataFrame(
            data=_to_numpy(aleatoric_pred),
            columns=[
                f'aleatoric_pred_{i}' for i in range(aleatoric_pred.shape[1])
            ],
        )

    logger.info(f'Shape of attention scores {attention.shape}.')
    np.save(os.path.join(output_folder, 'attention_raw.npy'), attention)
    # Average over the attention-layer axis (axis 1 per the stacking above).
    attention_avg = np.mean(attention, axis=1)
    att_df = pd.DataFrame(
        data=attention_avg,
        columns=[f'att_idx_{i}' for i in range(attention_avg.shape[1])],
    )
    pred_df = pd.DataFrame(
        data=preds,
        columns=[f'pred_{i}' for i in range(len(preds[0]))],
    )
    lab_df = pd.DataFrame(
        data=labels,
        columns=[f'label_{i}' for i in range(len(labels[0]))],
    )
    df = pd.concat([pred_df, lab_df], axis=1)
    if confidence:
        df = pd.concat(
            [df, epi_conf_df, ale_conf_df, epi_pred_df, ale_pred_df], axis=1)
    df = pd.concat([df, att_df], axis=1)
    df.insert(0, 'SMILES', smiles)
    df.to_csv(os.path.join(output_folder, 'results.csv'), index=False)

    logger.info('Done, shutting down.')