Example #1
    def test_all_base_for_indexed_methods(self):

        with TestFileContent(self.smiles_content) as smiles_file:
            with TestFileContent(self.annotated_content) as annotation_file:
                smiles_dataset = SMILESTokenizerDataset(
                    smiles_file.filename,
                    add_start_and_stop=True,
                    backend='eager')
                annotated_dataset = AnnotatedDataset(
                    annotation_file.filename,
                    dataset=smiles_dataset,
                    index_col=0,
                    label_columns=['label_1'],
                )
                duplicate_ds = AnnotatedDataset(
                    annotation_file.filename,
                    dataset=smiles_dataset + smiles_dataset,
                )

        all_keys = [
            row.split(',')[-1]
            for row in self.annotated_content.split(os.linesep)[1:]
        ]

        for ds, keys in [
            (annotated_dataset, all_keys),
        ]:
            index = -1
            self._test_indexed(ds, keys, index)

        # duplicates in datasource can be checked directly
        self.assertTrue(duplicate_ds.datasource.has_duplicate_keys)
        # DataFrame is the dataset
        self.assertFalse(duplicate_ds.has_duplicate_keys)
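
The two assertions above hinge on where the duplication lives: concatenating the SMILES dataset doubles every key in the datasource, while the annotation DataFrame backing the AnnotatedDataset keeps one row per key. A minimal, self-contained sketch of that distinction using plain pandas; the keys and the `duplicated` check are illustrative, not the library's internals:

import pandas as pd

# Keys as seen by a doubled datasource: every key occurs twice.
datasource_keys = pd.Index(['CHEMBL545', 'CHEMBL17564', 'CHEMBL602'] * 2)
print(datasource_keys.duplicated().any())  # True -> duplicate keys

# The annotation DataFrame keeps a single row per key.
annotation_keys = pd.Index(['CHEMBL545', 'CHEMBL17564', 'CHEMBL602'])
print(annotation_keys.duplicated().any())  # False -> no duplicates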
Example #2
    def test___getitem___from_indexed_annotation(self) -> None:
        """Test __getitem__ with index in the annotation file."""
        with TestFileContent(self.smiles_content) as smiles_file:
            with TestFileContent(self.annotated_content) as annotation_file:
                smiles_dataset = SMILESTokenizerDataset(
                    smiles_file.filename,
                    add_start_and_stop=True,
                    backend='eager')
                annotated_dataset = AnnotatedDataset(annotation_file.filename,
                                                     dataset=smiles_dataset)

        pad_index = smiles_dataset.smiles_language.padding_index
        start_index = smiles_dataset.smiles_language.start_index
        stop_index = smiles_dataset.smiles_language.stop_index
        c_index = smiles_dataset.smiles_language.token_to_index['C']
        o_index = smiles_dataset.smiles_language.token_to_index['O']
        n_index = smiles_dataset.smiles_language.token_to_index['N']
        s_index = smiles_dataset.smiles_language.token_to_index['S']
        # test first sample
        smiles_tokens, labels = annotated_dataset[0]
        self.assertEqual(
            smiles_tokens.numpy().flatten().tolist(),
            [pad_index, start_index, c_index, c_index, o_index, stop_index],
        )
        self.assertTrue(
            np.allclose(labels.numpy().flatten().tolist(), [2.3, 3.4]))
        # test last sample
        smiles_tokens, labels = annotated_dataset[2]
        self.assertEqual(
            smiles_tokens.numpy().flatten().tolist(),
            [start_index, n_index, c_index, c_index, s_index, stop_index],
        )
        self.assertTrue(
            np.allclose(labels.numpy().flatten().tolist(), [6.7, 7.8]))
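
The expected index sequence in the first assertion (pad, start, C, C, O, stop) follows from left-padding a start/stop-wrapped token sequence to a fixed length. A toy illustration with made-up vocabulary indexes; the real indexes come from the SMILES language:

# Hypothetical vocabulary indexes, for illustration only.
token_to_index = {'<PAD>': 0, '<START>': 2, '<STOP>': 3, 'C': 4, 'O': 5}

tokens = ['C', 'C', 'O']  # tokenized 'CCO'
indexes = (
    [token_to_index['<START>']]
    + [token_to_index[t] for t in tokens]
    + [token_to_index['<STOP>']]
)
padding_length = 6
padded = [token_to_index['<PAD>']] * (padding_length - len(indexes)) + indexes
print(padded)  # [0, 2, 4, 4, 5, 3]: pad, start, C, C, O, stop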
Example #3
    def test_return_integer_index(self) -> None:
        """Test __getitem__ with index in dataset."""
        with TestFileContent(self.smiles_content) as smiles_file:
            with TestFileContent(self.annotated_content) as annotation_file:
                smiles_dataset = SMILESTokenizerDataset(smiles_file.filename)
                # default
                default_annotated_dataset = AnnotatedDataset(
                    annotation_file.filename, dataset=smiles_dataset)
                # outer
                indexed_annotated_dataset = indexed(default_annotated_dataset)
                # inner
                annotated_dataset = AnnotatedDataset(
                    annotation_file.filename, dataset=indexed(smiles_dataset))

        # default
        # relevant to check that `smiles_dataset` was not mutated from
        # the `indexed(smiles_dataset)` call
        smiles_tokens, labels = default_annotated_dataset[2]
        self.check_CHEMBL602(smiles_tokens, labels)

        # outer __getitem__
        (smiles_tokens, labels), sample_index = indexed_annotated_dataset[2]
        self.assertEqual(sample_index, 2)
        self.check_CHEMBL602(smiles_tokens, labels)
        # outer get_item_from_key
        (
            (smiles_tokens, labels),
            sample_index,
        ) = indexed_annotated_dataset.get_item_from_key('CHEMBL602')
        self.assertEqual(sample_index, 2)
        self.check_CHEMBL602(smiles_tokens, labels)

        # inner __getitem__
        (smiles_tokens, sample_index), labels = annotated_dataset[0]
        self.assertEqual(sample_index, 0)
        # inner __getitem__ with different index in smiles_dataset
        (smiles_tokens, sample_index), labels = annotated_dataset[2]
        self.assertEqual(sample_index, 3)
        self.check_CHEMBL602(smiles_tokens, labels)
        # inner get_item_from_key
        (
            (smiles_tokens, sample_index),
            labels,
        ) = annotated_dataset.get_item_from_key('CHEMBL602')
        self.assertEqual(sample_index, 3)
        self.check_CHEMBL602(smiles_tokens, labels)
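
Example #3 relies on `indexed(...)` returning a wrapped view rather than mutating its argument (hence the `default_annotated_dataset` check above). A minimal sketch of a wrapper with that contract, assuming a torch-style dataset; `IndexedView` is illustrative, not the library's implementation:

from torch.utils.data import Dataset

class IndexedView(Dataset):
    """Wrap a dataset so that __getitem__ also returns the integer index."""

    def __init__(self, dataset):
        self.dataset = dataset  # held by reference, never mutated

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index], index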
Example #4
    def test_return_key_index_stacked(self) -> None:
        """Test __getitem__ with key in dataset."""
        with TestFileContent(self.smiles_content) as smiles_file:
            with TestFileContent(self.annotated_content) as annotation_file:
                smiles_dataset = keyed(
                    indexed(SMILESTokenizerDataset(smiles_file.filename)))
                annotated_dataset = indexed(
                    keyed(
                        AnnotatedDataset(
                            annotation_file.filename,
                            dataset=smiles_dataset,
                        )))

        (smiles_tokens, smiles_index), smiles_key = smiles_dataset[3]
        self.assertEqual(smiles_key, 'CHEMBL602')
        self.assertEqual(smiles_index, 3)
        self.check_CHEMBL602(smiles_tokens)
        (
            (smiles_tokens, smiles_index),
            smiles_key,
        ) = smiles_dataset.get_item_from_key('CHEMBL602')
        self.assertEqual(smiles_key, 'CHEMBL602')
        self.assertEqual(smiles_index, 3)
        self.check_CHEMBL602(smiles_tokens)

        (
            (
                (((smiles_tokens, smiles_index), smiles_key), labels),  # inner
                annotation_key,
            ),
            annotation_index,
        ) = annotated_dataset[2]
        self.assertEqual(smiles_key, 'CHEMBL602')
        self.assertEqual(smiles_index, 3)
        self.assertEqual(annotation_index, 2)
        self.assertEqual(annotation_key, 'CHEMBL602')
        (
            (
                (((smiles_tokens, smiles_index), smiles_key), labels),  # inner
                annotation_key,
            ),
            annotation_index,
        ) = annotated_dataset.get_item_from_key('CHEMBL602')
        self.assertEqual(smiles_key, 'CHEMBL602')
        self.assertEqual(smiles_index, 3)
        self.assertEqual(annotation_index, 2)
        self.assertEqual(annotation_key, 'CHEMBL602')
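
The deep tuples in Example #4 mirror the wrapping order: each wrapper appends its own element outside the result of the dataset it wraps. Continuing the illustrative `IndexedView` sketch above, a hypothetical `KeyedView` makes the nesting rule explicit:

class KeyedView(Dataset):
    """Wrap a dataset so that __getitem__ also returns the sample key."""

    def __init__(self, dataset, keys):
        self.dataset = dataset
        self.keys = list(keys)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index], self.keys[index]

# Wrapping order fixes the nesting: each wrapper appends its element last.
# KeyedView(IndexedView(ds), keys)[i]  ->  ((sample, i), key)
# IndexedView(KeyedView(ds, keys))[i]  ->  ((sample, key), i)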
Example #5
def main(
    train_scores_filepath, test_scores_filepath, smi_filepath, model_path,
    params_filepath, training_name
):

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger(f'{training_name}')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()
    device = get_device()

    # Restore pretrained model
    logger.info('Start model restoring.')
    try:
        with open(os.path.join(model_path, 'model_params.json'), 'r') as fp:
            params = json.load(fp)
        smiles_language = SMILESLanguage.load(
            os.path.join(model_path, 'smiles_language.pkl')
        )
        model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
        # Try weight restoring
        try:
            weight_path = os.path.join(
                model_path, 'weights',
                params.get('weights_name', 'best_ROC-AUC_mca.pt')
            )
            model.load(weight_path)
        except Exception:
            try:
                wp = os.listdir(os.path.join(model_path, 'weights'))[0]
                logger.info(
                    f"Weights {weight_path} not found. Try restore {wp}"
                )
                model.load(os.path.join(model_path, 'weights', wp))
            except Exception:
                raise Exception('Error in weight loading.')

    except Exception:
        raise Exception(
            'Error in model restoring. model_path should point to the model root '
            'folder that contains a model_params.json, a smiles_language.pkl and a '
            'weights folder.'
        )
    logger.info('Model restored. Now starting to craft it for the task')

    # Process parameter file:
    model_params = {}
    with open(params_filepath) as fp:
        model_params.update(json.load(fp))

    model = update_mca_model(model, model_params)

    logger.info('Model set up.')
    for idx, (name, param) in enumerate(model.named_parameters()):
        logger.info(
            (idx, name, param.shape, f'Gradients: {param.requires_grad}')
        )

    ft_model_path = os.path.join(model_path, 'finetuned')
    os.makedirs(os.path.join(ft_model_path, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(ft_model_path, 'results'), exist_ok=True)

    logger.info('Now start data preprocessing...')

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('sanitize', True)
    )

    train_dataset = AnnotatedDataset(
        annotations_filepath=train_scores_filepath,
        dataset=smiles_dataset,
        device=get_device()
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )

    # Generally, if sanitize is True molecules are de-kekulized. Augmentation
    # preserves the "kekulization", so if it is used, test data should be
    # sanitized or canonicalized.
    smiles_test_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False)
    )

    # Dump the (possibly modified) SMILES language
    smiles_language.save(os.path.join(ft_model_path, 'smiles_language.pkl'))

    test_dataset = AnnotatedDataset(
        annotations_filepath=test_scores_filepath,
        dataset=smiles_test_dataset,
        device=get_device()
    )

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=False,
        drop_last=False,
        num_workers=params.get('num_workers', 0)
    )

    save_top_model = os.path.join(ft_model_path, 'weights/{}_{}_{}.pt')

    num_t_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    num_params = sum(p.numel() for p in model.parameters())
    params.update({'params': num_params, 'trainable_params': num_t_params})
    logger.info(
        f'Number of parameters: {num_params} (trainable {num_t_params}).'
    )

    # Define optimizer only for those layers which require gradients
    optimizer = (
        OPTIMIZER_FACTORY[params.get('optimizer', 'adam')](
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=params.get('lr', 0.00001)
        )
    )
    # Dump params.json file with updated parameters.
    with open(os.path.join(ft_model_path, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    min_loss, max_roc_auc = 1000000, 0
    max_precision_recall_score = 0

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, y) in enumerate(train_loader):

            smiles = torch.squeeze(smiles.to(device))
            y_hat, pred_dict = model(smiles)

            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        logger.info(
            '\t **** TRAINING ****   '
            f"Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {train_loss / len(train_loader):.5f}. '
            f'This took {time() - t:.1f} secs.'
        )
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, y) in enumerate(test_loader):

                smiles = torch.squeeze(smiles.to(device))

                y_hat, pred_dict = model(smiles)
                predictions.append(y_hat)
                # Copy y tensor since loss function applies downstream modification
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()

        # Remove NaNs from labels to compute scores
        predictions = predictions[~np.isnan(labels)]
        labels = labels[~np.isnan(labels)]
        test_loss_a = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc_a = auc(fpr, tpr)

        # calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        # summary score for the precision-recall curve
        test_precision_recall_score = average_precision_score(
            labels, predictions
        )

        logger.info(
            f"\t **** TEST **** Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {test_loss_a:.5f}, roc_auc: {test_roc_auc_a:.5f}, '
            f'avg precision-recall score: {test_precision_recall_score:.5f}'
        )
        info = {
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
            'test_auc': test_roc_auc_a,
            'best_test_auc': max_roc_auc,
            'test_precision_recall_score': test_precision_recall_score,
            'best_precision_recall_score': max_precision_recall_score,
        }

        def save(path, metric, typ, val=None):
            model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
            if typ == 'best':
                logger.info(
                    f'\t New best performance in {metric}'
                    f' with value : {val:.7f} in epoch: {epoch+1}'
                )

        if test_roc_auc_a > max_roc_auc:
            max_roc_auc = test_roc_auc_a
            info.update({'best_test_auc': max_roc_auc})
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            np.save(
                os.path.join(ft_model_path, 'results', 'best_predictions.npy'),
                predictions
            )
            with open(
                os.path.join(ft_model_path, 'results', 'metrics.json'), 'w'
            ) as f:
                json.dump(info, f)

        if test_precision_recall_score > max_precision_recall_score:
            max_precision_recall_score = test_precision_recall_score
            info.update(
                {'best_precision_recall_score': max_precision_recall_score}
            )
            save(
                save_top_model, 'precision-recall score', 'best',
                max_precision_recall_score
            )

        if test_loss_a < min_loss:
            min_loss = test_loss_a
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch

        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info(
        'Overall best performances are:\n\t'
        f'Loss = {min_loss:.4f} in epoch {ep_loss}.'
    )
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
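
Both training scripts mask NaN labels before computing metrics, since `roc_curve` cannot handle missing targets. A tiny, self-contained illustration of that masking step with toy data:

import numpy as np
from sklearn.metrics import auc, roc_curve

labels = np.array([1.0, np.nan, 0.0, 1.0])
predictions = np.array([0.9, 0.4, 0.2, 0.7])

mask = ~np.isnan(labels)  # keep only positions with a defined label
fpr, tpr, _ = roc_curve(labels[mask], predictions[mask])
print(auc(fpr, tpr))  # 1.0: the toy positives score above the negative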
Example #6
def main(train_scores_filepath,
         test_scores_filepath,
         smi_filepath,
         smiles_language_filepath,
         model_path,
         params_filepath,
         training_name,
         embedding_path=None):

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger(f'{training_name}')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()

    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    if embedding_path:
        params['embedding_path'] = embedding_path

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    logger.info('Start data preprocessing...')
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Prepare FP processing
    if params.get('model_fn', 'mca') == 'dense':
        morgan_transform = Compose([
            SMILESToMorganFingerprints(
                radius=params.get('fp_radius', 2),
                bits=params.get('num_drug_features', 512),
                chirality=params.get('fp_chirality', True),
            ),
            ToTensor(get_device()),
        ])

        def smiles_tensor_batch_to_fp(smiles):
            """Reuse the SMILES pipeline to build Morgan fingerprints."""
            # Default must match the `bits` default of the Morgan transform.
            out = torch.Tensor(
                smiles.shape[0], params.get('num_drug_features', 512)
            )
            for ind, tensor in enumerate(smiles):
                # Use a separate name; do not shadow the batch being iterated.
                smiles_string = smiles_language.token_indexes_to_smiles(
                    tensor.tolist()
                )
                out[ind, :] = torch.squeeze(morgan_transform(smiles_string))
            return out

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('sanitize', True))

    # Pass label_columns if the data file has unwanted columns (such as an
    # index) that should be ignored.
    train_dataset = AnnotatedDataset(
        annotations_filepath=train_scores_filepath,
        dataset=smiles_dataset,
        device=get_device())
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0),
    )

    if (params.get('uncertainty', True)
            and params.get('augment_test_data', False)):
        raise ValueError(
            'Epistemic uncertainty evaluation is not supported if '
            'augmentation is enabled for test data.')

    # Generally, if sanitize is True molecules are de-kekulized. Augmentation
    # preserves the "kekulization", so if it is used, test data should be
    # sanitized or canonicalized.
    smiles_test_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False))

    # Dump the (possibly modified) SMILES language
    smiles_language.save(os.path.join(model_dir, 'smiles_language.pkl'))

    logger.info(smiles_dataset._dataset.transform)
    logger.info(smiles_test_dataset._dataset.transform)

    # Pass label_columns if the data file has unwanted columns (such as an
    # index) that should be ignored.
    test_dataset = AnnotatedDataset(
        annotations_filepath=test_scores_filepath,
        dataset=smiles_test_dataset,
        device=get_device())

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=False,
        drop_last=False,
        num_workers=params.get('num_workers', 0),
    )

    if params.get('confidence', False):
        smiles_conf_dataset = SMILESDataset(
            smi_filepath,
            smiles_language=smiles_language,
            padding_length=params.get('smiles_padding_length', None),
            padding=params.get('padd_smiles', True),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=True,  # Augmentation is needed for the uncertainty estimates
            canonical=params.get('canonical', False),
            kekulize=params.get('kekulize', False),
            all_bonds_explicit=params.get('all_bonds_explicit', False),
            all_hs_explicit=params.get('all_hs_explicit', False),
            randomize=params.get('randomize', False),
            remove_bonddir=params.get('remove_bonddir', False),
            remove_chirality=params.get('remove_chirality', False),
            selfies=params.get('selfies', False))
        conf_dataset = AnnotatedDataset(
            annotations_filepath=test_scores_filepath,
            dataset=smiles_conf_dataset,
            device=get_device())
        conf_loader = torch.utils.data.DataLoader(
            dataset=conf_dataset,
            batch_size=params['batch_size'],
            shuffle=False,
            drop_last=False,
            num_workers=params.get('num_workers', 0))

    if params.get('embedding', 'learned') != 'pretrained':
        params.update(
            {'smiles_vocabulary_size': smiles_language.number_of_tokens})

    device = get_device()
    logger.info(f'Device for data loader is {train_dataset.device} and for '
                f'model is {device}')
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')

    model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
    logger.info(model)
    logger.info(model.loss_fn.class_weights)

    logger.info('Parameters follow')
    for name, param in model.named_parameters():
        logger.info((name, param.shape))

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params.update({'number_of_parameters': num_params})
    logger.info(f'Number of parameters {num_params}')

    # Define optimizer
    optimizer = OPTIMIZER_FACTORY[params.get('optimizer', 'adam')](
        model.parameters(), lr=params.get('lr', 0.00001)
    )
    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    min_loss, max_roc_auc = 1000000, 0
    max_precision_recall_score = 0

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, y) in enumerate(train_loader):

            smiles = torch.squeeze(smiles.to(device))
            # Transform smiles to FP if needed
            if params.get('model_fn', 'mca') == 'dense':
                smiles = smiles_tensor_batch_to_fp(smiles).to(device)

            y_hat, pred_dict = model(smiles)

            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        logger.info('\t **** TRAINING ****   '
                    f"Epoch [{epoch + 1}/{params['epochs']}], "
                    f'loss: {train_loss / len(train_loader):.5f}. '
                    f'This took {time() - t:.1f} secs.')
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, y) in enumerate(test_loader):

                smiles = torch.squeeze(smiles.to(device))
                # Transform smiles to FP if needed
                if params.get('model_fn', 'mca') == 'dense':
                    smiles = smiles_tensor_batch_to_fp(smiles).to(device)

                y_hat, pred_dict = model(smiles)
                predictions.append(y_hat)
                # Copy y tensor since loss function applies downstream
                #   modification
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()

        # Remove NaNs from labels to compute scores
        predictions = predictions[~np.isnan(labels)]
        labels = labels[~np.isnan(labels)]
        test_loss_a = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc_a = auc(fpr, tpr)

        # calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        # summary score for the precision-recall curve
        test_precision_recall_score = average_precision_score(
            labels, predictions)

        logger.info(
            f"\t **** TEST **** Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {test_loss_a:.5f}, roc_auc: {test_roc_auc_a:.5f}, '
            f'avg precision-recall score: {test_precision_recall_score:.5f}')
        info = {
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
            'test_auc': test_roc_auc_a,
            'best_test_auc': max_roc_auc,
            'test_precision_recall_score': test_precision_recall_score,
            'best_precision_recall_score': max_precision_recall_score,
        }

        def save(path, metric, typ, val=None):
            model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
            if typ == 'best':
                logger.info(f'\t New best performance in {metric}'
                            f' with value : {val:.7f} in epoch: {epoch+1}')

        if test_roc_auc_a > max_roc_auc:
            max_roc_auc = test_roc_auc_a
            info.update({'best_test_auc': max_roc_auc})
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            np.save(os.path.join(model_dir, 'results', 'best_predictions.npy'),
                    predictions)
            with open(os.path.join(model_dir, 'results', 'metrics.json'),
                      'w') as f:
                json.dump(info, f)
            if params.get('confidence', False):
                # Compute uncertainty estimates and save them
                epistemic_conf = monte_carlo_dropout(
                    model, regime='loader', loader=conf_loader
                )
                aleatoric_conf = test_time_augmentation(
                    model, regime='loader', loader=conf_loader
                )

                np.save(
                    os.path.join(model_dir, 'results', 'epistemic_conf.npy'),
                    epistemic_conf)
                np.save(
                    os.path.join(model_dir, 'results', 'aleatoric_conf.npy'),
                    aleatoric_conf)

        if test_precision_recall_score > max_precision_recall_score:
            max_precision_recall_score = test_precision_recall_score
            info.update(
                {'best_precision_recall_score': max_precision_recall_score})
            save(save_top_model, 'precision-recall score', 'best',
                 max_precision_recall_score)

        if test_loss_a < min_loss:
            min_loss = test_loss_a
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch

        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info('Overall best performances are:\n\t'
                f'Loss = {min_loss:.4f} in epoch {ep_loss}.')
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
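
Both scripts name checkpoints through the three-slot `save_top_model` format string, which `save()` fills as (typ, metric, model_fn). A short illustration of the resulting filenames, using an assumed `model_dir`:

import os

save_top_model = os.path.join('model_dir', 'weights/{}_{}_{}.pt')
# save(save_top_model, 'ROC-AUC', 'best', ...) formats as (typ, metric, model_fn):
print(save_top_model.format('best', 'ROC-AUC', 'mca'))
# -> model_dir/weights/best_ROC-AUC_mca.pt, matching the default
#    weights_name expected by the restoring code in Example #5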
Example #7
def main(model_path, smi_filepath, labels_filepath, output_folder,
         smiles_language_filepath, model_id, confidence):

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger('eval_toxicity')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()

    # Process parameter file:
    params = {}
    with open(os.path.join(model_path, 'model_params.json'), 'r') as fp:
        params.update(json.load(fp))

    # Create model directory
    os.makedirs(output_folder, exist_ok=True)

    if model_id not in MODEL_FACTORY.keys():
        raise KeyError(
            f'Unknown model_id {model_id}; pass one of {MODEL_FACTORY.keys()}.')

    device = get_device()
    weights_path = os.path.join(model_path, 'weights',
                                f'best_ROC-AUC_{model_id}.pt')

    # Restore model
    model = MODEL_FACTORY[model_id](params).to(device)
    if os.path.isfile(weights_path):
        try:
            model.load(weights_path, map_location=device)
        except Exception:
            logger.error(f'Error in model restoring from {weights_path}')
    else:
        logger.info(f'Did not find weights at {weights_path}; '
                    f'expected a file named "best_ROC-AUC_{model_id}.pt".')
    model.eval()

    logger.info('Model restored. Model specs & parameters follow')
    for name, param in model.named_parameters():
        logger.info((name, param.shape))

    # Load language
    if smiles_language_filepath == '.':
        smiles_language_filepath = os.path.join(model_path,
                                                'smiles_language.pkl')
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False),
    )
    logger.info(
        f'SMILES padding length is {smiles_dataset._dataset.padding_length}. '
        'Consider setting it manually if this looks wrong.')
    dataset = AnnotatedDataset(annotations_filepath=labels_filepath,
                               dataset=smiles_dataset,
                               device=device)
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=10,
                                         shuffle=False,
                                         drop_last=False,
                                         num_workers=0)
    if confidence:
        smiles_aleatoric_dataset = SMILESDataset(
            smi_filepath,
            smiles_language=smiles_language,
            padding_length=params.get('smiles_padding_length', None),
            padding=params.get('padd_smiles', True),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=True,  # Natively true for aleatoric uncertainty estimate
            canonical=params.get('test_canonical', False),
            kekulize=params.get('test_kekulize', False),
            all_bonds_explicit=params.get('test_all_bonds_explicit', False),
            all_hs_explicit=params.get('test_all_hs_explicit', False),
            remove_bonddir=params.get('test_remove_bonddir', False),
            remove_chirality=params.get('test_remove_chirality', False),
            selfies=params.get('selfies', False),
            sanitize=params.get('test_sanitize', False),
        )
        ale_dataset = AnnotatedDataset(annotations_filepath=labels_filepath,
                                       dataset=smiles_aleatoric_dataset,
                                       device=device)
        ale_loader = torch.utils.data.DataLoader(dataset=ale_dataset,
                                                 batch_size=10,
                                                 shuffle=False,
                                                 drop_last=False,
                                                 num_workers=0)
        smiles_epi_dataset = SMILESDataset(
            smi_filepath,
            smiles_language=smiles_language,
            padding_length=params.get('smiles_padding_length', None),
            padding=params.get('padd_smiles', True),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=False,  # Natively false for epistemic uncertainty estimate
            canonical=params.get('test_canonical', False),
            kekulize=params.get('test_kekulize', False),
            all_bonds_explicit=params.get('test_all_bonds_explicit', False),
            all_hs_explicit=params.get('test_all_hs_explicit', False),
            remove_bonddir=params.get('test_remove_bonddir', False),
            remove_chirality=params.get('test_remove_chirality', False),
            selfies=params.get('selfies', False),
            sanitize=params.get('test_sanitize', False),
        )
        epi_dataset = AnnotatedDataset(annotations_filepath=labels_filepath,
                                       dataset=smiles_epi_dataset,
                                       device=device)
        epi_loader = torch.utils.data.DataLoader(dataset=epi_dataset,
                                                 batch_size=10,
                                                 shuffle=False,
                                                 drop_last=False,
                                                 num_workers=0)

    logger.info(f'Device for data loader is {dataset.device} and for '
                f'model is {device}')
    # Start evaluation
    logger.info('Evaluation about to start...\n')

    preds, labels, attention_scores, smiles = [], [], [], []
    for idx, (smiles_batch, labels_batch) in enumerate(loader):
        pred, pred_dict = model(smiles_batch.to(device))
        preds.extend(pred.detach().squeeze().tolist())
        attention_scores.extend(
            torch.stack(pred_dict['smiles_attention'], dim=1).detach())
        smiles.extend([
            smiles_language.token_indexes_to_smiles(s.tolist())
            for s in smiles_batch
        ])
        labels.extend(labels_batch.squeeze().tolist())
    # Scores are now 3D: num_samples x num_att_layers x padding_length
    attention = torch.stack(attention_scores, dim=0).numpy()

    if confidence:
        # Compute uncertainty estimates and save them
        epistemic_conf, epistemic_pred = monte_carlo_dropout(
            model, regime='loader', loader=epi_loader)
        aleatoric_conf, aleatoric_pred = test_time_augmentation(
            model, regime='loader', loader=ale_loader)
        epi_conf_df = pd.DataFrame(
            data=epistemic_conf.numpy(),
            columns=[
                f'epistemic_conf_{i}' for i in range(epistemic_conf.shape[1])
            ],
        )
        ale_conf_df = pd.DataFrame(
            data=aleatoric_conf.numpy(),
            columns=[
                f'aleatoric_conf_{i}' for i in range(aleatoric_conf.shape[1])
            ],
        )
        epi_pred_df = pd.DataFrame(
            data=epistemic_pred.numpy(),
            columns=[
                f'epistemic_pred_{i}' for i in range(epistemic_pred.shape[1])
            ],
        )
        ale_pred_df = pd.DataFrame(
            data=aleatoric_pred.numpy(),
            columns=[
                f'aleatoric_pred_{i}' for i in range(aleatoric_pred.shape[1])
            ],
        )

    logger.info(f'Shape of attention scores {attention.shape}.')
    np.save(os.path.join(output_folder, 'attention_raw.npy'), attention)
    attention_avg = np.mean(attention, axis=1)
    att_df = pd.DataFrame(
        data=attention_avg,
        columns=[f'att_idx_{i}' for i in range(attention_avg.shape[1])],
    )
    pred_df = pd.DataFrame(
        data=preds,
        columns=[f'pred_{i}' for i in range(len(preds[0]))],
    )
    lab_df = pd.DataFrame(
        data=labels,
        columns=[f'label_{i}' for i in range(len(labels[0]))],
    )
    df = pd.concat([pred_df, lab_df], axis=1)
    if confidence:
        df = pd.concat(
            [df, epi_conf_df, ale_conf_df, epi_pred_df, ale_pred_df], axis=1)
    df = pd.concat([df, att_df], axis=1)
    df.insert(0, 'SMILES', smiles)
    df.to_csv(os.path.join(output_folder, 'results.csv'), index=False)

    logger.info('Done, shutting down.')
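
The attention post-processing above averages the per-layer scores into one column per token position. A small self-contained sketch of that reshaping with toy values:

import numpy as np
import pandas as pd

# Toy scores: 2 samples, 3 attention layers, padding length 4.
attention = np.random.rand(2, 3, 4)
attention_avg = attention.mean(axis=1)  # average over layers -> shape (2, 4)
att_df = pd.DataFrame(
    data=attention_avg,
    columns=[f'att_idx_{i}' for i in range(attention_avg.shape[1])],
)
print(att_df.shape)  # (2, 4): one column per SMILES token position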