def get_vae(molecules=True,
            grammar=True,
            weights_file=None,
            epsilon_std=1,
            decoder_type='step',
            **kwargs):
    """Build a variational autoencoder and the matching codec.

    Args:
        molecules: forwarded to the settings, codec and decoder factories.
        grammar: forwarded to the settings, codec and decoder factories.
        weights_file: if not None, weights are loaded with ``model.load``.
        epsilon_std: forwarded to ``VariationalAutoEncoderHead``.
        decoder_type: forwarded to ``get_decoder``.
        **kwargs: overrides for the defaults returned by ``get_model_args``
            (e.g. ``drop_rate``, ``sample_z``, ``encoder_type``) and any
            other encoder/decoder argument such as ``batch_size``. Any other
            key is ignored with a warning.

    Returns:
        ``(model, codec)``.
    """
    import warnings

    encoder_args = [
        'feature_len', 'max_seq_length', 'cnn_encoder_params', 'drop_rate',
        'encoder_type', 'rnn_encoder_hidden_n'
    ]
    decoder_args = [
        'z_size', 'decoder_hidden_n', 'feature_len', 'max_seq_length',
        'drop_rate', 'batch_size'
    ]

    model_args = get_model_args(molecules=molecules, grammar=grammar)
    # get_model_args() does not provide every encoder/decoder argument
    # (e.g. 'batch_size'), so accept those straight from kwargs too;
    # otherwise they could never reach the encoder/decoder.
    accepted = set(model_args) | set(encoder_args) | set(decoder_args)
    unknown = sorted(key for key in kwargs if key not in accepted)
    if unknown:
        # These used to be dropped silently, hiding misspelled argument names.
        warnings.warn(
            'get_vae: ignoring unrecognized arguments {}'.format(unknown))
    for key, value in kwargs.items():
        if key in accepted:
            model_args[key] = value
    # sample_z configures the VAE head, not the encoder or decoder.
    sample_z = model_args.pop('sample_z')

    encoder = get_encoder(**{
        key: value
        for key, value in model_args.items() if key in encoder_args
    })

    decoder, _ = get_decoder(molecules,
                             grammar,
                             decoder_type=decoder_type,
                             **{
                                 key: value
                                 for key, value in model_args.items()
                                 if key in decoder_args
                             })

    model = generative_playground.models.heads.vae.VariationalAutoEncoderHead(
        encoder=encoder,
        decoder=decoder,
        sample_z=sample_z,
        epsilon_std=epsilon_std,
        z_size=model_args['z_size'])

    if weights_file is not None:
        model.load(weights_file)

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec = get_codec(molecules,
                      grammar,
                      max_seq_length=settings['max_seq_length'])
    # codec.set_model(model)  # todo do we ever use this?
    return model, codec
Beispiel #2
0
 def __init__(self,
              invalid_value=-3 * 3.5,
              sa_mult=0.0,
              sa_thresh=0.5,
              normalize_scores=False):
     """Load score normalization statistics and store scoring options.

     The first three entries of the ``score_mean`` and ``score_std``
     datasets in the HDF5 file at ``get_settings(True, True)['data_path']``
     are kept as ``self.means`` and ``self.stds``. The remaining arguments
     are stored unchanged as attributes of the same name.
     """
     settings = get_settings(True, True)
     # The context manager releases the file handle even if a dataset is
     # missing or reading it raises (the explicit close() was skipped then).
     with h5py.File(settings['data_path'], 'r') as h5f:
         self.means = np.array(h5f['score_mean'])[:3]
         self.stds = np.array(h5f['score_std'])[:3]
     self.invalid_value = invalid_value
     self.sa_mult = sa_mult
     self.sa_thresh = sa_thresh
     self.normalize_scores = normalize_scores
def get_model_args(molecules,
                   grammar,
                   drop_rate=0.5,
                   sample_z=False,
                   encoder_type='rnn'):
    """Assemble the default keyword arguments used to build a VAE.

    Sizes and encoder parameters come from ``get_settings(molecules,
    grammar)``. ``feature_len`` is read from the matching codec.
    ``drop_rate``, ``sample_z`` and ``encoder_type`` are copied through
    unchanged.
    """
    cfg = get_settings(molecules, grammar)
    seq_len = cfg['max_seq_length']
    codec = get_codec(molecules, grammar, seq_len)
    return dict(
        z_size=cfg['z_size'],
        decoder_hidden_n=cfg['decoder_hidden_n'],
        feature_len=codec.feature_len(),
        max_seq_length=seq_len,
        cnn_encoder_params=cfg['cnn_encoder_params'],
        drop_rate=drop_rate,
        sample_z=sample_z,
        encoder_type=encoder_type,
        rnn_encoder_hidden_n=cfg['rnn_encoder_hidden_n'],
    )
Beispiel #4
0
 def __init__(self,
              molecules=True,
              grammar=True,
              reward_fun=None,
              batch_size=1,
              max_steps=None,
              save_dataset=None):
     """Set up a sequence-generation task around a codec.

     Action and state dimensions both equal ``codec.feature_len()``. The
     episode length defaults to the settings' ``max_seq_length`` unless
     ``max_steps`` is given. Finishes by calling ``self.reset()``.
     """
     cfg = get_settings(molecules, grammar)
     seq_limit = cfg['max_seq_length']
     self.codec = get_codec(molecules, grammar, seq_limit)
     self.action_dim = self.state_dim = self.codec.feature_len()
     self._max_episode_steps = seq_limit if max_steps is None else max_steps
     self.reward_fun = reward_fun
     self.batch_size = batch_size
     self.save_dataset = save_dataset
     # Episode bookkeeping starts out empty; reset() runs right after.
     self.smiles = self.seq_len = self.valid = None
     self.actions = self.done_rewards = None
     self.reset()
Beispiel #5
0
 def check_codec(self, input, molecules, grammar):
     """Assert that ``input`` survives a strings -> actions -> strings round trip."""
     max_len = get_settings(molecules, grammar)['max_seq_length']
     codec = get_codec(molecules, grammar, max_len)
     round_tripped = codec.actions_to_strings(
         codec.strings_to_actions([input]))
     self.assertEqual(input, round_tripped[0])
def train_policy_gradient(molecules=True,
                          grammar=True,
                          smiles_source='ZINC',
                          EPOCHS=None,
                          BATCH_SIZE=None,
                          reward_fun_on=None,
                          reward_fun_off=None,
                          max_steps=277,
                          lr_on=2e-4,
                          lr_discrim=1e-4,
                          lr_schedule=None,
                          discrim_wt=2,
                          p_thresh=0.5,
                          drop_rate=0.0,
                          plot_ignore_initial=0,
                          randomize_reward=False,
                          save_file_root_name=None,
                          reward_sm=0.0,
                          preload_file_root_name=None,
                          anchor_file=None,
                          anchor_weight=0.0,
                          decoder_type='action',
                          plot_prefix='',
                          dashboard='policy gradient',
                          smiles_save_file=None,
                          on_policy_loss_type='best',
                          priors=True,
                          node_temperature_schedule=lambda x: 1.0,
                          rule_temperature_schedule=lambda x: 1.0,
                          eps=0.0,
                          half_float=False,
                          extra_repetition_penalty=0.0,
                          entropy_wgt=1.0):
    """Set up policy-gradient training of a generator against a discriminator.

    Builds a decoder (the generator) and a ``GraphDiscriminator``. The
    generator's reward comes from ``adj_reward``, which combines
    ``reward_fun_on``, the discriminator and a repetition penalty against
    recent history. The discriminator is trained on a blend of recently
    generated SMILES versus ``smiles_source`` molecules.

    Args (selected):
        save_file_root_name: required. Checkpoints are written to
            ``pretrained/<root>_gen.h5`` and ``pretrained/<root>_disc.h5``.
        preload_file_root_name: if given, generator weights are loaded from
            ``pretrained/<root>_gen.h5``; a failure is printed, not raised.
        BATCH_SIZE: batch size; defaults to ``settings['BATCH_SIZE']``.
        EPOCHS: forwarded as ``epochs`` to ``fit_rl``.
        node_temperature_schedule, rule_temperature_schedule: functions of
            the callback call count that set the sampling temperatures.

    Returns:
        ``(model, fitter1, fitter2)``: the generator, the on-policy
        generator fitter and the discriminator fitter.

    Raises:
        ValueError: if ``save_file_root_name`` is None.
    """
    if save_file_root_name is None:
        # Both fitters need checkpoint paths; without this check the function
        # failed much later (after loading the database) with a NameError.
        raise ValueError(
            'save_file_root_name is required: it names the generator and '
            'discriminator checkpoint files')

    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../../'

    def full_path(x):
        return os.path.realpath(root_location + 'pretrained/' + x)

    settings = get_settings(molecules=molecules, grammar=grammar)
    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is None:
        # A None batch size used to crash when building the lookbacks below.
        BATCH_SIZE = settings.get('BATCH_SIZE')
    else:
        settings['BATCH_SIZE'] = BATCH_SIZE

    zinc_data = get_smiles_from_database(source=smiles_source)
    zinc_set = set(zinc_data)
    # Histories of generated molecules over 1, 10 and 100 batches.
    lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
    # Seeded with one molecule: the discriminator DataLoader built below
    # fails on an empty dataset.
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

    gen_save_file = save_file_root_name + '_gen.h5'
    disc_save_file = save_file_root_name + '_disc.h5'
    if preload_file_root_name is not None:
        gen_preload_file = preload_file_root_name + '_gen.h5'
        disc_preload_file = preload_file_root_name + '_disc.h5'

    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    discrim_model = GraphDiscriminator(codec.grammar, drop_rate=drop_rate)
    # Discriminator preloading is deliberately disabled via `False and`.
    if False and preload_file_root_name is not None:
        try:
            preload_path = full_path(disc_preload_file)
            discrim_model.load_state_dict(torch.load(preload_path),
                                          strict=False)
            print('Discriminator weights loaded successfully!')
        except Exception as e:
            print('failed to load discriminator weights ' + str(e))

    # NOTE(review): the calculator gets discrim_model=None even though a
    # discriminator is built above — confirm this is intended.
    alt_reward_calc = AdjustedRewardCalculator(reward_fun_on,
                                               zinc_set,
                                               lookbacks,
                                               extra_repetition_penalty,
                                               discrim_wt,
                                               discrim_model=None)

    reward_fun = lambda x: adj_reward(discrim_wt,
                                      discrim_model,
                                      reward_fun_on,
                                      zinc_set,
                                      history_data,
                                      extra_repetition_penalty,
                                      x,
                                      alt_calc=alt_reward_calc)

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=reward_fun,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=None)

    node_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0),
                                            eps=eps)
    rule_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(2.0),
                                            eps=eps)

    model = get_decoder(molecules,
                        grammar,
                        z_size=settings['z_size'],
                        decoder_hidden_n=200,
                        feature_len=codec.feature_len(),
                        max_seq_length=max_steps,
                        drop_rate=drop_rate,
                        batch_size=BATCH_SIZE,
                        decoder_type=decoder_type,
                        reward_fun=reward_fun,
                        task=task,
                        node_policy=node_policy,
                        rule_policy=rule_policy,
                        priors=priors)[0]

    if preload_file_root_name is not None:
        try:
            preload_path = full_path(gen_preload_file)
            model.load_state_dict(torch.load(preload_path, map_location='cpu'),
                                  strict=False)
            print('Generator weights loaded successfully!')
        except Exception as e:
            print('failed to load generator weights ' + str(e))

    anchor_model = None

    # construct the loader to feed the discriminator
    def make_callback(data):
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            for s in smiles:  # only store unique instances of molecules so discriminator can't guess on frequency
                if s not in data:
                    data.append(s)

        return hc

    class TemperatureCallback:
        """Callback that sets a policy's temperature from a call-count schedule."""

        def __init__(self, policy, temperature_function):
            self.policy = policy
            self.counter = 0
            self.temp_fun = temperature_function

        def __call__(self, inputs, model, outputs, loss_fn, loss):
            self.counter += 1
            target_temp = self.temp_fun(self.counter)
            self.policy.set_temperature(target_temp)

    def get_rl_fitter(model,
                      loss_obj,
                      train_gen,
                      save_path,
                      fit_plot_prefix='',
                      model_process_fun=None,
                      lr=None,
                      lr_schedule=lr_schedule,
                      extra_callbacks=None,
                      loss_display_cap=float('inf'),
                      anchor_model=None,
                      anchor_weight=0):
        """Wrap ``model`` in an Adam optimizer, LR schedule and ``fit_rl``."""
        if extra_callbacks is None:
            extra_callbacks = []
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr, eps=1e-4)
        if lr_schedule is None:
            lr_schedule = lambda x: 1.0
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun,
                smooth_weight=reward_sm,
                save_location=os.path.dirname(save_path))
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=10,
                                    save_path=save_path,
                                    save_always=True)

        fitter = fit_rl(train_gen=train_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        half_float=half_float,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        callbacks=[metric_monitor, checkpointer] +
                        extra_callbacks)

        return fitter

    class GeneratorToIterable:
        """Re-iterable, sized wrapper around a generator function."""

        def __init__(self, gen):
            self.gen = gen
            # we assume the generator is finite
            self.len = 0
            for _ in gen():
                self.len += 1

        def __len__(self):
            return self.len

        def __iter__(self):
            return self.gen()

    def my_gen():
        for _ in range(1000):
            yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

    # the on-policy fitter

    gen_extra_callbacks = [make_callback(d) for d in history_data]

    if smiles_save_file is not None:
        smiles_save_path = full_path(smiles_save_file)
        gen_extra_callbacks.append(MoleculeSaver(smiles_save_path, gzip=True))
        print('Saved SMILES to {}'.format(smiles_save_file))

    if node_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(node_policy, node_temperature_schedule))

    if rule_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(rule_policy, rule_temperature_schedule))

    # NOTE(review): model_process_fun is not defined in this function, so it
    # must resolve to a module-level name — confirm it exists.
    fitter1 = get_rl_fitter(
        model,
        PolicyGradientLoss(
            on_policy_loss_type,
            entropy_wgt=entropy_wgt),  # last_reward_wgt=reward_sm),
        GeneratorToIterable(my_gen),
        full_path(gen_save_file),
        plot_prefix + 'on-policy',
        model_process_fun=model_process_fun,
        lr=lr_on,
        extra_callbacks=gen_extra_callbacks,
        anchor_model=anchor_model,
        anchor_weight=anchor_weight)

    # Discriminator data: generated molecules (a blend of 3 time horizons,
    # the shortest one weighted twice) versus the reference database.
    pre_dataset = EvenlyBlendedDataset(
        2 * [history_data[0]] + history_data[1:],
        labels=False)
    dataset = EvenlyBlendedDataset([pre_dataset, zinc_data], labels=True)
    discrim_loader = DataLoader(dataset, shuffle=True, batch_size=50)

    class MyLoss(nn.Module):
        """Cross-entropy between discriminator logits and dataset labels."""

        def __init__(self):
            super().__init__()
            self.celoss = nn.CrossEntropyLoss()

        def forward(self, x):
            return self.celoss(x['p_zinc'].to(device),
                               x['dataset_index'].to(device))

    fitter2 = get_rl_fitter(discrim_model,
                            MyLoss(),
                            IterableTransform(
                                discrim_loader, lambda x: {
                                    'smiles': x['X'],
                                    'dataset_index': x['dataset_index']
                                }),
                            full_path(disc_save_file),
                            plot_prefix + ' discriminator',
                            lr=lr_discrim,
                            model_process_fun=None)

    def on_policy_gen(fitter, model):
        while True:
            yield next(fitter)

    return model, fitter1, fitter2  #,on_policy_gen(fitter1, model)
Beispiel #7
0
from generative_playground.codec.hypergraph_grammar import HypergraphMaskGenerator

from generative_playground.codec.hypergraph_parser import hypergraph_parser, \
    check_validity
from generative_playground.molecules.model_settings import get_settings
from collections import OrderedDict
from rdkit.Chem import MolFromSmiles, MolToSmiles
import copy
import random, os
import numpy as np

# Script entry point: read up to ~100000 SMILES strings from
# settings['source_data'] and load a cached HypergraphGrammar from
# rule_hypergraphs.pickle if that file exists. The else branch (presumably
# building the rules when no cache exists) is not visible in this excerpt.
if __name__ == '__main__':
    from generative_playground.codec.hypergraph_grammar import HypergraphGrammar, evaluate_rules, hypergraphs_are_equivalent

    settings = get_settings(molecules=True, grammar='new')
    thresh = 100000
    # Read in the strings
    # (the length check runs after each append, so thresh + 1 lines are kept)
    f = open(settings['source_data'], 'r')
    L = []
    for line in f:
        line = line.strip()
        L.append(line)
        if len(L) > thresh:
            break
    f.close()

    fn = "rule_hypergraphs.pickle"
    # NOTE(review): max_rules is not used in the visible part of this script.
    max_rules = 50
    if os.path.isfile(fn):
        rm = HypergraphGrammar.load(fn)
    else:
Beispiel #8
0
def train_mol_descriptor(grammar=True,
                         EPOCHS=None,
                         BATCH_SIZE=None,
                         lr=2e-4,
                         gradient_clip=5,
                         drop_rate=0.0,
                         plot_ignore_initial=0,
                         save_file=None,
                         preload_file=None,
                         encoder_type='rnn',
                         plot_prefix='',
                         dashboard='properties',
                         aux_dataset=None,
                         preload_weights=False):
    """Set up training of a molecular property predictor.

    An encoder from ``get_encoder`` is wrapped in ``MeanVarianceSkewHead``
    with 4 outputs and trained with ``VariationalLoss`` on the
    ``['valid', 'logP', 'SA', 'cyc_sc']`` targets. Targets are computed by
    ``property_scorer`` from the SMILES stored next to the actions in
    ``settings['data_path']`` (and in ``aux_dataset``, if given).

    Args (selected):
        save_file: required; checkpoints go to ``<root>/pretrained/<save_file>``.
        EPOCHS: epochs; defaults to ``settings['EPOCHS']``.
        BATCH_SIZE: batch size; defaults to ``settings['BATCH_SIZE']``.
        aux_dataset: optional extra dataset mixed into training/validation.

    Returns:
        ``(model, fitter, main_dataset)``.

    Raises:
        ValueError: if ``save_file`` is None.
    """
    if save_file is None:
        # save_file is concatenated into the checkpoint path below; fail with
        # a clear message instead of a TypeError.
        raise ValueError('save_file is required: it names the checkpoint file')

    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../'
    save_path = root_location + 'pretrained/' + save_file

    # NOTE(review): preload_path and preload_weights are never used below, so
    # no weights are preloaded by this function — confirm intent.
    if preload_file is None:
        preload_path = save_path
    else:
        preload_path = root_location + 'pretrained/' + preload_file

    # Only used by the disabled decoder-based branch below.
    batch_mult = 1 if aux_dataset is None else 2

    settings = get_settings(molecules=True, grammar=grammar)
    max_steps = settings['max_seq_length']

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE
    # Use the resolved batch size everywhere below. Passing the raw argument
    # let BATCH_SIZE=None reach the data loaders even when settings define one.
    batch_size = settings.get('BATCH_SIZE')

    if False:  # disabled: decoder-based attention model, kept for reference
        pre_model, _ = get_decoder(True,
                                   grammar,
                                   z_size=settings['z_size'],
                                   decoder_hidden_n=200,
                                   feature_len=settings['feature_len'],
                                   max_seq_length=max_steps,
                                   drop_rate=drop_rate,
                                   decoder_type=encoder_type,
                                   batch_size=batch_size * batch_mult)

        class AttentionSimulator(nn.Module):
            def __init__(self, pre_model, drop_rate):
                super().__init__()
                self.pre_model = pre_model
                pre_model_2 = AttentionAggregatingHead(pre_model,
                                                       drop_rate=drop_rate)
                pre_model_2.model_out_transform = lambda x: x[1]
                self.model = MeanVarianceSkewHead(pre_model_2,
                                                  4,
                                                  drop_rate=drop_rate)

            def forward(self, x):
                self.pre_model.policy = PolicyFromTarget(x)
                return self.model(None)

        model = to_gpu(AttentionSimulator(pre_model, drop_rate=drop_rate))
    else:
        pre_model = get_encoder(feature_len=settings['feature_len'],
                                max_seq_length=settings['max_seq_length'],
                                cnn_encoder_params={
                                    'kernel_sizes': (2, 3, 4),
                                    'filters': (2, 3, 4),
                                    'dense_size': 100
                                },
                                drop_rate=drop_rate,
                                encoder_type=encoder_type)

        model = MeanVarianceSkewHead(pre_model, 4, drop_rate=drop_rate)

    nice_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(nice_params, lr=lr)

    main_dataset = MultiDatasetFromHDF5(settings['data_path'],
                                        ['actions', 'smiles'])
    train_loader, valid_loader = train_valid_loaders(main_dataset,
                                                     valid_fraction=0.1,
                                                     batch_size=batch_size,
                                                     pin_memory=use_gpu)

    def scoring_fun(x):
        """Map a batch to ``(action prefix, property scores)``."""
        if isinstance(x, (tuple, list)):
            x = {'actions': x[0], 'smiles': x[1]}
        out_x = to_gpu(x['actions'])
        # NOTE(review): the inclusive/exclusive upper bound depends on which
        # randint is imported at module level — confirm.
        end_of_slice = randint(3, out_x.size()[1])
        #TODO inject random slicing back
        out_x = out_x[:, 0:end_of_slice]
        smiles = x['smiles']
        scores = to_gpu(
            torch.from_numpy(property_scorer(smiles).astype(np.float32)))
        return out_x, scores

    train_gen_main = IterableTransform(train_loader, scoring_fun)
    valid_gen_main = IterableTransform(valid_loader, scoring_fun)

    if aux_dataset is not None:
        train_aux, valid_aux = SamplingWrapper(aux_dataset) \
            .get_train_valid_loaders(batch_size,
                                     dataset_name=['actions',
                                                   'smiles'])
        train_gen_aux = IterableTransform(train_aux, scoring_fun)
        valid_gen_aux = IterableTransform(valid_aux, scoring_fun)
        train_gen = CombinedLoader([train_gen_main, train_gen_aux],
                                   num_batches=90)
        valid_gen = CombinedLoader([valid_gen_main, valid_gen_aux],
                                   num_batches=10)
    else:
        train_gen = train_gen_main
        valid_gen = valid_gen_main

    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               factor=0.2,
                                               patience=3,
                                               min_lr=min(0.0001, 0.1 * lr),
                                               eps=1e-08)
    loss_obj = VariationalLoss(['valid', 'logP', 'SA', 'cyc_sc'])

    metric_monitor = MetricPlotter(plot_prefix=plot_prefix,
                                   loss_display_cap=4.0,
                                   dashboard_name=dashboard,
                                   plot_ignore_initial=plot_ignore_initial)

    checkpointer = Checkpointer(valid_batches_to_checkpoint=10,
                                save_path=save_path)

    fitter = fit(train_gen=train_gen,
                 valid_gen=valid_gen,
                 model=model,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 grad_clip=gradient_clip,
                 epochs=settings['EPOCHS'],
                 loss_fn=loss_obj,
                 metric_monitor=metric_monitor,
                 checkpointer=checkpointer)

    return model, fitter, main_dataset
Beispiel #9
0
def main():
    """Encode the first 10000 source strings into an incrementing HDF5 dataset.

    Strings from ``settings['source_data']`` are converted to actions and
    one-hot arrays with the codec returned by ``get_vae``. They are appended
    to ``settings['data_path']`` in chunks of ``step`` rows. With the 'new'
    grammar, strings the parser rejects are skipped. For molecules, score
    components and atom counts are stored per row, and the mean/std of the
    scores are appended at the end.
    """
    from itertools import islice

    # change this to False to produce the equation dataset
    molecules = True
    # False gives character-based encodings instead of grammar-based ones;
    # use True for the grammar used by Kusner et al.
    grammar = 'new'
    max_items = 10000
    step = 100

    if molecules:
        from rdkit.Chem.rdmolfiles import MolFromSmiles

    # can't define model class inside settings as it itself uses settings a lot
    # get_vae returns (model, codec); only the codec is needed here.
    _, codec = get_vae(molecules, grammar)

    def pre_parser(x):
        """Return the first parse of token list ``x``, or None if it fails."""
        try:
            return next(codec._parser.parse(x))
        except Exception:
            return None

    settings = get_settings(molecules, grammar)
    dest_file = settings['data_path']
    source_file = settings['source_data']

    # Only the first max_items strings are ever encoded, so read no more.
    with open(source_file, 'r') as f:
        L = [line.strip() for line in islice(f, max_items)]

    # convert to one-hot and save, in small increments to save RAM
    ds = IncrementingHDF5Dataset(dest_file)
    dt = h5py.special_dtype(vlen=str)  # PY3 hdf5 datatype for variable-length Unicode strings

    for i in tqdm(range(0, len(L), step)):
        these_indices = list(range(i, min(i + step, len(L))))
        these_smiles = L[i:i + step]
        if grammar == 'new':  # have to weed out non-parseable strings
            tokens = [codec._tokenize(s.replace('-c', 'c')) for s in these_smiles]
            kept = [(s, ind)
                    for s, t, ind in zip(these_smiles, tokens, these_indices)
                    if pre_parser(t) is not None]
            if not kept:
                # Unpacking zip(*[]) would raise; nothing to store for this chunk.
                continue
            these_smiles, these_indices = zip(*kept)
        these_actions = torch.tensor(codec.strings_to_actions(these_smiles))
        action_seq_length = codec.action_seq_length(these_actions)
        onehot = codec.actions_to_one_hot(these_actions)
        append_data = {'smiles': np.array(these_smiles, dtype=dt),
                       'indices': np.array(these_indices),
                       'actions': these_actions,
                       'valid': np.ones((len(these_smiles))),
                       'seq_len': action_seq_length,
                       'data': onehot}
        if molecules:
            mols = [MolFromSmiles(s) for s in these_smiles]
            append_data['raw_scores'] = np.array(
                [get_score_components_from_mol(m) for m in mols])
            append_data['num_atoms'] = np.array(
                [len(m.GetAtoms()) for m in mols])

        ds.append(append_data)

    if molecules:
        # also calculate mean and std of the scores, to use in the ultimate objective
        raw_scores = np.array(ds.h5f['raw_scores'])
        score_std = raw_scores.std(0)
        score_mean = raw_scores.mean(0)
        ds.append_to_dataset('score_std', score_std)
        ds.append_to_dataset('score_mean', score_mean)

    print('success!')
def train_reinforcement(grammar=True,
                        model=None,
                        EPOCHS=None,
                        BATCH_SIZE=None,
                        lr=2e-4,
                        main_dataset=None,
                        new_datasets=None,
                        plot_ignore_initial=0,
                        save_file=None,
                        plot_prefix='',
                        dashboard='main',
                        preload_weights=False):
    """Set up fine-tuning of ``model`` with ``ReinforcementLoss``.

    Each dataset in ``new_datasets`` is wrapped with ``SamplingWrapper`` to
    produce train/validation loaders over the ``actions``, ``seq_len``,
    ``valid`` and ``sample_seq_ind`` fields. These are combined (90 train
    and 10 validation batches) and passed to ``fit``.

    Args (selected):
        model: the model to train (required).
        new_datasets: iterable of datasets to sample from (required).
        save_file: if given, checkpoints go to ``<root>/pretrained/<save_file>``.
        preload_weights: if True, try ``model.load(save_path)`` first. A
            failure is reported and training continues from current weights.
        BATCH_SIZE: batch size; defaults to ``settings['BATCH_SIZE']``.
        main_dataset: currently unused.

    Returns:
        The fitter returned by ``fit``.

    Raises:
        ValueError: if ``model`` or ``new_datasets`` is None.
    """
    if model is None:
        raise ValueError('model is required')
    if new_datasets is None:
        raise ValueError('new_datasets is required')

    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../'
    if save_file is not None:
        save_path = root_location + 'pretrained/' + save_file
    else:
        save_path = None
    molecules = True  # checking for validity only makes sense for molecules
    settings = get_settings(molecules=molecules, grammar=grammar)

    # TODO: separate settings for this?
    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE
    # Resolved batch size; a None argument used to crash at int(None / 5).
    batch_size = settings.get('BATCH_SIZE')

    if preload_weights:
        try:
            model.load(save_path)
        except Exception as e:
            # Best effort: keep training from the current weights, but report
            # why (a bare except also swallowed KeyboardInterrupt).
            print('failed to load weights from {}: {}'.format(save_path, e))
    nice_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(nice_params, lr=lr)

    train_l = []
    valid_l = []
    for ds in new_datasets:
        train_loader, valid_loader = SamplingWrapper(ds)\
                        .get_train_valid_loaders(batch_size,
                                                 valid_batch_size=1 + int(batch_size / 5),
                                                 dataset_name=['actions', 'seq_len', 'valid', 'sample_seq_ind'],
                                                 window=1000)
        train_l.append(train_loader)
        valid_l.append(valid_loader)
    train_gen = CombinedLoader(train_l, num_batches=90)
    valid_gen = CombinedLoader(valid_l, num_batches=10)

    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               factor=0.2,
                                               patience=3,
                                               min_lr=0.0001,
                                               eps=1e-08)
    loss_obj = ReinforcementLoss()

    # NOTE(review): these keyword arguments (save_path, save_always,
    # dashboard_name, ...) differ from the metric_monitor/checkpointer style
    # used by other fit() calls — confirm they match fit's current signature.
    fitter = fit(train_gen=train_gen,
                 valid_gen=valid_gen,
                 model=model,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 epochs=settings['EPOCHS'],
                 loss_fn=loss_obj,
                 save_path=save_path,
                 save_always=True,
                 dashboard_name=dashboard,
                 plot_ignore_initial=plot_ignore_initial,
                 plot_prefix=plot_prefix,
                 loss_display_cap=200)

    return fitter
Beispiel #11
0
def train_vae(molecules = True,
              grammar = True,
              EPOCHS = None,
              BATCH_SIZE = None,
              lr = 2e-4,
              drop_rate = 0.0,
              plot_ignore_initial = 0,
              reg_weight = 1,
              epsilon_std = 0.01,
              sample_z = True,
              save_file = None,
              preload_file = None,
              encoder_type='cnn',
              decoder_type='step',
              plot_prefix = '',
              dashboard = 'main',
              preload_weights=False):
    """Build a VAE and a fitter that trains it on the HDF5 'data' dataset.

    Args:
        molecules, grammar: forwarded to get_settings / get_vae.
        EPOCHS: overrides settings['EPOCHS'] when not None.
        BATCH_SIZE: overrides settings['BATCH_SIZE'] when not None; the
            resolved settings value is what the data loaders use.
        lr: Adam learning rate (the plateau scheduler floors at
            min(1e-4, 0.1 * lr)).
        drop_rate, sample_z, epsilon_std, encoder_type, decoder_type:
            forwarded to get_vae.
        reg_weight: regularization weight passed to VAELoss.
        save_file: checkpoint file name under ``pretrained/`` (required).
        preload_file: file under ``pretrained/`` to preload from; defaults to
            save_file. Only used when ``preload_weights`` is True.
        plot_prefix, dashboard, plot_ignore_initial: MetricPlotter options.

    Returns:
        (model, fitter, main_dataset)

    Raises:
        ValueError: if ``save_file`` is None.
    """
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../'
    if save_file is None:
        # the Checkpointer below always needs a destination; fail early with a
        # clear message rather than a TypeError from string concatenation
        raise ValueError('train_vae requires a save_file')
    save_path = root_location + 'pretrained/' + save_file
    if preload_file is None:
        preload_path = save_path
    else:
        preload_path = root_location + 'pretrained/' + preload_file

    settings = get_settings(molecules=molecules, grammar=grammar)

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    # get_vae only applies kwargs whose names appear in its model args, and the
    # encoder receives 'encoder_type' -- a 'rnn_encoder' kwarg would be ignored.
    model, _ = get_vae(molecules=molecules,
                       grammar=grammar,
                       drop_rate=drop_rate,
                       sample_z=sample_z,
                       encoder_type=encoder_type,
                       decoder_type=decoder_type,
                       weights_file=preload_path if preload_weights else None,
                       epsilon_std=epsilon_std)

    nice_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(nice_params, lr=lr)

    main_dataset = DatasetFromHDF5(settings['data_path'], 'data')
    # use the resolved batch size so BATCH_SIZE=None falls back to settings
    train_loader, valid_loader = train_valid_loaders(
        main_dataset,
        valid_fraction=0.1,
        batch_size=settings['BATCH_SIZE'],
        pin_memory=use_gpu)

    train_gen = TwinGenerator(train_loader)
    valid_gen = TwinGenerator(valid_loader)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               factor=0.2,
                                               patience=3,
                                               min_lr=min(0.0001, 0.1 * lr),
                                               eps=1e-08)
    loss_obj = VAELoss(settings['grammar'], sample_z, reg_weight)

    metric_monitor = MetricPlotter(plot_prefix=plot_prefix,
                                   loss_display_cap=4.0,
                                   dashboard_name=dashboard,
                                   plot_ignore_initial=plot_ignore_initial)

    checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                save_path=save_path)

    fitter = fit(train_gen=train_gen,
                 valid_gen=valid_gen,
                 model=model,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 epochs=settings['EPOCHS'],
                 loss_fn=loss_obj,
                 metric_monitor=metric_monitor,
                 checkpointer=checkpointer)

    return model, fitter, main_dataset
        if isinstance(x, str):
            molecule = MolFromSmiles(x)
            tree = hypergraph_parser(molecule)
            norm_tree = self.grammar.normalize_tree(tree)
        else:
            norm_tree = x

        for rule_pair in self.rule_pairs:
            tree = apply_hypergraph_substitution(self.grammar, norm_tree,
                                                 rule_pair)
        return tree


if __name__ == '__main__':
    # Ad-hoc script: extract popular rule pairs from a sample of parse trees
    # and construct an RPEParser from them.
    from generative_playground.molecules.model_settings import get_settings
    settings = get_settings(True, 'new')  # True)  #

    # presumably one SMILES string per line -- each line is stripped below
    with open(settings['source_data']) as f:
        smiles = f.readlines()

    for i in range(len(smiles)):
        smiles[i] = smiles[i].strip()

    codec = settings['codec']
    # only the first 100 molecules are parsed for pair extraction
    trees = get_parse_trees(codec, smiles[:100])
    # presumably the 10 most frequent rule pairs -- verify extract_popular_pairs
    rule_pairs = extract_popular_pairs(trees, 10)

    # and a test parse
    rpe_parser = RPEParser(codec._parser, rule_pairs)
    # NOTE(review): rpe_parser and tokens are not used in the visible code;
    # this loop looks truncated -- confirm the intended test parse.
    for smile in smiles:
        tokens = codec._tokenize(smile)
    def __init__(self,
                 grammar,
                 smiles_source='ZINC',
                 BATCH_SIZE=None,
                 reward_fun=None,
                 max_steps=277,
                 num_batches=100,
                 lr=2e-4,
                 entropy_wgt=1.0,
                 lr_schedule=None,
                 root_name=None,
                 preload_file_root_name=None,
                 save_location=None,
                 plot_metrics=True,
                 metric_smooth=0.0,
                 decoder_type='graph_conditional',
                 on_policy_loss_type='advantage_record',
                 priors='conditional',
                 rule_temperature_schedule=None,
                 eps=0.0,
                 half_float=False,
                 extra_repetition_penalty=0.0):
        """Set up the generator model, reward, optimizer and fitter.

        Args:
            grammar: forwarded to get_settings, get_codec, the task and
                get_decoder.
            smiles_source: source passed to get_smiles_from_database; the
                resulting set is handed to the reward calculators.
            BATCH_SIZE: molecules per batch. When None, settings['BATCH_SIZE']
                is used (the repetition-history window sizes are multiples of
                it, so a concrete value is required).
            reward_fun: raw reward, wrapped by AdjustedRewardCalculator.
            max_steps: maximum decoding steps (also the decoder's
                max_seq_length).
            num_batches: stored as self.num_batches.
            lr: Adam learning rate for the generator.
            entropy_wgt: entropy weight for PolicyGradientLoss.
            lr_schedule: LambdaLR multiplier function; constant 1.0 if None.
            root_name: passed to self.set_root_name at the end.
            preload_file_root_name: if set, generator weights are loaded
                (best-effort) from save_location + root + '_gen.h5'.
            save_location: path prefix for preloading and metric plotting.
            plot_metrics: if False the metric-monitor factory returns None.
            metric_smooth: smoothing weight for MetricPlotter.
            decoder_type: decoder variant; a 'sparse' substring selects the
                sparse rule-sampling policy.
            on_policy_loss_type: loss type for PolicyGradientLoss.
            priors: forwarded to get_decoder.
            rule_temperature_schedule: optional schedule applied to the rule
                policy through TemperatureCallback.
            eps: eps of the dense rule-sampling policy.
            half_float: accepted but not used by this method.
            extra_repetition_penalty: forwarded to the reward calculators.
        """
        self.num_batches = num_batches
        self.save_location = save_location
        self.molecule_saver = MoleculeSaver(None, gzip=True)
        self.metric_monitor = None  # to be populated by self.set_root_name(...)

        settings = get_settings(molecules=True, grammar=grammar)
        codec = get_codec(True, grammar, settings['max_seq_length'])
        if BATCH_SIZE is not None:
            settings['BATCH_SIZE'] = BATCH_SIZE
        else:
            # resolve the default up front: the lookbacks below multiply it,
            # and the task/decoder/fitter all need a concrete batch size
            BATCH_SIZE = settings['BATCH_SIZE']

        zinc_data = get_smiles_from_database(source=smiles_source)
        zinc_set = set(zinc_data)
        # history windows of 1, 10 and 100 batches of generated molecules
        lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
        # each history starts with a placeholder 'O' (presumably so consumers
        # never see an empty history)
        history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

        self.alt_reward_calc = AdjustedRewardCalculator(
            reward_fun,
            zinc_set,
            lookbacks,
            extra_repetition_penalty,
            0,
            discrim_model=None)
        self.reward_fun = lambda x: adj_reward(0,
                                               None,
                                               reward_fun,
                                               zinc_set,
                                               history_data,
                                               extra_repetition_penalty,
                                               x,
                                               alt_calc=self.alt_reward_calc)

        task = SequenceGenerationTask(molecules=True,
                                      grammar=grammar,
                                      reward_fun=self.alt_reward_calc,
                                      batch_size=BATCH_SIZE,
                                      max_steps=max_steps,
                                      save_dataset=None)

        if 'sparse' in decoder_type:
            rule_policy = SoftmaxRandomSamplePolicySparse()
        else:
            rule_policy = SoftmaxRandomSamplePolicy(
                temperature=torch.tensor(1.0), eps=eps)

        # TODO: strip this down to the normal call
        self.model = get_decoder(True,
                                 grammar,
                                 z_size=settings['z_size'],
                                 decoder_hidden_n=200,
                                 feature_len=codec.feature_len(),
                                 max_seq_length=max_steps,
                                 batch_size=BATCH_SIZE,
                                 decoder_type=decoder_type,
                                 reward_fun=self.alt_reward_calc,
                                 task=task,
                                 rule_policy=rule_policy,
                                 priors=priors)[0]

        if preload_file_root_name is not None:
            gen_preload_file = preload_file_root_name + '_gen.h5'
            try:
                # inside the try: save_location=None also lands in the handler
                preload_path = os.path.realpath(save_location +
                                                gen_preload_file)
                self.model.load_state_dict(torch.load(preload_path,
                                                      map_location='cpu'),
                                           strict=False)
                print('Generator weights loaded successfully!')
            except Exception as e:
                # best-effort: keep the freshly initialized weights
                print('failed to load generator weights ' + str(e))

        # construct the loader to feed the discriminator
        def make_callback(data):
            def hc(inputs, model, outputs, loss_fn, loss):
                graphs = outputs['graphs']
                smiles = [g.to_smiles() for g in graphs]
                for s in smiles:  # only store unique instances of molecules so discriminator can't guess on frequency
                    if s not in data:
                        data.append(s)

            return hc

        if plot_metrics:
            # TODO: save_file for rewards data goes here?
            # NOTE(review): model_process_fun is not defined in this method;
            # presumably a module-level name -- verify.
            self.metric_monitor_factory = lambda name: MetricPlotter(
                plot_prefix='',
                loss_display_cap=float('inf'),
                dashboard_name=name,
                save_location=save_location,
                process_model_fun=model_process_fun,
                smooth_weight=metric_smooth)
        else:
            self.metric_monitor_factory = lambda x: None

        # the on-policy fitter
        gen_extra_callbacks = [make_callback(d) for d in history_data]
        gen_extra_callbacks.append(self.molecule_saver)
        if rule_temperature_schedule is not None:
            gen_extra_callbacks.append(
                TemperatureCallback(rule_policy, rule_temperature_schedule))

        nice_params = filter(lambda p: p.requires_grad,
                             self.model.parameters())
        self.optimizer = optim.Adam(nice_params, lr=lr, eps=1e-4)

        if lr_schedule is None:
            lr_schedule = lambda x: 1.0
        self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_schedule)
        self.loss = PolicyGradientLoss(on_policy_loss_type,
                                       entropy_wgt=entropy_wgt)
        # self.metric_monitor is read when the factory is called, not here
        self.fitter_factory = lambda: make_fitter(BATCH_SIZE, settings[
            'z_size'], [self.metric_monitor] + gen_extra_callbacks, self)

        self.fitter = self.fitter_factory()
        self.set_root_name(root_name)
        print('Runner initialized!')
def train_policy_gradient(molecules=True,
                          grammar=True,
                          EPOCHS=None,
                          BATCH_SIZE=None,
                          reward_fun_on=None,
                          reward_fun_off=None,
                          max_steps=277,
                          lr_on=2e-4,
                          lr_off=1e-4,
                          drop_rate=0.0,
                          plot_ignore_initial=0,
                          save_file=None,
                          preload_file=None,
                          anchor_file=None,
                          anchor_weight=0.0,
                          decoder_type='action',
                          plot_prefix='',
                          dashboard='policy gradient',
                          smiles_save_file=None,
                          on_policy_loss_type='best',
                          off_policy_loss_type='mean',
                          sanity_checks=True):
    """Set up on-policy and off-policy policy-gradient training of a decoder.

    The on-policy generator samples with SoftmaxRandomSamplePolicy and is
    rewarded through the task's ``reward_fun_on``. The off-policy generator
    replays action sequences from the ``'actions'`` data set via
    PolicyFromTarget. Both drive the same ``model`` and checkpoint to
    ``save_file``.

    Args:
        molecules, grammar: forwarded to get_settings / the task / get_decoder.
        EPOCHS: overrides settings['EPOCHS'] when not None; passed as-is to
            fit_rl.
        BATCH_SIZE: overrides settings['BATCH_SIZE']; when None the settings
            value is used.
        reward_fun_on: reward used by the generation task.
        reward_fun_off: defaults to reward_fun_on; not otherwise used here.
        max_steps: maximum decoding steps.
        lr_on, lr_off: Adam learning rates of the on-/off-policy fitters.
        save_file: checkpoint file under ``pretrained/`` (required).
        preload_file: optional weights file under ``pretrained/``.
        anchor_file, anchor_weight: optional anchor model for fit_rl.
        smiles_save_file: file under ``pretrained/`` for the
            IncrementingHDF5Dataset of generated molecules (required).
        sanity_checks: accepted for backward compatibility; unused.

    Returns:
        ``(model, on_policy_gen, off_policy_gen)``; the generators yield
        ``next(fitter)`` of their respective fitters.

    Raises:
        ValueError: if ``save_file`` or ``smiles_save_file`` is None.
    """
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../'
    if save_file is None or smiles_save_file is None:
        raise ValueError('train_policy_gradient requires both save_file and '
                         'smiles_save_file')
    save_path = root_location + 'pretrained/' + save_file
    smiles_save_path = root_location + 'pretrained/' + smiles_save_file

    settings = get_settings(molecules=molecules, grammar=grammar)

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE
    else:
        # the task, my_gen and the data loader all need a concrete batch size
        BATCH_SIZE = settings['BATCH_SIZE']

    save_dataset = IncrementingHDF5Dataset(smiles_save_path)

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=reward_fun_on,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=save_dataset)

    def get_model():
        return get_decoder(
            molecules,
            grammar,
            z_size=settings['z_size'],
            decoder_hidden_n=200,
            feature_len=None,  # made redundant, need to factor it out
            max_seq_length=max_steps,
            drop_rate=drop_rate,
            decoder_type=decoder_type,
            task=task)[0]

    model = get_model()

    if preload_file is not None:
        preload_path = root_location + 'pretrained/' + preload_file
        try:
            model.load_state_dict(torch.load(preload_path))
        except Exception as e:
            # best-effort: keep the freshly initialized weights
            print('failed to load generator weights ' + str(e))

    anchor_model = None
    if anchor_file is not None:
        anchor_model = get_model()
        anchor_path = root_location + 'pretrained/' + anchor_file
        try:
            anchor_model.load_state_dict(torch.load(anchor_path))
        except Exception as e:
            # an untrained anchor is useless, so train without one
            print('failed to load anchor weights ' + str(e))
            anchor_model = None

    from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer
    import rdkit.Chem.rdMolDescriptors as desc
    import numpy as np
    scorer = NormalizedScorer()

    def model_process_fun(model_out, visdom, n):
        """Plot the best molecule of the batch and its score components."""
        from rdkit import Chem
        from rdkit.Chem.Draw import MolToFile
        actions, logits, rewards, terminals, info = model_out
        smiles, valid = info
        total_rewards = rewards.sum(1)
        best_ind = torch.argmax(total_rewards).data.item()
        this_smile = smiles[best_ind]
        mol = Chem.MolFromSmiles(this_smile)
        pic_save_path = root_location + 'images/' + 'test.svg'
        if mol is not None:
            try:
                MolToFile(mol, pic_save_path, imageType='svg')
                with open(pic_save_path, 'r') as myfile:
                    data = myfile.read()
                data = data.replace('svg:', '')
                visdom.append('best molecule of batch', 'svg', svgstr=data)
            except Exception:
                pass  # drawing is cosmetic; never let it break training
            scores, norm_scores = scorer.get_scores([this_smile])
            visdom.append(
                'score component',
                'line',
                X=np.array([n]),
                Y=np.array(
                    [[x for x in norm_scores[0]] + [norm_scores[0].sum()] +
                     [scores[0].sum()] + [desc.CalcNumAromaticRings(mol)]]),
                opts={
                    'legend': [
                        'logP', 'SA', 'cycle', 'norm_reward', 'reward',
                        'Aromatic rings'
                    ]
                })
            visdom.append('fraction valid',
                          'line',
                          X=np.array([n]),
                          Y=np.array([valid.mean().data.item()]))

    if reward_fun_off is None:
        reward_fun_off = reward_fun_on  # NOTE(review): not used below

    def get_fitter(model,
                   loss_obj,
                   fit_plot_prefix='',
                   model_process_fun=None,
                   lr=None,
                   loss_display_cap=float('inf'),
                   anchor_model=None,
                   anchor_weight=0):
        """Build an Adam + StepLR RL fitter for ``model`` that checkpoints to save_path."""
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun)
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=True)

        def my_gen():
            for _ in range(1000):
                yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

        fitter = fit_rl(train_gen=my_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        metric_monitor=metric_monitor,
                        checkpointer=checkpointer)

        return fitter

    # the on-policy fitter
    fitter1 = get_fitter(model,
                         PolicyGradientLoss(on_policy_loss_type),
                         plot_prefix + 'on-policy',
                         model_process_fun=model_process_fun,
                         lr=lr_on,
                         anchor_model=anchor_model,
                         anchor_weight=anchor_weight)
    # get existing molecule data to add training
    main_dataset = DatasetFromHDF5(settings['data_path'], 'actions')

    # TODO change call to a simple DataLoader, no validation
    train_loader, valid_loader = train_valid_loaders(main_dataset,
                                                     valid_fraction=0.1,
                                                     batch_size=BATCH_SIZE,
                                                     pin_memory=use_gpu)

    fitter2 = get_fitter(model,
                         PolicyGradientLoss(off_policy_loss_type),
                         plot_prefix + ' off-policy',
                         lr=lr_off,
                         model_process_fun=model_process_fun,
                         loss_display_cap=125)

    def on_policy_gen(fitter, model):
        while True:
            model.policy = SoftmaxRandomSamplePolicy()
            yield next(fitter)

    def off_policy_gen(fitter, data_gen, model):
        # Keep one iterator alive across steps so successive batches walk
        # through the loader; start a new pass only when it is exhausted.
        data_iter = iter(data_gen)
        while True:
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(data_gen)
                continue
            model.policy = PolicyFromTarget(batch.to(torch.int64))
            yield next(fitter)

    return model, on_policy_gen(fitter1,
                                model), off_policy_gen(fitter2, train_loader,
                                                       model)
Beispiel #15
0
# from deep_rl import *
# from generative_playground.models.problem.rl.network_heads import CategoricalActorCriticNet
# from generative_playground.train.rl.run_iterations import run_iterations
from generative_playground.molecules.rdkit_utils.rdkit_utils import num_atoms, num_aromatic_rings, NormalizedScorer
# from generative_playground.models.problem.rl.DeepRL_wrappers import BodyAdapter, MyA2CAgent
from generative_playground.molecules.model_settings import get_settings
from generative_playground.molecules.train.pg.hypergraph.main_train_policy_gradient_minimal import train_policy_gradient
from generative_playground.codec.hypergraph_grammar import GrammarInitializer
from generative_playground.molecules.guacamol_utils import guacamol_goal_scoring_functions, version_name_list

# --- run configuration for the hypergraph policy-gradient experiment ---
batch_size = 15  # 20
drop_rate = 0.5
molecules = True
# grammar is identified by its cached pickle; the 'hypergraph:' prefix is
# part of the grammar spec string passed to get_settings
grammar_cache = 'hyper_grammar_guac_10k_with_clique_collapse.pickle'  #'hyper_grammar.pickle'
grammar = 'hypergraph:' + grammar_cache
settings = get_settings(molecules, grammar)
# select the objective: obj_num indexes the list returned for this guacamol version
ver = 'trivial'
obj_num = 0
reward_funs = guacamol_goal_scoring_functions(ver)
reward_fun = reward_funs[obj_num]
# # later will run this ahead of time
# gi = GrammarInitializer(grammar_cache)

# NOTE(review): the 'do 0.5 lr4e-5' suffix is hard-coded text; it is not
# derived from drop_rate or any learning-rate variable -- keep it in sync by hand
root_name = 'canned_' + ver + '_' + str(obj_num) + 'do 0.5 lr4e-5'
max_steps = 45
model, gen_fitter, disc_fitter = train_policy_gradient(
    molecules,
    grammar,
    EPOCHS=100,
    BATCH_SIZE=batch_size,
    reward_fun_on=reward_fun,
Beispiel #16
0
def train_policy_gradient_ppo(molecules=True,
                              grammar=True,
                              smiles_source='ZINC',
                              EPOCHS=None,
                              BATCH_SIZE=None,
                              reward_fun_on=None,
                              reward_fun_off=None,
                              max_steps=277,
                              lr_on=2e-4,
                              lr_discrim=1e-4,
                              discrim_wt=2,
                              p_thresh=0.5,
                              drop_rate=0.0,
                              plot_ignore_initial=0,
                              randomize_reward=False,
                              save_file_root_name=None,
                              reward_sm=0.0,
                              preload_file_root_name=None,
                              anchor_file=None,
                              anchor_weight=0.0,
                              decoder_type='action',
                              plot_prefix='',
                              dashboard='policy gradient',
                              smiles_save_file=None,
                              on_policy_loss_type='best',
                              priors=True,
                              node_temperature_schedule=lambda x: 1.0,
                              rule_temperature_schedule=lambda x: 1.0,
                              eps=0.0,
                              half_float=False,
                              extra_repetition_penalty=0.0):
    """Build a generator RL fitter and a discriminator fitter.

    The generator reward (``adj_reward``) starts from ``reward_fun_on``. It is
    penalized for molecules found in the ``smiles_source`` database or in the
    recent generation history. When ``discrim_wt`` > 1e-5 it also gains
    ``discrim_wt * p * orig_mult``, where ``p`` is the discriminator's
    probability for class 1 (the database class). The discriminator is trained
    to separate recent generations (label 0) from database molecules (label 1).

    Args (selected):
        save_file_root_name: required; checkpoints go to
            ``pretrained/<root>_gen.h5`` and ``pretrained/<root>_disc.h5``.
        preload_file_root_name: if set, generator weights are loaded
            best-effort from ``pretrained/<root>_gen.h5``.
        BATCH_SIZE: molecules per batch; settings['BATCH_SIZE'] when None.
        discrim_wt: weight of the discriminator term in the reward.
        extra_repetition_penalty: extra penalty subtracted for repeats.
        node_temperature_schedule, rule_temperature_schedule: functions of the
            step counter setting the sampling-policy temperatures (None skips).
        smiles_save_file: optional gzip MoleculeSaver target under pretrained/.
        p_thresh, randomize_reward: only used by the unused ``adj_reward_old``.
        reward_fun_off, anchor_file: accepted but not used here.

    Returns:
        ``(model, fitter1, fitter2)``: generator model, its fitter, and the
        discriminator fitter.

    Raises:
        ValueError: if ``save_file_root_name`` is None.
    """
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../../'

    def full_path(x):
        return os.path.realpath(root_location + 'pretrained/' + x)

    if save_file_root_name is None:
        # both fitters checkpoint unconditionally; fail now instead of with an
        # UnboundLocalError after the expensive setup below
        raise ValueError('train_policy_gradient_ppo requires save_file_root_name')
    gen_save_file = save_file_root_name + '_gen.h5'
    disc_save_file = save_file_root_name + '_disc.h5'
    if preload_file_root_name is not None:
        gen_preload_file = preload_file_root_name + '_gen.h5'
        disc_preload_file = preload_file_root_name + '_disc.h5'

    settings = get_settings(molecules=molecules, grammar=grammar)
    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE
    else:
        # needed concretely by the lookback windows, the task and my_gen
        BATCH_SIZE = settings['BATCH_SIZE']

    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    discrim_model = GraphDiscriminator(codec.grammar, drop_rate=drop_rate)
    # discriminator preloading is deliberately disabled ('if False and ...')
    if False and preload_file_root_name is not None:
        try:
            preload_path = full_path(disc_preload_file)
            discrim_model.load_state_dict(torch.load(preload_path),
                                          strict=False)
            print('Discriminator weights loaded successfully!')
        except Exception as e:
            print('failed to load discriminator weights ' + str(e))

    zinc_data = get_smiles_from_database(source=smiles_source)
    zinc_set = set(zinc_data)
    # recent-generation histories over 1, 10 and 100 batches
    lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

    def originality_mult(smiles_list):
        """Per-molecule multiplier in (0, 1]: lower for known or recent molecules."""
        out = []
        for s in smiles_list:
            if s in zinc_set:
                out.append(0.5)
            elif s in history_data[0]:
                out.append(0.5)
            elif s in history_data[1]:
                out.append(0.70)
            elif s in history_data[2]:
                out.append(0.85)
            else:
                out.append(1.0)
        return np.array(out)

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def discriminator_reward_mult(smiles_list):
        """Discriminator probability of class 1 for each molecule, in eval mode."""
        orig_state = discrim_model.training
        discrim_model.eval()
        discrim_out_logits = discrim_model(smiles_list)['p_zinc']
        discrim_probs = F.softmax(discrim_out_logits, dim=1)
        prob_zinc = discrim_probs[:, 1].detach().cpu().numpy()
        if orig_state:
            discrim_model.train()
        return prob_zinc

    def apply_originality_penalty(x, orig_mult):
        """Penalize reward ``x`` (must be <= 1) according to ``orig_mult``."""
        assert x <= 1, "Reward must be no greater than 1"
        if x > 0.5:  # want to punish nearly-perfect scores less and less
            out = math.pow(x, 1 / orig_mult)
        else:  # continuous join at 0.5
            penalty = math.pow(0.5, 1 / orig_mult) - 0.5
            out = x + penalty

        out -= extra_repetition_penalty * (1 - 1 / orig_mult)
        return out

    def adj_reward(x):
        """Reward for a list of SMILES: penalized base reward plus discriminator term."""
        if discrim_wt > 1e-5:
            p = discriminator_reward_mult(x)
        else:
            p = 0
        rwd = np.array(reward_fun_on(x))
        orig_mult = originality_mult(x)
        reward = np.array([
            apply_originality_penalty(r, om) for r, om in zip(rwd, orig_mult)
        ])
        out = reward + discrim_wt * p * orig_mult
        return out

    def adj_reward_old(x):
        # NOTE(review): legacy reward variant, not referenced in this function
        p = discriminator_reward_mult(x)
        w = sigmoid(-(p - p_thresh) / 0.01)
        if randomize_reward:
            rand = np.random.uniform(size=p.shape)
            w *= rand
        reward = np.maximum(reward_fun_on(x), p_thresh)
        weighted_reward = w * p + (1 - w) * reward
        out = weighted_reward * originality_mult(x)  #
        return out

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=adj_reward,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=None)

    node_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0),
                                            eps=eps)
    rule_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0),
                                            eps=eps)

    model = get_decoder(molecules,
                        grammar,
                        z_size=settings['z_size'],
                        decoder_hidden_n=200,
                        feature_len=codec.feature_len(),
                        max_seq_length=max_steps,
                        drop_rate=drop_rate,
                        batch_size=BATCH_SIZE,
                        decoder_type=decoder_type,
                        reward_fun=adj_reward,
                        task=task,
                        node_policy=node_policy,
                        rule_policy=rule_policy,
                        priors=priors)[0]

    if preload_file_root_name is not None:
        try:
            preload_path = full_path(gen_preload_file)
            model.load_state_dict(torch.load(preload_path, map_location='cpu'),
                                  strict=False)
            print('Generator weights loaded successfully!')
        except Exception as e:
            # best-effort: keep the freshly initialized weights
            print('failed to load generator weights ' + str(e))

    anchor_model = None

    from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer
    # np is bound as a local here; the closures above read it at call time
    import numpy as np
    scorer = NormalizedScorer()

    if reward_fun_off is None:
        reward_fun_off = reward_fun_on

    # construct the loader to feed the discriminator
    def make_callback(data):
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            for s in smiles:  # only store unique instances of molecules so discriminator can't guess on frequency
                if s not in data:
                    data.append(s)

        return hc

    class TemperatureCallback:
        """Callback setting ``policy``'s temperature from a step-count schedule."""
        def __init__(self, policy, temperature_function):
            self.policy = policy
            self.counter = 0
            self.temp_fun = temperature_function

        def __call__(self, inputs, model, outputs, loss_fn, loss):
            self.counter += 1
            target_temp = self.temp_fun(self.counter)
            self.policy.set_temperature(target_temp)

    def get_rl_fitter(model,
                      loss_obj,
                      train_gen,
                      save_path,
                      fit_plot_prefix='',
                      model_process_fun=None,
                      lr=None,
                      extra_callbacks=None,
                      loss_display_cap=float('inf'),
                      anchor_model=None,
                      anchor_weight=0):
        """Build an Adam + StepLR fitter for ``model`` checkpointing to ``save_path``."""
        if extra_callbacks is None:
            extra_callbacks = []
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr, eps=1e-4)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun,
                smooth_weight=reward_sm)
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=True)

        fitter = fit_rl(train_gen=train_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        half_float=half_float,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        callbacks=[metric_monitor, checkpointer] +
                        extra_callbacks)

        return fitter

    def my_gen():
        for _ in range(1000):
            yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

    # the on-policy fitter
    gen_extra_callbacks = [make_callback(d) for d in history_data]

    if smiles_save_file is not None:
        smiles_save_path = os.path.realpath(root_location + 'pretrained/' +
                                            smiles_save_file)
        gen_extra_callbacks.append(MoleculeSaver(smiles_save_path, gzip=True))
        print('Saved SMILES to {}'.format(smiles_save_file))

    if node_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(node_policy, node_temperature_schedule))

    if rule_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(rule_policy, rule_temperature_schedule))

    # NOTE(review): model_process_fun is not defined in this function;
    # presumably a module-level name -- verify.
    fitter1 = get_rl_fitter(
        model,
        PolicyGradientLoss(on_policy_loss_type),
        GeneratorToIterable(my_gen),
        full_path(gen_save_file),
        plot_prefix + 'on-policy',
        model_process_fun=model_process_fun,
        lr=lr_on,
        extra_callbacks=gen_extra_callbacks,
        anchor_model=anchor_model,
        anchor_weight=anchor_weight)

    # discriminator data: recent generations (label 0, shortest history
    # weighted twice) blended against database molecules (label 1)
    pre_dataset = EvenlyBlendedDataset(
        2 * [history_data[0]] + history_data[1:],
        labels=False)  # a blend of 3 time horizons
    dataset = EvenlyBlendedDataset([pre_dataset, zinc_data], labels=True)
    discrim_loader = DataLoader(dataset, shuffle=True, batch_size=50)

    class MyLoss(nn.Module):
        """Cross-entropy between discriminator logits and the dataset label."""
        def __init__(self):
            super().__init__()
            self.celoss = nn.CrossEntropyLoss()

        def forward(self, x):
            return self.celoss(x['p_zinc'].to(device),
                               x['dataset_index'].to(device))

    fitter2 = get_rl_fitter(discrim_model,
                            MyLoss(),
                            IterableTransform(
                                discrim_loader, lambda x: {
                                    'smiles': x['X'],
                                    'dataset_index': x['dataset_index']
                                }),
                            full_path(disc_save_file),
                            plot_prefix + ' discriminator',
                            lr=lr_discrim,
                            model_process_fun=None)

    return model, fitter1, fitter2
Beispiel #17
0
def train_validity(grammar=True,
                   model=None,
                   EPOCHS=None,
                   BATCH_SIZE=None,
                   lr=2e-4,
                   main_dataset=None,
                   new_datasets=None,
                   plot_ignore_initial=0,
                   save_file=None,
                   plot_prefix='',
                   dashboard='main',
                   preload_weights=False):
    """Build a fitter that trains ``model`` with BCE on mixed SMILES batches.

    Batches are drawn (via ``MixedLoader``) from ``main_dataset`` plus the
    train/validation loaders of a valid-SMILES dataset and an invalid-SMILES
    dataset.

    Args:
        grammar: forwarded to ``get_settings`` (molecules is always True).
        model: network to train; must expose ``parameters()`` and, when
            ``preload_weights`` is set, ``load(path)``.
        EPOCHS: overrides ``settings['EPOCHS']`` when not None.
        BATCH_SIZE: overrides ``settings['BATCH_SIZE']`` when not None; when
            None, the configured ``settings['BATCH_SIZE']`` (if any) is used.
        lr: Adam learning rate.
        main_dataset: dataset split 90/10 by ``train_valid_loaders``.
        new_datasets: pair ``(valid_smile_ds, invalid_smile_ds)``; each must
            provide ``get_train_valid_loaders(batch_size)``.
        plot_ignore_initial, plot_prefix, dashboard: forwarded to ``fit``.
        save_file: file name under ``<root>/pretrained/`` for checkpoints and
            preloading; None disables both.
        preload_weights: if True, try to load weights from the save path
            first; failure is reported as a warning and training starts fresh.

    Returns:
        The object returned by ``fit``.
    """
    import warnings

    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../'
    if save_file is not None:
        save_path = root_location + 'pretrained/' + save_file
    else:
        save_path = None
    molecules = True  # checking for validity only makes sense for molecules
    settings = get_settings(molecules=molecules, grammar=grammar)

    # TODO: separate settings for this?
    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE
    # Use the configured batch size when none was passed explicitly; if the
    # settings have no such key this is just BATCH_SIZE, as before.
    batch_size = settings.get('BATCH_SIZE', BATCH_SIZE)

    if preload_weights and save_path is not None:
        try:
            model.load(save_path)
        except Exception as e:
            # Best effort: starting from scratch is acceptable, but say so
            # instead of silently ignoring the failure.
            warnings.warn('Could not preload weights from {}: {}'.format(
                save_path, e))

    nice_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(nice_params, lr=lr)

    # create the composite loaders
    train_loader, valid_loader = train_valid_loaders(main_dataset,
                                                     valid_fraction=0.1,
                                                     batch_size=batch_size,
                                                     pin_memory=use_gpu)
    valid_smile_ds, invalid_smile_ds = new_datasets
    valid_train, valid_val = valid_smile_ds.get_train_valid_loaders(batch_size)
    # The invalid loaders must come from the invalid dataset (previously they
    # were drawn from valid_smile_ds by mistake).
    invalid_train, invalid_val = invalid_smile_ds.get_train_valid_loaders(
        batch_size)
    train_gen = MixedLoader(train_loader, valid_train, invalid_train)
    valid_gen = MixedLoader(valid_loader, valid_val, invalid_val)

    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               factor=0.2,
                                               patience=3,
                                               min_lr=0.0001,
                                               eps=1e-08)
    # size_average=True is deprecated; reduction='mean' is the equivalent
    loss_obj = nn.BCELoss(reduction='mean')

    fitter = fit(train_gen=train_gen,
                 valid_gen=valid_gen,
                 model=model,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 epochs=settings['EPOCHS'],
                 loss_fn=loss_obj,
                 save_path=save_path,
                 dashboard_name=dashboard,
                 plot_ignore_initial=plot_ignore_initial,
                 plot_prefix=plot_prefix)

    return fitter
# Example #18
# 0
def train_policy_gradient(molecules=True,
                          grammar=True,
                          EPOCHS=None,
                          BATCH_SIZE=None,
                          reward_fun_on=None,
                          reward_fun_off=None,
                          max_steps=277,
                          lr_on=2e-4,
                          lr_discrim=1e-4,
                          p_thresh=0.5,
                          drop_rate=0.0,
                          plot_ignore_initial=0,
                          randomize_reward=False,
                          save_file=None,
                          reward_sm=0.0,
                          preload_file=None,
                          anchor_file=None,
                          anchor_weight=0.0,
                          decoder_type='action',
                          plot_prefix='',
                          dashboard='policy gradient',
                          smiles_save_file=None,
                          on_policy_loss_type='best',
                          off_policy_loss_type='mean',
                          sanity_checks=True):
    """Set up policy-gradient training of a generator shaped by a discriminator.

    Two fitters are built:

    * an on-policy fitter that trains the decoder ``model`` with
      ``PolicyGradientLoss``. Its reward (``adj_reward``) is
      ``max(reward_fun_on(x), 0) * originality_mult(x) + 2 * p``, where ``p``
      is the discriminator's softmax probability for class 1 of ``p_zinc``;
    * a discriminator fitter that trains a ``GraphDiscriminator`` with
      cross-entropy on a blend of recently generated SMILES and ZINC SMILES.

    Args:
        molecules, grammar: select settings, codec and decoder variant.
        EPOCHS: stored in ``settings['EPOCHS']`` when not None, and forwarded
            as-is (possibly None) to ``fit_rl``.
        BATCH_SIZE: generation batch size; stored in ``settings['BATCH_SIZE']``
            when not None. When None, ``settings['BATCH_SIZE']`` is used.
        reward_fun_on: maps a list of SMILES to an array of raw rewards.
        max_steps: maximum decoding steps per sequence.
        lr_on, lr_discrim: Adam learning rates for generator / discriminator.
        drop_rate: dropout for the decoder and the discriminator.
        save_file: required; checkpoints go to ``pretrained/gen_<save_file>``
            and ``pretrained/disc_<save_file>``.
        reward_sm: ``last_reward_wgt`` passed to ``PolicyGradientLoss``.
        anchor_weight: forwarded to ``fit_rl`` (the anchor model is None).
        decoder_type: forwarded to ``get_decoder``.
        plot_prefix, dashboard, plot_ignore_initial: plotting options;
            ``dashboard=None`` disables the metric plotter.
        smiles_save_file: if given, an ``IncrementingHDF5Dataset`` under
            ``pretrained/`` is passed to the task as ``save_dataset``.
        on_policy_loss_type: loss type for ``PolicyGradientLoss``.
        p_thresh, randomize_reward: only used by ``adj_reward_old``, which is
            not wired in.
        reward_fun_off, preload_file, anchor_file, off_policy_loss_type,
        sanity_checks: currently unused; kept for interface compatibility.

    Returns:
        ``(model, on_policy_gen, fitter2)``: the decoder; a generator that
        installs a fresh ``SoftmaxRandomSamplePolicy`` and advances the
        on-policy fitter one step per ``next()``; and the discriminator fitter.

    Raises:
        ValueError: if ``save_file`` is None or no batch size is available.
    """
    # Imported up front because the nested helpers below use np.
    import numpy as np

    if save_file is None:
        raise ValueError('save_file is required: it names the generator and '
                         'discriminator checkpoint files')

    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../../'
    gen_save_path = root_location + 'pretrained/gen_' + save_file
    disc_save_path = root_location + 'pretrained/disc_' + save_file

    if smiles_save_file is not None:
        smiles_save_path = root_location + 'pretrained/' + smiles_save_file
        save_dataset = IncrementingHDF5Dataset(smiles_save_path)
    else:
        save_dataset = None

    settings = get_settings(molecules=molecules, grammar=grammar)
    # Apply overrides before anything depends on the batch size.
    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE
    batch_size = settings.get('BATCH_SIZE', BATCH_SIZE)
    if batch_size is None:
        raise ValueError('BATCH_SIZE must be given when the settings do not '
                         'define one')

    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    discrim_model = GraphDiscriminator(codec.grammar, drop_rate=drop_rate)

    zinc_data = get_zinc_smiles()
    zinc_set = set(zinc_data)
    lookbacks = [batch_size, 10 * batch_size, 100 * batch_size]
    # Seed each history with one molecule: an empty dataset makes the
    # DataLoader constructor fail.
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

    def originality_mult(smiles_list):
        # Penalize molecules seen in ZINC or in the generation history; the
        # more recent the sighting, the stronger the penalty.
        out = []
        for s in smiles_list:
            if s in zinc_set:
                out.append(0.5)
            elif s in history_data[0]:
                out.append(0.5)
            elif s in history_data[1]:
                out.append(0.70)
            elif s in history_data[2]:
                out.append(0.85)
            else:
                out.append(1.0)
        return np.array(out)

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def discriminator_reward_mult(smiles_list):
        # Evaluate in eval mode, then restore the previous mode even if the
        # forward pass raises.
        orig_state = discrim_model.training
        discrim_model.eval()
        try:
            discrim_out_logits = discrim_model(smiles_list)['p_zinc']
            discrim_probs = F.softmax(discrim_out_logits, dim=1)
            prob_zinc = discrim_probs[:, 1].detach().cpu().numpy()
        finally:
            if orig_state:
                discrim_model.train()
        return prob_zinc

    def adj_reward(x):
        p = discriminator_reward_mult(x)
        reward = np.maximum(reward_fun_on(x), 0)
        out = reward * originality_mult(x) + 2 * p
        return out

    def adj_reward_old(x):
        p = discriminator_reward_mult(x)
        w = sigmoid(-(p - p_thresh) / 0.01)
        if randomize_reward:
            rand = np.random.uniform(size=p.shape)
            w *= rand
        reward = np.maximum(reward_fun_on(x), p_thresh)
        weighted_reward = w * p + (1 - w) * reward
        out = weighted_reward * originality_mult(x)
        return out

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=adj_reward,
                                  batch_size=batch_size,
                                  max_steps=max_steps,
                                  save_dataset=save_dataset)

    model = get_decoder(molecules,
                        grammar,
                        z_size=settings['z_size'],
                        decoder_hidden_n=200,
                        feature_len=codec.feature_len(),
                        max_seq_length=max_steps,
                        drop_rate=drop_rate,
                        decoder_type=decoder_type,
                        task=task)[0]

    # TODO: really ugly, refactor! In fact this model doesn't need a MaskingHead at all!
    model.stepper.model.mask_gen.priors = True  # use empirical priors for the mask gen
    # NOTE: preload_file and anchor_file are currently ignored (preloading
    # was disabled), so training always starts from fresh weights.
    anchor_model = None

    from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer
    import rdkit.Chem.rdMolDescriptors as desc
    scorer = NormalizedScorer()

    def model_process_fun(model_out, visdom, n):
        # Plot the best molecule of the batch and its score components.
        # TODO: rephrase this to return a dict, instead of calling visdom directly
        from rdkit import Chem
        from rdkit.Chem.Draw import MolToFile
        smiles, valid = model_out['info']
        total_rewards = model_out['rewards'].sum(1)
        best_ind = torch.argmax(total_rewards).data.item()
        this_smile = smiles[best_ind]
        mol = Chem.MolFromSmiles(this_smile)
        pic_save_path = root_location + 'images/' + 'tmp.svg'
        if mol is not None:
            try:
                MolToFile(mol, pic_save_path, imageType='svg')
                with open(pic_save_path, 'r') as myfile:
                    data = myfile.read()
                data = data.replace('svg:', '')
                visdom.append('best molecule of batch', 'svg', svgstr=data)
            except Exception as e:
                print(e)
            scores, norm_scores = scorer.get_scores([this_smile])
            visdom.append(
                'score component',
                'line',
                X=np.array([n]),
                Y=np.array(
                    [[x for x in norm_scores[0]] + [norm_scores[0].sum()] +
                     [scores[0].sum()] + [desc.CalcNumAromaticRings(mol)] +
                     [total_rewards[best_ind].item()]]),
                opts={
                    'legend': [
                        'logP', 'SA', 'cycle', 'norm_reward', 'reward',
                        'Aromatic rings', 'eff_reward'
                    ]
                })
            # NOTE(review): only plotted when the best molecule parses;
            # possibly meant to be unconditional.
            visdom.append('fraction valid',
                          'line',
                          X=np.array([n]),
                          Y=np.array([valid.mean().data.item()]))

    # construct the loader to feed the discriminator
    def make_callback(data):
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            # only store unique instances of molecules so the discriminator
            # can't guess on frequency
            for s in smiles:
                if s not in data:
                    data.append(s)

        return hc

    def get_rl_fitter(model,
                      loss_obj,
                      train_gen,
                      save_path,
                      fit_plot_prefix='',
                      model_process_fun=None,
                      lr=None,
                      extra_callbacks=None,
                      loss_display_cap=float('inf'),
                      anchor_model=None,
                      anchor_weight=0):
        # Wire optimizer, LR schedule, plotting and checkpointing into fit_rl.
        if extra_callbacks is None:
            extra_callbacks = []
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun,
                smooth_weight=0.9)
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=True)

        fitter = fit_rl(train_gen=train_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,  # raw argument, possibly None
                        loss_fn=loss_obj,
                        grad_clip=5,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        callbacks=[metric_monitor, checkpointer] +
                        extra_callbacks)

        return fitter

    class GeneratorToIterable:
        # Wraps a generator *function* so it can be iterated repeatedly and
        # reports a length (obtained by exhausting it once).
        def __init__(self, gen):
            self.gen = gen
            # we assume the generator is finite
            self.len = 0
            for _ in gen():
                self.len += 1

        def __len__(self):
            return self.len

        def __iter__(self):
            return self.gen()

    def my_gen():
        for _ in range(1000):
            yield to_gpu(torch.zeros(batch_size, settings['z_size']))

    # the on-policy fitter
    history_callbacks = [make_callback(d) for d in history_data]
    fitter1 = get_rl_fitter(model,
                            PolicyGradientLoss(on_policy_loss_type,
                                               last_reward_wgt=reward_sm),
                            GeneratorToIterable(my_gen),
                            gen_save_path,
                            plot_prefix + 'on-policy',
                            model_process_fun=model_process_fun,
                            lr=lr_on,
                            extra_callbacks=history_callbacks,
                            anchor_model=anchor_model,
                            anchor_weight=anchor_weight)

    # discriminator data: generated molecules (blend of 3 time horizons, the
    # shortest one doubled) against ZINC
    pre_dataset = EvenlyBlendedDataset(
        2 * [history_data[0]] + history_data[1:],
        labels=False)
    dataset = EvenlyBlendedDataset([pre_dataset, zinc_data], labels=True)
    discrim_loader = DataLoader(dataset, shuffle=True, batch_size=50)
    celoss = nn.CrossEntropyLoss()

    def my_loss(x):
        return celoss(x['p_zinc'].to(device), x['dataset_index'].to(device))

    fitter2 = get_rl_fitter(discrim_model,
                            my_loss,
                            IterableTransform(
                                discrim_loader, lambda x: {
                                    'smiles': x['X'],
                                    'dataset_index': x['dataset_index']
                                }),
                            disc_save_path,
                            plot_prefix + ' discriminator',
                            lr=lr_discrim,
                            model_process_fun=None)

    def on_policy_gen(fitter, model):
        while True:
            model.policy = SoftmaxRandomSamplePolicy(
            )  #bias=codec.grammar.get_log_frequencies())
            yield next(fitter)

    return model, on_policy_gen(fitter1, model), fitter2