Example #1
    def load_pretrained_paccmann(self, params_file: str, lang_file: str,
                                 weights_file: str, batch_size: int,
                                 batch_mode: str):
        params = dict()
        with open(params_file, 'r') as f:
            params.update(json.load(f))
        params['batch_mode'] = batch_mode
        params['batch_size'] = batch_size

        self.selfies = params.get('selfies', False)

        self.device = get_device()
        self.smiles_language = SMILESLanguage.load(lang_file)

        self.gru_encoder = StackGRUEncoder(params).to(self.device)
        self.gru_decoder = StackGRUDecoder(params).to(self.device)
        self.gru_vae = TeacherVAE(self.gru_encoder,
                                  self.gru_decoder).to(self.device)
        self.gru_vae.load_state_dict(
            torch.load(weights_file, map_location=self.device))
        self.gru_vae.eval()

        transforms = []
        if self.selfies:
            transforms += [Selfies()]
        transforms += [
            SMILESToTokenIndexes(smiles_language=self.smiles_language)
        ]
        transforms += [ToTensor(device=self.device)]
        self.transform = Compose(transforms)
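
A hedged usage sketch for the method above. The owning class is not shown here; the object below, the file paths and the batch mode value are all placeholders, not part of the original snippet.

# Sketch only: `encoder` stands for an instance of the (unshown) class that
# defines load_pretrained_paccmann; paths and batch_mode are placeholders.
encoder.load_pretrained_paccmann(
    params_file='svae/model_params.json',
    lang_file='svae/selfies_language.pkl',
    weights_file='svae/weights/best_rec.pt',
    batch_size=1,
    batch_mode='packed',
)
# After loading, the composed transform maps a SMILES string to a tensor of
# token indexes on the configured device.
token_tensor = encoder.transform('CCO')
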
Example #2
    def load_mca(self, model_path: str):
        """
        Restores pretrained MCA

        Arguments:
            model_path {String} -- Path to the model
        """

        # Restore model
        self.model_path = model_path
        with open(os.path.join(model_path, 'model_params.json')) as f:
            params = json.load(f)

        # Set up language and transforms
        self.smiles_language = SMILESLanguage.load(
            os.path.join(model_path, 'smiles_language.pkl')
        )
        self.transforms = self.compose_smiles_transforms(params)

        # Initialize and restore model weights
        self.model = MODEL_FACTORY['mca'](params)
        self.model.load(
            os.path.join(model_path, 'weights', 'best_ROC-AUC_mca.pt'),
            map_location=get_device()
        )
        self.model.eval()
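
A hedged usage sketch for load_mca; the surrounding predictor class is not shown, and the folder name below is a placeholder. The expected directory layout (model_params.json, smiles_language.pkl, weights/best_ROC-AUC_mca.pt) is taken from the paths used above.

# Sketch only: `predictor` stands for an instance of the (unshown) class that
# defines load_mca; the directory name is a placeholder.
predictor.load_mca('models/tox21_mca')
# Afterwards the restored pieces can be used for inference, e.g. (assumed):
#   tokens = predictor.transforms(smiles_string)
#   prediction, _ = predictor.model(tokens.unsqueeze(0))
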
Example #3
    def __init__(self, filename, smiles_language):
        super().__init__()
        self.filename = filename
        with open(self.filename, 'r') as f:
            self.data = [line.strip().split('\t') for line in f.readlines()]
        self.data = [[(x1, x2), int(y)] for x1, x2, y in self.data]
        if isinstance(smiles_language, str):
            smiles_language = SMILESLanguage.load(smiles_language)
        self.smiles_language = smiles_language
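
The snippet above is only the constructor of a dataset class. A minimal sketch of how the full class might look, assuming a hypothetical class name and __getitem__ behaviour (smiles_to_token_indexes is used the same way in later examples):

import torch

from pytoda.smiles.smiles_language import SMILESLanguage  # assumed import path


class PairedSMILESDataset(torch.utils.data.Dataset):  # hypothetical name
    def __init__(self, filename, smiles_language):
        super().__init__()
        self.filename = filename
        with open(self.filename, 'r') as f:
            self.data = [line.strip().split('\t') for line in f.readlines()]
        self.data = [[(x1, x2), int(y)] for x1, x2, y in self.data]
        if isinstance(smiles_language, str):
            smiles_language = SMILESLanguage.load(smiles_language)
        self.smiles_language = smiles_language

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Tokenize both SMILES of the pair; the integer label is returned as-is.
        (x1, x2), y = self.data[index]
        return (
            self.smiles_language.smiles_to_token_indexes(x1),
            self.smiles_language.smiles_to_token_indexes(x2),
            y,
        )
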
Example #4

def create_smiles_language(smi_path: str, pretrained_path: str) -> None:
    """
    Create a SMILESLanguage object and save it to disk.

    Args:
        smi_path (str): path to a folder containing .smi files.
        pretrained_path (str): directory to store the language as text files.
    """
    os.makedirs(pretrained_path, exist_ok=True)
    smiles_language = SMILESLanguage()
    smiles_language.add_smis([
        os.path.join(smi_path, smi_filename)
        for smi_filename in os.listdir(smi_path)
        if smi_filename.endswith('.smi')
    ])
    smiles_language.save_pretrained(pretrained_path)
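
A usage sketch for the helper above (folder names are placeholders):

# Build a vocabulary from every .smi file in ./data and store the language
# files under ./smiles_language (the folder is created if missing).
create_smiles_language(smi_path='./data', pretrained_path='./smiles_language')
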
Example #5
def main(train_affinity_filepath, test_affinity_filepath, protein_filepath,
         smi_filepath, smiles_language_filepath, protein_language_filepath,
         model_path, params_filepath, training_name):

    logger = logging.getLogger(f'{training_name}')
    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    # Prepare the dataset
    logger.info("Start data preprocessing...")
    device = get_device()

    # Load languages
    smiles_language = SMILESLanguage.load(smiles_language_filepath)
    protein_language = ProteinLanguage.load(protein_language_filepath)

    # Assemble datasets
    train_dataset = DrugAffinityDataset(
        drug_affinity_filepath=train_affinity_filepath,
        smi_filepath=smi_filepath,
        protein_filepath=protein_filepath,
        smiles_language=smiles_language,
        protein_language=protein_language,
        smiles_padding=params.get('smiles_padding', True),
        smiles_padding_length=params.get('smiles_padding_length', None),
        smiles_add_start_and_stop=params.get('smiles_add_start_stop', True),
        smiles_augment=params.get('augment_smiles', False),
        smiles_canonical=params.get('smiles_canonical', False),
        smiles_kekulize=params.get('smiles_kekulize', False),
        smiles_all_bonds_explicit=params.get('smiles_bonds_explicit', False),
        smiles_all_hs_explicit=params.get('smiles_all_hs_explicit', False),
        smiles_remove_bonddir=params.get('smiles_remove_bonddir', False),
        smiles_remove_chirality=params.get('smiles_remove_chirality', False),
        smiles_selfies=params.get('selfies', False),
        protein_amino_acid_dict=params.get('protein_amino_acid_dict', 'iupac'),
        protein_padding=params.get('protein_padding', True),
        protein_padding_length=params.get('protein_padding_length', None),
        protein_add_start_and_stop=params.get('protein_add_start_stop', True),
        protein_augment_by_revert=params.get('protein_augment', False),
        device=device,
        drug_affinity_dtype=torch.float,
        backend='eager')
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=params['batch_size'],
                                               shuffle=True,
                                               drop_last=True,
                                               num_workers=params.get(
                                                   'num_workers', 0))

    test_dataset = DrugAffinityDataset(
        drug_affinity_filepath=test_affinity_filepath,
        smi_filepath=smi_filepath,
        protein_filepath=protein_filepath,
        smiles_language=smiles_language,
        protein_language=protein_language,
        smiles_padding=params.get('smiles_padding', True),
        smiles_padding_length=params.get('smiles_padding_length', None),
        smiles_add_start_and_stop=params.get('smiles_add_start_stop', True),
        smiles_augment=False,
        smiles_canonical=params.get('smiles_canonical', False),
        smiles_kekulize=params.get('smiles_kekulize', False),
        smiles_all_bonds_explicit=params.get('smiles_bonds_explicit', False),
        smiles_all_hs_explicit=params.get('smiles_all_hs_explicit', False),
        smiles_remove_bonddir=params.get('smiles_remove_bonddir', False),
        smiles_remove_chirality=params.get('smiles_remove_chirality', False),
        smiles_selfies=params.get('selfies', False),
        protein_amino_acid_dict=params.get('protein_amino_acid_dict', 'iupac'),
        protein_padding=params.get('protein_padding', True),
        protein_padding_length=params.get('protein_padding_length', None),
        protein_add_start_and_stop=params.get('protein_add_start_stop', True),
        protein_augment_by_revert=False,
        device=device,
        drug_affinity_dtype=torch.float,
        backend='eager')
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=params['batch_size'],
                                              shuffle=True,
                                              drop_last=True,
                                              num_workers=params.get(
                                                  'num_workers', 0))
    logger.info(
        f'Training dataset has {len(train_dataset)} samples, test set has '
        f'{len(test_dataset)}.')

    logger.info(f'Device for data loader is {train_dataset.device} and for '
                f'model is {device}')
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')
    params.update({
        'smiles_vocabulary_size': smiles_language.number_of_tokens,
        'protein_vocabulary_size': protein_language.number_of_tokens
    })  # yapf: disable

    model_fn = params.get('model_fn', 'bimodal_mca')
    model = MODEL_FACTORY[model_fn](params).to(device)
    model._associate_language(smiles_language)
    model._associate_language(protein_language)

    if os.path.isfile(os.path.join(model_dir, 'weights', 'best_mca.pt')):
        logger.info('Found existing model, restoring now...')
        try:
            model.load(os.path.join(model_dir, 'weights', 'best_mca.pt'))

            with open(os.path.join(model_dir, 'results', 'mse.json'),
                      'r') as f:
                info = json.load(f)

                # Values were dumped as strings in `save`, hence the cast
                max_roc_auc = float(info['best_roc_auc'])
                min_loss = float(info['test_loss'])

        except Exception:
            min_loss, max_roc_auc = 100, 0
    else:
        min_loss, max_roc_auc = 100, 0

    # Define optimizer
    optimizer = (OPTIMIZER_FACTORY[params.get('optimizer',
                                              'adam')](model.parameters(),
                                                       lr=params.get(
                                                           'lr', 0.001)))
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params.update({'number_of_parameters': num_params})
    logger.info(f'Number of parameters: {num_params}')
    logger.info(f'Model: {model}')

    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()

    model.save(save_top_model.format('epoch', '0', model_fn))

    for epoch in range(params['epochs']):

        model.train()
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, proteins, y) in enumerate(train_loader):
            if ind % 100 == 0:
                logger.info(f'Batch {ind}/{len(train_loader)}')
            y_hat, pred_dict = model(smiles, proteins)
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            # Apply gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(),1e-6)
            optimizer.step()
            train_loss += loss.item()

        logger.info("\t **** TRAINING ****   "
                    f"Epoch [{epoch + 1}/{params['epochs']}], "
                    f"loss: {train_loss / len(train_loader):.5f}. "
                    f"This took {time() - t:.1f} secs.")
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, proteins, y) in enumerate(test_loader):
                y_hat, pred_dict = model(smiles.to(device),
                                         proteins.to(device))
                predictions.append(y_hat)
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()

        test_loss = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc = auc(fpr, tpr)

        # calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        avg_precision = average_precision_score(labels, predictions)

        logger.info(
            f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {test_loss:.5f}, ROC-AUC: {test_roc_auc:.3f}, "
            f"Average precision: {avg_precision:.3f}.")

        def save(path, metric, typ, val=None):
            model.save(path.format(typ, metric, model_fn))
            info = {
                'best_roc_auc': str(max_roc_auc),
                'test_loss': str(min_loss)
            }
            with open(os.path.join(model_dir, 'results', metric + '.json'),
                      'w') as f:
                json.dump(info, f)
            np.save(os.path.join(model_dir, 'results', metric + '_preds.npy'),
                    np.vstack([predictions, labels]))
            if typ == 'best':
                logger.info(f'\t New best performance in "{metric}"'
                            f' with value : {val:.7f} in epoch: {epoch}')

        if test_roc_auc > max_roc_auc:
            max_roc_auc = test_roc_auc
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            ep_roc = epoch
            roc_auc_loss = test_loss

        if test_loss < min_loss:
            min_loss = test_loss
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch
            loss_roc_auc = test_roc_auc
        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))
    logger.info('Overall best performances are: \n \t'
                f'Loss = {min_loss:.4f} in epoch {ep_loss} '
                f'\t (ROC-AUC was {loss_roc_auc:.4f}) \n \t'
                f'ROC-AUC = {max_roc_auc:.4f} in epoch {ep_roc} '
                f'\t (Loss was {roc_auc_loss:.4f})')
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
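
The script above only consumes hyperparameters found in params_filepath; everything read via params.get(...) has a default, and only batch_size and epochs are accessed without one. A hedged sketch of writing such a file (key names are taken from the script, values are purely illustrative):

import json

params = {
    'batch_size': 128,          # required: accessed as params['batch_size']
    'epochs': 200,              # required: accessed as params['epochs']
    'optimizer': 'adam',        # default 'adam'
    'lr': 0.001,                # default 0.001
    'smiles_padding_length': 500,
    'protein_padding_length': 2048,
    'num_workers': 4,           # default 0
    'save_model': 25,           # checkpoint interval in epochs, default 100
    'model_fn': 'bimodal_mca',  # default 'bimodal_mca'
}
with open('params.json', 'w') as fp:
    json.dump(params, fp, indent=4)
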
Example #6
def load_data(
    filename: str,
    smiles_language: Union[SMILESLanguage, str],
    max_close_neighbors: int = 5,
    max_total_bundle: int = 20,
) -> List[Data]:
    """Loads the data from a pre-processed JSON file and parses it into
    a list of torch_geometric.data.Data objects.

    Each Data object is obtained by expanding a node up to two levels
    of depth. The sampling is controlled by `max_close_neighbors` and
    `max_total_bundle` (i.e. the maximum total number of elements per
    bundle). This is applied to every node of the dataset, so the
    returned list has a length equal to the number of SMILES (i.e. the
    total number of nodes).

    NOTE: The reason for this implementation is that ClusterData fails
    when run on a Data object holding the entire dataset:
    ```
    # smiles = [nodes2smiles[i] for i in range(len(nodes2smiles))]
    # return Data(
    #     x=torch.tensor(list(range(len(smiles)))).view(-1, 1),
    #     edge_index=torch.tensor(edgelist).T,
    #     # num_nodes=len(smiles)
    # )
    # number_of_nodes = 10
    # cluster_data = ClusterData(
    #     data, num_parts=data.num_nodes // number_of_nodes, log=True
    # )
    # --> This is causing: SIGSEGV (Address boundary error)
    ```

    Args:
        filename (str): Path to the JSON file. It must have the fields
          `smiles2nodes` and `edgelist`.
        smiles_language (Union[SMILESLanguage, str]): A SMILESLanguage
          object or the path to a stored one (loaded via
          `SMILESLanguage.load`).
        max_close_neighbors (int, optional): Number of close (one hop
          away) nodes that the algorithm samples. Defaults to 5.
        max_total_bundle (int, optional): Maximum total number of
          elements in each Data object. Defaults to 20.

    Returns:
        List[Data]
    """
    with open(filename, 'r') as f:
        raw = json.load(f)
    nodes2smiles = {v: k for k, v in raw['smiles2nodes'].items()}
    edgelist = raw['edgelist']

    if isinstance(smiles_language, str):
        smiles_language = SMILESLanguage.load(smiles_language)

    data = []
    # Custom sampling
    G = nx.from_edgelist(edgelist)
    for root in range(len(G)):
        # FIXME This might better have been a DFS-like approach with fixed
        # depth, but it was easier to first sample the close neighborhood and
        # then expand one more step from there. Not even sure DFS is what we
        # want though; there are pros and cons either way.
        shortlist = []
        close_neighbors = np.random.permutation(list(
            G.neighbors(root)))[:max_close_neighbors]
        total_far_neighbors = sum(
            [len(list(G.neighbors(x))) for x in close_neighbors])

        shortlist += [[root, x] for x in close_neighbors]

        counter = 0
        while (len(shortlist) < max_total_bundle
               and counter < total_far_neighbors):
            # TODO Random sampling probability inversely proportional to the
            # degree?
            current_node = close_neighbors[counter % len(close_neighbors)]
            far_node = np.random.choice(list(G.neighbors(current_node)))
            shortlist.append([current_node, far_node])
            counter += 1

        # We need to relabel the nodes, but keep the original location in order
        # to retrieve the corresponding smiles
        sub_graph = nx.convert_node_labels_to_integers(
            nx.from_edgelist(shortlist), label_attribute='original')

        x = [
            torch.tensor(
                smiles_language.smiles_to_token_indexes(
                    nodes2smiles[sub_graph.nodes[i]['original']]))
            for i in range(len(sub_graph))
        ]
        edge_index = torch.tensor(nx.to_pandas_edgelist(sub_graph).values).T

        data.append(Data(x=x, edge_index=edge_index, num_nodes=len(x)))

    return data
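
A usage sketch for load_data (file names are placeholders; the JSON file must contain the smiles2nodes and edgelist fields described in the docstring):

data_list = load_data(
    filename='graph.json',
    smiles_language='smiles_language.pkl',  # a str is loaded via SMILESLanguage.load
    max_close_neighbors=5,
    max_total_bundle=20,
)
# One Data object per node / SMILES in the graph.
print(len(data_list), data_list[0].num_nodes)
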
Example #7
def main(
    train_scores_filepath, test_scores_filepath, smi_filepath, model_path,
    params_filepath, training_name
):

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger(f'{training_name}')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()
    device = get_device()

    # Restore pretrained model
    logger.info('Start model restoring.')
    try:
        with open(os.path.join(model_path, 'model_params.json'), 'r') as fp:
            params = json.load(fp)
        smiles_language = SMILESLanguage.load(
            os.path.join(model_path, 'smiles_language.pkl')
        )
        model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
        # Try weight restoring
        try:
            weight_path = os.path.join(
                model_path, 'weights',
                params.get('weights_name', 'best_ROC-AUC_mca.pt')
            )
            model.load(weight_path)
        except Exception:
            try:
                wp = os.listdir(os.path.join(model_path, 'weights'))[0]
                logger.info(
                    f"Weights {weight_path} not found. Try restore {wp}"
                )
                model.load(os.path.join(model_path, 'weights', wp))
            except Exception:
                raise Exception('Error in weight loading.')

    except Exception:
        raise Exception(
            'Error in model restoring. model_path should point to the model root '
            'folder that contains a model_params.json, a smiles_language.pkl and a '
            'weights folder.'
        )
    logger.info('Model restored. Now starting to craft it for the task')

    # Process parameter file:
    model_params = {}
    with open(params_filepath) as fp:
        model_params.update(json.load(fp))

    model = update_mca_model(model, model_params)

    logger.info('Model set up.')
    for idx, (name, param) in enumerate(model.named_parameters()):
        logger.info(
            (idx, name, param.shape, f'Gradients: {param.requires_grad}')
        )

    ft_model_path = os.path.join(model_path, 'finetuned')
    os.makedirs(os.path.join(ft_model_path, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(ft_model_path, 'results'), exist_ok=True)

    logger.info('Now start data preprocessing...')

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('sanitize', True)
    )

    train_dataset = AnnotatedDataset(
        annotations_filepath=train_scores_filepath,
        dataset=smiles_dataset,
        device=get_device()
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )

    # Generally, if sanitize is True molecules are de-kekulized. Augmentation
    # preserves the "kekulization", so if it is used, test data should be
    # sanitized or canonicalized.
    smiles_test_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False)
    )

    # Dump eventually modified SMILES Language
    smiles_language.save(os.path.join(ft_model_path, 'smiles_language.pkl'))

    test_dataset = AnnotatedDataset(
        annotations_filepath=test_scores_filepath,
        dataset=smiles_test_dataset,
        device=get_device()
    )

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=False,
        drop_last=False,
        num_workers=params.get('num_workers', 0)
    )

    save_top_model = os.path.join(ft_model_path, 'weights/{}_{}_{}.pt')

    num_t_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    num_params = sum(p.numel() for p in model.parameters())
    params.update({'params': num_params, 'trainable_params': num_t_params})
    logger.info(
        f'Number of parameters: {num_params} (trainable {num_t_params}).'
    )

    # Define optimizer only for those layers which require gradients
    optimizer = (
        OPTIMIZER_FACTORY[params.get('optimizer', 'adam')](
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=params.get('lr', 0.00001)
        )
    )
    # Dump params.json file with updated parameters.
    with open(os.path.join(ft_model_path, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    min_loss, max_roc_auc = 1000000, 0
    max_precision_recall_score = 0

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, y) in enumerate(train_loader):

            smiles = torch.squeeze(smiles.to(device))
            y_hat, pred_dict = model(smiles)

            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        logger.info(
            '\t **** TRAINING ****   '
            f"Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {train_loss / len(train_loader):.5f}. '
            f'This took {time() - t:.1f} secs.'
        )
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, y) in enumerate(test_loader):

                smiles = torch.squeeze(smiles.to(device))

                y_hat, pred_dict = model(smiles)
                predictions.append(y_hat)
                # Copy y tensor since loss function applies downstream modification
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()

        # Remove NaNs from labels to compute scores
        predictions = predictions[~np.isnan(labels)]
        labels = labels[~np.isnan(labels)]
        test_loss_a = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc_a = auc(fpr, tpr)

        # calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        # score for precision vs accuracy
        test_precision_recall_score = average_precision_score(
            labels, predictions
        )

        logger.info(
            f"\t **** TEST **** Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {test_loss_a:.5f}, roc_auc: {test_roc_auc_a:.5f}, '
            f'avg precision-recall score: {test_precision_recall_score:.5f}'
        )
        info = {
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
            'test_auc': test_roc_auc_a,
            'best_test_auc': max_roc_auc,
            'test_precision_recall_score': test_precision_recall_score,
            'best_precision_recall_score': max_precision_recall_score,
        }

        def save(path, metric, typ, val=None):
            model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
            if typ == 'best':
                logger.info(
                    f'\t New best performance in {metric}'
                    f' with value : {val:.7f} in epoch: {epoch+1}'
                )

        if test_roc_auc_a > max_roc_auc:
            max_roc_auc = test_roc_auc_a
            info.update({'best_test_auc': max_roc_auc})
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            np.save(
                os.path.join(ft_model_path, 'results', 'best_predictions.npy'),
                predictions
            )
            with open(
                os.path.join(ft_model_path, 'results', 'metrics.json'), 'w'
            ) as f:
                json.dump(info, f)

        if test_precision_recall_score > max_precision_recall_score:
            max_precision_recall_score = test_precision_recall_score
            info.update(
                {'best_precision_recall_score': max_precision_recall_score}
            )
            save(
                save_top_model, 'precision-recall score', 'best',
                max_precision_recall_score
            )

        if test_loss_a < min_loss:
            min_loss = test_loss_a
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch

        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info(
        'Overall best performances are: \n \t'
        f'Loss = {min_loss:.4f} in epoch {ep_loss} '
    )
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
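
A hedged sketch of calling the fine-tuning entry point directly (the repository presumably wires this up as a CLI; all paths are placeholders, and model_path must be a model root folder with a model_params.json, a smiles_language.pkl and a weights folder, as the error message above spells out):

main(
    train_scores_filepath='data/train_labels.csv',
    test_scores_filepath='data/test_labels.csv',
    smi_filepath='data/molecules.smi',
    model_path='models/pretrained_mca',
    params_filepath='finetune_params.json',
    training_name='tox_finetune',
)
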
Example #8
def main(train_scores_filepath,
         test_scores_filepath,
         smi_filepath,
         smiles_language_filepath,
         model_path,
         params_filepath,
         training_name,
         embedding_path=None):

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger(f'{training_name}')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()

    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    if embedding_path:
        params['embedding_path'] = embedding_path

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    logger.info('Start data preprocessing...')
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Prepare FP processing
    if params.get('model_fn', 'mca') == 'dense':
        morgan_transform = Compose([
            SMILESToMorganFingerprints(radius=params.get('fp_radius', 2),
                                       bits=params.get('num_drug_features',
                                                       512),
                                       chirality=params.get(
                                           'fp_chirality', True)),
            ToTensor(get_device())
        ])

        def smiles_tensor_batch_to_fp(smiles):
            """Reuse the SMILES dataset to produce Morgan fingerprints."""
            # Keep the default number of bits in sync with morgan_transform.
            out = torch.Tensor(smiles.shape[0],
                               params.get('num_drug_features', 512))
            for ind, tensor in enumerate(smiles):
                smiles_str = smiles_language.token_indexes_to_smiles(
                    tensor.tolist())
                out[ind, :] = torch.squeeze(morgan_transform(smiles_str))
            return out

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('sanitize', True))

    # include arg label_columns if data file has any unwanted columns (such as index) to be ignored.
    train_dataset = AnnotatedDataset(
        annotations_filepath=train_scores_filepath,
        dataset=smiles_dataset,
        device=get_device())
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=params['batch_size'],
                                               shuffle=True,
                                               drop_last=True,
                                               num_workers=params.get(
                                                   'num_workers', 0))

    if (params.get('uncertainty', True)
            and params.get('augment_test_data', False)):
        raise ValueError(
            'Epistemic uncertainty evaluation is not supported if '
            'augmentation is enabled for test data.')

    # Generally, if sanitize is True molecules are de-kekulized. Augmentation
    # preserves the "kekulization", so if it is used, test data should be
    # sanitized or canonicalized.
    smiles_test_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False))

    # Dump eventually modified SMILES Language
    smiles_language.save(os.path.join(model_dir, 'smiles_language.pkl'))

    logger.info(smiles_dataset._dataset.transform)
    logger.info(smiles_test_dataset._dataset.transform)

    # include arg label_columns if data file has any unwanted columns (such as index) to be ignored.
    test_dataset = AnnotatedDataset(annotations_filepath=test_scores_filepath,
                                    dataset=smiles_test_dataset,
                                    device=get_device())

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=params['batch_size'],
                                              shuffle=False,
                                              drop_last=False,
                                              num_workers=params.get(
                                                  'num_workers', 0))

    if params.get('confidence', False):
        smiles_conf_dataset = SMILESDataset(
            smi_filepath,
            smiles_language=smiles_language,
            padding_length=params.get('smiles_padding_length', None),
            padding=params.get('padd_smiles', True),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=True,  # Natively true for epistemic uncertainty estimate
            canonical=params.get('canonical', False),
            kekulize=params.get('kekulize', False),
            all_bonds_explicit=params.get('all_bonds_explicit', False),
            all_hs_explicit=params.get('all_hs_explicit', False),
            randomize=params.get('randomize', False),
            remove_bonddir=params.get('remove_bonddir', False),
            remove_chirality=params.get('remove_chirality', False),
            selfies=params.get('selfies', False))
        conf_dataset = AnnotatedDataset(
            annotations_filepath=test_scores_filepath,
            dataset=smiles_conf_dataset,
            device=get_device())
        conf_loader = torch.utils.data.DataLoader(
            dataset=conf_dataset,
            batch_size=params['batch_size'],
            shuffle=False,
            drop_last=False,
            num_workers=params.get('num_workers', 0))

    if not params.get('embedding', 'learned') == 'pretrained':
        params.update(
            {'smiles_vocabulary_size': smiles_language.number_of_tokens})

    device = get_device()
    logger.info(f'Device for data loader is {train_dataset.device} and for '
                f'model is {device}')
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')

    model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
    logger.info(model)
    logger.info(model.loss_fn.class_weights)

    logger.info('Parameters follow')
    for name, param in model.named_parameters():
        logger.info((name, param.shape))

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params.update({'number_of_parameters': num_params})
    logger.info(f'Number of parameters {num_params}')

    # Define optimizer
    optimizer = (OPTIMIZER_FACTORY[params.get('optimizer',
                                              'adam')](model.parameters(),
                                                       lr=params.get(
                                                           'lr', 0.00001)))
    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    min_loss, max_roc_auc = 1000000, 0
    max_precision_recall_score = 0

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, y) in enumerate(train_loader):

            smiles = torch.squeeze(smiles.to(device))
            # Transform smiles to FP if needed
            if params.get('model_fn', 'mca') == 'dense':
                smiles = smiles_tensor_batch_to_fp(smiles).to(device)

            y_hat, pred_dict = model(smiles)

            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        logger.info('\t **** TRAINING ****   '
                    f"Epoch [{epoch + 1}/{params['epochs']}], "
                    f'loss: {train_loss / len(train_loader):.5f}. '
                    f'This took {time() - t:.1f} secs.')
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, y) in enumerate(test_loader):

                smiles = torch.squeeze(smiles.to(device))
                # Transform smiles to FP if needed
                if params.get('model_fn', 'mca') == 'dense':
                    smiles = smiles_tensor_batch_to_fp(smiles).to(device)

                y_hat, pred_dict = model(smiles)
                predictions.append(y_hat)
                # Copy y tensor since loss function applies downstream
                #   modification
                labels.append(y.clone())
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = torch.cat(predictions, dim=0).flatten().cpu().numpy()
        labels = torch.cat(labels, dim=0).flatten().cpu().numpy()

        # Remove NaNs from labels to compute scores
        predictions = predictions[~np.isnan(labels)]
        labels = labels[~np.isnan(labels)]
        test_loss_a = test_loss / len(test_loader)
        fpr, tpr, _ = roc_curve(labels, predictions)
        test_roc_auc_a = auc(fpr, tpr)

        # calculations for visualization plot
        precision, recall, _ = precision_recall_curve(labels, predictions)
        # score for precision vs accuracy
        test_precision_recall_score = average_precision_score(
            labels, predictions)

        logger.info(
            f"\t **** TEST **** Epoch [{epoch + 1}/{params['epochs']}], "
            f'loss: {test_loss_a:.5f}, roc_auc: {test_roc_auc_a:.5f}, '
            f'avg precision-recall score: {test_precision_recall_score:.5f}')
        info = {
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
            'test_auc': test_roc_auc_a,
            'best_test_auc': max_roc_auc,
            'test_precision_recall_score': test_precision_recall_score,
            'best_precision_recall_score': max_precision_recall_score,
        }

        def save(path, metric, typ, val=None):
            model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
            if typ == 'best':
                logger.info(f'\t New best performance in {metric}'
                            f' with value : {val:.7f} in epoch: {epoch+1}')

        if test_roc_auc_a > max_roc_auc:
            max_roc_auc = test_roc_auc_a
            info.update({'best_test_auc': max_roc_auc})
            save(save_top_model, 'ROC-AUC', 'best', max_roc_auc)
            np.save(os.path.join(model_dir, 'results', 'best_predictions.npy'),
                    predictions)
            with open(os.path.join(model_dir, 'results', 'metrics.json'),
                      'w') as f:
                json.dump(info, f)
            if params.get('confidence', False):
                # Compute uncertainty estimates and save them
                epistemic_conf = monte_carlo_dropout(model,
                                                     regime='loader',
                                                     loader=conf_loader)
                aleatoric_conf = test_time_augmentation(model,
                                                        regime='loader',
                                                        loader=conf_loader)

                np.save(
                    os.path.join(model_dir, 'results', 'epistemic_conf.npy'),
                    epistemic_conf)
                np.save(
                    os.path.join(model_dir, 'results', 'aleatoric_conf.npy'),
                    aleatoric_conf)

        if test_precision_recall_score > max_precision_recall_score:
            max_precision_recall_score = test_precision_recall_score
            info.update(
                {'best_precision_recall_score': max_precision_recall_score})
            save(save_top_model, 'precision-recall score', 'best',
                 max_precision_recall_score)

        if test_loss_a < min_loss:
            min_loss = test_loss_a
            save(save_top_model, 'loss', 'best', min_loss)
            ep_loss = epoch

        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))

    logger.info('Overall best performances are: \n \t'
                f'Loss = {min_loss:.4f} in epoch {ep_loss} ')
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
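
When model_fn is set to 'dense', the script above swaps token sequences for Morgan fingerprints. A hedged sketch of a matching parameter file (keys mirror the params.get(...) calls in the script, values are illustrative only):

import json

params = {
    'batch_size': 64,
    'epochs': 100,
    'model_fn': 'dense',        # selects the fingerprint branch above
    'fp_radius': 2,             # Morgan fingerprint radius
    'num_drug_features': 512,   # fingerprint bits / model input width
    'fp_chirality': True,
    'confidence': True,         # builds the extra confidence loader
    'uncertainty': False,
    'augment_test_data': False,
}
with open('dense_params.json', 'w') as fp:
    json.dump(params, fp, indent=4)
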
Example #9
def main(train_sensitivity_filepath, test_sensitivity_filepath, gep_filepath,
         smi_filepath, gene_filepath, smiles_language_filepath, model_path,
         params_filepath, training_name):

    logger = logging.getLogger(f'{training_name}')
    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    # Prepare the dataset
    logger.info("Start data preprocessing...")

    # Load SMILES language
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Load the gene list
    with open(gene_filepath, 'rb') as f:
        gene_list = pickle.load(f)

    # Assemble datasets
    train_dataset = DrugSensitivityDataset(
        drug_sensitivity_filepath=train_sensitivity_filepath,
        smi_filepath=smi_filepath,
        gene_expression_filepath=gep_filepath,
        smiles_language=smiles_language,
        gene_list=gene_list,
        drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True),
        augment=params.get('augment_smiles', True),
        add_start_and_stop=params.get('smiles_start_stop_token', True),
        padding_length=params.get('smiles_padding_length', None),
        device=torch.device(params.get('dataset_device', 'cpu')),
        backend='eager')
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=params['batch_size'],
                                               shuffle=True,
                                               drop_last=True,
                                               num_workers=params.get(
                                                   'num_workers', 0))

    test_dataset = DrugSensitivityDataset(
        drug_sensitivity_filepath=test_sensitivity_filepath,
        smi_filepath=smi_filepath,
        gene_expression_filepath=gep_filepath,
        smiles_language=smiles_language,
        gene_list=gene_list,
        drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True),
        augment=params.get('augment_smiles', True),
        add_start_and_stop=params.get('smiles_start_stop_token', True),
        padding_length=params.get('smiles_padding_length', None),
        device=torch.device(params.get('dataset_device', 'cpu')),
        backend='eager')
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=params['batch_size'],
                                              shuffle=True,
                                              drop_last=True,
                                              num_workers=params.get(
                                                  'num_workers', 0))

    params.update({
        'number_of_genes': len(gene_list),
        'smiles_vocabulary_size': smiles_language.number_of_tokens
    })

    # Set the tensorboard logger
    tb_logger = Logger(os.path.join(model_dir, 'tb_logs'))
    device = get_device()
    logger.info(f'Device for data loader is {train_dataset.device} and for '
                f'model is {device}')
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')

    model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)

    # Define optimizer
    optimizer = (OPTIMIZER_FACTORY[params.get('optimizer',
                                              'Adam')](model.parameters(),
                                                       lr=params.get(
                                                           'lr', 0.01)))
    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()
    min_loss, max_pearson = 100, 0

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, gep, y) in enumerate(train_loader):
            y_hat, pred_dict = model(torch.squeeze(smiles.to(device)),
                                     gep.to(device))
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            # Apply gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(),1e-6)
            optimizer.step()
            train_loss += loss.item()

        logger.info("\t **** TRAINING ****   "
                    f"Epoch [{epoch + 1}/{params['epochs']}], "
                    f"loss: {train_loss / len(train_loader):.5f}. "
                    f"This took {time() - t:.1f} secs.")
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, gep, y) in enumerate(test_loader):
                y_hat, pred_dict = model(torch.squeeze(smiles.to(device)),
                                         gep.to(device))
                predictions.append(y_hat)
                labels.append(y)
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = np.array(
            [p.cpu() for preds in predictions for p in preds])
        labels = np.array([l.cpu() for label in labels for l in label])
        test_pearson_a = pearsonr(predictions, labels)
        test_rmse_a = np.sqrt(np.mean((predictions - labels)**2))
        np.save(os.path.join(model_dir, 'results', 'preds.npy'),
                np.hstack([predictions, labels]))
        test_loss_a = test_loss / len(test_loader)
        logger.info(
            f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {test_loss_a:.5f}, "
            f"Pearson: {test_pearson_a:.3f}, "
            f"RMSE: {test_rmse_a:.3f}")

        # TensorBoard logging of scalars.
        info = {
            'train_loss': train_loss / len(train_loader),
            'test_loss': test_loss_a,
        }
        for tag, value in info.items():
            tb_logger.scalar_summary(tag, value, epoch + 1)

        def save(path, metric, typ, val=None):
            model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
            if typ == 'best':
                logger.info(f'\t New best performance in "{metric}"'
                            f' with value : {val:.7f} in epoch: {epoch}')

        if test_loss_a < min_loss:
            min_loss = test_loss_a
            min_loss_pearson = test_pearson_a
            save(save_top_model, 'mse', 'best', min_loss)
            ep_loss = epoch
        if test_pearson_a > max_pearson:
            max_pearson = test_pearson_a
            max_pearson_loss = test_loss_a
            save(save_top_model, 'pearson', 'best', max_pearson)
            ep_pearson = epoch
        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))
    logger.info('Overall best performances are: \n \t'
                f'Loss = {min_loss:.4f} in epoch {ep_loss} '
                f'\t (Pearson was {min_loss_pearson:.4f}) \n \t'
                f'Pearson = {max_pearson:.4f} in epoch {ep_pearson} '
                f'\t (Loss was {max_pearson_loss:.2f})')
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
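
The checkpoints written above follow the save_top_model pattern 'weights/{typ}_{metric}_{model_fn}.pt', so the best-MSE model ends up as weights/best_mse_mca.pt. A hedged restore sketch, reusing the names from the script and mirroring the load call in Example #2:

model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(get_device())
model.load(
    os.path.join(model_dir, 'weights', 'best_mse_mca.pt'),
    map_location=get_device()
)
model.eval()
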
Example #10
def main(parser_namespace):
    # model loading
    disable_rdkit_logging()
    affinity_path = parser_namespace.affinity_path
    svae_path = parser_namespace.svae_path
    svae_weights_path = os.path.join(svae_path, "weights", "best_rec.pt")
    results_file_name = parser_namespace.optimisation_name

    logger.add(results_file_name + ".log", rotation="10 MB")

    svae_params = dict()
    with open(os.path.join(svae_path, "model_params.json"), "r") as f:
        svae_params.update(json.load(f))

    smiles_language = SMILESLanguage.load(
        os.path.join(svae_path, "selfies_language.pkl"))

    # initialize encoder, decoder, testVAE, and GP_generator_MW
    gru_encoder = StackGRUEncoder(svae_params)
    gru_decoder = StackGRUDecoder(svae_params)
    gru_vae = TeacherVAE(gru_encoder, gru_decoder)
    gru_vae.load_state_dict(
        torch.load(svae_weights_path, map_location=get_device()))

    gru_vae._associate_language(smiles_language)
    gru_vae.eval()

    smiles_generator = SmilesGenerator(gru_vae)

    with open(os.path.join(affinity_path, "model_params.json")) as f:
        predictor_params = json.load(f)
    affinity_predictor = MODEL_FACTORY["bimodal_mca"](predictor_params)
    affinity_predictor.load(
        os.path.join(
            affinity_path,
            f"weights/best_{predictor_params.get('p_metric', 'ROC-AUC')}_bimodal_mca.pt",
        ),
        map_location=get_device(),
    )
    affinity_protein_language = ProteinLanguage.load(
        os.path.join(affinity_path, "protein_language.pkl"))
    affinity_smiles_language = SMILESLanguage.load(
        os.path.join(affinity_path, "smiles_language.pkl"))
    affinity_predictor._associate_language(affinity_smiles_language)
    affinity_predictor._associate_language(affinity_protein_language)
    affinity_predictor.eval()

    erg_protein = "MASTIKEALSVVSEDQSLFECAYGTPHLAKTEMTASSSSDYGQTSKMSPRVPQQDWLSQPPARVTIKMECNPSQVNGSRNSPDECSVAKGGKMVGSPDTVGMNYGSYMEEKHMPPPNMTTNERRVIVPADPTLWSTDHVRQWLEWAVKEYGLPDVNILLFQNIDGKELCKMTKDDFQRLTPSYNADILLSHLHYLRETPLPHLTSDDVDKALQNSPRLMHARNTGGAAFIFPNTSVYPEATQRITTRPDLPYEPPRRSAWTGHGHPTPQSKAAQPSPSTVPKTEDQRPQLDPYQILGPTSSRLANPGSGQIQLWQFLLELLSDSSNSSCITWEGTNGEFKMTDPDEVARRWGERKSKPNMNYDKLSRALRYYYDKNIMTKVHGKRYAYKFDFHGIAQALQPHPPESSLYKYPSDLPYMGSYHAHPQKMNFVAPHPPALPVTSSSFFAAPNPYWNSPTGGIYPNTRLPTSHMPSHLGTYY"

    target_minimization_function = AffinityMinimization(
        smiles_generator, 30, affinity_predictor, erg_protein)
    qed_function = QEDMinimization(smiles_generator, 30)
    sa_function = SAMinimization(smiles_generator, 30)
    combined_minimization = CombinedMinimization(
        [target_minimization_function, qed_function, sa_function], 1,
        [0.75, 1, 0.5])
    target_optimizer = GPOptimizer(combined_minimization.evaluate)

    params = dict(
        dimensions=[(-5.0, 5.0)] * 256,
        acq_func="EI",
        n_calls=20,
        n_initial_points=19,
        initial_point_generator="random",
        random_state=1234,
    )
    logger.info("Optimisation parameters: {params}", params=params)

    # optimisation
    for j in range(5):
        res = target_optimizer.optimize(params)
        latent_point = torch.tensor([[res.x]])

        with open(results_file_name + "_LP" + str(j + 1) + ".pkl", "wb") as f:
            pickle.dump(latent_point, f, protocol=2)

        smile_set = set()

        while len(smile_set) < 20:
            smiles = smiles_generator.generate_smiles(
                latent_point.repeat(1, 30, 1))
            smile_set.update(set(smiles))
        smile_set = list(smile_set)

        pad_smiles_predictor = LeftPadding(
            affinity_predictor.smiles_padding_length,
            affinity_predictor.smiles_language.padding_index,
        )
        to_tensor = ToTensor(get_device())
        smiles_num = [
            torch.unsqueeze(
                to_tensor(
                    pad_smiles_predictor(
                        affinity_predictor.smiles_language.
                        smiles_to_token_indexes(smile))),
                0,
            ) for smile in smile_set
        ]

        smiles_tensor = torch.cat(smiles_num, dim=0)

        pad_protein_predictor = LeftPadding(
            affinity_predictor.protein_padding_length,
            affinity_predictor.protein_language.padding_index,
        )

        protein_num = torch.unsqueeze(
            to_tensor(
                pad_protein_predictor(
                    affinity_predictor.protein_language.
                    sequence_to_token_indexes([erg_protein]))),
            0,
        )
        protein_num = protein_num.repeat(len(smile_set), 1)

        with torch.no_grad():
            pred, _ = affinity_predictor(smiles_tensor, protein_num)
        affinities = torch.squeeze(pred, 1).numpy()

        sas = SAS()
        sa_scores = [sas(smile) for smile in smile_set]
        qed_scores = [qed(Chem.MolFromSmiles(smile)) for smile in smile_set]

        # save to file
        file = results_file_name + str(j + 1) + ".txt"
        logger.info("creating {file}", file=file)

        with open(file, "w") as f:
            f.write(
                f'{"point":<10}{"Affinity":<10}{"QED":<10}{"SA":<10}{"smiles":<15}\n'
            )
            for i in range(20):
                dat = [
                    i + 1, affinities[i], qed_scores[i], sa_scores[i],
                    smile_set[i]
                ]
                f.write(
                    f'{dat[0]:<10}{"%.3f"%dat[1]:<10}{"%.3f"%dat[2]:<10}{"%.3f"%dat[3]:<10}{dat[4]:<15}\n'
                )
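
A hedged sketch of the argument namespace the function above expects; the real script presumably builds it with argparse, and the attribute names are taken from the accesses at the top of main:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('affinity_path', help='root folder of the affinity predictor')
parser.add_argument('svae_path', help='root folder of the SELFIES VAE')
parser.add_argument('optimisation_name', help='prefix for result and log files')
main(parser.parse_args())
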
Example #11

# The model is separated in 3 files: PARAMS, WEIGHTS, LANGUAGE.
# See: https://github.com/PaccMann/paccmann_chemistry for more info.
# In the thesis we used 40k molecules extracted
# from the PubChem database.
DATA_DIR =  # Directory with the model and the data
PARAM_FILE = os.path.join(DATA_DIR, 'model_params.json')
WEIGHT_FILE = os.path.join(DATA_DIR, 'weights', 'best_loss.pt')
LANG_FILE = os.path.join(DATA_DIR, 'selfies_language.pkl')

MODEL_DIR = './finetuned_model' # Directory for the finetuned model


# START LOADING
paccmann_vae = Encoder(PARAM_FILE, LANG_FILE, WEIGHT_FILE, 1)
smiles_language = SMILESLanguage.load(LANG_FILE)

SAVE_FOLDER = './'
LATENTS_FILE = latents_file if latents_file is not None \
    else os.path.join(SAVE_FOLDER, f'latents.{model}.npy')
CATALYST_LATENTS_FILE = os.path.join(
    SAVE_FOLDER, f'catalyst_latents.{model}.npy'
)


df = pd.read_csv(COMPOUNDS_PATH, header=None, sep='\t', names=['SMILES', 'id'])
db_catalysts = pd.read_csv(CATALYSTS_PATH)

latents, failed_smiles = calculate_latents(
    paccmann_vae, df.SMILES, LATENTS_FILE
)
Example #12

def main(parser_namespace):

    disable_rdkit_logging()

    model_path = parser_namespace.model_path
    data_path = parser_namespace.data_path

    weights_path = os.path.join(model_path, 'weights', 'best_loss.pt')

    device = get_device()
    # read the params json
    params = dict()
    with open(os.path.join(model_path, 'model_params.json'), 'r') as f:
        params.update(json.load(f))

    params['batch_size'] = 1

    # Load SMILES language
    smiles_language = SMILESLanguage.load(
        os.path.join(model_path, 'selfies_language.pkl'))

    data_preparation = get_data_preparation(params.get('batch_mode'))
    device = get_device()

    dataset = SMILESDataset(
        data_path,
        smiles_language=smiles_language,
        padding=False,
        selfies=params.get('selfies', False),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=False,  #params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        backend='lazy',
        device=device)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=params.get('batch_size', 64),
        collate_fn=collate_fn,
        drop_last=True,
        shuffle=True,
        pin_memory=params.get('pin_memory', True),
        num_workers=params.get('num_workers', 8))
    # initialize encoder and decoder
    gru_encoder = StackGRUEncoder(params).to(device)
    gru_decoder = StackGRUDecoder(params).to(device)
    gru_vae = TeacherVAE(gru_encoder, gru_decoder).to(device)
    logger.info('\n***MODEL SUMMARY***\n')
    for name, parameter in gru_vae.named_parameters():
        logger.info(f'Param {name}, shape:\t{parameter.shape}')
    total_params = sum(p.numel() for p in gru_vae.parameters())
    logger.info(f'Total # params: {total_params}')

    gru_vae.load_state_dict(torch.load(weights_path, map_location=device))

    # Updating the vocab size will break the model
    params.update({
        # 'vocab_size': smiles_language.number_of_tokens,
        'pad_index': smiles_language.padding_index
    })  # yapf:disable

    # if params.get('embedding', 'learned') == 'one_hot':
    #     params.update({'embedding_size': params['vocab_size']})

    # train for n_epoch epochs
    logger.info(
        'Model creation, loading and data processing done. Evaluation starts.')

    gru_vae.eval()
    gru_vae.to(device)

    counter = 0
    with torch.no_grad():
        latent_code = []
        from tqdm import tqdm
        for batch in tqdm(dataloader, total=len(dataloader)):
            (encoder_seq, _, _) = data_preparation(batch,
                                                   input_keep=0.,
                                                   start_index=2,
                                                   end_index=3,
                                                   device=device)
            try:
                mu, logvar = gru_vae.encode(encoder_seq)
            except RuntimeError:
                # Substitute any new tokens by "<UNK>" tokens
                new_seq = []
                padded_encoder_seq, lengths = (
                    torch.nn.utils.rnn.pad_packed_sequence(encoder_seq,
                                                           batch_first=True))
                for seq, _len in zip(padded_encoder_seq, lengths):
                    seq = seq[:_len]
                    if any([x >= params['vocab_size'] for x in seq]):
                        seq = torch.tensor([
                            x if x < params['vocab_size'] else
                            smiles_language.unknown_index
                            for x in seq.tolist()
                        ]).short()

                        failed_smiles = smiles_language.selfies_to_smiles(
                            smiles_language.token_indexes_to_smiles(
                                seq.tolist()))
                        logger.warning(
                            f'Out of bounds sample: ~{counter}\t{failed_smiles}'
                        )
                    new_seq.append(seq)

                if new_seq:
                    for _ in range(params['batch_size'] - len(new_seq)):
                        new_seq.append(torch.ones_like(new_seq[-1]))
                    (encoder_seq, _, _) = data_preparation(new_seq,
                                                           input_keep=0.,
                                                           start_index=2,
                                                           end_index=3,
                                                           device=device)
                    mu, logvar = gru_vae.encode(encoder_seq)
            for _mu in mu.tolist():
                latent_code.append([counter, _mu])
                counter += 1

    LATENT_CODE_PATH = os.path.join(os.path.dirname(data_path),
                                    'samples_latent_code.tsv')

    with open(LATENT_CODE_PATH, 'w') as f:
        for i, mu in latent_code:
            f.write(f'{i}\t{",".join([str(x) for x in mu[0]])}\n')
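
# A small sketch of reading the latent-code TSV written above back into numpy
# arrays; it only assumes the line format used in the write loop
# (index, tab, comma-separated floats).
import numpy as np


def load_latent_codes(path):
    indices, codes = [], []
    with open(path) as f:
        for line in f:
            idx, values = line.rstrip('\n').split('\t')
            indices.append(int(idx))
            codes.append([float(x) for x in values.split(',')])
    return np.array(indices), np.array(codes)
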
Beispiel #13
0
def main(parser_namespace):
    try:
        device = get_device()
        disable_rdkit_logging()
        # read the params json
        params = dict()
        with open(parser_namespace.params_filepath) as f:
            params.update(json.load(f))

        # get params
        train_smiles_filepath = parser_namespace.train_smiles_filepath
        test_smiles_filepath = parser_namespace.test_smiles_filepath
        smiles_language_filepath = (
            parser_namespace.smiles_language_filepath
            if parser_namespace.smiles_language_filepath.lower() != 'none' else
            None)

        model_path = parser_namespace.model_path
        training_name = parser_namespace.training_name

        writer = SummaryWriter(f'logs/{training_name}')

        logger.info(f'Model with name {training_name} starts.')

        model_dir = os.path.join(model_path, training_name)
        log_path = os.path.join(model_dir, 'logs')
        val_dir = os.path.join(log_path, 'val_logs')
        os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
        os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
        os.makedirs(log_path, exist_ok=True)
        os.makedirs(val_dir, exist_ok=True)

        # Load SMILES language
        smiles_language = None
        if smiles_language_filepath is not None:
            smiles_language = SMILESLanguage.load(smiles_language_filepath)

        logger.info(f'Smiles filepath: {train_smiles_filepath}')

        # create SMILES eager dataset
        smiles_train_data = SMILESDataset(
            train_smiles_filepath,
            smiles_language=smiles_language,
            padding=False,
            selfies=params.get('selfies', False),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=params.get('augment_smiles', False),
            canonical=params.get('canonical', False),
            kekulize=params.get('kekulize', False),
            all_bonds_explicit=params.get('all_bonds_explicit', False),
            all_hs_explicit=params.get('all_hs_explicit', False),
            remove_bonddir=params.get('remove_bonddir', False),
            remove_chirality=params.get('remove_chirality', False),
            backend='lazy',
            device=device,
        )
        smiles_test_data = SMILESDataset(
            test_smiles_filepath,
            smiles_language=smiles_language,
            padding=False,
            selfies=params.get('selfies', False),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=params.get('augment_smiles', False),
            canonical=params.get('canonical', False),
            kekulize=params.get('kekulize', False),
            all_bonds_explicit=params.get('all_bonds_explicit', False),
            all_hs_explicit=params.get('all_hs_explicit', False),
            remove_bonddir=params.get('remove_bonddir', False),
            remove_chirality=params.get('remove_chirality', False),
            backend='lazy',
            device=device,
        )

        if smiles_language_filepath is None:
            smiles_language = smiles_train_data.smiles_language
            smiles_language.save(
                os.path.join(model_path, f'{training_name}.lang'))
        else:
            smiles_language_filename = os.path.basename(
                smiles_language_filepath)
            smiles_language.save(
                os.path.join(model_dir, smiles_language_filename))

        params.update({
            'vocab_size': smiles_language.number_of_tokens,
            'pad_index': smiles_language.padding_index
        })

        vocab_dict = smiles_language.index_to_token
        params.update({
            'start_index':
            list(vocab_dict.keys())[list(
                vocab_dict.values()).index('<START>')],
            'end_index':
            list(vocab_dict.keys())[list(vocab_dict.values()).index('<STOP>')]
        })

        if params.get('embedding', 'learned') == 'one_hot':
            params.update({'embedding_size': params['vocab_size']})

        with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
            json.dump(params, fp)

        # create DataLoaders
        train_data_loader = torch.utils.data.DataLoader(
            smiles_train_data,
            batch_size=params.get('batch_size', 64),
            collate_fn=collate_fn,
            drop_last=True,
            shuffle=True,
            pin_memory=params.get('pin_memory', True),
            num_workers=params.get('num_workers', 8))

        test_data_loader = torch.utils.data.DataLoader(
            smiles_test_data,
            batch_size=params.get('batch_size', 64),
            collate_fn=collate_fn,
            drop_last=True,
            shuffle=True,
            pin_memory=params.get('pin_memory', True),
            num_workers=params.get('num_workers', 8))
        # initialize encoder and decoder
        gru_encoder = StackGRUEncoder(params).to(device)
        gru_decoder = StackGRUDecoder(params).to(device)
        gru_vae = TeacherVAE(gru_encoder, gru_decoder).to(device)
        # TODO I haven't managed to get this to work. I will leave it here
        # if someone (or future me) wants to give it a look and get the
        # tensorboard graph to work
        # if writer and False:
        #     gru_vae.set_batch_mode('padded')
        #     dummy_input = torch.ones(smiles_train_data[0].shape)
        #     dummy_input = dummy_input.unsqueeze(0).to(device)
        #     writer.add_graph(gru_vae, (dummy_input, dummy_input, dummy_input))
        #     gru_vae.set_batch_mode(params.get('batch_mode'))
        logger.info('\n***MODEL SUMMARY***\n')
        for name, parameter in gru_vae.named_parameters():
            logger.info(f'Param {name}, shape:\t{parameter.shape}')
        total_params = sum(p.numel() for p in gru_vae.parameters())
        logger.info(f'Total # params: {total_params}')

        loss_tracker = {
            'test_loss_a': 10e4,
            'test_rec_a': 10e4,
            'test_kld_a': 10e4,
            'ep_loss': 0,
            'ep_rec': 0,
            'ep_kld': 0
        }

        # train for n_epoch epochs
        logger.info(
            'Model creation and data processing done, Training starts.')
        decoder_search = SEARCH_FACTORY[
            params.get('decoder_search', 'sampling')
        ](
            temperature=params.get('temperature', 1.),
            beam_width=params.get('beam_width', 3),
            top_tokens=params.get('top_tokens', 5)
        )  # yapf: disable

        if writer:
            pparams = params.copy()
            pparams['training_file'] = train_smiles_filepath
            pparams['test_file'] = test_smiles_filepath
            pparams['language_file'] = smiles_language_filepath
            pparams['model_path'] = model_path
            pparams = {
                k: v if v is not None else 'N.A.'
                for k, v in pparams.items()
            }
            pparams['training_name'] = training_name
            from pprint import pprint
            pprint(pparams)
            writer.add_hparams(hparam_dict=pparams, metric_dict={})

        for epoch in range(params['epochs'] + 1):
            t = time()
            loss_tracker = train_vae(
                epoch,
                gru_vae,
                train_data_loader,
                test_data_loader,
                smiles_language,
                model_dir,
                search=decoder_search,
                optimizer=params.get('optimizer', 'adadelta'),
                lr=params['learning_rate'],
                kl_growth=params['kl_growth'],
                input_keep=params['input_keep'],
                test_input_keep=params['test_input_keep'],
                generate_len=params['generate_len'],
                log_interval=params['log_interval'],
                save_interval=params['save_interval'],
                eval_interval=params['eval_interval'],
                loss_tracker=loss_tracker,
                logger=logger,
                # writer=writer,
                batch_mode=params.get('batch_mode'))
            logger.info(f'Epoch {epoch}, took {time() - t:.1f}.')

        logger.info('OVERALL: \t Best loss = {0:.4f} in Ep {1}, '
                    'best Rec = {2:.4f} in Ep {3}, '
                    'best KLD = {4:.4f} in Ep {5}'.format(
                        loss_tracker['test_loss_a'], loss_tracker['ep_loss'],
                        loss_tracker['test_rec_a'], loss_tracker['ep_rec'],
                        loss_tracker['test_kld_a'], loss_tracker['ep_kld']))
        logger.info('Training done, shutting down.')
    except Exception:
        logger.exception('Exception occurred while running train_vae.py.')
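
# A hedged usage sketch for the training entry point above: it only needs an
# argparse-style namespace exposing the attributes read inside main(). The
# argument names below mirror those attributes; descriptions and ordering are
# illustrative.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train a SMILES/SELFIES VAE.')
    parser.add_argument('train_smiles_filepath')
    parser.add_argument('test_smiles_filepath')
    parser.add_argument('smiles_language_filepath',
                        help="path to a SMILES language .pkl, or 'none'")
    parser.add_argument('model_path')
    parser.add_argument('params_filepath')
    parser.add_argument('training_name')
    main(parser.parse_args())
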
def main(*, parser_namespace):

    disable_rdkit_logging()

    # read the params json
    params = dict()
    with open(parser_namespace.params_path) as f:
        params.update(json.load(f))

    # get params, json args take precedence
    mol_model_path = params.get('mol_model_path',
                                parser_namespace.mol_model_path)
    protein_model_path = params.get('protein_model_path',
                                    parser_namespace.protein_model_path)
    affinity_model_path = params.get('affinity_model_path',
                                     parser_namespace.affinity_model_path)
    protein_data_path = params.get('protein_data_path',
                                   parser_namespace.protein_data_path)
    model_name = params.get(
        'model_name', parser_namespace.model_name
    )   # yapf: disable
    test_id = int(params.get(
        'test_protein_id', parser_namespace.test_protein_id
    ))   # yapf: disable
    unbiased_preds_path = params.get(
        'unbiased_predictions_path', parser_namespace.unbiased_predictions_path
    )   # yapf: disable
    model_name += '_' + str(test_id)
    logger.info(f'Model with name {model_name} starts.')

    # passing optional paths to params to possibly update_reward_fn
    optional_reward_args = [
        'tox21_path', 'organdb_path', 'site', 'clintox_path', 'sider_path'
    ]
    for arg in optional_reward_args:
        if parser_namespace.__dict__[arg]:
            # json still has precedence
            params[arg] = params.get(arg, parser_namespace.__dict__[arg])

    # Load protein sequence data
    if protein_data_path.endswith('.smi'):
        protein_df = read_smi(protein_data_path, names=['Sequence'])
    elif protein_data_path.endswith('.csv'):
        protein_df = pd.read_csv(protein_data_path, index_col='entry_name')
    else:
        raise TypeError(
            f"{protein_data_path.split('.')[-1]} files are not supported.")

    protein_test_name = protein_df.iloc[test_id].name
    logger.info(f'Test protein is {protein_test_name}')

    # Restore SMILES Model
    with open(os.path.join(mol_model_path, 'model_params.json')) as f:
        mol_params = json.load(f)

    gru_encoder = StackGRUEncoder(mol_params)
    gru_decoder = StackGRUDecoder(mol_params)
    generator = TeacherVAE(gru_encoder, gru_decoder)
    generator.load(os.path.join(
        mol_model_path,
        f"weights/best_{params.get('smiles_metric', 'rec')}.pt"),
                   map_location=get_device())
    # Load languages
    generator_smiles_language = SMILESLanguage.load(
        os.path.join(mol_model_path, 'selfies_language.pkl'))
    generator._associate_language(generator_smiles_language)

    # Restore protein model
    with open(os.path.join(protein_model_path, 'model_params.json')) as f:
        protein_params = json.load(f)

    # Define network
    protein_encoder = ENCODER_FACTORY['dense'](protein_params)
    protein_encoder.load(os.path.join(
        protein_model_path,
        f"weights/best_{params.get('omics_metric','both')}_encoder.pt"),
                         map_location=get_device())
    protein_encoder.eval()

    # Restore affinity predictor
    with open(os.path.join(affinity_model_path, 'model_params.json')) as f:
        predictor_params = json.load(f)
    predictor = MODEL_FACTORY['bimodal_mca'](predictor_params)
    predictor.load(os.path.join(
        affinity_model_path,
        f"weights/best_{params.get('p_metric', 'ROC-AUC')}_bimodal_mca.pt"),
                   map_location=get_device())
    predictor.eval()

    # Load languages
    affinity_smiles_language = SMILESLanguage.load(
        os.path.join(affinity_model_path, 'smiles_language.pkl'))
    affinity_protein_language = ProteinLanguage.load(
        os.path.join(affinity_model_path, 'protein_language.pkl'))
    predictor._associate_language(affinity_smiles_language)
    predictor._associate_language(affinity_protein_language)

    # Load unbiased predictions used as the baseline for comparison
    unbiased_preds = np.array(
        pd.read_csv(
            os.path.join(unbiased_preds_path, protein_test_name + '.csv')
        )['affinity'].values
    )  # yapf: disable

    # Create a fresh model that will be optimized
    gru_encoder_rl = StackGRUEncoder(mol_params)
    gru_decoder_rl = StackGRUDecoder(mol_params)
    generator_rl = TeacherVAE(gru_encoder_rl, gru_decoder_rl)
    generator_rl.load(os.path.join(
        mol_model_path, f"weights/best_{params.get('metric', 'rec')}.pt"),
                      map_location=get_device())
    generator_rl.eval()
    # Load languages
    generator_rl._associate_language(generator_smiles_language)

    protein_encoder_rl = ENCODER_FACTORY['dense'](protein_params)
    protein_encoder_rl.load(os.path.join(
        protein_model_path,
        f"weights/best_{params.get('metric', 'both')}_encoder.pt"),
                            map_location=get_device())
    protein_encoder_rl.eval()
    model_folder_name = model_name
    learner = ReinforceProtein(generator_rl, protein_encoder_rl, predictor,
                               protein_df, params, model_folder_name, logger)

    biased_ratios, tox_ratios = [], []
    rewards, rl_losses = [], []
    gen_mols, gen_prot, gen_affinity, mode = [], [], [], []

    logger.info(f'Model stored at {learner.model_path}')

    for epoch in range(1, params['epochs'] + 1):

        for step in range(1, params['steps']):

            # Randomly sample a protein
            protein_name = np.random.choice(protein_df.index)
            while protein_name == protein_test_name:
                protein_name = np.random.choice(protein_df.index)

            logger.info(f'Current train protein: {protein_name}')

            rew, loss = learner.policy_gradient(protein_name, epoch,
                                                params['batch_size'])
            logger.info(
                f"Epoch {epoch:d}/{params['epochs']:d}, step {step:d}/"
                f"{params['steps']:d}\t loss={loss:.2f}, mean rew={rew:.2f}")

            rewards.append(rew.item())
            rl_losses.append(loss)

        # Save model
        if epoch % 10 == 0:
            learner.save(f'gen_{epoch}.pt', f'enc_{epoch}.pt')
        logger.info(f'EVAL protein: {protein_test_name}')

        smiles, preds = (learner.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], protein_test_name))
        gs = [s for i, s in enumerate(smiles) if preds[i] > 0.5]
        gp = preds[preds > 0.5]
        for p, s in zip(gp, gs):
            gen_mols.append(s)
            gen_prot.append(protein_test_name)
            gen_affinity.append(p)
            mode.append('eval')

        inds = np.argsort(gp)[::-1]
        for i in inds[:5]:
            logger.info(f'Epoch {epoch:d}, generated {gs[i]} against '
                        f'{protein_test_name}.\n '
                        f'Predicted binding probability = {gp[i]}. ')

        plot_and_compare_proteins(unbiased_preds, preds, protein_test_name,
                                  epoch, learner.model_path, 'train',
                                  params['eval_batch_size'])
        biased_ratios.append(
            np.round(100 * (np.sum(preds > 0.5) / len(preds)), 1))
        all_toxes = np.array([learner.tox21(s) for s in smiles])
        tox_ratios.append(
            np.round(100 * (np.sum(all_toxes == 1.) / len(all_toxes)), 1))
        logger.info(f'Percentage of non-toxic compounds {tox_ratios[-1]}')

        toxes = [learner.tox21(s) for s in gen_mols]
        # Save results (good molecules!) in DF
        df = pd.DataFrame({
            'protein': gen_prot,
            'SMILES': gen_mols,
            'Binding probability': gen_affinity,
            'mode': mode,
            'Tox21': toxes
        })
        df.to_csv(os.path.join(learner.model_path, 'results', 'generated.csv'))
        # Plot loss development
        loss_df = pd.DataFrame({'loss': rl_losses, 'rewards': rewards})
        loss_df.to_csv(learner.model_path +
                       '/results/loss_reward_evolution.csv')
        plot_loss(rl_losses,
                  rewards,
                  params['epochs'],
                  protein_name,
                  learner.model_path,
                  rolling=5)
    pd.DataFrame({
        'efficacy_ratio': biased_ratios,
        'tox_ratio': tox_ratios
    }).to_csv(learner.model_path + '/results/ratios.csv')
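
# A small post-processing sketch for the generated.csv written above: rank the
# stored molecules by predicted binding probability. Column names match those
# used when the results DataFrame is built ('SMILES', 'Binding probability',
# 'mode').
import pandas as pd


def top_generated(results_csv, n=10):
    df = pd.read_csv(results_csv)
    df = df[df['mode'] == 'eval']  # this script only stores 'eval' molecules
    return df.sort_values('Binding probability', ascending=False).head(n)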
Beispiel #15
0
def main(
    train_sensitivity_filepath, test_sensitivity_filepath, gep_filepath,
    smi_filepath, gene_filepath, smiles_language_filepath, model_path,
    params_filepath, training_name
):

    logger = logging.getLogger(f'{training_name}')
    # Process parameter file:
    params = {}
    with open(params_filepath) as fp:
        params.update(json.load(fp))

    # Create model directory and dump files
    model_dir = os.path.join(model_path, training_name)
    os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp, indent=4)

    # Prepare the dataset
    logger.info("Start data preprocessing...")

    # Load SMILES language
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Load the gene list
    with open(gene_filepath, 'rb') as f:
        gene_list = pickle.load(f)

    # Assemble datasets
    train_dataset = DrugSensitivityDataset(
        drug_sensitivity_filepath=train_sensitivity_filepath,
        smi_filepath=smi_filepath,
        gene_expression_filepath=gep_filepath,
        smiles_language=smiles_language,
        gene_list=gene_list,
        drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True),
        drug_sensitivity_processing_parameters=params.get(
            'drug_sensitivity_processing_parameters', {}
        ),
        augment=params.get('augment_smiles', True),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        add_start_and_stop=params.get('smiles_start_stop_token', True),
        padding_length=params.get('smiles_padding_length', None),
        gene_expression_standardize=params.get(
            'gene_expression_standardize', True
        ),
        gene_expression_min_max=params.get('gene_expression_min_max', False),
        gene_expression_processing_parameters=params.get(
            'gene_expression_processing_parameters', {}
        ),
        device=torch.device(params.get('dataset_device', 'cpu')),
        backend='eager'
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )

    test_dataset = DrugSensitivityDataset(
        drug_sensitivity_filepath=test_sensitivity_filepath,
        smi_filepath=smi_filepath,
        gene_expression_filepath=gep_filepath,
        smiles_language=smiles_language,
        gene_list=gene_list,
        drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True),
        drug_sensitivity_processing_parameters=params.get(
            'drug_sensitivity_processing_parameters',
            train_dataset.drug_sensitivity_processing_parameters
        ),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        randomize=params.get('randomize', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        selfies=params.get('selfies', False),
        add_start_and_stop=params.get('smiles_start_stop_token', True),
        padding_length=params.get('smiles_padding_length', None),
        gene_expression_standardize=params.get(
            'gene_expression_standardize', True
        ),
        gene_expression_min_max=params.get('gene_expression_min_max', False),
        gene_expression_processing_parameters=params.get(
            'gene_expression_processing_parameters',
            train_dataset.gene_expression_dataset.processing
        ),
        device=torch.device(params.get('dataset_device', 'cpu')),
        backend='eager'
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=params['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=params.get('num_workers', 0)
    )
    logger.info(
        f'Training dataset has {len(train_dataset)} samples, test set has '
        f'{len(test_dataset)}.'
    )

    device = get_device()
    logger.info(
        f'Device for data loader is {train_dataset.device} and for '
        f'model is {device}'
    )
    save_top_model = os.path.join(model_dir, 'weights/{}_{}_{}.pt')
    params.update({  # yapf: disable
        'number_of_genes': len(gene_list),  # yapf: disable
        'smiles_vocabulary_size': smiles_language.number_of_tokens,
        'drug_sensitivity_processing_parameters':
            train_dataset.drug_sensitivity_processing_parameters,
        'gene_expression_processing_parameters':
            train_dataset.gene_expression_dataset.processing
    })

    model = MODEL_FACTORY[params.get('model_fn', 'mca')](params).to(device)
    model._associate_language(smiles_language)

    if os.path.isfile(os.path.join(model_dir, 'weights', 'best_mse_mca.pt')):
        logger.info('Found existing model, restoring now...')
        model.load(os.path.join(model_dir, 'weights', 'best_mse_mca.pt'))

        with open(os.path.join(model_dir, 'results', 'mse.json'), 'r') as f:
            info = json.load(f)

            min_rmse = info['best_rmse']
            max_pearson = info['best_pearson']
            min_loss = info['test_loss']

    else:
        min_loss, min_rmse, max_pearson = 100, 1000, 0

    # Define optimizer
    optimizer = (
        OPTIMIZER_FACTORY[params.get('optimizer', 'Adam')]
        (model.parameters(), lr=params.get('lr', 0.01))
    )

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params.update({'number_of_parameters': num_params})
    logger.info(f'Number of parameters {num_params}')

    # Overwrite params.json file with updated parameters.
    with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
        json.dump(params, fp)

    # Start training
    logger.info('Training about to start...\n')
    t = time()

    model.save(
        save_top_model.format('epoch', '0', params.get('model_fn', 'mca'))
    )

    for epoch in range(params['epochs']):

        model.train()
        logger.info(params_filepath.split('/')[-1])
        logger.info(f"== Epoch [{epoch}/{params['epochs']}] ==")
        train_loss = 0

        for ind, (smiles, gep, y) in enumerate(train_loader):
            y_hat, pred_dict = model(
                torch.squeeze(smiles.to(device)), gep.to(device)
            )
            loss = model.loss(y_hat, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            # Apply gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(),1e-6)
            optimizer.step()
            train_loss += loss.item()

        logger.info(
            "\t **** TRAINING ****   "
            f"Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {train_loss / len(train_loader):.5f}. "
            f"This took {time() - t:.1f} secs."
        )
        t = time()

        # Measure validation performance
        model.eval()
        with torch.no_grad():
            test_loss = 0
            predictions = []
            labels = []
            for ind, (smiles, gep, y) in enumerate(test_loader):
                y_hat, pred_dict = model(
                    torch.squeeze(smiles.to(device)), gep.to(device)
                )
                predictions.append(y_hat)
                labels.append(y)
                loss = model.loss(y_hat, y.to(device))
                test_loss += loss.item()

        predictions = np.array(
            [p.cpu() for preds in predictions for p in preds]
        )
        labels = np.array([l.cpu() for label in labels for l in label])
        test_pearson_a = pearsonr(
            torch.Tensor(predictions), torch.Tensor(labels)
        )
        test_rmse_a = np.sqrt(np.mean((predictions - labels)**2))
        test_loss_a = test_loss / len(test_loader)
        logger.info(
            f"\t **** TESTING **** Epoch [{epoch + 1}/{params['epochs']}], "
            f"loss: {test_loss_a:.5f}, "
            f"Pearson: {test_pearson_a:.3f}, "
            f"RMSE: {test_rmse_a:.3f}"
        )

        def save(path, metric, typ, val=None):
            model.save(path.format(typ, metric, params.get('model_fn', 'mca')))
            with open(
                os.path.join(model_dir, 'results', metric + '.json'), 'w'
            ) as f:
                json.dump(info, f)
            np.save(
                os.path.join(model_dir, 'results', metric + '_preds.npy'),
                np.vstack([predictions, labels])
            )
            if typ == 'best':
                logger.info(
                    f'\t New best performance in "{metric}"'
                    f' with value : {val:.7f} in epoch: {epoch}'
                )

        def update_info():
            return {
                'best_rmse': str(min_rmse),
                'best_pearson': str(float(max_pearson)),
                'test_loss': str(min_loss),
                'predictions': [float(p) for p in predictions]
            }

        if test_loss_a < min_loss:
            min_rmse = test_rmse_a
            min_loss = test_loss_a
            min_loss_pearson = test_pearson_a
            info = update_info()
            save(save_top_model, 'mse', 'best', min_loss)
            ep_loss = epoch
        if test_pearson_a > max_pearson:
            max_pearson = test_pearson_a
            max_pearson_loss = test_loss_a
            info = update_info()
            save(save_top_model, 'pearson', 'best', max_pearson)
            ep_pearson = epoch
        if (epoch + 1) % params.get('save_model', 100) == 0:
            save(save_top_model, 'epoch', str(epoch))
    logger.info(
        'Overall best performances are: \n \t'
        f'Loss = {min_loss:.4f} in epoch {ep_loss} '
        f'\t (Pearson was {min_loss_pearson:.4f}) \n \t'
        f'Pearson = {max_pearson:.4f} in epoch {ep_pearson} '
        f'\t (Loss was {max_pearson_loss:.2f})'
    )
    save(save_top_model, 'training', 'done')
    logger.info('Done with training, models saved, shutting down.')
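
# The `pearsonr` used above is imported elsewhere and is called directly on
# torch tensors. For reference, a plain Pearson correlation over two 1D tensors
# can be computed as below; this is a sketch of the metric itself, not
# necessarily the imported implementation.
import torch


def pearson_correlation(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x_centered = x - x.mean()
    y_centered = y - y.mean()
    denominator = torch.sqrt((x_centered ** 2).sum()) * torch.sqrt(
        (y_centered ** 2).sum())
    return (x_centered * y_centered).sum() / denominator
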
def main(*, parser_namespace):

    disable_rdkit_logging()

    # read the params json
    params = dict()
    with open(parser_namespace.params_path) as f:
        params.update(json.load(f))

    # get params
    mol_model_path = params.get('mol_model_path',
                                parser_namespace.mol_model_path)
    omics_model_path = params.get('omics_model_path',
                                  parser_namespace.omics_model_path)
    ic50_model_path = params.get('ic50_model_path',
                                 parser_namespace.ic50_model_path)
    omics_data_path = params.get('omics_data_path',
                                 parser_namespace.omics_data_path)
    model_name = params.get(
        'model_name', parser_namespace.model_name
    )   # yapf: disable
    site = params.get(
        'site', parser_namespace.site
    )   # yapf: disable

    params['site'] = site

    logger.info(f'Model with name {model_name} starts.')

    # Load omics profiles for conditional generation,
    # complement with avg per site
    omics_df = pd.read_pickle(omics_data_path)
    omics_df = add_avg_profile(omics_df)

    # Restore SMILES Model
    with open(os.path.join(mol_model_path, 'model_params.json')) as f:
        mol_params = json.load(f)
    gru_encoder = StackGRUEncoder(mol_params)
    gru_decoder = StackGRUDecoder(mol_params)
    generator = TeacherVAE(gru_encoder, gru_decoder)
    generator.load(os.path.join(
        mol_model_path,
        f"weights/best_{params.get('smiles_metric', 'rec')}.pt"),
                   map_location=get_device())
    # Load languages
    generator_smiles_language = SMILESLanguage.load(
        os.path.join(mol_model_path, 'selfies_language.pkl'))
    generator._associate_language(generator_smiles_language)

    # Restore omics model
    with open(os.path.join(omics_model_path, 'model_params.json')) as f:
        cell_params = json.load(f)

    # Define network
    cell_encoder = ENCODER_FACTORY['dense'](cell_params)
    cell_encoder.load(os.path.join(
        omics_model_path,
        f"weights/best_{params.get('omics_metric','both')}_encoder.pt"),
                      map_location=get_device())
    cell_encoder.eval()

    # Restore PaccMann
    with open(os.path.join(ic50_model_path, 'model_params.json')) as f:
        paccmann_params = json.load(f)
    paccmann_predictor = MODEL_FACTORY['mca'](paccmann_params)
    paccmann_predictor.load(os.path.join(
        ic50_model_path,
        f"weights/best_{params.get('ic50_metric', 'rmse')}_mca.pt"),
                            map_location=get_device())
    paccmann_predictor.eval()
    paccmann_smiles_language = SMILESLanguage.load(
        os.path.join(ic50_model_path, 'smiles_language.pkl'))
    paccmann_predictor._associate_language(paccmann_smiles_language)

    # Specifies the baseline model used for comparison
    baseline = ReinforceOmic(generator, cell_encoder, paccmann_predictor,
                             omics_df, params, 'baseline', logger)

    # Create a fresh model that will be optimized
    gru_encoder_rl = StackGRUEncoder(mol_params)
    gru_decoder_rl = StackGRUDecoder(mol_params)
    generator_rl = TeacherVAE(gru_encoder_rl, gru_decoder_rl)
    generator_rl.load(os.path.join(
        mol_model_path, f"weights/best_{params.get('metric', 'rec')}.pt"),
                      map_location=get_device())
    generator_rl.eval()
    generator_rl._associate_language(generator_smiles_language)

    cell_encoder_rl = ENCODER_FACTORY['dense'](cell_params)
    cell_encoder_rl.load(os.path.join(
        omics_model_path,
        f"weights/best_{params.get('metric', 'both')}_encoder.pt"),
                         map_location=get_device())
    cell_encoder_rl.eval()
    model_folder_name = site + '_' + model_name
    learner = ReinforceOmic(generator_rl, cell_encoder_rl, paccmann_predictor,
                            omics_df, params, model_folder_name, logger)

    # Split the samples for conditional generation and initialize training
    train_omics, test_omics = omics_data_splitter(
        omics_df, site, params.get('test_fraction', 0.2))
    rewards, rl_losses = [], []
    gen_mols, gen_cell, gen_ic50, modes = [], [], [], []
    logger.info('Models restored, start training.')

    for epoch in range(1, params['epochs'] + 1):

        for step in range(1, params['steps']):

            # Randomly sample a cell line:
            cell_line = np.random.choice(train_omics)

            rew, loss = learner.policy_gradient(cell_line, epoch,
                                                params['batch_size'])
            print(f"Epoch {epoch:d}/{params['epochs']:d}, step {step:d}/"
                  f"{params['steps']:d}\t loss={loss:.2f}, rew={rew:.2f}")

            rewards.append(rew.item())
            rl_losses.append(loss)

        # Save model
        learner.save(f'gen_{epoch}.pt', f'enc_{epoch}.pt')

        # Compare baseline and trained model on cell line
        base_smiles, base_preds = baseline.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], cell_line)
        smiles, preds = learner.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], cell_line)
        gs = [
            s for i, s in enumerate(smiles)
            if preds[i] < learner.ic50_threshold
        ]
        gp = preds[preds < learner.ic50_threshold]
        for p, s in zip(gp, gs):
            gen_mols.append(s)
            gen_cell.append(cell_line)
            gen_ic50.append(p)
            modes.append('train')

        plot_and_compare(base_preds, preds, site, cell_line, epoch,
                         learner.model_path, 'train',
                         params['eval_batch_size'])

        # Evaluate on a validation cell line.
        eval_cell_line = np.random.choice(test_omics)
        base_smiles, base_preds = baseline.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], eval_cell_line)
        smiles, preds = learner.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], eval_cell_line)
        plot_and_compare(base_preds, preds, site, eval_cell_line, epoch,
                         learner.model_path, 'test', params['eval_batch_size'])
        gs = [
            s for i, s in enumerate(smiles)
            if preds[i] < learner.ic50_threshold
        ]
        gp = preds[preds < learner.ic50_threshold]
        for p, s in zip(gp, gs):
            gen_mols.append(s)
            gen_cell.append(eval_cell_line)
            gen_ic50.append(p)
            modes.append('test')

        inds = np.argsort(preds)
        for i in inds[:5]:
            logger.info(f'Epoch {epoch:d}, generated {smiles[i]} against '
                        f'{eval_cell_line}.\n Predicted IC50 = {preds[i]}. ')

        # Save results (good molecules!) in DF
        df = pd.DataFrame({
            'cell_line': gen_cell,
            'SMILES': gen_mols,
            'IC50': gen_ic50,
            'mode': modes,
            'tox21': [learner.tox21(s) for s in gen_mols]
        })
        df.to_csv(os.path.join(learner.model_path, 'results', 'generated.csv'))
        # Plot loss development
        loss_df = pd.DataFrame({'loss': rl_losses, 'rewards': rewards})
        loss_df.to_csv(learner.model_path +
                       '/results/loss_reward_evolution.csv')
        plot_loss(rl_losses,
                  rewards,
                  params['epochs'],
                  cell_line,
                  learner.model_path,
                  rolling=5,
                  site=site)
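
# A small sketch for inspecting the loss/reward CSV written above, using the
# same rolling-window smoothing (window of 5) that is passed to plot_loss.
import pandas as pd


def smoothed_learning_curves(csv_path, window=5):
    df = pd.read_csv(csv_path)  # columns: 'loss', 'rewards'
    return df[['loss', 'rewards']].rolling(window, min_periods=1).mean()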
Beispiel #17
0
def main(model_path, smi_filepath, labels_filepath, output_folder,
         smiles_language_filepath, model_id, confidence):

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger('eval_toxicity')
    logger.setLevel(logging.INFO)
    disable_rdkit_logging()

    # Process parameter file:
    params = {}
    with open(os.path.join(model_path, 'model_params.json'), 'r') as fp:
        params.update(json.load(fp))

    # Create model directory
    os.makedirs(output_folder, exist_ok=True)

    if model_id not in MODEL_FACTORY.keys():
        raise KeyError(
            f'Model ID: Pass one of {MODEL_FACTORY.keys()}, not {model_id}')

    device = get_device()
    weights_path = os.path.join(model_path, 'weights',
                                f'best_ROC-AUC_{model_id}.pt')

    # Restore model
    model = MODEL_FACTORY[model_id](params).to(device)
    if os.path.isfile(weights_path):
        try:
            model.load(weights_path, map_location=device)
        except Exception:
            logger.error(f'Error in model restoring from {weights_path}')
    else:
        logger.info(f'Did not find weights at {weights_path}, '
                    f'name weights: "best_ROC-AUC_{model_id}.pt".')
    model.eval()

    logger.info('Model restored. Model specs & parameters follow')
    for name, param in model.named_parameters():
        logger.info((name, param.shape))

    # Load language
    if smiles_language_filepath == '.':
        smiles_language_filepath = os.path.join(model_path,
                                                'smiles_language.pkl')
    smiles_language = SMILESLanguage.load(smiles_language_filepath)

    # Assemble datasets
    smiles_dataset = SMILESDataset(
        smi_filepath,
        smiles_language=smiles_language,
        padding_length=params.get('smiles_padding_length', None),
        padding=params.get('padd_smiles', True),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=params.get('augment_test_smiles', False),
        canonical=params.get('test_canonical', False),
        kekulize=params.get('test_kekulize', False),
        all_bonds_explicit=params.get('test_all_bonds_explicit', False),
        all_hs_explicit=params.get('test_all_hs_explicit', False),
        randomize=False,
        remove_bonddir=params.get('test_remove_bonddir', False),
        remove_chirality=params.get('test_remove_chirality', False),
        selfies=params.get('selfies', False),
        sanitize=params.get('test_sanitize', False),
    )
    logger.info(
        f'SMILES Padding length is {smiles_dataset._dataset.padding_length}.'
        ' Consider setting manually if this looks wrong.')
    dataset = AnnotatedDataset(annotations_filepath=labels_filepath,
                               dataset=smiles_dataset,
                               device=device)
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=10,
                                         shuffle=False,
                                         drop_last=False,
                                         num_workers=0)
    if confidence:
        smiles_aleatoric_dataset = SMILESDataset(
            smi_filepath,
            smiles_language=smiles_language,
            padding_length=params.get('smiles_padding_length', None),
            padding=params.get('padd_smiles', True),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=True,  # Intentionally True for the aleatoric uncertainty estimate
            canonical=params.get('test_canonical', False),
            kekulize=params.get('test_kekulize', False),
            all_bonds_explicit=params.get('test_all_bonds_explicit', False),
            all_hs_explicit=params.get('test_all_hs_explicit', False),
            remove_bonddir=params.get('test_remove_bonddir', False),
            remove_chirality=params.get('test_remove_chirality', False),
            selfies=params.get('selfies', False),
            sanitize=params.get('test_sanitize', False),
        )
        ale_dataset = AnnotatedDataset(annotations_filepath=labels_filepath,
                                       dataset=smiles_aleatoric_dataset,
                                       device=device)
        ale_loader = torch.utils.data.DataLoader(dataset=ale_dataset,
                                                 batch_size=10,
                                                 shuffle=False,
                                                 drop_last=False,
                                                 num_workers=0)
        smiles_epi_dataset = SMILESDataset(
            smi_filepath,
            smiles_language=smiles_language,
            padding_length=params.get('smiles_padding_length', None),
            padding=params.get('padd_smiles', True),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=False,  # Intentionally False for the epistemic uncertainty estimate
            canonical=params.get('test_canonical', False),
            kekulize=params.get('test_kekulize', False),
            all_bonds_explicit=params.get('test_all_bonds_explicit', False),
            all_hs_explicit=params.get('test_all_hs_explicit', False),
            remove_bonddir=params.get('test_remove_bonddir', False),
            remove_chirality=params.get('test_remove_chirality', False),
            selfies=params.get('selfies', False),
            sanitize=params.get('test_sanitize', False),
        )
        epi_dataset = AnnotatedDataset(annotations_filepath=labels_filepath,
                                       dataset=smiles_epi_dataset,
                                       device=device)
        epi_loader = torch.utils.data.DataLoader(dataset=epi_dataset,
                                                 batch_size=10,
                                                 shuffle=False,
                                                 drop_last=False,
                                                 num_workers=0)

    logger.info(f'Device for data loader is {dataset.device} and for '
                f'model is {device}')
    # Start evaluation
    logger.info('Evaluation about to start...\n')

    preds, labels, attention_scores, smiles = [], [], [], []
    for idx, (smiles_batch, labels_batch) in enumerate(loader):
        pred, pred_dict = model(smiles_batch.to(device))
        preds.extend(pred.detach().squeeze().tolist())
        attention_scores.extend(
            torch.stack(pred_dict['smiles_attention'], dim=1).detach())
        smiles.extend([
            smiles_language.token_indexes_to_smiles(s.tolist())
            for s in smiles_batch
        ])
        labels.extend(labels_batch.squeeze().tolist())
    # Scores are now 3D: num_samples x num_att_layers x padding_length
    attention = torch.stack(attention_scores, dim=0).numpy()

    if confidence:
        # Compute uncertainty estimates and save them
        epistemic_conf, epistemic_pred = monte_carlo_dropout(model,
                                                             regime='loader',
                                                             loader=epi_loader)
        aleatoric_conf, aleatoric_pred = test_time_augmentation(
            model, regime='loader', loader=ale_loader)
        epi_conf_df = pd.DataFrame(
            data=epistemic_conf.numpy(),
            columns=[
                f'epistemic_conf_{i}' for i in range(epistemic_conf.shape[1])
            ],
        )
        ale_conf_df = pd.DataFrame(
            data=aleatoric_conf.numpy(),
            columns=[
                f'aleatoric_conf_{i}' for i in range(aleatoric_conf.shape[1])
            ],
        )
        epi_pred_df = pd.DataFrame(
            data=epistemic_pred.numpy(),
            columns=[
                f'epistemic_pred_{i}' for i in range(epistemic_pred.shape[1])
            ],
        )
        ale_pred_df = pd.DataFrame(
            data=aleatoric_pred.numpy(),
            columns=[
                f'aleatoric_pred_{i}' for i in range(aleatoric_pred.shape[1])
            ],
        )

    logger.info(f'Shape of attention scores {attention.shape}.')
    np.save(os.path.join(output_folder, 'attention_raw.npy'), attention)
    attention_avg = np.mean(attention, axis=1)
    att_df = pd.DataFrame(
        data=attention_avg,
        columns=[f'att_idx_{i}' for i in range(attention_avg.shape[1])],
    )
    pred_df = pd.DataFrame(
        data=preds,
        columns=[f'pred_{i}' for i in range(len(preds[0]))],
    )
    lab_df = pd.DataFrame(
        data=labels,
        columns=[f'label_{i}' for i in range(len(labels[0]))],
    )
    df = pd.concat([pred_df, lab_df], axis=1)
    if confidence:
        df = pd.concat(
            [df, epi_conf_df, ale_conf_df, epi_pred_df, ale_pred_df], axis=1)
    df = pd.concat([df, att_df], axis=1)
    df.insert(0, 'SMILES', smiles)
    df.to_csv(os.path.join(output_folder, 'results.csv'), index=False)

    logger.info('Done, shutting down.')
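
# A conceptual sketch of the idea behind the monte_carlo_dropout call above
# (not the library implementation): keep dropout active at inference time, run
# several stochastic forward passes over the loader, and use the spread of the
# predictions as an epistemic-uncertainty proxy. `model` and `loader` are
# assumed to behave like the torch module and DataLoader used in the example.
import torch


def mc_dropout_predictions(model, loader, n_passes=10):
    # Note: .train() also affects layers such as batch norm; a fuller
    # implementation would re-enable only the dropout modules.
    model.train()
    passes = []
    with torch.no_grad():
        for _ in range(n_passes):
            batch_preds = []
            for smiles_batch, _ in loader:
                pred, _ = model(smiles_batch)
                batch_preds.append(pred)
            passes.append(torch.cat(batch_preds, dim=0))
    stacked = torch.stack(passes, dim=0)  # n_passes x n_samples x n_tasks
    # Mean as the prediction, standard deviation as the uncertainty proxy.
    return stacked.mean(dim=0), stacked.std(dim=0)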