Example #1
    def train_as_rl(self,
                    reward_fn,
                    num_iterations=100000, verbose_step=50,
                    batch_size=200,
                    cond_lb=-2, cond_rb=0,
                    lr_lp=1e-5, lr_dec=1e-6):
        optimizer_lp = optim.Adam(self.lp.parameters(), lr=lr_lp)
        optimizer_dec = optim.Adam(self.dec.latent_fc.parameters(), lr=lr_dec)

        global_stats = TrainStats()
        local_stats = TrainStats()

        cur_iteration = 0
        while cur_iteration < num_iterations:
            print("!", end='')

            # Exploitation: ~70% of the batch is drawn from the learned latent prior
            # (the 50 latent dimensions are sampled; the property dimension is marginalised)
            exploit_size = int(batch_size * (1 - 0.3))
            exploit_z = self.lp.sample(exploit_size, 50 * ['s'] + ['m'])

            z_means = exploit_z.mean(dim=0)
            z_stds = exploit_z.std(dim=0)

            # Exploration: the remaining ~30% is Gaussian noise centred on the
            # exploitation samples' mean and scaled to twice their standard deviation
            expl_size = int(batch_size * 0.3)
            expl_z = torch.randn(expl_size, exploit_z.shape[1])
            expl_z = 2 * expl_z.to(exploit_z.device) * z_stds[None, :]
            expl_z += z_means[None, :]

            z = torch.cat([exploit_z, expl_z])
            smiles = self.dec.sample(50, z, argmax=False)  # stochastic decoding of SMILES from the latents
            zc = torch.zeros(z.shape[0], 1).to(z.device)
            conc_zy = torch.cat([z, zc], dim=1)
            # log-probability under the latent prior (property dim marginalised) plus the decoder likelihood
            log_probs = self.lp.log_prob(conc_zy, marg=50 * [False] + [True])
            log_probs += self.dec.weighted_forward(smiles, z)
            r_list = [reward_fn(s) for s in smiles]

            rewards = torch.tensor(r_list).float().to(exploit_z.device)
            rewards_bl = rewards - rewards.mean()  # subtract the batch mean reward as a baseline

            optimizer_dec.zero_grad()
            optimizer_lp.zero_grad()
            # REINFORCE-style objective: maximise the baseline-adjusted, reward-weighted log-probability
            loss = -(log_probs * rewards_bl).mean()
            loss.backward()
            optimizer_dec.step()
            optimizer_lp.step()

            valid_sm = [s for s in smiles if get_mol(s) is not None]
            cur_stats = {'mean_reward': sum(r_list) / len(smiles),
                         'valid_perc': len(valid_sm) / len(smiles)}

            local_stats.update(cur_stats)
            global_stats.update(cur_stats)

            cur_iteration += 1

            if verbose_step and (cur_iteration + 1) % verbose_step == 0:
                local_stats.print()
                local_stats.reset()

        return global_stats
Example #2
def penalized_logP(mol_or_smiles, masked=False, default=-5):
    mol = get_mol(mol_or_smiles)
    if mol is None:
        return default
    # penalised logP: logP minus the synthetic accessibility score minus the number of rings larger than six atoms
    reward = logP(mol) - SA(mol) - get_num_rings_6(mol)
    if masked and not mol_passes_filters(mol):
        return default
    return reward
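
The helpers get_mol, logP, SA, and mol_passes_filters are not defined in this snippet, and get_num_rings_6 appears in Example #4 below. A minimal sketch of the imports this reward relies on, assuming the MOSES benchmarking library used in the GENTRL example notebooks:

# Assumed imports for the helpers used above; in the GENTRL example notebooks
# they come from the MOSES benchmarking library.
from moses.metrics import mol_passes_filters, SA, logP
from moses.metrics.utils import get_mol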
Example #3
    def training_step(self, batch, batch_idx):
        exploit_size = int(self.rl_batch_size * (1 - 0.3))
        exploit_z = self.lp.sample(exploit_size, 50 * ['s'] + ['m'])

        z_means = exploit_z.mean(dim=0)
        z_stds = exploit_z.std(dim=0)

        expl_size = int(self.rl_batch_size * 0.3)
        expl_z = torch.randn(expl_size, exploit_z.shape[1]).to(self.device)
        expl_z = 2 * expl_z * z_stds[None, :]
        expl_z += z_means[None, :]

        z = torch.cat([exploit_z, expl_z])
        smiles = self.dec.sample(50, z, argmax=False)
        zc = torch.zeros(z.shape[0], 1).to(z.device)
        conc_zy = torch.cat([z, zc], dim=1)
        log_probs = self.lp.log_prob(conc_zy, marg=50 * [False] + [True])
        log_probs += self.dec(smiles, z)
        r_list = [self.reward_fn(s) for s in smiles]

        rewards = torch.tensor(r_list).float().to(exploit_z.device)
        rewards_bl = rewards - rewards.mean()
        loss = -(log_probs * rewards_bl).mean()

        valid_sm = [s for s in smiles if get_mol(s) is not None]
        cur_stats = {
            'mean_reward': torch.tensor(sum(r_list) / len(smiles)),
            'valid_perc': torch.tensor(len(valid_sm) / len(smiles))
        }
        
        # Dict format used by older (pre-1.0) PyTorch Lightning releases:
        # 'loss' drives the optimiser step, 'log' and 'progress_bar' feed the logger and progress bar
        output_dict = OrderedDict({
            'loss': loss,
            'log': cur_stats,
            'progress_bar': cur_stats
        })
        
        return output_dict
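
This training_step never reads its batch argument, so the RL loop can be driven by a dummy dataloader. Below is a minimal sketch of wiring it into a pre-1.0-style PyTorch Lightning Trainer; the wrapper class name GENTRLLightning, its constructor arguments, and the iteration count are assumptions, not part of the original snippet.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical LightningModule that holds lp, dec, rl_batch_size and reward_fn
# and defines the training_step shown above.
model = GENTRLLightning(rl_batch_size=200, reward_fn=penalized_logP)

# One dummy sample per RL iteration; training_step ignores the batch contents.
dummy_loader = DataLoader(TensorDataset(torch.zeros(10000, 1)), batch_size=1)

trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, dummy_loader)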
Example #4
def get_num_rings_6(mol):
    # number of rings with more than six atoms (penalised in penalized_logP)
    r = mol.GetRingInfo()
    return len([x for x in r.AtomRings() if len(x) > 6])


def penalized_logP(mol_or_smiles, masked=True, default=-5):
    mol = get_mol(mol_or_smiles)
    if mol is None:
        return default
    reward = logP(mol) - SA(mol) - get_num_rings_6(mol)
    if masked and not mol_passes_filters(mol):
        return default
    return reward


generated = []

while len(generated) < 50:
    print(len(generated))
    sampled = model.sample(100)
    sampled_valid = [s for s in sampled if get_mol(s)]

    generated += sampled_valid

os.makedirs('./images', exist_ok=True)
# Draw only the molecules from the last sampled batch
img = Draw.MolsToGridImage(
    [get_mol(s) for s in sampled_valid],
    legends=[str(penalized_logP(s)) for s in sampled_valid])
img.save('./images/mols.png')
Example #5
def get_num_rings_6(mol):
    r = mol.GetRingInfo()
    return len([x for x in r.AtomRings() if len(x) > 6])


def penalized_logP(mol_or_smiles, masked=True, default=-5):
    mol = get_mol(mol_or_smiles)
    if mol is None:
        return default
    reward = logP(mol) - SA(mol) - get_num_rings_6(mol)
    if masked and not mol_passes_filters(mol):
        return default
    return reward

model.train_as_rl(penalized_logP)

! mkdir -p saved_gentrl_after_rl
model.save('./saved_gentrl_after_rl/')

# sampling, following https://github.com/insilicomedicine/GENTRL/blob/master/examples/sampling.ipynb
from rdkit.Chem import Draw

model.load('./saved_gentrl_after_rl/')
# model.cuda();

def get_num_rings_6(mol):
    r = mol.GetRingInfo()
    return len([x for x in r.AtomRings() if len(x) > 6])


def penalized_logP(mol_or_smiles, masked=True, default=-5):
    mol = get_mol(mol_or_smiles)
    if mol is None:
        return default
    reward = logP(mol) - SA(mol) - get_num_rings_6(mol)
    if masked and not mol_passes_filters(mol):
        return default
    return reward


generated = []
while len(generated) < 1000:
    sampled = model.sample(100)
    sampled_valid = [s for s in sampled if get_mol(s)]
    generated += sampled_valid

Draw.MolsToGridImage(
    [get_mol(s) for s in sampled_valid],
    legends=[str(penalized_logP(s)) for s in sampled_valid])
Example #6
def train_as_rl(args):

    is_distributed = True
    multi_machine = False

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if is_distributed:
        # WORLD_SIZE and RANK are exported so that torch.distributed can pick them up
        # when the default process group is initialised with the env:// method
        os.environ['WORLD_SIZE'] = str(args['size'])
        host_rank = args['rank']
        os.environ['RANK'] = str(host_rank)
        dp_device_ids = [host_rank]
        torch.cuda.set_device(host_rank)

    model = Net(args)
    model = model.to(device)
    model.load('saved_gentrl/')

    reward_fn = args['reward_fn']
    batch_size = args['batch_size']
    verbose_step = args['verbose_step']
    num_iterations = args['num_iterations']

    if is_distributed:
        if multi_machine and use_cuda:
            # multi-machine multi-gpu case
            model = torch.nn.parallel.DistributedDataParallel(model).to(device)
        else:
            # single-machine multi-GPU case, or single-/multi-machine CPU case
            model = torch.nn.DataParallel(model).to(device)

    optimizer_lp = optim.Adam(model.module.lp.parameters(), lr=args['lr_lp'])
    optimizer_dec = optim.Adam(model.module.dec.latent_fc.parameters(),
                               lr=args['lr_dec'])

    global_stats = TrainStats()
    local_stats = TrainStats()

    cur_iteration = 0

    while cur_iteration < num_iterations:
        print("!", end='')

        exploit_size = int(batch_size * (1 - 0.3))
        exploit_z = model.module.lp.sample(exploit_size, 50 * ['s'] + ['m'])

        z_means = exploit_z.mean(dim=0)
        z_stds = exploit_z.std(dim=0)

        expl_size = int(batch_size * 0.3)
        expl_z = torch.randn(expl_size, exploit_z.shape[1])
        expl_z = 2 * expl_z.to(exploit_z.device) * z_stds[None, :]
        expl_z += z_means[None, :]

        z = torch.cat([exploit_z, expl_z])
        smiles = model.module.dec.sample(50, z, argmax=False)
        zc = torch.zeros(z.shape[0], 1).to(z.device)
        conc_zy = torch.cat([z, zc], dim=1)
        log_probs = model.module.lp.log_prob(conc_zy,
                                             marg=50 * [False] + [True])
        log_probs += model.module.dec.weighted_forward(smiles, z)
        r_list = [reward_fn(s) for s in smiles]

        rewards = torch.tensor(r_list).float().to(exploit_z.device)
        rewards_bl = rewards - rewards.mean()

        optimizer_dec.zero_grad()
        optimizer_lp.zero_grad()
        loss = -(log_probs * rewards_bl).mean()
        loss.backward()

        if is_distributed and not use_cuda:
            # average gradients manually for the multi-machine CPU case only
            average_gradients(model)

        optimizer_dec.step()
        optimizer_lp.step()

        valid_sm = [s for s in smiles if get_mol(s) is not None]
        cur_stats = {
            'mean_reward': sum(r_list) / len(smiles),
            'valid_perc': len(valid_sm) / len(smiles)
        }

        local_stats.update(cur_stats)
        global_stats.update(cur_stats)

        cur_iteration += 1

        if verbose_step and (cur_iteration + 1) % verbose_step == 0:
            local_stats.print()
            local_stats.reset()

    model.module.save('./saved_gentrl_after_rl/')
    return global_stats
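
average_gradients is called above but not defined in the snippet. A minimal sketch of the usual implementation from the PyTorch distributed tutorials follows; it assumes the default process group has already been initialised (e.g. via dist.init_process_group with the env:// method).

import torch.distributed as dist

def average_gradients(model):
    # Sum each parameter's gradient across all processes, then divide by the world size.
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size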