Code Example #1
import numpy as np
from rdkit.Chem import Descriptors, MolFromSmiles
from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer


def get_all_metrics(smiles):
    mols = [MolFromSmiles(s) for s in smiles]
    scorer = NormalizedScorer()
    scores, norm_scores = scorer.get_scores_from_mols(mols)
    arom_rings = np.array([Descriptors.NumAromaticRings(m) for m in mols])
    metrics = np.concatenate([scores.sum(axis=1)[:, None],
                              norm_scores.sum(axis=1)[:, None],
                              scores[:, 1][:, None],
                              norm_scores[:, 1][:, None],
                              arom_rings[:, None]],
                             axis=1)
    return (smiles, metrics)
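
A minimal usage sketch; the SMILES strings below are placeholders, not taken from the source:

smiles_in = ['c1ccccc1O', 'CC(=O)Nc1ccc(O)cc1']  # hypothetical input molecules
smiles_out, metrics = get_all_metrics(smiles_in)
# metrics has one row per molecule and five columns: summed raw score, summed normalized
# score, the second score component (raw and normalized), and the aromatic ring count
print(metrics.shape)  # (2, 5)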
Code Example #2

def reward_aliphatic_rings(smiles):
    '''
    A simple reward to encourage larger molecule length
    :param smiles: list of strings
    :return: reward, list of float
    '''
    return np.array(
        [-1 if num is None else num for num in num_aliphatic_rings(smiles)])


batch_size = 40
drop_rate = 0.3
molecules = True
grammar = 'new'  #True#
settings = get_settings(molecules, grammar)
invalid_value = -3 * 3.5
scorer = NormalizedScorer(invalid_value=invalid_value,
                          sa_mult=0.0,
                          sa_thresh=-1.5,
                          normalize_scores=False)
max_steps = 277  #settings['max_seq_length']


def second_score(smiles):
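    # Shift each score component by 7.5, take the product across the components and
    # raise it to the 0.333 power (roughly a geometric mean); NaNs are mapped to -1 below.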
    pre_scores = 3 * 2.5 + scorer.get_scores(smiles)[0]
    score = np.power(pre_scores.prod(1), 0.333)
    for i in range(len(score)):
        if np.isnan(score[i]):
            score[i] = -1
    return score


reward_fun = lambda x: 3 * 2.5 + scorer(x)
Code Example #3

def reward_aromatic_rings(smiles):
    '''
    '''
    if not len(smiles):
        return -1  # an empty string is invalid for our purposes
    atoms = num_aromatic_rings(smiles)
    return [-1 if num is None else num + 0.5 for num in atoms]


batch_size = 20  # 20
drop_rate = 0.5
molecules = True
grammar_cache = 'hyper_grammar.pickle'
grammar = 'hypergraph:' + grammar_cache
settings = get_settings(molecules, grammar)
# max_steps = 277  # settings['max_seq_length']
invalid_value = -3.5
scorer = NormalizedScorer(invalid_value=invalid_value)
reward_fun = scorer  #lambda x: np.ones(len(x)) # lambda x: reward_aromatic_rings(x)#
# later will run this ahead of time
gi = GrammarInitializer(grammar_cache)
# if True:
#     gi.delete_cache()
#     gi = GrammarInitializer(grammar_cache)
#     max_steps_smiles = gi.init_grammar(1000)

max_steps = 30
model, gen_fitter, disc_fitter = train_policy_gradient(
    molecules,
    grammar,
    EPOCHS=100,
    BATCH_SIZE=batch_size,
    reward_fun_on=reward_fun,
Code Example #4

def reward_aliphatic_rings(smiles):
    '''
    A simple reward to encourage larger molecule length
    :param smiles: list of strings
    :return: reward, list of float
    '''
    return np.array(
        [-1 if num is None else num for num in num_aliphatic_rings(smiles)])


batch_size = 40
drop_rate = 0.3
molecules = True
grammar = 'new'  #True#
settings = get_settings(molecules, grammar)
invalid_value = -3 * 3.5
scorer = NormalizedScorer(invalid_value=invalid_value,
                          sa_mult=20,
                          sa_thresh=0,
                          normalize_scores=True)
max_steps = 277  #settings['max_seq_length']


def second_score(smiles):
    pre_scores = 3 * 2.5 + scorer.get_scores(smiles)[0]
    score = np.power(pre_scores.prod(1), 0.333)
    for i in range(len(score)):
        if np.isnan(score[i]):
            score[i] = -1
    return score


reward_fun = lambda x: scorer(x) + np.array(
    [-5 * max(num - 5, 0) for num in num_aromatic_rings(x)])
Code Example #5
    atoms = num_atoms(smiles)
    return [-1 if num is None else num for num in atoms]



batch_size = 100
drop_rate = 0.3
molecules = True
grammar = True
settings = get_settings(molecules, grammar)
max_steps = 50 #settings['max_seq_length']
invalid_value = -7.0

task = SequenceGenerationTask(molecules = molecules,
                              grammar = grammar,
                              reward_fun = NormalizedScorer(settings['data_path'],
                                                            invalid_value=invalid_value),
                              batch_size = batch_size,
                              max_steps=50)

if grammar:
    mask_gen = GrammarMaskGenerator(task.env._max_episode_steps, grammar=settings['grammar'])
else:
    mask_gen = None



def a2c_sequence(name = 'a2c_sequence', task=None, body=None):
    config = Config()
    config.num_workers = batch_size # same thing as batch size
    config.task_fn = lambda: task
    config.optimizer_fn = lambda params: torch.optim.RMSprop(params, lr=0.0007)
Code Example #6
import torch
import numpy as np
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors as desc
from rdkit.Chem.Draw import MolToFile
import os, inspect
from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer

scorer = NormalizedScorer()
root_location = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))


def model_process_fun(model_out, visdom, n):
    # TODO: rephrase this to return a dict, instead of calling visdom directly
    from rdkit import Chem
    from rdkit.Chem.Draw import MolToFile
    # actions, logits, rewards, terminals, info = model_out
    smiles, valid = model_out['info']
    valid = to_numpy(valid)
    total_rewards = to_numpy(model_out['rewards'])
    if len(total_rewards.shape) > 1:
        total_rewards = total_rewards.sum(1)
    best_ind = np.argmax(total_rewards)
    this_smile = smiles[best_ind]
    mol = Chem.MolFromSmiles(this_smile)
    pic_save_path = os.path.realpath(root_location + '/images/' + 'tmp.svg')
    if mol is not None:
        try:
            MolToFile(mol, pic_save_path, imageType='svg')
Code Example #7
def train_policy_gradient_ppo(molecules=True,
                              grammar=True,
                              smiles_source='ZINC',
                              EPOCHS=None,
                              BATCH_SIZE=None,
                              reward_fun_on=None,
                              reward_fun_off=None,
                              max_steps=277,
                              lr_on=2e-4,
                              lr_discrim=1e-4,
                              discrim_wt=2,
                              p_thresh=0.5,
                              drop_rate=0.0,
                              plot_ignore_initial=0,
                              randomize_reward=False,
                              save_file_root_name=None,
                              reward_sm=0.0,
                              preload_file_root_name=None,
                              anchor_file=None,
                              anchor_weight=0.0,
                              decoder_type='action',
                              plot_prefix='',
                              dashboard='policy gradient',
                              smiles_save_file=None,
                              on_policy_loss_type='best',
                              priors=True,
                              node_temperature_schedule=lambda x: 1.0,
                              rule_temperature_schedule=lambda x: 1.0,
                              eps=0.0,
                              half_float=False,
                              extra_repetition_penalty=0.0):
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../../'

    def full_path(x):
        return os.path.realpath(root_location + 'pretrained/' + x)

    if save_file_root_name is not None:
        gen_save_file = save_file_root_name + '_gen.h5'
        disc_save_file = save_file_root_name + '_disc.h5'
    if preload_file_root_name is not None:
        gen_preload_file = preload_file_root_name + '_gen.h5'
        disc_preload_file = preload_file_root_name + '_disc.h5'

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    discrim_model = GraphDiscriminator(codec.grammar, drop_rate=drop_rate)
    if False and preload_file_root_name is not None:
        try:
            preload_path = full_path(disc_preload_file)
            discrim_model.load_state_dict(torch.load(preload_path),
                                          strict=False)
            print('Discriminator weights loaded successfully!')
        except Exception as e:
            print('failed to load discriminator weights ' + str(e))

    zinc_data = get_smiles_from_database(source=smiles_source)
    zinc_set = set(zinc_data)
    lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

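    # originality_mult down-weights the reward for molecules already present in ZINC or
    # generated recently (three lookback horizons); novel molecules get the full 1.0.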
    def originality_mult(smiles_list):
        out = []
        for s in smiles_list:
            if s in zinc_set:
                out.append(0.5)
            elif s in history_data[0]:
                out.append(0.5)
            elif s in history_data[1]:
                out.append(0.70)
            elif s in history_data[2]:
                out.append(0.85)
            else:
                out.append(1.0)
        return np.array(out)

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def discriminator_reward_mult(smiles_list):
        orig_state = discrim_model.training
        discrim_model.eval()
        discrim_out_logits = discrim_model(smiles_list)['p_zinc']
        discrim_probs = F.softmax(discrim_out_logits, dim=1)
        prob_zinc = discrim_probs[:, 1].detach().cpu().numpy()
        if orig_state:
            discrim_model.train()
        return prob_zinc

    def apply_originality_penalty(x, orig_mult):
        assert x <= 1, "Reward must be no greater than 1"
        if x > 0.5:  # want to punish nearly-perfect scores less and less
            out = math.pow(x, 1 / orig_mult)
        else:  # continuous join at 0.5
            penalty = math.pow(0.5, 1 / orig_mult) - 0.5
            out = x + penalty

        out -= extra_repetition_penalty * (1 - 1 / orig_mult)
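        # Illustration: with orig_mult=0.5 and extra_repetition_penalty=0, a reward of 0.9
        # becomes 0.9**2 = 0.81, while 0.3 becomes 0.3 + (0.5**2 - 0.5) = 0.05.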
        return out

    def adj_reward(x):
        if discrim_wt > 1e-5:
            p = discriminator_reward_mult(x)
        else:
            p = 0
        rwd = np.array(reward_fun_on(x))
        orig_mult = originality_mult(x)
        # we assume the reward is <=1, first term will dominate for reward <0, second for 0 < reward < 1
        # reward = np.minimum(rwd/orig_mult, np.power(np.abs(rwd),1/orig_mult))
        reward = np.array([
            apply_originality_penalty(r, om) for r, om in zip(rwd, orig_mult)
        ])
        out = reward + discrim_wt * p * orig_mult
        return out

    def adj_reward_old(x):
        p = discriminator_reward_mult(x)
        w = sigmoid(-(p - p_thresh) / 0.01)
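        # w is close to 1 when the discriminator probability p is below p_thresh and close
        # to 0 above it; the 0.01 scale makes the transition sharp.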
        if randomize_reward:
            rand = np.random.uniform(size=p.shape)
            w *= rand
        reward = np.maximum(reward_fun_on(x), p_thresh)
        weighted_reward = w * p + (1 - w) * reward
        out = weighted_reward * originality_mult(x)  #
        return out

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=adj_reward,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=None)

    node_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0),
                                            eps=eps)
    rule_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0),
                                            eps=eps)

    model = get_decoder(molecules,
                        grammar,
                        z_size=settings['z_size'],
                        decoder_hidden_n=200,
                        feature_len=codec.feature_len(),
                        max_seq_length=max_steps,
                        drop_rate=drop_rate,
                        batch_size=BATCH_SIZE,
                        decoder_type=decoder_type,
                        reward_fun=adj_reward,
                        task=task,
                        node_policy=node_policy,
                        rule_policy=rule_policy,
                        priors=priors)[0]

    if preload_file_root_name is not None:
        try:
            preload_path = full_path(gen_preload_file)
            model.load_state_dict(torch.load(preload_path, map_location='cpu'),
                                  strict=False)
            print('Generator weights loaded successfully!')
        except Exception as e:
            print('failed to load generator weights ' + str(e))

    anchor_model = None

    from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer
    import numpy as np
    scorer = NormalizedScorer()

    if reward_fun_off is None:
        reward_fun_off = reward_fun_on

    # construct the loader to feed the discriminator
    def make_callback(data):
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            for s in smiles:  # only store unique instances of molecules so discriminator can't guess on frequency
                if s not in data:
                    data.append(s)

        return hc

    class TemperatureCallback:
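        # A fit_rl callback: each call advances an internal step counter and sets the
        # wrapped policy's temperature from the supplied schedule.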
        def __init__(self, policy, temperature_function):
            self.policy = policy
            self.counter = 0
            self.temp_fun = temperature_function

        def __call__(self, inputs, model, outputs, loss_fn, loss):
            self.counter += 1
            target_temp = self.temp_fun(self.counter)
            self.policy.set_temperature(target_temp)

    # need to have something there to begin with, else the DataLoader constructor barfs

    def get_rl_fitter(model,
                      loss_obj,
                      train_gen,
                      save_path,
                      fit_plot_prefix='',
                      model_process_fun=None,
                      lr=None,
                      extra_callbacks=[],
                      loss_display_cap=float('inf'),
                      anchor_model=None,
                      anchor_weight=0):
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr, eps=1e-4)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun,
                smooth_weight=reward_sm)
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=True)

        fitter = fit_rl(train_gen=train_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        half_float=half_float,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        callbacks=[metric_monitor, checkpointer] +
                        extra_callbacks)

        return fitter

    def my_gen():
        for _ in range(1000):
            yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

    # the on-policy fitter

    gen_extra_callbacks = [make_callback(d) for d in history_data]

    if smiles_save_file is not None:
        smiles_save_path = os.path.realpath(root_location + 'pretrained/' +
                                            smiles_save_file)
        gen_extra_callbacks.append(MoleculeSaver(smiles_save_path, gzip=True))
        print('Saving SMILES to {}'.format(smiles_save_file))

    if node_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(node_policy, node_temperature_schedule))

    if rule_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(rule_policy, rule_temperature_schedule))

    fitter1 = get_rl_fitter(
        model,
        PolicyGradientLoss(on_policy_loss_type),  # last_reward_wgt=reward_sm),
        GeneratorToIterable(my_gen),
        full_path(gen_save_file),
        plot_prefix + 'on-policy',
        model_process_fun=model_process_fun,
        lr=lr_on,
        extra_callbacks=gen_extra_callbacks,
        anchor_model=anchor_model,
        anchor_weight=anchor_weight)
    # get existing molecule data to add to training
    pre_dataset = EvenlyBlendedDataset(
        2 * [history_data[0]] + history_data[1:],
        labels=False)  # a blend of 3 time horizons
    dataset = EvenlyBlendedDataset([pre_dataset, zinc_data], labels=True)
    discrim_loader = DataLoader(dataset, shuffle=True, batch_size=50)

    class MyLoss(nn.Module):
        def __init__(self):
            super().__init__()
            self.celoss = nn.CrossEntropyLoss()

        def forward(self, x):
            # tmp = discriminator_reward_mult(x['smiles'])
            # tmp2 = F.softmax(x['p_zinc'], dim=1)[:,1].detach().cpu().numpy()
            # import numpy as np
            # assert np.max(np.abs(tmp-tmp2)) < 1e-6
            return self.celoss(x['p_zinc'].to(device),
                               x['dataset_index'].to(device))

    fitter2 = get_rl_fitter(discrim_model,
                            MyLoss(),
                            IterableTransform(
                                discrim_loader, lambda x: {
                                    'smiles': x['X'],
                                    'dataset_index': x['dataset_index']
                                }),
                            full_path(disc_save_file),
                            plot_prefix + ' discriminator',
                            lr=lr_discrim,
                            model_process_fun=None)

    def on_policy_gen(fitter, model):
        while True:
            # model.policy = SoftmaxRandomSamplePolicy()#bias=codec.grammar.get_log_frequencies())
            yield next(fitter)

    return model, fitter1, fitter2  #,on_policy_gen(fitter1, model)
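
A minimal, hypothetical invocation of train_policy_gradient_ppo; the reward function, epoch count, batch size and file root below are placeholders rather than values from the source, and the fitters returned by fit_rl appear to be generators (cf. on_policy_gen above), so each next() call runs one training step:

import numpy as np

dummy_reward = lambda smiles_list: np.zeros(len(smiles_list))  # constant placeholder reward

model, gen_fitter, disc_fitter = train_policy_gradient_ppo(
    molecules=True,
    grammar='hypergraph:hyper_grammar.pickle',
    EPOCHS=2,
    BATCH_SIZE=8,
    reward_fun_on=dummy_reward,
    max_steps=30,
    save_file_root_name='ppo_demo')  # hypothetical root for the *_gen.h5 / *_disc.h5 checkpoints
next(gen_fitter)   # one on-policy generator step
next(disc_fitter)  # one discriminator step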
Code Example #8
from generative_playground.molecules.model_settings import get_settings
from generative_playground.molecules.train.pg.hypergraph.main_train_policy_gradient_minimal import train_policy_gradient
from generative_playground.codec.hypergraph_grammar import GrammarInitializer



batch_size = 8  # 20
drop_rate = 0.5
molecules = True
grammar_cache = 'hyper_grammar_guac_10k_with_clique_collapse.pickle'  # 'hyper_grammar.pickle'
grammar = 'hypergraph:' + grammar_cache
# settings = get_settings(molecules, grammar)
# max_steps = 277  # settings['max_seq_length']
invalid_value = -3.5
atom_penalty = lambda x: -0.05*(np.maximum(37, np.array(num_atoms(x)))-37)
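# atom_penalty subtracts 0.05 per atom above 37; molecules with 37 atoms or fewer incur no penalty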
scorer = NormalizedScorer(invalid_value=invalid_value, normalize_scores=True)
reward_fun = lambda x: np.tanh(0.1*scorer(x)) + atom_penalty(x) # lambda x: reward_aromatic_rings(x)#
# later will run this ahead of time
# gi = GrammarInitializer(grammar_cache)

max_steps = 60
root_name = 'classic_logP'
model, gen_fitter, disc_fitter = train_policy_gradient(molecules,
                                                       grammar,
                                                       EPOCHS=100,
                                                       BATCH_SIZE=batch_size,
                                                       reward_fun_on=reward_fun,
                                                       max_steps=max_steps,
                                                       lr_on=3e-5,
                                                       lr_discrim=5e-4,
                                                       discrim_wt=0.3,
Code Example #9
def reward_aliphatic_rings(smiles):
    '''
    A simple reward to encourage larger molecule length
    :param smiles: list of strings
    :return: reward, list of float
    '''
    return np.array([-1 if num is None else num for num in num_aliphatic_rings(smiles)])


batch_size = 40
drop_rate = 0.3
molecules = True
grammar = 'new'#True#
settings = get_settings(molecules, grammar)
invalid_value = -10
scorer = NormalizedScorer(invalid_value=invalid_value, sa_mult=0)
max_steps = 277 #settings['max_seq_length']

def second_score(smiles):
    pre_scores = 2.5 + scorer.get_scores(smiles)[0]
    score = np.power(pre_scores.prod(1), 0.333)
    for i in range(len(score)):
        if np.isnan(score[i]):
            score[i] = -1
    return score

reward_fun = lambda x: np.array([1 for _ in x])

model, fitter1, fitter2 = train_policy_gradient(molecules,
                                                grammar,
                                                EPOCHS=100,
Code Example #10
def train_policy_gradient(molecules=True,
                          grammar=True,
                          EPOCHS=None,
                          BATCH_SIZE=None,
                          reward_fun_on=None,
                          reward_fun_off=None,
                          max_steps=277,
                          lr_on=2e-4,
                          lr_off=1e-4,
                          drop_rate=0.0,
                          plot_ignore_initial=0,
                          save_file=None,
                          preload_file=None,
                          anchor_file=None,
                          anchor_weight=0.0,
                          decoder_type='action',
                          plot_prefix='',
                          dashboard='policy gradient',
                          smiles_save_file=None,
                          on_policy_loss_type='best',
                          off_policy_loss_type='mean',
                          sanity_checks=True):

    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../'
    save_path = root_location + 'pretrained/' + save_file
    smiles_save_path = root_location + 'pretrained/' + smiles_save_file

    settings = get_settings(molecules=molecules, grammar=grammar)

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    save_dataset = IncrementingHDF5Dataset(smiles_save_path)

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=reward_fun_on,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=save_dataset)

    def get_model(sanity_checks=sanity_checks):
        return get_decoder(
            molecules,
            grammar,
            z_size=settings['z_size'],
            decoder_hidden_n=200,
            feature_len=None,  # made redundant, need to factor it out
            max_seq_length=max_steps,
            drop_rate=drop_rate,
            decoder_type=decoder_type,
            task=task)[0]

    model = get_model()

    if preload_file is not None:
        try:
            preload_path = root_location + 'pretrained/' + preload_file
            model.load_state_dict(torch.load(preload_path))
        except:
            pass

    anchor_model = None
    if anchor_file is not None:
        anchor_model = get_model()
        try:
            anchor_path = root_location + 'pretrained/' + anchor_file
            anchor_model.load_state_dict(torch.load(anchor_path))
        except:
            anchor_model = None

    from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer
    import rdkit.Chem.rdMolDescriptors as desc
    import numpy as np
    scorer = NormalizedScorer()

    def model_process_fun(model_out, visdom, n):
        from rdkit import Chem
        from rdkit.Chem.Draw import MolToFile
        actions, logits, rewards, terminals, info = model_out
        smiles, valid = info
        total_rewards = rewards.sum(1)
        best_ind = torch.argmax(total_rewards).data.item()
        this_smile = smiles[best_ind]
        mol = Chem.MolFromSmiles(this_smile)
        pic_save_path = root_location + 'images/' + 'test.svg'
        if mol is not None:
            try:
                MolToFile(mol, pic_save_path, imageType='svg')
                with open(pic_save_path, 'r') as myfile:
                    data = myfile.read()
                data = data.replace('svg:', '')
                visdom.append('best molecule of batch', 'svg', svgstr=data)
            except:
                pass
            scores, norm_scores = scorer.get_scores([this_smile])
            visdom.append(
                'score component',
                'line',
                X=np.array([n]),
                Y=np.array(
                    [[x for x in norm_scores[0]] + [norm_scores[0].sum()] +
                     [scores[0].sum()] + [desc.CalcNumAromaticRings(mol)]]),
                opts={
                    'legend': [
                        'logP', 'SA', 'cycle', 'norm_reward', 'reward',
                        'Aromatic rings'
                    ]
                })
            visdom.append('fraction valid',
                          'line',
                          X=np.array([n]),
                          Y=np.array([valid.mean().data.item()]))

    if reward_fun_off is None:
        reward_fun_off = reward_fun_on

    def get_fitter(model,
                   loss_obj,
                   fit_plot_prefix='',
                   model_process_fun=None,
                   lr=None,
                   loss_display_cap=float('inf'),
                   anchor_model=None,
                   anchor_weight=0):
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun)
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=True)

        def my_gen():
            for _ in range(1000):
                yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

        fitter = fit_rl(train_gen=my_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        metric_monitor=metric_monitor,
                        checkpointer=checkpointer)

        return fitter

    # the on-policy fitter
    fitter1 = get_fitter(model,
                         PolicyGradientLoss(on_policy_loss_type),
                         plot_prefix + 'on-policy',
                         model_process_fun=model_process_fun,
                         lr=lr_on,
                         anchor_model=anchor_model,
                         anchor_weight=anchor_weight)
    # get existing molecule data to add to training
    main_dataset = DatasetFromHDF5(settings['data_path'], 'actions')

    # TODO change call to a simple DataLoader, no validation
    train_loader, valid_loader = train_valid_loaders(main_dataset,
                                                     valid_fraction=0.1,
                                                     batch_size=BATCH_SIZE,
                                                     pin_memory=use_gpu)

    fitter2 = get_fitter(model,
                         PolicyGradientLoss(off_policy_loss_type),
                         plot_prefix + ' off-policy',
                         lr=lr_off,
                         model_process_fun=model_process_fun,
                         loss_display_cap=125)

    def on_policy_gen(fitter, model):
        while True:
            model.policy = SoftmaxRandomSamplePolicy()
            yield next(fitter)

    def off_policy_gen(fitter, data_gen, model):
        data_iter = data_gen.__iter__()
        while True:
            try:
                x_actions = next(data_iter).to(torch.int64)
                model.policy = PolicyFromTarget(x_actions)
                yield next(fitter)
            except StopIteration:
                # restart the data iterator once it is exhausted
                data_iter = data_gen.__iter__()

    return model, on_policy_gen(fitter1,
                                model), off_policy_gen(fitter2, train_loader,
                                                       model)
Code Example #11

def reward_aromatic_rings(smiles):
    '''
    '''
    if not len(smiles):
        return -1  # an empty string is invalid for our purposes
    atoms = num_aromatic_rings(smiles)
    return [-1 if num is None else num + 0.5 for num in atoms]


batch_size = 40
drop_rate = 0.3
molecules = True
grammar = True
settings = get_settings(molecules, grammar)
max_steps = 277  #settings['max_seq_length']
invalid_value = -3.5
reward_fun = lambda x: 2.5 + NormalizedScorer(
    settings['data_path'], invalid_value=invalid_value)(
        x)  #lambda x: reward_aromatic_rings(x)#

model, fitter1, fitter2 = train_policy_gradient(
    molecules,
    grammar,
    EPOCHS=100,
    BATCH_SIZE=batch_size,
    reward_fun_on=reward_fun,
    max_steps=max_steps,
    lr=1e-4,
    drop_rate=drop_rate,
    decoder_type='attention',
    plot_prefix='rings ',
    dashboard='policy gradient',
    save_file='policy_gradient_rings.h5',
Code Example #12
def train_policy_gradient(molecules=True,
                          grammar=True,
                          EPOCHS=None,
                          BATCH_SIZE=None,
                          reward_fun_on=None,
                          reward_fun_off=None,
                          max_steps=277,
                          lr_on=2e-4,
                          lr_discrim=1e-4,
                          p_thresh=0.5,
                          drop_rate=0.0,
                          plot_ignore_initial=0,
                          randomize_reward=False,
                          save_file=None,
                          reward_sm=0.0,
                          preload_file=None,
                          anchor_file=None,
                          anchor_weight=0.0,
                          decoder_type='action',
                          plot_prefix='',
                          dashboard='policy gradient',
                          smiles_save_file=None,
                          on_policy_loss_type='best',
                          off_policy_loss_type='mean',
                          sanity_checks=True):
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../../'
    gen_save_path = root_location + 'pretrained/gen_' + save_file
    disc_save_path = root_location + 'pretrained/disc_' + save_file

    if smiles_save_file is not None:
        smiles_save_path = root_location + 'pretrained/' + smiles_save_file
        save_dataset = IncrementingHDF5Dataset(smiles_save_path)
    else:
        save_dataset = None

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    discrim_model = GraphDiscriminator(codec.grammar, drop_rate=drop_rate)

    zinc_data = get_zinc_smiles()
    zinc_set = set(zinc_data)
    lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

    def originality_mult(smiles_list):
        out = []
        for s in smiles_list:
            if s in zinc_set:
                out.append(0.5)
            elif s in history_data[0]:
                out.append(0.5)
            elif s in history_data[1]:
                out.append(0.70)
            elif s in history_data[2]:
                out.append(0.85)
            else:
                out.append(1.0)
        return np.array(out)

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def discriminator_reward_mult(smiles_list):
        orig_state = discrim_model.training
        discrim_model.eval()
        discrim_out_logits = discrim_model(smiles_list)['p_zinc']
        discrim_probs = F.softmax(discrim_out_logits, dim=1)
        prob_zinc = discrim_probs[:, 1].detach().cpu().numpy()
        if orig_state:
            discrim_model.train()
        return prob_zinc

    def adj_reward(x):
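        # Clip the base reward at zero, scale it by the originality multiplier, then add
        # twice the discriminator's probability that the molecule looks like ZINC.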
        p = discriminator_reward_mult(x)
        reward = np.maximum(reward_fun_on(x), 0)
        out = reward * originality_mult(x) + 2 * p
        return out

    def adj_reward_old(x):
        p = discriminator_reward_mult(x)
        w = sigmoid(-(p - p_thresh) / 0.01)
        if randomize_reward:
            rand = np.random.uniform(size=p.shape)
            w *= rand
        reward = np.maximum(reward_fun_on(x), p_thresh)
        weighted_reward = w * p + (1 - w) * reward
        out = weighted_reward * originality_mult(x)  #
        return out

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=adj_reward,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=save_dataset)

    model = get_decoder(molecules,
                        grammar,
                        z_size=settings['z_size'],
                        decoder_hidden_n=200,
                        feature_len=codec.feature_len(),
                        max_seq_length=max_steps,
                        drop_rate=drop_rate,
                        decoder_type=decoder_type,
                        task=task)[0]

    # TODO: really ugly, refactor! In fact this model doesn't need a MaskingHead at all!
    model.stepper.model.mask_gen.priors = True  #'conditional' # use empirical priors for the mask gen
    # if preload_file is not None:
    #     try:
    #         preload_path = root_location + 'pretrained/' + preload_file
    #         model.load_state_dict(torch.load(preload_path))
    #     except:
    #         pass

    anchor_model = None

    from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer
    import rdkit.Chem.rdMolDescriptors as desc
    import numpy as np
    scorer = NormalizedScorer()

    def model_process_fun(model_out, visdom, n):
        # TODO: rephrase this to return a dict, instead of calling visdom directly
        from rdkit import Chem
        from rdkit.Chem.Draw import MolToFile
        # actions, logits, rewards, terminals, info = model_out
        smiles, valid = model_out['info']
        total_rewards = model_out['rewards'].sum(1)
        best_ind = torch.argmax(total_rewards).data.item()
        this_smile = smiles[best_ind]
        mol = Chem.MolFromSmiles(this_smile)
        pic_save_path = root_location + 'images/' + 'tmp.svg'
        if mol is not None:
            try:
                MolToFile(mol, pic_save_path, imageType='svg')
                with open(pic_save_path, 'r') as myfile:
                    data = myfile.read()
                data = data.replace('svg:', '')
                visdom.append('best molecule of batch', 'svg', svgstr=data)
            except Exception as e:
                print(e)
            scores, norm_scores = scorer.get_scores([this_smile])
            visdom.append(
                'score component',
                'line',
                X=np.array([n]),
                Y=np.array(
                    [[x for x in norm_scores[0]] + [norm_scores[0].sum()] +
                     [scores[0].sum()] + [desc.CalcNumAromaticRings(mol)] +
                     [total_rewards[best_ind].item()]]),
                opts={
                    'legend': [
                        'logP', 'SA', 'cycle', 'norm_reward', 'reward',
                        'Aromatic rings', 'eff_reward'
                    ]
                })
            visdom.append('fraction valid',
                          'line',
                          X=np.array([n]),
                          Y=np.array([valid.mean().data.item()]))

    if reward_fun_off is None:
        reward_fun_off = reward_fun_on

    # construct the loader to feed the discriminator
    def make_callback(data):
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            for s in smiles:  # only store unique instances of molecules so discriminator can't guess on frequency
                if s not in data:
                    data.append(s)

        return hc

    # need to have something there to begin with, else the DataLoader constructor barfs

    def get_rl_fitter(model,
                      loss_obj,
                      train_gen,
                      save_path,
                      fit_plot_prefix='',
                      model_process_fun=None,
                      lr=None,
                      extra_callbacks=[],
                      loss_display_cap=float('inf'),
                      anchor_model=None,
                      anchor_weight=0):
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun,
                smooth_weight=0.9)
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=True)

        fitter = fit_rl(train_gen=train_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        callbacks=[metric_monitor, checkpointer] +
                        extra_callbacks)

        return fitter

    class GeneratorToIterable:
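        # Wraps a zero-argument generator factory so it can be re-iterated and exposes a
        # __len__, computed by exhausting the generator once, for downstream consumers.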
        def __init__(self, gen):
            self.gen = gen
            # we assume the generator is finite
            self.len = 0
            for _ in gen():
                self.len += 1

        def __len__(self):
            return self.len

        def __iter__(self):
            return self.gen()

    def my_gen():
        for _ in range(1000):
            yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

    # the on-policy fitter

    history_callbacks = [make_callback(d) for d in history_data]
    fitter1 = get_rl_fitter(model,
                            PolicyGradientLoss(on_policy_loss_type,
                                               last_reward_wgt=reward_sm),
                            GeneratorToIterable(my_gen),
                            gen_save_path,
                            plot_prefix + 'on-policy',
                            model_process_fun=model_process_fun,
                            lr=lr_on,
                            extra_callbacks=history_callbacks,
                            anchor_model=anchor_model,
                            anchor_weight=anchor_weight)
    # get existing molecule data to add to training
    pre_dataset = EvenlyBlendedDataset(
        2 * [history_data[0]] + history_data[1:],
        labels=False)  # a blend of 3 time horizons
    dataset = EvenlyBlendedDataset([pre_dataset, zinc_data], labels=True)
    discrim_loader = DataLoader(dataset, shuffle=True, batch_size=50)
    celoss = nn.CrossEntropyLoss()

    def my_loss(x):
        # tmp = discriminator_reward_mult(x['smiles'])
        # tmp2 = F.softmax(x['p_zinc'], dim=1)[:,1].detach().cpu().numpy()
        # import numpy as np
        # assert np.max(np.abs(tmp-tmp2)) < 1e-6
        return celoss(x['p_zinc'].to(device), x['dataset_index'].to(device))

    fitter2 = get_rl_fitter(discrim_model,
                            my_loss,
                            IterableTransform(
                                discrim_loader, lambda x: {
                                    'smiles': x['X'],
                                    'dataset_index': x['dataset_index']
                                }),
                            disc_save_path,
                            plot_prefix + ' discriminator',
                            lr=lr_discrim,
                            model_process_fun=None)

    def on_policy_gen(fitter, model):
        while True:
            model.policy = SoftmaxRandomSamplePolicy()  # bias=codec.grammar.get_log_frequencies()
            yield next(fitter)

    return model, on_policy_gen(fitter1, model), fitter2