Example #1
def main():
  # change this to False to produce the equation dataset
  molecules = True
  # change this to False to get character-based encodings instead of grammar-based
  grammar = 'new'  # use True for the grammar used by Kusner et al.

  # can't define model class inside settings as it itself uses settings a lot
  _, my_model = get_vae(molecules, grammar)
  def pre_parser(x):
      # return the first parse tree, or None if the grammar can't parse x
      try:
          return next(my_model._parser.parse(x))
      except Exception:
          return None

  settings = get_settings(molecules, grammar)
  MAX_LEN = settings['max_seq_length']
  #feature_len = settings['feature_len']
  dest_file = settings['data_path']
  source_file = settings['source_data']

  # Read in the strings
  with open(source_file, 'r') as f:
      L = [line.strip() for line in f]

  # convert to one-hot and save, in small increments to save RAM
  #dest_file = dest_file.replace('.h5','_new.h5')
  ds = IncrementingHDF5Dataset(dest_file)

  step = 100
  dt = h5py.special_dtype(vlen=str)     # PY3 hdf5 datatype for variable-length Unicode strings
  size = min(10000, len(L))
  for i in tqdm(range(0, size, step)):
      these_indices = list(range(i, min(i + step, len(L))))
      these_smiles = L[i:min(i + step, len(L))]
      if grammar == 'new':  # have to weed out non-parseable strings
          tokens = [my_model._tokenize(s.replace('-c', 'c')) for s in these_smiles]
          these_smiles, these_indices = list(
              zip(*[(s, ind)
                    for s, t, ind in zip(these_smiles, tokens, these_indices)
                    if pre_parser(t) is not None]))
      these_actions = torch.tensor(my_model.strings_to_actions(these_smiles))
      action_seq_length = my_model.action_seq_length(these_actions)
      onehot = my_model.actions_to_one_hot(these_actions)
      append_data = {'smiles': np.array(these_smiles, dtype=dt),
                     'indices': np.array(these_indices),
                     'actions': these_actions,
                     'valid': np.ones(len(these_smiles)),
                     'seq_len': action_seq_length,
                     'data': onehot}
      if molecules:
          from rdkit.Chem.rdmolfiles import MolFromSmiles
          mols = [MolFromSmiles(s) for s in these_smiles]
          raw_scores = np.array([get_score_components_from_mol(m) for m in mols])
          append_data['raw_scores'] = raw_scores
          num_atoms = np.array([len(m.GetAtoms()) for m in mols])
          append_data['num_atoms'] = num_atoms

      ds.append(append_data)

  if molecules:
      # also calculate mean and std of the scores, to use in the ultimate objective
      raw_scores = np.array(ds.h5f['raw_scores'])
      score_std = raw_scores.std(0)
      score_mean = raw_scores.mean(0)
      ds.append_to_dataset('score_std',score_std)
      ds.append_to_dataset('score_mean', score_mean)

  print('success!')
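
A minimal read-back sketch (assuming standard h5py access; the file name is hypothetical, standing in for settings['data_path']), using the dataset keys written above plus the stored score statistics intended for normalizing the ultimate objective:

import h5py

with h5py.File('zinc_grammar_dataset.h5', 'r') as h5f:  # hypothetical path
    onehot = h5f['data'][:100]       # one-hot encoded action sequences
    actions = h5f['actions'][:100]   # integer action sequences
    smiles = list(h5f['smiles'][:10])
    # normalize the raw score components with the stored statistics
    norm_scores = (h5f['raw_scores'][:] - h5f['score_mean'][:]) / h5f['score_std'][:]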
Example #2
                                        grammar=True,
                                        BATCH_SIZE=150,
                                        drop_rate=0.3,
                                        sample_z=True,
                                        save_file='next_gen.h5',
                                        encoder_type=False,
                                        lr=5e-4,
                                        plot_prefix='RNN enc lr 1e-4',
                                        dashboard=dash_name,
                                        preload_weights=False)
# this is a wrapper for encoding/decoding
grammar_model = ZincGrammarModel(model=model)
validity_model = to_gpu(
    DenseHead(model.encoder, body_out_dim=settings['z_size'], drop_rate=0.3))

valid_smile_ds = IncrementingHDF5Dataset('valid_smiles.h5', valid_frac=0.1)
invalid_smile_ds = IncrementingHDF5Dataset('invalid_smiles.h5', valid_frac=0.1)

valid_fitter = train_validity(grammar=grammar,
                              model=validity_model,
                              EPOCHS=100,
                              BATCH_SIZE=40,
                              lr=1e-4,
                              main_dataset=main_dataset,
                              new_datasets=(valid_smile_ds, invalid_smile_ds),
                              save_file=None,
                              plot_prefix='valid_model',
                              dashboard=dash_name,
                              preload_weights=False)

# TODO: collect the smiles strings too, just for kicks, into hdf5!
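
A minimal driver sketch, assuming train_validity returns a step generator in the same style as the fit_rl fitters in the later examples:

for step in range(10000):  # hypothetical number of training steps
    next(valid_fitter)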
Example #3
from generative_playground.molecules.rdkit_utils.rdkit_utils import fraction_valid
import numpy as np

molecules = True
grammar = True
settings = get_settings(molecules, grammar)

dash_name = 'test'
visdom = Dashboard(dash_name)
model, grammar_model = get_vae(molecules,
                               grammar,
                               drop_rate=0.5,
                               decoder_type='action')  # or 'old', 'step', 'attention'
reinforcement_model = ReinforcementModel(model.decoder)
h5_prefix = 'new4_'
valid_smile_ds = IncrementingHDF5Dataset(h5_prefix + 'valid_smiles.h5')
invalid_smile_ds = IncrementingHDF5Dataset(h5_prefix + 'invalid_smiles.h5')
original_ds = IncrementingHDF5Dataset('../data/zinc_grammar_dataset.h5', mode='r')

RL_fitter = train_reinforcement(grammar=grammar,
                                model=reinforcement_model,
                                EPOCHS=10000,
                                BATCH_SIZE=25,
                                lr=1e-4,
                                new_datasets=(valid_smile_ds, invalid_smile_ds, original_ds),
                                save_file='first_reinforcement.h5',
                                plot_prefix='valid_model',
                                dashboard=dash_name,
                                preload_weights=False)

Example #4
count = 0
def train_policy_gradient(molecules=True,
                          grammar=True,
                          EPOCHS=None,
                          BATCH_SIZE=None,
                          reward_fun_on=None,
                          reward_fun_off=None,
                          max_steps=277,
                          lr_on=2e-4,
                          lr_off=1e-4,
                          drop_rate=0.0,
                          plot_ignore_initial=0,
                          save_file=None,
                          preload_file=None,
                          anchor_file=None,
                          anchor_weight=0.0,
                          decoder_type='action',
                          plot_prefix='',
                          dashboard='policy gradient',
                          smiles_save_file=None,
                          on_policy_loss_type='best',
                          off_policy_loss_type='mean',
                          sanity_checks=True):

    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../'
    save_path = root_location + 'pretrained/' + save_file
    # smiles_save_file may legitimately be None, so only build a dataset if given
    if smiles_save_file is not None:
        smiles_save_path = root_location + 'pretrained/' + smiles_save_file
        save_dataset = IncrementingHDF5Dataset(smiles_save_path)
    else:
        save_dataset = None

    settings = get_settings(molecules=molecules, grammar=grammar)

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=reward_fun_on,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=save_dataset)

    def get_model(sanity_checks=sanity_checks):
        return get_decoder(
            molecules,
            grammar,
            z_size=settings['z_size'],
            decoder_hidden_n=200,
            feature_len=None,  # made redundant, need to factor it out
            max_seq_length=max_steps,
            drop_rate=drop_rate,
            decoder_type=decoder_type,
            task=task)[0]

    model = get_model()

    if preload_file is not None:
        try:
            preload_path = root_location + 'pretrained/' + preload_file
            model.load_state_dict(torch.load(preload_path))
        except Exception:
            pass  # fall back to freshly initialized weights if loading fails

    anchor_model = None
    if anchor_file is not None:
        anchor_model = get_model()
        try:
            anchor_path = root_location + 'pretrained/' + anchor_file
            anchor_model.load_state_dict(torch.load(anchor_path))
        except Exception:
            anchor_model = None  # no usable anchor weights, so skip anchoring

    from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer
    import rdkit.Chem.rdMolDescriptors as desc
    import numpy as np
    scorer = NormalizedScorer()

    def model_process_fun(model_out, visdom, n):
        from rdkit import Chem
        from rdkit.Chem.Draw import MolToFile
        actions, logits, rewards, terminals, info = model_out
        smiles, valid = info
        total_rewards = rewards.sum(1)
        best_ind = torch.argmax(total_rewards).data.item()
        this_smile = smiles[best_ind]
        mol = Chem.MolFromSmiles(this_smile)
        pic_save_path = root_location + 'images/' + 'test.svg'
        if mol is not None:
            try:
                MolToFile(mol, pic_save_path, imageType='svg')
                with open(pic_save_path, 'r') as myfile:
                    data = myfile.read()
                data = data.replace('svg:', '')
                visdom.append('best molecule of batch', 'svg', svgstr=data)
            except Exception:
                pass  # drawing the molecule is best-effort
            scores, norm_scores = scorer.get_scores([this_smile])
            visdom.append(
                'score component',
                'line',
                X=np.array([n]),
                Y=np.array(
                    [[x for x in norm_scores[0]] + [norm_scores[0].sum()] +
                     [scores[0].sum()] + [desc.CalcNumAromaticRings(mol)]]),
                opts={
                    'legend': [
                        'logP', 'SA', 'cycle', 'norm_reward', 'reward',
                        'Aromatic rings'
                    ]
                })
            visdom.append('fraction valid',
                          'line',
                          X=np.array([n]),
                          Y=np.array([valid.mean().data.item()]))

    if reward_fun_off is None:
        reward_fun_off = reward_fun_on

    def get_fitter(model,
                   loss_obj,
                   fit_plot_prefix='',
                   model_process_fun=None,
                   lr=None,
                   loss_display_cap=float('inf'),
                   anchor_model=None,
                   anchor_weight=0):
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun)
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=True)

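        # the training 'inputs' are just batches of all-zero latent vectors;
        # what actually gets generated is controlled by the policy that
        # on_policy_gen / off_policy_gen installs on the model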
        def my_gen():
            for _ in range(1000):
                yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

        fitter = fit_rl(train_gen=my_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        metric_monitor=metric_monitor,
                        checkpointer=checkpointer)

        return fitter

    # the on-policy fitter
    fitter1 = get_fitter(model,
                         PolicyGradientLoss(on_policy_loss_type),
                         plot_prefix + 'on-policy',
                         model_process_fun=model_process_fun,
                         lr=lr_on,
                         anchor_model=anchor_model,
                         anchor_weight=anchor_weight)
    # get existing molecule data to add to the training mix
    main_dataset = DatasetFromHDF5(settings['data_path'], 'actions')

    # TODO change call to a simple DataLoader, no validation
    train_loader, valid_loader = train_valid_loaders(main_dataset,
                                                     valid_fraction=0.1,
                                                     batch_size=BATCH_SIZE,
                                                     pin_memory=use_gpu)

    fitter2 = get_fitter(model,
                         PolicyGradientLoss(off_policy_loss_type),
                         plot_prefix + ' off-policy',
                         lr=lr_off,
                         model_process_fun=model_process_fun,
                         loss_display_cap=125)

    def on_policy_gen(fitter, model):
        while True:
            model.policy = SoftmaxRandomSamplePolicy()
            yield next(fitter)

    def off_policy_gen(fitter, data_gen, model):
        data_iter = iter(data_gen)
        while True:
            try:
                x_actions = next(data_iter).to(torch.int64)
                model.policy = PolicyFromTarget(x_actions)
                yield next(fitter)
            except StopIteration:
                # dataset exhausted: start a fresh pass
                data_iter = iter(data_gen)

    return (model,
            on_policy_gen(fitter1, model),
            off_policy_gen(fitter2, train_loader, model))
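
A minimal driver sketch (hypothetical usage; my_reward_fun and the file names stand in for real arguments): train_policy_gradient returns the model plus one on-policy and one off-policy step generator, which can be alternated:

model, on_policy_iter, off_policy_iter = train_policy_gradient(
    reward_fun_on=my_reward_fun,
    EPOCHS=10,
    BATCH_SIZE=25,
    save_file='pg_test.h5',
    smiles_save_file='pg_smiles.h5')
for step in range(1000):
    next(on_policy_iter)   # policy-gradient step on freshly sampled sequences
    next(off_policy_iter)  # step on action sequences from the stored dataset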
Example #6
def train_policy_gradient(molecules=True,
                          grammar=True,
                          EPOCHS=None,
                          BATCH_SIZE=None,
                          reward_fun_on=None,
                          reward_fun_off=None,
                          max_steps=277,
                          lr_on=2e-4,
                          lr_discrim=1e-4,
                          p_thresh=0.5,
                          drop_rate=0.0,
                          plot_ignore_initial=0,
                          randomize_reward=False,
                          save_file=None,
                          reward_sm=0.0,
                          preload_file=None,
                          anchor_file=None,
                          anchor_weight=0.0,
                          decoder_type='action',
                          plot_prefix='',
                          dashboard='policy gradient',
                          smiles_save_file=None,
                          on_policy_loss_type='best',
                          off_policy_loss_type='mean',
                          sanity_checks=True):
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../../'
    gen_save_path = root_location + 'pretrained/gen_' + save_file
    disc_save_path = root_location + 'pretrained/disc_' + save_file

    if smiles_save_file is not None:
        smiles_save_path = root_location + 'pretrained/' + smiles_save_file
        save_dataset = IncrementingHDF5Dataset(smiles_save_path)
    else:
        save_dataset = None

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    discrim_model = GraphDiscriminator(codec.grammar, drop_rate=drop_rate)

    zinc_data = get_zinc_smiles()
    zinc_set = set(zinc_data)
    lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

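    # originality discount: molecules already in ZINC, or generated within the
    # last 1x/10x/100x batches, earn a reduced reward multiplier, nudging the
    # generator toward novel structures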
    def originality_mult(smiles_list):
        out = []
        for s in smiles_list:
            if s in zinc_set:
                out.append(0.5)
            elif s in history_data[0]:
                out.append(0.5)
            elif s in history_data[1]:
                out.append(0.70)
            elif s in history_data[2]:
                out.append(0.85)
            else:
                out.append(1.0)
        return np.array(out)

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def discriminator_reward_mult(smiles_list):
        # probability per molecule that the discriminator assigns to 'came from
        # ZINC'; run in eval mode, then restore the previous training state
        orig_state = discrim_model.training
        discrim_model.eval()
        discrim_out_logits = discrim_model(smiles_list)['p_zinc']
        discrim_probs = F.softmax(discrim_out_logits, dim=1)
        prob_zinc = discrim_probs[:, 1].detach().cpu().numpy()
        if orig_state:
            discrim_model.train()
        return prob_zinc

    def adj_reward(x):
        # blend the base reward (floored at 0 and scaled by the originality
        # discount) with the discriminator's 'looks like ZINC' probability
        p = discriminator_reward_mult(x)
        reward = np.maximum(reward_fun_on(x), 0)
        return reward * originality_mult(x) + 2 * p

    def adj_reward_old(x):
        # earlier reward variant: smoothly switch between discriminator
        # probability and base reward around p_thresh (the sigmoid is very
        # steep at scale 0.01)
        p = discriminator_reward_mult(x)
        w = sigmoid(-(p - p_thresh) / 0.01)
        if randomize_reward:
            rand = np.random.uniform(size=p.shape)
            w *= rand
        reward = np.maximum(reward_fun_on(x), p_thresh)
        weighted_reward = w * p + (1 - w) * reward
        return weighted_reward * originality_mult(x)

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=adj_reward,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=save_dataset)

    model = get_decoder(molecules,
                        grammar,
                        z_size=settings['z_size'],
                        decoder_hidden_n=200,
                        feature_len=codec.feature_len(),
                        max_seq_length=max_steps,
                        drop_rate=drop_rate,
                        decoder_type=decoder_type,
                        task=task)[0]

    # TODO: really ugly, refactor! In fact this model doesn't need a MaskingHead at all!
    model.stepper.model.mask_gen.priors = True  # use empirical priors for the mask gen ('conditional' is the other option)
    # if preload_file is not None:
    #     try:
    #         preload_path = root_location + 'pretrained/' + preload_file
    #         model.load_state_dict(torch.load(preload_path))
    #     except:
    #         pass

    anchor_model = None

    from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer
    import rdkit.Chem.rdMolDescriptors as desc
    import numpy as np
    scorer = NormalizedScorer()

    def model_process_fun(model_out, visdom, n):
        # TODO: rephrase this to return a dict, instead of calling visdom directly
        from rdkit import Chem
        from rdkit.Chem.Draw import MolToFile
        # actions, logits, rewards, terminals, info = model_out
        smiles, valid = model_out['info']
        total_rewards = model_out['rewards'].sum(1)
        best_ind = torch.argmax(total_rewards).data.item()
        this_smile = smiles[best_ind]
        mol = Chem.MolFromSmiles(this_smile)
        pic_save_path = root_location + 'images/' + 'tmp.svg'
        if mol is not None:
            try:
                MolToFile(mol, pic_save_path, imageType='svg')
                with open(pic_save_path, 'r') as myfile:
                    data = myfile.read()
                data = data.replace('svg:', '')
                visdom.append('best molecule of batch', 'svg', svgstr=data)
            except Exception as e:
                print(e)
            scores, norm_scores = scorer.get_scores([this_smile])
            visdom.append(
                'score component',
                'line',
                X=np.array([n]),
                Y=np.array(
                    [[x for x in norm_scores[0]] + [norm_scores[0].sum()] +
                     [scores[0].sum()] + [desc.CalcNumAromaticRings(mol)] +
                     [total_rewards[best_ind].item()]]),
                opts={
                    'legend': [
                        'logP', 'SA', 'cycle', 'norm_reward', 'reward',
                        'Aromatic rings', 'eff_reward'
                    ]
                })
            visdom.append('fraction valid',
                          'line',
                          X=np.array([n]),
                          Y=np.array([valid.mean().data.item()]))

    if reward_fun_off is None:
        reward_fun_off = reward_fun_on

    # construct the loader to feed the discriminator
    def make_callback(data):
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            for s in smiles:  # only store unique instances of molecules so discriminator can't guess on frequency
                if s not in data:
                    data.append(s)

        return hc

    # history_data above was seeded with a placeholder molecule because the
    # DataLoader constructor barfs on an empty dataset

    def get_rl_fitter(model,
                      loss_obj,
                      train_gen,
                      save_path,
                      fit_plot_prefix='',
                      model_process_fun=None,
                      lr=None,
                      extra_callbacks=[],
                      loss_display_cap=float('inf'),
                      anchor_model=None,
                      anchor_weight=0):
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun,
                smooth_weight=0.9)
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=True)

        fitter = fit_rl(train_gen=train_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        callbacks=[metric_monitor, checkpointer] +
                        extra_callbacks)

        return fitter

    class GeneratorToIterable:
        def __init__(self, gen):
            self.gen = gen
            # we assume the generator is finite
            self.len = 0
            for _ in gen():
                self.len += 1

        def __len__(self):
            return self.len

        def __iter__(self):
            return self.gen()
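
    # note: fit_rl presumably needs a sized iterable, so this wrapper adds
    # __len__ by running one full pass of gen() up front; gen must therefore
    # be a finite, re-callable generator factory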

    def my_gen():
        for _ in range(1000):
            yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

    # the on-policy fitter

    history_callbacks = [make_callback(d) for d in history_data]
    fitter1 = get_rl_fitter(model,
                            PolicyGradientLoss(on_policy_loss_type,
                                               last_reward_wgt=reward_sm),
                            GeneratorToIterable(my_gen),
                            gen_save_path,
                            plot_prefix + 'on-policy',
                            model_process_fun=model_process_fun,
                            lr=lr_on,
                            extra_callbacks=history_callbacks,
                            anchor_model=anchor_model,
                            anchor_weight=anchor_weight)
    # blend three time horizons of recently generated molecules, then mix them
    # with ZINC data; labels=True marks which dataset each sample came from
    pre_dataset = EvenlyBlendedDataset(2 * [history_data[0]] + history_data[1:],
                                       labels=False)  # a blend of 3 time horizons
    dataset = EvenlyBlendedDataset([pre_dataset, zinc_data], labels=True)
    discrim_loader = DataLoader(dataset, shuffle=True, batch_size=50)
    celoss = nn.CrossEntropyLoss()

    def my_loss(x):
        # cross-entropy between the discriminator's logits and dataset_index,
        # which labels whether a sample came from the generated set or from ZINC
        return celoss(x['p_zinc'].to(device), x['dataset_index'].to(device))

    fitter2 = get_rl_fitter(discrim_model,
                            my_loss,
                            IterableTransform(discrim_loader,
                                              lambda x: {'smiles': x['X'],
                                                         'dataset_index': x['dataset_index']}),
                            disc_save_path,
                            plot_prefix + ' discriminator',
                            lr=lr_discrim,
                            model_process_fun=None)

    def on_policy_gen(fitter, model):
        while True:
            # optional: pass bias=codec.grammar.get_log_frequencies() to bias sampling
            model.policy = SoftmaxRandomSamplePolicy()
            yield next(fitter)

    return model, on_policy_gen(fitter1, model), fitter2
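
As above, a hypothetical driver for this variant: here the third return value is the discriminator fitter itself, so generator and discriminator can be stepped in alternation, GAN-style (my_reward_fun and the file name are stand-ins):

model, generator_iter, discrim_fitter = train_policy_gradient(
    reward_fun_on=my_reward_fun,  # hypothetical reward function
    BATCH_SIZE=25,
    save_file='gan_test.h5')
while True:
    next(generator_iter)  # sample molecules and take a policy-gradient step
    next(discrim_fitter)  # train the discriminator on generated vs ZINC data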