Example #1
def packed_sequential_data_preparation(input_batch,
                                       input_keep=1,
                                       start_index=2,
                                       end_index=3,
                                       dropout_index=1,
                                       device=get_device(),
                                       enforce_sorted=False):
    """
    Sequential Training Data Builder.

    Args:
        input_batch (Iterable[torch.Tensor]): Iterable of 1D token-index
            tensors, one per sequence (e.g. a list of sequences as produced
            by the collate function); each sample must be one-dimensional.
        input_keep (float): The probability not to drop input sequence tokens
            according to a Bernoulli distribution with p = input_keep.
            Defaults to 1.
        start_index (int): The index of the sequence start token.
        end_index (int): The index of the sequence end token.
        dropout_index (int): The index of the dropout token. Defaults to 1.
        device (torch.device): Device on which the returned tensors are
            placed. Defaults to get_device().
        enforce_sorted (bool): Passed to nn.utils.rnn.pack_sequence.
            Defaults to False.

    Returns:
        (PackedSequence, PackedSequence, PackedSequence): encoder_seq,
        decoder_seq, target_seq.

        encoder_seq is the packed batch of input sequences
            (nn.utils.rnn.pack_sequence of the samples).
        decoder_seq is like encoder_seq but with word dropout applied
            (so if input_keep == 1, then decoder_seq equals encoder_seq).
        target_seq is the packed batch of target sequences, i.e. each input
            shifted left by one token and padded with a trailing 0.
    """
    def _process_sample(sample):
        if len(sample.shape) != 1:
            raise ValueError(
                f'Expected a 1D sample tensor, got shape {tuple(sample.shape)}.'
            )
        input = sample.long().to(device)
        decoder = input.clone()

        # apply token dropout if keep != 1
        if input_keep != 1:
            # mask for token dropout
            mask = Bernoulli(input_keep).sample(input.shape)
            mask = torch.LongTensor(mask.numpy())
            dropout_loc = np.where(mask == 0)[0]
            decoder[dropout_loc] = dropout_index

        # a plain .clone() would keep the target in the autograd graph,
        # so detach it first
        target = torch.cat(
            [input[1:].detach().clone(),
             torch.Tensor([0]).long().to(device)])
        return input, decoder, target.to(device)

    batch = [_process_sample(sample) for sample in input_batch]

    encoder_decoder_target = zip(*batch)
    encoder_decoder_target = [
        torch.nn.utils.rnn.pack_sequence(entry, enforce_sorted=enforce_sorted)
        for entry in encoder_decoder_target
    ]
    return encoder_decoder_target
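
A minimal usage sketch (not part of the example above), assuming the function and its module imports (torch, numpy, torch.distributions.Bernoulli, get_device) are in scope: two toy token sequences that already contain the start (2) and stop (3) indices are turned into packed encoder/decoder/target sequences.

import torch

toy_batch = [
    torch.tensor([2, 5, 7, 9, 3]),  # <START> t1 t2 t3 <STOP>
    torch.tensor([2, 6, 8, 3]),
]
encoder_seq, decoder_seq, target_seq = packed_sequential_data_preparation(
    toy_batch, input_keep=0.9, start_index=2, end_index=3)

# unpack to inspect: padded shape [max sequence length, batch_size] and lengths
padded, lengths = torch.nn.utils.rnn.pad_packed_sequence(encoder_seq)
print(padded.shape, lengths)  # torch.Size([5, 2]) tensor([5, 4])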
Example #2
    def load_pretrained_paccmann(self, params_file: str, lang_file: str,
                                 weights_file: str, batch_size: int,
                                 batch_mode: str):
        """Restore a pretrained PaccMann VAE and its SMILES preprocessing."""
        params = dict()
        with open(params_file, 'r') as f:
            params.update(json.load(f))
        params['batch_mode'] = batch_mode
        params['batch_size'] = batch_size

        self.selfies = params.get('selfies', False)

        self.device = get_device()
        self.smiles_language = SMILESLanguage.load(lang_file)

        self.gru_encoder = StackGRUEncoder(params).to(self.device)
        self.gru_decoder = StackGRUDecoder(params).to(self.device)
        self.gru_vae = TeacherVAE(self.gru_encoder,
                                  self.gru_decoder).to(self.device)
        self.gru_vae.load_state_dict(
            torch.load(weights_file, map_location=self.device))
        self.gru_vae.eval()

        transforms = []
        if self.selfies:
            transforms += [Selfies()]
        transforms += [
            SMILESToTokenIndexes(smiles_language=self.smiles_language)
        ]
        transforms += [ToTensor(device=self.device)]
        self.transform = Compose(transforms)
def main(*, parser_namespace):

    disable_rdkit_logging()

    # read the params json
    params = dict()
    with open(parser_namespace.params_path) as f:
        params.update(json.load(f))

    # get params
    mol_model_path = params.get('mol_model_path',
                                parser_namespace.mol_model_path)
    omics_model_path = params.get('omics_model_path',
                                  parser_namespace.omics_model_path)
    ic50_model_path = params.get('ic50_model_path',
                                 parser_namespace.ic50_model_path)
    omics_data_path = params.get('omics_data_path',
                                 parser_namespace.omics_data_path)
    model_name = params.get(
        'model_name', parser_namespace.model_name
    )   # yapf: disable
    site = params.get(
        'site', parser_namespace.site
    )   # yapf: disable

    params['site'] = site

    logger.info(f'Model with name {model_name} starts.')

    # Load omics profiles for conditional generation,
    # complement with avg per site
    omics_df = pd.read_pickle(omics_data_path)
    omics_df = add_avg_profile(omics_df)

    # Restore SMILES Model
    with open(os.path.join(mol_model_path, 'model_params.json')) as f:
        mol_params = json.load(f)
    gru_encoder = StackGRUEncoder(mol_params)
    gru_decoder = StackGRUDecoder(mol_params)
    generator = TeacherVAE(gru_encoder, gru_decoder)
    generator.load(os.path.join(
        mol_model_path,
        f"weights/best_{params.get('smiles_metric', 'rec')}.pt"),
                   map_location=get_device())
    # Load languages
    generator_smiles_language = SMILESLanguage.load(
        os.path.join(mol_model_path, 'selfies_language.pkl'))
    generator._associate_language(generator_smiles_language)

    # Restore omics model
    with open(os.path.join(omics_model_path, 'model_params.json')) as f:
        cell_params = json.load(f)

    # Define network
    cell_encoder = ENCODER_FACTORY['dense'](cell_params)
    cell_encoder.load(os.path.join(
        omics_model_path,
        f"weights/best_{params.get('omics_metric','both')}_encoder.pt"),
                      map_location=get_device())
    cell_encoder.eval()

    # Restore PaccMann
    with open(os.path.join(ic50_model_path, 'model_params.json')) as f:
        paccmann_params = json.load(f)
    paccmann_predictor = MODEL_FACTORY['mca'](paccmann_params)
    paccmann_predictor.load(os.path.join(
        ic50_model_path,
        f"weights/best_{params.get('ic50_metric', 'rmse')}_mca.pt"),
                            map_location=get_device())
    paccmann_predictor.eval()
    paccmann_smiles_language = SMILESLanguage.load(
        os.path.join(ic50_model_path, 'smiles_language.pkl'))
    paccmann_predictor._associate_language(paccmann_smiles_language)

    # Specifies the baseline model used for comparison
    baseline = ReinforceOmic(generator, cell_encoder, paccmann_predictor,
                             omics_df, params, 'baseline', logger)

    # Create a fresh model that will be optimized
    gru_encoder_rl = StackGRUEncoder(mol_params)
    gru_decoder_rl = StackGRUDecoder(mol_params)
    generator_rl = TeacherVAE(gru_encoder_rl, gru_decoder_rl)
    generator_rl.load(os.path.join(
        mol_model_path, f"weights/best_{params.get('metric', 'rec')}.pt"),
                      map_location=get_device())
    generator_rl.eval()
    generator_rl._associate_language(generator_smiles_language)

    cell_encoder_rl = ENCODER_FACTORY['dense'](cell_params)
    cell_encoder_rl.load(os.path.join(
        omics_model_path,
        f"weights/best_{params.get('metric', 'both')}_encoder.pt"),
                         map_location=get_device())
    cell_encoder_rl.eval()
    model_folder_name = site + '_' + model_name
    learner = ReinforceOmic(generator_rl, cell_encoder_rl, paccmann_predictor,
                            omics_df, params, model_folder_name, logger)

    # Split the samples for conditional generation and initialize training
    train_omics, test_omics = omics_data_splitter(
        omics_df, site, params.get('test_fraction', 0.2))
    rewards, rl_losses = [], []
    gen_mols, gen_cell, gen_ic50, modes = [], [], [], []
    logger.info('Models restored, start training.')

    for epoch in range(1, params['epochs'] + 1):

        for step in range(1, params['steps']):

            # Randomly sample a cell line:
            cell_line = np.random.choice(train_omics)

            rew, loss = learner.policy_gradient(cell_line, epoch,
                                                params['batch_size'])
            print(f"Epoch {epoch:d}/{params['epochs']:d}, step {step:d}/"
                  f"{params['steps']:d}\t loss={loss:.2f}, rew={rew:.2f}")

            rewards.append(rew.item())
            rl_losses.append(loss)

        # Save model
        learner.save(f'gen_{epoch}.pt', f'enc_{epoch}.pt')

        # Compare baseline and trained model on cell line
        base_smiles, base_preds = baseline.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], cell_line)
        smiles, preds = learner.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], cell_line)
        gs = [
            s for i, s in enumerate(smiles)
            if preds[i] < learner.ic50_threshold
        ]
        gp = preds[preds < learner.ic50_threshold]
        for p, s in zip(gp, gs):
            gen_mols.append(s)
            gen_cell.append(cell_line)
            gen_ic50.append(p)
            modes.append('train')

        plot_and_compare(base_preds, preds, site, cell_line, epoch,
                         learner.model_path, 'train',
                         params['eval_batch_size'])

        # Evaluate on a validation cell line.
        eval_cell_line = np.random.choice(test_omics)
        base_smiles, base_preds = baseline.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], eval_cell_line)
        smiles, preds = learner.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], eval_cell_line)
        plot_and_compare(base_preds, preds, site, eval_cell_line, epoch,
                         learner.model_path, 'test', params['eval_batch_size'])
        gs = [
            s for i, s in enumerate(smiles)
            if preds[i] < learner.ic50_threshold
        ]
        gp = preds[preds < learner.ic50_threshold]
        for p, s in zip(gp, gs):
            gen_mols.append(s)
            gen_cell.append(eval_cell_line)
            gen_ic50.append(p)
            modes.append('test')

        inds = np.argsort(preds)
        for i in inds[:5]:
            logger.info(f'Epoch {epoch:d}, generated {smiles[i]} against '
                        f'{eval_cell_line}.\n Predicted IC50 = {preds[i]}. ')

        # Save results (good molecules!) in DF
        df = pd.DataFrame({
            'cell_line': gen_cell,
            'SMILES': gen_mols,
            'IC50': gen_ic50,
            'mode': modes,
            'tox21': [learner.tox21(s) for s in gen_mols]
        })
        df.to_csv(os.path.join(learner.model_path, 'results', 'generated.csv'))
        # Plot loss development
        loss_df = pd.DataFrame({'loss': rl_losses, 'rewards': rewards})
        loss_df.to_csv(learner.model_path +
                       '/results/loss_reward_evolution.csv')
        plot_loss(rl_losses,
                  rewards,
                  params['epochs'],
                  cell_line,
                  learner.model_path,
                  rolling=5,
                  site=site)
Example #4
"""SMILES decoding from latent space module."""
import sys
from itertools import count
from typing import Any, List

import torch
from rdkit import Chem
from paccmann_chemistry.models.vae import TeacherVAE
from paccmann_chemistry.utils import get_device
from paccmann_chemistry.utils.search import SamplingSearch

device = get_device()


def get_stack_size(size_hint: int = 2) -> int:
    """Stack size from caller's frame from: https://stackoverflow.com/a/47956089.

    Args:
        size_hint: hint for the stack size. Defaults to 2.

    Returns:
        size of the stack.
    """
    get_frame = sys._getframe
    frame = None
    try:
        while True:
            frame = get_frame(size_hint)
            size_hint *= 2
    except ValueError:
        if frame:
            # the probe at size_hint overshot: the deepest reachable frame
            # sits at depth size_hint // 2
            depth = size_hint // 2
        else:
            # even the first probe failed: count the (short) stack starting
            # from the current frame
            frame, depth = get_frame(0), 0
    # walk the remaining outer frames to obtain the total stack size
    for size in count(depth + 1):
        frame = frame.f_back
        if frame is None:
            return size
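
A quick sanity-check sketch of the helper above: each nested call adds one stack frame, so the value reported two calls deep is larger by two.

def _outer():
    return _inner()

def _inner():
    return get_stack_size()

print(get_stack_size(), _outer())  # the second value is larger by 2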
Example #5
def train(
    epoch,
    model,
    train_loader,
    optimizer,
    scheduler,
    graph_gamma,
    growth_rate=0.0015,
    writer=None,
    verbose=False
):
    start_time = time()
    device = get_device()
    # selfies = train_loader.dataset._dataset.selfies
    data_preparation = packed_sequential_data_preparation
    model.gru_vae.to(device)
    model.gru_vae.train()

    input_keep = 1.
    start_index = 2
    end_index = 3

    train_loss = 0

    for _iter, data in tqdm.tqdm(
        enumerate(train_loader),
        total=len(train_loader),
        disable=(not verbose)
    ):

        seqs = data.x
        batch_size = len(seqs)

        # FIXME? variable batch size in data
        model.gru_decoder._update_batch_size(batch_size)
        model.gru_encoder._update_batch_size(batch_size)

        encoder_seq, decoder_seq, target_seq = data_preparation(
            seqs,
            input_keep=input_keep,
            start_index=start_index,
            end_index=end_index,
            device=device
        )

        optimizer.zero_grad()
        decoder_loss, mu, logvar, z = model.train_step(
            encoder_seq, decoder_seq, target_seq
        )

        # Compute distances
        graph = nx.from_edgelist(data.edge_index.T.tolist())
        graph.add_nodes_from(list(range(len(seqs))))
        dists = nx.floyd_warshall_numpy(graph)
        dists[np.isinf(dists)] = 0
        dists = torch.tensor(dists).to(device)

        kl_scale = 1 / (
            1 +
            np.exp((6 - growth_rate * epoch + (_iter / len(train_loader))))
        )

        z = z.squeeze()
        gr_loss = graph_loss(z, dists)

        loss = decoder_loss + graph_gamma * gr_loss

        loss.backward()
        optimizer.step()

        writer.add_scalar(
            'recon_loss', decoder_loss, _iter + epoch * len(train_loader)
        )
        writer.add_scalar(
            'graph_loss', gr_loss, _iter + epoch * len(train_loader)
        )
        writer.add_scalar(
            'loss', loss.item(), _iter + epoch * len(train_loader)
        )

        train_loss += loss.item()
        if _iter % (len(train_loader) // 10) == 0:
            tqdm.tqdm.write(f'{decoder_loss}\t{gr_loss}')
    scheduler.step()
    logger.info(
        f"Learning rate {optimizer.param_groups[0]['lr']}"
        f"\tkl_scale {kl_scale}"
    )
    logger.info(f'{epoch}\t{train_loss/_iter}\t{time()-start_time}')
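
For reference, the kl_scale computed in the loop above is a logistic warm-up of the KL weight driven mainly by the epoch counter (in the snippet shown it is only logged, not added to the loss). A small sketch evaluating the same expression at the first batch of an epoch with the default growth_rate:

import numpy as np

growth_rate = 0.0015  # default of the train() function above

for epoch in (0, 1000, 2000, 4000, 8000):
    kl_scale = 1 / (1 + np.exp(6 - growth_rate * epoch))
    print(f'epoch {epoch:>4d}: kl_scale = {kl_scale:.4f}')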
Example #6
def main(parser_namespace):
    # model loading
    disable_rdkit_logging()
    affinity_path = parser_namespace.affinity_path
    svae_path = parser_namespace.svae_path
    svae_weights_path = os.path.join(svae_path, "weights", "best_rec.pt")
    results_file_name = parser_namespace.optimisation_name

    logger.add(results_file_name + ".log", rotation="10 MB")

    svae_params = dict()
    with open(os.path.join(svae_path, "model_params.json"), "r") as f:
        svae_params.update(json.load(f))

    smiles_language = SMILESLanguage.load(
        os.path.join(svae_path, "selfies_language.pkl"))

    # initialize encoder, decoder and the TeacherVAE used as SMILES generator
    gru_encoder = StackGRUEncoder(svae_params)
    gru_decoder = StackGRUDecoder(svae_params)
    gru_vae = TeacherVAE(gru_encoder, gru_decoder)
    gru_vae.load_state_dict(
        torch.load(svae_weights_path, map_location=get_device()))

    gru_vae._associate_language(smiles_language)
    gru_vae.eval()

    smiles_generator = SmilesGenerator(gru_vae)

    with open(os.path.join(affinity_path, "model_params.json")) as f:
        predictor_params = json.load(f)
    affinity_predictor = MODEL_FACTORY["bimodal_mca"](predictor_params)
    affinity_predictor.load(
        os.path.join(
            affinity_path,
            f"weights/best_{predictor_params.get('p_metric', 'ROC-AUC')}_bimodal_mca.pt",
        ),
        map_location=get_device(),
    )
    affinity_protein_language = ProteinLanguage.load(
        os.path.join(affinity_path, "protein_language.pkl"))
    affinity_smiles_language = SMILESLanguage.load(
        os.path.join(affinity_path, "smiles_language.pkl"))
    affinity_predictor._associate_language(affinity_smiles_language)
    affinity_predictor._associate_language(affinity_protein_language)
    affinity_predictor.eval()

    erg_protein = "MASTIKEALSVVSEDQSLFECAYGTPHLAKTEMTASSSSDYGQTSKMSPRVPQQDWLSQPPARVTIKMECNPSQVNGSRNSPDECSVAKGGKMVGSPDTVGMNYGSYMEEKHMPPPNMTTNERRVIVPADPTLWSTDHVRQWLEWAVKEYGLPDVNILLFQNIDGKELCKMTKDDFQRLTPSYNADILLSHLHYLRETPLPHLTSDDVDKALQNSPRLMHARNTGGAAFIFPNTSVYPEATQRITTRPDLPYEPPRRSAWTGHGHPTPQSKAAQPSPSTVPKTEDQRPQLDPYQILGPTSSRLANPGSGQIQLWQFLLELLSDSSNSSCITWEGTNGEFKMTDPDEVARRWGERKSKPNMNYDKLSRALRYYYDKNIMTKVHGKRYAYKFDFHGIAQALQPHPPESSLYKYPSDLPYMGSYHAHPQKMNFVAPHPPALPVTSSSFFAAPNPYWNSPTGGIYPNTRLPTSHMPSHLGTYY"

    target_minimization_function = AffinityMinimization(
        smiles_generator, 30, affinity_predictor, erg_protein)
    qed_function = QEDMinimization(smiles_generator, 30)
    sa_function = SAMinimization(smiles_generator, 30)
    combined_minimization = CombinedMinimization(
        [target_minimization_function, qed_function, sa_function], 1,
        [0.75, 1, 0.5])
    target_optimizer = GPOptimizer(combined_minimization.evaluate)

    params = dict(
        dimensions=[(-5.0, 5.0)] * 256,
        acq_func="EI",
        n_calls=20,
        n_initial_points=19,
        initial_point_generator="random",
        random_state=1234,
    )
    logger.info("Optimisation parameters: {params}", params=params)

    # optimisation
    for j in range(5):
        res = target_optimizer.optimize(params)
        latent_point = torch.tensor([[res.x]])

        with open(results_file_name + "_LP" + str(j + 1) + ".pkl", "wb") as f:
            pickle.dump(latent_point, f, protocol=2)

        smile_set = set()

        while len(smile_set) < 20:
            smiles = smiles_generator.generate_smiles(
                latent_point.repeat(1, 30, 1))
            smile_set.update(set(smiles))
        smile_set = list(smile_set)

        pad_smiles_predictor = LeftPadding(
            affinity_predictor.smiles_padding_length,
            affinity_predictor.smiles_language.padding_index,
        )
        to_tensor = ToTensor(get_device())
        smiles_num = [
            torch.unsqueeze(
                to_tensor(
                    pad_smiles_predictor(
                        affinity_predictor.smiles_language.
                        smiles_to_token_indexes(smile))),
                0,
            ) for smile in smile_set
        ]

        smiles_tensor = torch.cat(smiles_num, dim=0)

        pad_protein_predictor = LeftPadding(
            affinity_predictor.protein_padding_length,
            affinity_predictor.protein_language.padding_index,
        )

        protein_num = torch.unsqueeze(
            to_tensor(
                pad_protein_predictor(
                    affinity_predictor.protein_language.
                    sequence_to_token_indexes([erg_protein]))),
            0,
        )
        protein_num = protein_num.repeat(len(smile_set), 1)

        with torch.no_grad():
            pred, _ = affinity_predictor(smiles_tensor, protein_num)
        affinities = torch.squeeze(pred, 1).numpy()

        sas = SAS()
        sa_scores = [sas(smile) for smile in smile_set]
        qed_scores = [qed(Chem.MolFromSmiles(smile)) for smile in smile_set]

        # save to file
        file = results_file_name + str(j + 1) + ".txt"
        logger.info("creating {file}", file=file)

        with open(file, "w") as f:
            f.write(
                f'{"point":<10}{"Affinity":<10}{"QED":<10}{"SA":<10}{"smiles":<15}\n'
            )
            for i in range(20):
                dat = [
                    i + 1, affinities[i], qed_scores[i], sa_scores[i],
                    smile_set[i]
                ]
                f.write(
                    f'{dat[0]:<10}{dat[1]:<10.3f}{dat[2]:<10.3f}'
                    f'{dat[3]:<10.3f}{dat[4]:<15}\n'
                )
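
The optimisation parameters above mirror the signature of skopt.gp_minimize, so presumably GPOptimizer forwards them to it; a minimal sketch under that assumption, using a toy 2-D objective instead of the 256-dimensional latent space and the combined minimization function used above.

from skopt import gp_minimize

def toy_objective(x):
    # hypothetical stand-in for combined_minimization.evaluate
    return (x[0] - 1.0) ** 2 + (x[1] + 2.0) ** 2

res = gp_minimize(
    toy_objective,
    dimensions=[(-5.0, 5.0)] * 2,
    acq_func="EI",
    n_calls=20,
    n_initial_points=19,
    initial_point_generator="random",
    random_state=1234,
)
print(res.x, res.fun)  # best point found and its objective value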
def main(parser_namespace):

    disable_rdkit_logging()

    model_path = parser_namespace.model_path
    data_path = parser_namespace.data_path

    weights_path = os.path.join(model_path, 'weights', 'best_loss.pt')

    device = get_device()
    # read the params json
    params = dict()
    with open(os.path.join(model_path, 'model_params.json'), 'r') as f:
        params.update(json.load(f))

    params['batch_size'] = 1

    # Load SMILES language
    smiles_language = SMILESLanguage.load(
        os.path.join(model_path, 'selfies_language.pkl'))

    data_preparation = get_data_preparation(params.get('batch_mode'))

    dataset = SMILESDataset(
        data_path,
        smiles_language=smiles_language,
        padding=False,
        selfies=params.get('selfies', False),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=False,  #params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        backend='lazy',
        device=device)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=params.get('batch_size', 64),
        collate_fn=collate_fn,
        drop_last=True,
        shuffle=True,
        pin_memory=params.get('pin_memory', True),
        num_workers=params.get('num_workers', 8))
    # initialize encoder and decoder
    gru_encoder = StackGRUEncoder(params).to(device)
    gru_decoder = StackGRUDecoder(params).to(device)
    gru_vae = TeacherVAE(gru_encoder, gru_decoder).to(device)
    logger.info('\n****MODEL SUMMARY***\n')
    for name, parameter in gru_vae.named_parameters():
        logger.info(f'Param {name}, shape:\t{parameter.shape}')
    total_params = sum(p.numel() for p in gru_vae.parameters())
    logger.info(f'Total # params: {total_params}')

    gru_vae.load_state_dict(torch.load(weights_path, map_location=device))

    # Updating the vocab size will break the model
    params.update({
        # 'vocab_size': smiles_language.number_of_tokens,
        'pad_index': smiles_language.padding_index
    })  # yapf:disable

    # if params.get('embedding', 'learned') == 'one_hot':
    #     params.update({'embedding_size': params['vocab_size']})

    # train for n_epoch epochs
    logger.info(
        'Model creation, loading and data processing done. Evaluation starts.')

    gru_vae.eval()
    gru_vae.to(device)

    counter = 0
    with torch.no_grad():
        latent_code = []
        from tqdm import tqdm
        for batch in tqdm(dataloader, total=len(dataloader)):
            (encoder_seq, _, _) = data_preparation(batch,
                                                   input_keep=0.,
                                                   start_index=2,
                                                   end_index=3,
                                                   device=device)
            try:
                mu, logvar = gru_vae.encode(encoder_seq)
            except RuntimeError:
                # Substitute any new tokens by "<UNK>" tokens
                new_seq = []
                padded_encoder_seq, lengths = (
                    torch.nn.utils.rnn.pad_packed_sequence(encoder_seq,
                                                           batch_first=True))
                for seq, _len in zip(padded_encoder_seq, lengths):
                    seq = seq[:_len]
                    if any([x >= params['vocab_size'] for x in seq]):
                        seq = torch.tensor([
                            x if x < params['vocab_size'] else
                            smiles_language.unknown_index
                            for x in seq.tolist()
                        ]).short()

                        failed_smiles = smiles_language.selfies_to_smiles(
                            smiles_language.token_indexes_to_smiles(
                                seq.tolist()))
                        logger.warning(
                            f'Out of bounds sample: ~{counter}\t{failed_smiles}'
                        )
                    new_seq.append(seq)

                if new_seq:
                    for _ in range(params['batch_size'] - len(new_seq)):
                        new_seq.append(torch.ones_like(new_seq[-1]))
                    (encoder_seq, _, _) = data_preparation(new_seq,
                                                           input_keep=0.,
                                                           start_index=2,
                                                           end_index=3,
                                                           device=device)
                    mu, logvar = gru_vae.encode(encoder_seq)
            for _mu in mu.tolist():
                latent_code.append([counter, _mu])
                counter += 1

    LATENT_CODE_PATH = os.path.join(os.path.dirname(data_path),
                                    'samples_latent_code.tsv')

    with open(LATENT_CODE_PATH, 'w') as f:
        for i, mu in latent_code:
            f.write(f'{i}\t{",".join([str(x) for x in mu[0]])}\n')
Example #8
def main(parser_namespace):
    try:
        device = get_device()
        disable_rdkit_logging()
        # read the params json
        params = dict()
        with open(parser_namespace.params_filepath) as f:
            params.update(json.load(f))

        # get params
        train_smiles_filepath = parser_namespace.train_smiles_filepath
        test_smiles_filepath = parser_namespace.test_smiles_filepath
        smiles_language_filepath = (
            parser_namespace.smiles_language_filepath
            if parser_namespace.smiles_language_filepath.lower() != 'none' else
            None)

        model_path = parser_namespace.model_path
        training_name = parser_namespace.training_name

        writer = SummaryWriter(f'logs/{training_name}')

        logger.info(f'Model with name {training_name} starts.')

        model_dir = os.path.join(model_path, training_name)
        log_path = os.path.join(model_dir, 'logs')
        val_dir = os.path.join(log_path, 'val_logs')
        os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
        os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
        os.makedirs(log_path, exist_ok=True)
        os.makedirs(val_dir, exist_ok=True)

        # Load SMILES language
        smiles_language = None
        if smiles_language_filepath is not None:
            smiles_language = SMILESLanguage.load(smiles_language_filepath)

        logger.info(f'Smiles filepath: {train_smiles_filepath}')

        # create SMILES eager dataset
        smiles_train_data = SMILESDataset(
            train_smiles_filepath,
            smiles_language=smiles_language,
            padding=False,
            selfies=params.get('selfies', False),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=params.get('augment_smiles', False),
            canonical=params.get('canonical', False),
            kekulize=params.get('kekulize', False),
            all_bonds_explicit=params.get('all_bonds_explicit', False),
            all_hs_explicit=params.get('all_hs_explicit', False),
            remove_bonddir=params.get('remove_bonddir', False),
            remove_chirality=params.get('remove_chirality', False),
            backend='lazy',
            device=device,
        )
        smiles_test_data = SMILESDataset(
            test_smiles_filepath,
            smiles_language=smiles_language,
            padding=False,
            selfies=params.get('selfies', False),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=params.get('augment_smiles', False),
            canonical=params.get('canonical', False),
            kekulize=params.get('kekulize', False),
            all_bonds_explicit=params.get('all_bonds_explicit', False),
            all_hs_explicit=params.get('all_hs_explicit', False),
            remove_bonddir=params.get('remove_bonddir', False),
            remove_chirality=params.get('remove_chirality', False),
            backend='lazy',
            device=device,
        )

        if smiles_language_filepath is None:
            smiles_language = smiles_train_data.smiles_language
            smiles_language.save(
                os.path.join(model_path, f'{training_name}.lang'))
        else:
            smiles_language_filename = os.path.basename(
                smiles_language_filepath)
            smiles_language.save(
                os.path.join(model_dir, smiles_language_filename))

        params.update({
            'vocab_size': smiles_language.number_of_tokens,
            'pad_index': smiles_language.padding_index
        })

        vocab_dict = smiles_language.index_to_token
        token_to_index = {token: index for index, token in vocab_dict.items()}
        params.update({
            'start_index': token_to_index['<START>'],
            'end_index': token_to_index['<STOP>']
        })

        if params.get('embedding', 'learned') == 'one_hot':
            params.update({'embedding_size': params['vocab_size']})

        with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
            json.dump(params, fp)

        # create DataLoaders
        train_data_loader = torch.utils.data.DataLoader(
            smiles_train_data,
            batch_size=params.get('batch_size', 64),
            collate_fn=collate_fn,
            drop_last=True,
            shuffle=True,
            pin_memory=params.get('pin_memory', True),
            num_workers=params.get('num_workers', 8))

        test_data_loader = torch.utils.data.DataLoader(
            smiles_test_data,
            batch_size=params.get('batch_size', 64),
            collate_fn=collate_fn,
            drop_last=True,
            shuffle=True,
            pin_memory=params.get('pin_memory', True),
            num_workers=params.get('num_workers', 8))
        # initialize encoder and decoder
        gru_encoder = StackGRUEncoder(params).to(device)
        gru_decoder = StackGRUDecoder(params).to(device)
        gru_vae = TeacherVAE(gru_encoder, gru_decoder).to(device)
        # TODO I haven't managed to get this to work. I leave it here in case
        # someone (or future me) wants to give it a look and get the
        # tensorboard graph to work.
        # if writer and False:
        #     gru_vae.set_batch_mode('padded')
        #     dummy_input = torch.ones(smiles_train_data[0].shape)
        #     dummy_input = dummy_input.unsqueeze(0).to(device)
        #     writer.add_graph(gru_vae, (dummy_input, dummy_input, dummy_input))
        #     gru_vae.set_batch_mode(params.get('batch_mode'))
        logger.info('\n****MODEL SUMMARY***\n')
        for name, parameter in gru_vae.named_parameters():
            logger.info(f'Param {name}, shape:\t{parameter.shape}')
        total_params = sum(p.numel() for p in gru_vae.parameters())
        logger.info(f'Total # params: {total_params}')

        loss_tracker = {
            'test_loss_a': 10e4,
            'test_rec_a': 10e4,
            'test_kld_a': 10e4,
            'ep_loss': 0,
            'ep_rec': 0,
            'ep_kld': 0
        }

        # train for n_epoch epochs
        logger.info(
            'Model creation and data processing done, Training starts.')
        decoder_search = SEARCH_FACTORY[
            params.get('decoder_search', 'sampling')
        ](
            temperature=params.get('temperature', 1.),
            beam_width=params.get('beam_width', 3),
            top_tokens=params.get('top_tokens', 5)
        )  # yapf: disable

        if writer:
            pparams = params.copy()
            pparams['training_file'] = train_smiles_filepath
            pparams['test_file'] = test_smiles_filepath
            pparams['language_file'] = smiles_language_filepath
            pparams['model_path'] = model_path
            pparams = {
                k: v if v is not None else 'N.A.'
                for k, v in pparams.items()
            }
            pparams['training_name'] = training_name
            from pprint import pprint
            pprint(pparams)
            writer.add_hparams(hparam_dict=pparams, metric_dict={})

        for epoch in range(params['epochs'] + 1):
            t = time()
            loss_tracker = train_vae(
                epoch,
                gru_vae,
                train_data_loader,
                test_data_loader,
                smiles_language,
                model_dir,
                search=decoder_search,
                optimizer=params.get('optimizer', 'adadelta'),
                lr=params['learning_rate'],
                kl_growth=params['kl_growth'],
                input_keep=params['input_keep'],
                test_input_keep=params['test_input_keep'],
                generate_len=params['generate_len'],
                log_interval=params['log_interval'],
                save_interval=params['save_interval'],
                eval_interval=params['eval_interval'],
                loss_tracker=loss_tracker,
                logger=logger,
                # writer=writer,
                batch_mode=params.get('batch_mode'))
            logger.info(f'Epoch {epoch}, took {time() - t:.1f}.')

        logger.info('OVERALL: \t Best loss = {0:.4f} in Ep {1}, '
                    'best Rec = {2:.4f} in Ep {3}, '
                    'best KLD = {4:.4f} in Ep {5}'.format(
                        loss_tracker['test_loss_a'], loss_tracker['ep_loss'],
                        loss_tracker['test_rec_a'], loss_tracker['ep_rec'],
                        loss_tracker['test_kld_a'], loss_tracker['ep_kld']))
        logger.info('Training done, shutting down.')
    except Exception:
        logger.exception('Exception occurred while running train_vae.py.')
Example #9
def train(
    epoch,
    model,
    train_loader,
    optimizer,
    scheduler,
    writer=None,
    verbose=False
):
    start_time = time()
    device = get_device()
    # selfies = train_loader.dataset._dataset.selfies
    data_preparation = packed_sequential_data_preparation
    model.to(device)
    model.train()

    input_keep = 1.
    start_index = 2
    end_index = 3

    train_loss = 0

    for _iter, data in tqdm.tqdm(
        enumerate(train_loader),
        total=len(train_loader),
        disable=(not verbose)
    ):

        # Seqs are list of strings, so they must be first preprocessed
        # and the data has then to be moved .to(device)
        seqs = data.x
        batch_size = len(seqs)

        # FIXME? variable batch size in data
        model.encoder.update_batch_size(batch_size)

        encoder_seq, _, _ = data_preparation(
            seqs,
            input_keep=input_keep,
            start_index=start_index,
            end_index=end_index,
            device=device
        )

        optimizer.zero_grad()
        pos_z, neg_z, summary = model(encoder_seq, data.edge_index.to(device))
        loss = model.loss(pos_z, neg_z, summary)
        loss.backward()
        optimizer.step()

        writer.add_scalar(
            'loss', loss.item(), _iter + epoch * len(train_loader)
        )

        train_loss += loss.item()
        if _iter % (len(train_loader) // 10) == 0:
            tqdm.tqdm.write(f'{loss}')
    if scheduler is not None:
        scheduler.step()
    logger.info(f"Learning rate {optimizer.param_groups[0]['lr']}")
    logger.info(f'Epoch: {epoch}\t{train_loss/_iter}\t{time()-start_time}')
def main(*, parser_namespace):

    disable_rdkit_logging()

    # read the params json
    params = dict()
    with open(parser_namespace.params_path) as f:
        params.update(json.load(f))

    # get params, json args take precedence
    mol_model_path = params.get('mol_model_path',
                                parser_namespace.mol_model_path)
    protein_model_path = params.get('protein_model_path',
                                    parser_namespace.protein_model_path)
    affinity_model_path = params.get('affinity_model_path',
                                     parser_namespace.affinity_model_path)
    protein_data_path = params.get('protein_data_path',
                                   parser_namespace.protein_data_path)
    model_name = params.get(
        'model_name', parser_namespace.model_name
    )   # yapf: disable
    test_id = int(params.get(
        'test_protein_id', parser_namespace.test_protein_id
    ))   # yapf: disable
    unbiased_preds_path = params.get(
        'unbiased_predictions_path', parser_namespace.unbiased_predictions_path
    )   # yapf: disable
    model_name += '_' + str(test_id)
    logger.info(f'Model with name {model_name} starts.')

    # passing optional paths to params to possibly update_reward_fn
    optional_reward_args = [
        'tox21_path', 'organdb_path', 'site', 'clintox_path', 'sider_path'
    ]
    for arg in optional_reward_args:
        if parser_namespace.__dict__[arg]:
            # the json config still takes precedence
            params[arg] = params.get(arg, parser_namespace.__dict__[arg])

    # Load protein sequence data
    if protein_data_path.endswith('.smi'):
        protein_df = read_smi(protein_data_path, names=['Sequence'])
    elif protein_data_path.endswith('.csv'):
        protein_df = pd.read_csv(protein_data_path, index_col='entry_name')
    else:
        raise TypeError(
            f"{protein_data_path.split('.')[-1]} files are not supported.")

    protein_test_name = protein_df.iloc[test_id].name
    logger.info(f'Test protein is {protein_test_name}')

    # Restore SMILES Model
    with open(os.path.join(mol_model_path, 'model_params.json')) as f:
        mol_params = json.load(f)

    gru_encoder = StackGRUEncoder(mol_params)
    gru_decoder = StackGRUDecoder(mol_params)
    generator = TeacherVAE(gru_encoder, gru_decoder)
    generator.load(os.path.join(
        mol_model_path,
        f"weights/best_{params.get('smiles_metric', 'rec')}.pt"),
                   map_location=get_device())
    # Load languages
    generator_smiles_language = SMILESLanguage.load(
        os.path.join(mol_model_path, 'selfies_language.pkl'))
    generator._associate_language(generator_smiles_language)

    # Restore protein model
    with open(os.path.join(protein_model_path, 'model_params.json')) as f:
        protein_params = json.load(f)

    # Define network
    protein_encoder = ENCODER_FACTORY['dense'](protein_params)
    protein_encoder.load(os.path.join(
        protein_model_path,
        f"weights/best_{params.get('omics_metric','both')}_encoder.pt"),
                         map_location=get_device())
    protein_encoder.eval()

    # Restore affinity predictor
    with open(os.path.join(affinity_model_path, 'model_params.json')) as f:
        predictor_params = json.load(f)
    predictor = MODEL_FACTORY['bimodal_mca'](predictor_params)
    predictor.load(os.path.join(
        affinity_model_path,
        f"weights/best_{params.get('p_metric', 'ROC-AUC')}_bimodal_mca.pt"),
                   map_location=get_device())
    predictor.eval()

    # Load languages
    affinity_smiles_language = SMILESLanguage.load(
        os.path.join(affinity_model_path, 'smiles_language.pkl'))
    affinity_protein_language = ProteinLanguage.load(
        os.path.join(affinity_model_path, 'protein_language.pkl'))
    predictor._associate_language(affinity_smiles_language)
    predictor._associate_language(affinity_protein_language)

    # Load the unbiased predictions used as the baseline for comparison
    unbiased_preds = np.array(
        pd.read_csv(
            os.path.join(unbiased_preds_path, protein_test_name + '.csv')
        )['affinity'].values
    )  # yapf: disable

    # Create a fresh model that will be optimized
    gru_encoder_rl = StackGRUEncoder(mol_params)
    gru_decoder_rl = StackGRUDecoder(mol_params)
    generator_rl = TeacherVAE(gru_encoder_rl, gru_decoder_rl)
    generator_rl.load(os.path.join(
        mol_model_path, f"weights/best_{params.get('metric', 'rec')}.pt"),
                      map_location=get_device())
    generator_rl.eval()
    # Load languages
    generator_rl._associate_language(generator_smiles_language)

    protein_encoder_rl = ENCODER_FACTORY['dense'](protein_params)
    protein_encoder_rl.load(os.path.join(
        protein_model_path,
        f"weights/best_{params.get('metric', 'both')}_encoder.pt"),
                            map_location=get_device())
    protein_encoder_rl.eval()
    model_folder_name = model_name
    learner = ReinforceProtein(generator_rl, protein_encoder_rl, predictor,
                               protein_df, params, model_folder_name, logger)

    biased_ratios, tox_ratios = [], []
    rewards, rl_losses = [], []
    gen_mols, gen_prot, gen_affinity, mode = [], [], [], []

    logger.info(f'Model stored at {learner.model_path}')

    for epoch in range(1, params['epochs'] + 1):

        for step in range(1, params['steps']):

            # Randomly sample a protein
            protein_name = np.random.choice(protein_df.index)
            while protein_name == protein_test_name:
                protein_name = np.random.choice(protein_df.index)

            logger.info(f'Current train protein: {protein_name}')

            rew, loss = learner.policy_gradient(protein_name, epoch,
                                                params['batch_size'])
            logger.info(
                f"Epoch {epoch:d}/{params['epochs']:d}, step {step:d}/"
                f"{params['steps']:d}\t loss={loss:.2f}, mean rew={rew:.2f}")

            rewards.append(rew.item())
            rl_losses.append(loss)

        # Save model
        if epoch % 10 == 0:
            learner.save(f'gen_{epoch}.pt', f'enc_{epoch}.pt')
        logger.info(f'EVAL protein: {protein_test_name}')

        smiles, preds = (learner.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], protein_test_name))
        gs = [s for i, s in enumerate(smiles) if preds[i] > 0.5]
        gp = preds[preds > 0.5]
        for p, s in zip(gp, gs):
            gen_mols.append(s)
            gen_prot.append(protein_test_name)
            gen_affinity.append(p)
            mode.append('eval')

        inds = np.argsort(gp)[::-1]
        for i in inds[:5]:
            logger.info(f'Epoch {epoch:d}, generated {gs[i]} against '
                        f'{protein_test_name}.\n Predicted binding probability = {gp[i]}.')

        plot_and_compare_proteins(unbiased_preds, preds, protein_test_name,
                                  epoch, learner.model_path, 'train',
                                  params['eval_batch_size'])
        biased_ratios.append(
            np.round(100 * (np.sum(preds > 0.5) / len(preds)), 1))
        all_toxes = np.array([learner.tox21(s) for s in smiles])
        tox_ratios.append(
            np.round(100 * (np.sum(all_toxes == 1.) / len(all_toxes)), 1))
        logger.info(f'Percentage of non-toxic compounds {tox_ratios[-1]}')

        toxes = [learner.tox21(s) for s in gen_mols]
        # Save results (good molecules!) in DF
        df = pd.DataFrame({
            'protein': gen_prot,
            'SMILES': gen_mols,
            'Binding probability': gen_affinity,
            'mode': mode,
            'Tox21': toxes
        })
        df.to_csv(os.path.join(learner.model_path, 'results', 'generated.csv'))
        # Plot loss development
        loss_df = pd.DataFrame({'loss': rl_losses, 'rewards': rewards})
        loss_df.to_csv(learner.model_path +
                       '/results/loss_reward_evolution.csv')
        plot_loss(rl_losses,
                  rewards,
                  params['epochs'],
                  protein_name,
                  learner.model_path,
                  rolling=5)
    pd.DataFrame({
        'efficacy_ratio': biased_ratios,
        'tox_ratio': tox_ratios
    }).to_csv(learner.model_path + '/results/ratios.csv')