コード例 #1
0
ファイル: train.py プロジェクト: tammykan/moses
def main(config):
    """Train an ORGAN model (optionally fingerprint-conditioned) and save it.

    The vocabulary is built from the training SMILES and written to
    ``config.vocab_save``; ``config`` itself is written to
    ``config.config_save``. Both are saved before training starts. After
    training, the model's ``state_dict`` is written to ``config.model_save``.

    Args:
        config: Run configuration. Attributes read here: ``seed``,
            ``train_load``, ``vocab_save``, ``config_save``, ``device``,
            ``conditional``, ``n_jobs``, ``n_ref_subsample``, ``rollouts``,
            ``addition_rewards`` and ``model_save``.
    """
    set_seed(config.seed)

    smiles = read_smiles_csv(config.train_load)
    vocabulary = CharVocab.from_data(smiles)
    torch.save(vocabulary, config.vocab_save)
    torch.save(config, config.config_save)
    target_device = torch.device(config.device)

    # Unconditional defaults; replaced below when conditioning is enabled.
    fingerprints = None
    fingerprint_size = 0
    if config.conditional:
        # Fingerprints are read from the same file as the SMILES strings.
        raw_fingerprints = fps_to_list(read_fps_csv(config.train_load))
        fingerprints = [
            torch.tensor(row, dtype=torch.float, device=target_device)
            for row in raw_fingerprints
        ]
        # The first fingerprint's length is the size handed to the model.
        fingerprint_size = len(fingerprints[0])

    # The reward function holds on to the pool, so model construction and
    # training both happen while the pool is open.
    with Pool(config.n_jobs) as workers:
        reward = MetricsReward(
            smiles,
            config.n_ref_subsample,
            config.rollouts,
            workers,
            config.addition_rewards,
        )
        model = ORGAN(vocabulary, config, fingerprint_size, reward).to(
            target_device
        )
        ORGANTrainer(config).fit(model, smiles, fingerprints)

    torch.save(model.state_dict(), config.model_save)
コード例 #2
0
ファイル: train.py プロジェクト: ZixinYang/moses
def main(config):
    """Train an AAE model (optionally label-conditioned) and save it.

    In conditional mode each training sample is a ``(smiles, labels)`` pair,
    where ``labels`` is a list of ints, and ``config.labels_size`` is set to
    the length of the first raw label entry. Otherwise each sample is the
    SMILES entry alone. Samples are shuffled and truncated to the first
    500000 before training. The vocabulary is always built from the full
    SMILES list, not the truncated sample list.

    Args:
        config: Run configuration. Attributes read here: ``seed``,
            ``train_load``, ``conditional_model``, ``config_save``,
            ``vocab_save``, ``device`` and ``model_save``. ``labels_size``
            is written in conditional mode.
    """
    set_seed(config.seed)

    smiles = read_smiles_csv(config.train_load)
    if not config.conditional_model:
        samples = list(smiles)
    else:
        raw_labels = read_label_csv(config.train_load)
        # Record the label width before the entries are converted to ints.
        config.labels_size = len(raw_labels[0])
        label_rows = [[int(symbol) for symbol in entry] for entry in raw_labels]
        samples = list(zip(smiles, label_rows))

    shuffle(samples)
    samples = samples[:500000]

    vocabulary = CharVocab.from_data(smiles)
    torch.save(config, config.config_save)
    torch.save(vocabulary, config.vocab_save)

    target_device = torch.device(config.device)
    model = AAE(vocabulary, config).to(target_device)

    AAETrainer(config).fit(model, samples)

    # Move the model to the CPU before saving its state dict.
    model.to('cpu')
    torch.save(model.state_dict(), config.model_save)
コード例 #3
0
ファイル: train.py プロジェクト: tammykan/moses
def main(config):
    """Train an AAE model on the SMILES training set and save it.

    ``config`` is written to ``config.config_save`` and the vocabulary to
    ``config.vocab_save`` before training. After training, the model is moved
    to the CPU and its ``state_dict`` is written to ``config.model_save``.

    Args:
        config: Run configuration. Attributes read here: ``seed``,
            ``train_load``, ``config_save``, ``vocab_save``, ``device`` and
            ``model_save``.
    """
    set_seed(config.seed)

    smiles = read_smiles_csv(config.train_load)
    vocabulary = CharVocab.from_data(smiles)

    torch.save(config, config.config_save)
    torch.save(vocabulary, config.vocab_save)

    target_device = torch.device(config.device)
    model = AAE(vocabulary, config).to(target_device)

    AAETrainer(config).fit(model, smiles)

    # Move the model to the CPU before saving its state dict.
    model.to('cpu')
    torch.save(model.state_dict(), config.model_save)
コード例 #4
0
def main(config):
    """Train an ORGAN model on the SMILES training set and save the results.

    All three artifacts are written only after training has finished, in
    this order: the model's ``state_dict`` (``config.model_save``), the
    config (``config.config_save``) and the vocabulary
    (``config.vocab_save``).

    Args:
        config: Run configuration. Attributes read here: ``seed``,
            ``train_load``, ``device``, ``n_jobs``, ``n_ref_subsample``,
            ``rollouts``, ``addition_rewards``, ``model_save``,
            ``config_save`` and ``vocab_save``.
    """
    set_seed(config.seed)

    smiles = read_smiles_csv(config.train_load)
    vocabulary = CharVocab.from_data(smiles)
    target_device = torch.device(config.device)

    # The reward function holds on to the pool, so model construction and
    # training both happen while the pool is open.
    with Pool(config.n_jobs) as workers:
        reward = MetricsReward(
            smiles,
            config.n_ref_subsample,
            config.rollouts,
            workers,
            config.addition_rewards,
        )
        model = ORGAN(vocabulary, config, reward).to(target_device)
        ORGANTrainer(config).fit(model, smiles)

    torch.save(model.state_dict(), config.model_save)
    torch.save(config, config.config_save)
    torch.save(vocabulary, config.vocab_save)
コード例 #5
0
 def get_vocabulary(self, data):
     """Return the result of ``CharVocab.from_data(data)``.

     ``self`` is not used.
     """
     vocabulary = CharVocab.from_data(data)
     return vocabulary
コード例 #6
0
ファイル: datautils.py プロジェクト: tammykan/moses
    def fit(self, dataset):
        """Build a ``CharVocab`` from ``dataset`` and store it on ``self.vocab``.

        Returns:
            ``self``, so calls can be chained.
        """
        vocabulary = CharVocab.from_data(dataset)
        self.vocab = vocabulary
        return self
コード例 #7
0
import gentrl
import torch
from moses.metrics.utils import get_mol
import pandas as pd
import pickle
import moses
from moses.utils import CharVocab
from rdkit import RDLogger 
# Disable RDKit logging for the 'rdApp.*' spec.
RDLogger.DisableLog('rdApp.*')

# Build the character vocabulary from the 'SMILES' column of the
# gzip-compressed CSV dataset.
dataset_path = "../data/moses_qed_props.csv.gz"
df = pd.read_csv(dataset_path, compression="gzip")
vocab = CharVocab.from_data(df['SMILES'])

# Assemble the GENTRL model from an RNNEncoder and a DilConvDecoder that share
# the vocabulary. The encoder's latent_size and the decoder's
# latent_input_size are both 50, matching the 50 repeated ('c', 20) entries.
# NOTE(review): presumably one ('c', 20) descriptor per latent dimension plus
# a single one for the property — confirm against the gentrl API.
enc = gentrl.RNNEncoder(vocab, latent_size=50)
dec = gentrl.DilConvDecoder(vocab, latent_input_size=50, split_len=100)
model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)
model.cuda()

# NOTE(review): the CUDA device is selected only after the first
# model.cuda() call above — confirm this ordering is intended.
torch.cuda.set_device(0)

# Load the saved model from the given directory path.
moses_qed_props_model_path = "../models/moses/"
model.load(moses_qed_props_model_path)
# cuda() is called again after loading; presumably so the loaded state ends
# up on the GPU — TODO confirm whether model.load() requires this.
model.cuda()

# NOTE(review): the RDLogger import and DisableLog call below repeat the ones
# made earlier in this script.
import random
from rdkit import RDLogger 
RDLogger.DisableLog('rdApp.*')

# Starts empty; presumably filled by later code not visible in this excerpt.
generated = []