Code Example #1
File: pretrainer.py  Project: DayDayUpDeng/DrugEx
def training(is_lstm=True):
    voc = utils.Voc(init_from_file="data/voc.txt")
    if is_lstm:
        netP_path = 'output/lstm_chembl'
        netE_path = 'output/lstm_ligand'
    else:
        netP_path = 'output/gru_chembl'
        netE_path = 'output/gru_ligand'

    prior = models.Generator(voc, is_lstm=is_lstm)
    if not os.path.exists(netP_path + '.pkg'):
        # pre-train on the ChEMBL corpus when no saved weights exist yet
        df = pd.read_table("data/chembl_corpus.txt")
        chembl = df.Token
        chembl = torch.LongTensor(voc.encode([seq.split(' ') for seq in chembl]))
        chembl = DataLoader(chembl, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
        prior.fit(chembl, out=netP_path, epochs=50)
    prior.load_state_dict(torch.load(netP_path + '.pkg'))

    # explore = model.Generator(voc)
    df = pd.read_table('data/ligand_corpus.txt').drop_duplicates('Smiles')
    valid = df.sample(len(df) // 10).Token
    train = df.drop(valid.index).Token
    # explore.load_state_dict(torch.load(netP_path + '.pkg'))

    train = torch.LongTensor(voc.encode([seq.split(' ') for seq in train]))
    train = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)

    valid = torch.LongTensor(voc.encode([seq.split(' ') for seq in valid]))
    valid = DataLoader(TensorDataset(valid), batch_size=BATCH_SIZE, shuffle=True)
    print('Fine-tuning started...')

    prior.fit(train, loader_valid=valid, out=netE_path, epochs=1000, lr=lr)
    print('Fine-tuning finished.')
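
The function depends on module-level constants (BATCH_SIZE and lr) defined elsewhere in pretrainer.py. A minimal invocation sketch, with assumed values for those constants:

# Sketch only -- BATCH_SIZE and lr are module-level constants in
# pretrainer.py; the values below are assumptions.
BATCH_SIZE = 512
lr = 1e-4

if __name__ == '__main__':
    training(is_lstm=True)  # pre-train on ChEMBL, then fine-tune on the ligand corpus
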
Code Example #2
File: sampler.py  Project: DayDayUpDeng/DrugEx
def sampling(netG_path, out, size=10000):
    """
    sampling a series of tokens squentially for molecule generation
    Args:
        netG_path (str): The file path of generator.
        out (str): The file path of genrated molecules including SMILES, and its scores for each sample
        size (int): The number of molecules required to be generated.
        env (utils.Environment): The environment to provide the scores of all objectives for each sample

    Returns:
        smiles (List): A list of generated SMILES-based molecules
    """
    batch_size = 250
    samples = []
    voc = utils.Voc(init_from_file="data/voc.txt")
    netG = models.Generator(voc)
    netG.load_state_dict(torch.load(netG_path))
    batch = size // batch_size
    mod = size % batch_size
    for i in tqdm(range(batch + 1)):
        if i == 0:
            # the first iteration draws the remainder; skip it when size divides evenly
            if mod == 0: continue
            tokens = netG.sample(mod)
        else:
            tokens = netG.sample(batch_size)
        smiles = [voc.decode(s) for s in tokens]
        samples.extend(smiles)
    return samples
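
The out argument is accepted but never written inside the function, so persisting the samples is left to the caller. A minimal usage sketch with hypothetical paths:

# Hypothetical checkpoint and output paths.
smiles = sampling('output/lstm_ligand.pkg', out='output/samples.txt', size=10000)
with open('output/samples.txt', 'w') as f:
    f.write('\n'.join(smiles))
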
Code Example #3
File: dataset.py  Project: DayDayUpDeng/DrugEx
def corpus(input, output, is_sdf=False, requires_clean=True, is_isomeric=False):
    """ Construct a dataset of SMILES-based molecules; each molecule is decomposed
        into a series of tokens. Finally, all tokens are collected into one set as the vocabulary.

        Arguments:
            input (string): The file path of the input, either an .sdf file or a tab-delimited file

            output (string): The file path of the output

            is_sdf (bool): Whether the input file is an SDF file

            requires_clean (bool): If True, molecules are cleaned: charged metals are
                    removed and only the largest fragment is kept.

            is_isomeric (bool): Whether to keep stereochemical information. If False,
                    the stereochemistry tokens (e.g. @@, @, \, /) are removed.

    """
    if is_sdf:
        # read the (gzipped) SDF file with RDKit
        inf = gzip.open(input)
        fsuppl = Chem.ForwardSDMolSupplier(inf)
        df = []
        for mol in fsuppl:
            try:
                df.append(Chem.MolToSmiles(mol, is_isomeric))
            except Exception:
                print(mol)
    else:
        # deal with table file
        df = pd.read_table(input).Smiles.dropna()
    voc = utils.Voc()
    words = set()
    canons = []
    tokens = []
    if requires_clean:
        smiles = set()
        for smile in tqdm(df):
            try:
                smile = utils.clean_mol(smile, is_isomeric=is_isomeric)
                smiles.add(Chem.CanonSmiles(smile))
            except Exception:
                print('Parsing Error:', smile)
    else:
        smiles = df.values
    for smile in tqdm(smiles):
        token = voc.tokenize(smile)
        # Only collect the organic molecules
        if {'C', 'c'}.isdisjoint(token):
            print('Warning:', smile)
            continue
        # Skip molecules containing metal tokens
        if not {'[Na]', '[Zn]'}.isdisjoint(token):
            print('Redundant:', smile)
            continue
            continue
        # keep only sequences between 10 and 100 tokens
        if 10 < len(token) <= 100:
            words.update(token)
            canons.append(smile)
            tokens.append(' '.join(token))

    # output the vocabulary file
    with open(output + '_voc.txt', 'w') as log:
        log.write('\n'.join(sorted(words)))

    # output the dataset file as tab-delimited file
    log = pd.DataFrame()
    log['Smiles'] = canons
    log['Token'] = tokens
    # drop_duplicates returns a new DataFrame; reassign to keep the result
    log = log.drop_duplicates(subset='Smiles')
    log.to_csv(output + '_corpus.txt', sep='\t', index=False)
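
A minimal usage sketch for corpus; the input path and output prefix below are assumptions:

# Build a corpus from a gzipped SDF dump (hypothetical path); this writes
# data/chembl_voc.txt and data/chembl_corpus.txt.
corpus('data/chembl.sdf.gz', 'data/chembl', is_sdf=True, requires_clean=True)
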
Code Example #4
File: train_smiles.py  Project: XuhanLiu/DrugEx
        'ft_path': 'output/ligand_mf_brics_gpt_256'
    }
    # 's' must be in the option string: '-s' is read below without a default
    opts, args = getopt.getopt(sys.argv[1:], "m:g:b:s:d:")
    OPT = dict(opts)
    # restrict visible GPUs before any CUDA initialization, then select the first one
    os.environ["CUDA_VISIBLE_DEVICES"] = OPT.get('-g', "0,1,2,3")
    torch.cuda.set_device(0)
    method = OPT.get('-m', 'gpt')
    step = OPT['-s']
    BATCH_SIZE = int(OPT.get('-b', '256'))
    dataset = OPT.get('-d', 'ligand_mf_brics')

    data = pd.read_table('data/%s_train_smi.txt' % dataset)
    test = pd.read_table('data/%s_test_smi.txt' % dataset)
    test = test.Input.drop_duplicates().sample(BATCH_SIZE * 10).values
    if method in ['gpt']:
        voc = utils.Voc('data/voc_smiles.txt', src_len=100, trg_len=100)
    else:
        voc = utils.VocSmiles('data/voc_smiles.txt', max_len=100)
    data_in = voc.encode([seq.split(' ') for seq in data.Input.values])
    data_out = voc.encode([seq.split(' ') for seq in data.Output.values])
    data_set = TensorDataset(data_in, data_out)
    data_loader = DataLoader(data_set, batch_size=BATCH_SIZE, shuffle=True)

    test_set = voc.encode([seq.split(' ') for seq in test])
    test_set = utils.TgtData(
        test_set, ix=[voc.decode(seq, is_tk=False) for seq in test_set])
    test_loader = DataLoader(test_set,
                             batch_size=BATCH_SIZE,
                             collate_fn=test_set.collate_fn)

    pretrain(method=method)
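
Given the option string "m:g:b:s:d:" above, a plausible command-line invocation (the -s value is an assumption; this fragment does not show how step is used):

python train_smiles.py -m gpt -g 0,1 -b 256 -s 1000 -d ligand_mf_brics
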
Code Example #5
torch.set_num_threads(1)
BATCH_SIZE = 1024

if __name__ == "__main__":
    opts, args = getopt.getopt(sys.argv[1:], "m:d:g:p:")
    OPT = dict(opts)
    # torch.cuda.set_device(0)
    os.environ["CUDA_VISIBLE_DEVICES"] = OPT.get('-g', "0,1,2,3")
    method = OPT.get('-m', 'atom')
    dataset = OPT.get('-d', 'ligand_mf_brics')
    path = OPT.get('-p', dataset)
    utils.devices = [0]

    if method in ['gpt']:
        voc = utils.Voc('data/chembl_voc.txt', src_len=100, trg_len=100)
    else:
        voc = utils.VocSmiles('data/chembl_voc.txt', max_len=100)
    if method == 'ved':
        agent = generator.EncDec(voc, voc).to(utils.dev)
    elif method == 'attn':
        agent = generator.Seq2Seq(voc, voc).to(utils.dev)
    elif method == 'gpt':
        agent = GPT2Model(voc, n_layer=12).to(utils.dev)
    else:
        voc = utils.VocGraph('data/voc_atom.txt')
        agent = GraphModel(voc_trg=voc)

    for agent_path in [
            'benchmark/graph_PR_REG_OBJ1_0e+00.pkg',
            'benchmark/graph_PR_REG_OBJ1_1e-01.pkg',
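
Given the option string "m:d:g:p:" above, a plausible command-line invocation (the script name train.py is a placeholder, since this example does not name its file):

python train.py -m gpt -d ligand_mf_brics -g 0 -p my_run
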
Code Example #6
    else:
        mod1 = utils.ClippedScore(lower_x=3, upper_x=6.5)
        mod2 = utils.ClippedScore(lower_x=10, upper_x=6.5)
        ths = [0.99] * 3
    mods = [mod1, mod1, mod2] if case == 'OBJ3' else [mod2, mod1, mod2]
    env = utils.Env(objs=objs, mods=mods, keys=keys, ths=ths)

    root = 'output/%s_%s_%s_%s/' % (alg, case, scheme, time.strftime('%y%m%d_%H%M%S', time.localtime()))
    os.mkdir(root)
    copy2('models/rlearner.py', root)
    copy2('trainer.py', root)

    pr_path = 'output/lstm_chembl'
    ft_path = 'output/lstm_ligand'

    voc = utils.Voc(init_from_file="data/voc.txt")
    agent = generator.Generator(voc)
    agent.load_state_dict(torch.load(ft_path + '.pkg'))

    prior = generator.Generator(voc)
    prior.load_state_dict(torch.load(pr_path + '.pkg'))

    if alg == 'drugex':
        learner = rlearner.DrugEx(prior, env, agent)
    elif alg == 'organic':
        embed_dim = 128
        filter_size = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
        num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160]
        prior = classifier.Discriminator(agent.voc.size, embed_dim, filter_size, num_filters)
        df = pd.read_table('data/LIGAND_%s_%s.tsv' % (z, case))
        df = df[df.DESIRE == 1]
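
For context, utils.ClippedScore rescales a raw predictor output into [0, 1]. A minimal sketch of the presumed behavior; this reimplementation is an assumption, not the library's code:

import numpy as np

def clipped_score(x, lower_x, upper_x):
    # Presumed semantics: linear ramp from 0 at lower_x to 1 at upper_x,
    # clipped to [0, 1]. lower_x > upper_x (as in ClippedScore(lower_x=10,
    # upper_x=6.5) above) gives a descending ramp.
    return np.clip((x - lower_x) / (upper_x - lower_x), 0.0, 1.0)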