Example #1
def main(config):
    set_seed(config.seed)

    train = read_smiles_csv(config.train_load)
    vocab = CharVocab.from_data(train)
    torch.save(vocab, config.vocab_save)
    torch.save(config, config.config_save)
    device = torch.device(config.device)

    # conditional mode: load a fingerprint condition vector for each training molecule
    if config.conditional:
        fps = read_fps_csv(config.train_load)
        fps = fps_to_list(fps)
        fps = [torch.tensor(f, dtype=torch.float, device=device) for f in fps]
        # fingerprints length
        fps_len = len(fps[0])
    else:
        fps = None
        fps_len = 0

    with Pool(config.n_jobs) as pool:
        reward_func = MetricsReward(train, config.n_ref_subsample,
                                    config.rollouts, pool,
                                    config.addition_rewards)
        model = ORGAN(vocab, config, fps_len, reward_func)
        model = model.to(device)

        trainer = ORGANTrainer(config)
        trainer.fit(model, train, fps)

    torch.save(model.state_dict(), config.model_save)
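Here `config` is whatever the script's argument parser produces; any attribute bag exposing the fields used above will do. A minimal sketch of a direct call, with purely illustrative values (none of these are library defaults):

from argparse import Namespace

config = Namespace(
    seed=42,
    train_load='data/train.csv',        # CSV of training SMILES
    vocab_save='checkpoints/vocab.pt',
    config_save='checkpoints/config.pt',
    model_save='checkpoints/organ.pt',
    device='cuda:0',
    conditional=False,                  # no fingerprint conditioning
    n_jobs=4,
    n_ref_subsample=500,
    rollouts=16,
    addition_rewards=[],
)
main(config)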
Example #2
def main(config):
    set_seed(config.seed)

    train = read_smiles_csv(config.train_load)
    if config.conditional_model:
        labels = read_label_csv(config.train_load)
        config.labels_size = len(labels[0])
        # each label is a string of digits; convert it to a list of ints
        labels = [[int(x) for x in l] for l in labels]
        train_data = list(zip(train, labels))
    else:
        train_data = list(train)
    shuffle(train_data)
    # cap the training set at 500k samples
    train_data = train_data[:500000]
    vocab = CharVocab.from_data(train)
    torch.save(config, config.config_save)
    torch.save(vocab, config.vocab_save)

    device = torch.device(config.device)

    model = AAE(vocab, config)
    model = model.to(device)

    trainer = AAETrainer(config)
    trainer.fit(model, train_data)

    model.to('cpu')
    torch.save(model.state_dict(), config.model_save)
Example #3
def main(config):
    set_seed(config.seed)

    train = read_smiles_csv(config.train_load)

    vocab = CharVocab.from_data(train)
    torch.save(config, config.config_save)
    torch.save(vocab, config.vocab_save)

    device = torch.device(config.device)

    model = AAE(vocab, config)
    model = model.to(device)

    trainer = AAETrainer(config)
    trainer.fit(model, train)

    model.to('cpu')
    torch.save(model.state_dict(), config.model_save)
Example #4
def merge_vocab(*args):
    """
        helper function to merge multiple vocab objects...helpful for cases that may require the processing of more data than 
        is able to held in memory or for getting a common vocab to use to merge multiple disjoint datasets, etc..

        *args: a list of an arbitrary number of vocab objects
    """

    # special tokens that CharVocab adds on its own; filter them out when collecting characters
    ignore_char_list = ["<bos>", "<eos>", "<pad>", "<unk>"]
    merged_char_set = set()

    for vocab_path in args:
        vocab = torch.load(vocab_path)
        vocab_chars_set = set(
            [x for x in vocab.c2i.keys() if x not in ignore_char_list]
        )
        merged_char_set.update(vocab_chars_set)

    return CharVocab(merged_char_set)
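Since merge_vocab takes paths to vocab files saved with torch.save (as produced by the training scripts above), a typical call looks like the following; the shard paths are hypothetical:

# build one vocab covering several disjoint dataset shards (paths are hypothetical)
merged = merge_vocab('vocab_shard_0.pt', 'vocab_shard_1.pt', 'vocab_shard_2.pt')
torch.save(merged, 'vocab_merged.pt')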
Example #5
def main(config):
    set_seed(config.seed)

    train = read_smiles_csv(config.train_load)
    vocab = CharVocab.from_data(train)
    device = torch.device(config.device)

    with Pool(config.n_jobs) as pool:
        reward_func = MetricsReward(train, config.n_ref_subsample,
                                    config.rollouts, pool,
                                    config.addition_rewards)
        model = ORGAN(vocab, config, reward_func)
        model = model.to(device)

        trainer = ORGANTrainer(config)
        trainer.fit(model, train)

    torch.save(model.state_dict(), config.model_save)
    torch.save(config, config.config_save)
    torch.save(vocab, config.vocab_save)
Example #6
def compute_vocab(smiles_list, n_jobs=mp.cpu_count()):
    """
        simple function that can be used to create a vocabulary for an arbitrary set of smiles strings

        smiles_list: list of smiles strings
    """
    # extract all unique characters in smiles_list
    # char_set = set.union(*[set(x) for x in smiles_list])

    with mp.Pool(n_jobs) as pool:
        result = list(
            tqdm(
                pool.imap_unordered(compute_vocab_job, smiles_list),
                total=len(smiles_list),
            )
        )
        char_set = set.union(*result)

    # create the vocab
    vocab = CharVocab(char_set)

    return vocab
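compute_vocab_job is not shown here; given the set.union over its results, it presumably maps a single SMILES string to its set of unique characters. A minimal sketch under that assumption, followed by a call:

def compute_vocab_job(smiles):
    # one worker task: the unique characters of a single SMILES string
    return set(smiles)

vocab = compute_vocab(['CCO', 'c1ccccc1', 'CC(=O)O'], n_jobs=2)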
Example #7
    def get_vocabulary(self, data):
        return CharVocab.from_data(data)
Example #8
    def fit(self, dataset):
        self.vocab = CharVocab.from_data(dataset)

        return self
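Returning self makes fit chainable in the usual scikit-learn style; for a hypothetical class named SmilesTokenizer holding this method:

vocab = SmilesTokenizer().fit(dataset).vocab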
Example #9
import gentrl
import torch
from moses.metrics.utils import get_mol
import pandas as pd
import pickle
import moses
from moses.utils import CharVocab
from rdkit import RDLogger 
RDLogger.DisableLog('rdApp.*')

# Load vocab
dataset_path = "../data/moses_qed_props.csv.gz"
df = pd.read_csv(dataset_path, compression="gzip")
vocab = CharVocab.from_data(df['SMILES'])

# build the GENTRL model: RNN encoder + dilated-convolution decoder over a 50-dim latent space
enc = gentrl.RNNEncoder(vocab, latent_size=50)
dec = gentrl.DilConvDecoder(vocab, latent_input_size=50, split_len=100)
model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)
model.cuda()

torch.cuda.set_device(0)

moses_qed_props_model_path = "../models/moses/"
model.load(moses_qed_props_model_path)
model.cuda()

import random

generated = []
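The snippet stops where generation would begin. A plausible continuation, assuming model.sample(n) returns a list of SMILES strings (as in GENTRL's own sampling examples) and using the get_mol import from above to keep only valid molecules:

# sample in batches until 1000 valid molecules are collected (batch size is arbitrary)
while len(generated) < 1000:
    sampled = model.sample(100)
    generated.extend([s for s in sampled if get_mol(s) is not None])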