    def test_hypergraph_mask_gen(self):
        molecules = True
        grammar_cache = 'tmp.pickle'
        grammar = 'hypergraph:' + grammar_cache
        # create a grammar cache inferred from our sample molecules
        g = HypergraphGrammar(cache_file=grammar_cache)
        if os.path.isfile(g.cache_file):
            os.remove(g.cache_file)
        g.strings_to_actions(get_smiles_from_database(5))
        mask_gen1 = get_codec(molecules, grammar, 30).mask_gen
        mask_gen2 = get_codec(molecules, grammar, 30).mask_gen
        mask_gen1.priors = False
        mask_gen2.priors = True
        policy1 = SoftmaxRandomSamplePolicy()
        policy2 = SoftmaxRandomSamplePolicy()
        lp = []
        for mg in [mask_gen1, mask_gen2]:
            mg.reset()
            mg.apply_action([None])
            logit_priors = mg.action_prior_logits()  # that includes any priors
            lp.append(
                torch.from_numpy(logit_priors).to(device=device,
                                                  dtype=torch.float32))

        dummy_model_output = torch.ones_like(lp[0])
        eff_logits = []
        for this_lp, policy in zip(lp, [policy1, policy2]):
            eff_logits.append(policy.effective_logits(dummy_model_output))
    def test_discriminator_class_batch_independence(self):
        d = GraphDiscriminator(gi.grammar, drop_rate=0.0)
        smiles = get_smiles_from_database(5)
        out1 = d({'smiles': smiles})['p_zinc']
        out2 = d({'smiles': smiles[:1]})['p_zinc']
        diff = torch.max((out1[0, :] - out2[0, :]).abs())
        assert diff < 1e-6, "There is cross-talk between batches"

    def test_discriminator_class_determinism(self):
        d = GraphDiscriminator(gi.grammar, drop_rate=0.0)
        smiles = get_smiles_from_database(5)
        out1 = d({'smiles': smiles})['p_zinc']
        out2 = d({'smiles': smiles})['p_zinc']
        diff = torch.max((out1 - out2).abs())
        assert diff < 1e-6, "Function is non-deterministic"

    def test_discriminator_class(self):
        d = GraphDiscriminator(gi.grammar, drop_rate=0.1)
        smiles = get_smiles_from_database(5)
        out = d(smiles)
        assert out['p_zinc'].size(0) == len(smiles)
        assert out['p_zinc'].size(1) == 2
        assert len(out['p_zinc'].size()) == 2
        assert out['smiles'] == smiles
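A minimal inference sketch for the pattern these tests exercise, assuming GraphDiscriminator, gi and get_smiles_from_database are available as in this file; it mirrors the discriminator_reward_mult helper further down, which reads the 'p_zinc' logits and softmaxes them into a "looks like ZINC" probability.

# Hedged usage sketch (not part of the test suite above)
import torch
import torch.nn.functional as F

d = GraphDiscriminator(gi.grammar, drop_rate=0.0)
d.eval()                                        # disable dropout for inference
smiles = get_smiles_from_database(5)
with torch.no_grad():
    logits = d({'smiles': smiles})['p_zinc']    # shape: (batch, 2)
    p_zinc = F.softmax(logits, dim=1)[:, 1]     # column 1 = "is ZINC" class
print(p_zinc.cpu().numpy())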
Example #5
    def init_grammar(self, max_num_mols):
        L = get_smiles_from_database(max_num_mols)
        for ind, smiles in enumerate(L):
            if ind >= max_num_mols:
                break
            if ind > self.last_processed:  # don't repeat
                try:
                    # this causes g to remember all the rules occurring in these molecules
                    these_actions = self.grammar.raw_strings_to_actions(
                        [smiles])
                    this_tree = self.grammar.last_tree_processed
                    these_tuples = tree_with_rule_inds_to_list_of_tuples(
                        this_tree)
                    # count conditional frequencies of child rule c given (parent rule, nonterminal)
                    for p, nt, c in these_tuples:
                        freqs = self.grammar.conditional_frequencies.setdefault((p, nt), {})
                        freqs[c] = freqs.get(c, 0) + 1
                    # count the frequency of the occurring rules
                    for aa in these_actions:
                        for a in aa:
                            if a not in self.grammar.rule_frequency_dict:
                                self.grammar.rule_frequency_dict[a] = 0
                            self.grammar.rule_frequency_dict[a] += 1

                    lengths = [len(x) for x in these_actions]
                    new_max_len = max(lengths)
                    self.total_len += sum(lengths)
                    if new_max_len > self.max_len:
                        self.max_len = new_max_len
                        print("Max len so far:", self.max_len)
                except Exception as e:  #TODO: fix this, make errors not happen ;)
                    print(e)
                self.last_processed = ind
                # if we discovered a new rule, remember that
                if not len(self.new_rules) or self.grammar.rate_tracker[-1][
                        -1] > self.new_rules[-1][-1]:
                    self.new_rules.append(
                        (ind, *self.grammar.rate_tracker[-1]))
                    print(self.new_rules[-1])
            if ind % 10 == 9:
                self.save()
            if ind % 100 == 0 and ind > 0:
                self.stats[ind] = {
                    'max_len': self.max_len,
                    'avg_len': self.total_len / ind,
                    'num_rules': len(self.grammar.rules),
                }
        self.grammar.normalize_conditional_frequencies()
        self.grammar.calc_terminal_distance()
        return self.max_len  # maximum observed molecule length
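A hedged usage sketch for the method above: judging by the gi.grammar references in the discriminator tests, this is the grammar initializer's init_grammar, which infers rules from sample molecules and caches them; the class name and constructor argument below are assumptions.

# Hedged sketch -- GrammarInitializer and its cache-file argument are assumed
gi = GrammarInitializer('tmp_grammar_cache.pickle')
max_len = gi.init_grammar(100)       # infer rules from the first 100 molecules
print(max_len, len(gi.grammar.rules))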
    def test_zinc_loaders(self):
        history_size = 1000
        history_data = deque(['aaa', 'aaa', 'aaa'], maxlen=history_size)
        zinc_data = get_smiles_from_database(100)
        dataset = EvenlyBlendedDataset([history_data, zinc_data], labels=True)
        loader = DataLoader(dataset, shuffle=True, batch_size=10)
        for batch in loader:
            assert type(batch) == dict
            assert 'X' in batch
            assert 'dataset_index' in batch
            assert len(batch['X']) == len(batch['dataset_index'])
            for x, label in zip(batch['X'], batch['dataset_index']):
                if label == 1:
                    assert x != 'aaa'
                elif label == 0:
                    assert x == 'aaa'
                else:
                    raise ValueError("Unknown label")
Example #7
import logging
import random
import numpy as np
from unittest import TestCase, skip
from generative_playground.codec.hypergraph import to_mol, HyperGraph, HypergraphTree
from generative_playground.codec.hypergraph_parser import hypergraph_parser, graph_from_graph_tree
from generative_playground.molecules.data_utils.zinc_utils import get_smiles_from_database
from generative_playground.codec.hypergraph_grammar import evaluate_rules, HypergraphGrammar, apply_rule
from generative_playground.codec.hypergraph_mask_generator import HypergraphMaskGenerator
from generative_playground.codec.hypergraph_rpe_grammar import HypergraphRPEGrammar
from rdkit.Chem import MolFromSmiles, AddHs, MolToSmiles, RemoveHs, Kekulize, BondType

smiles = get_smiles_from_database(10)
smiles1 = smiles[0]
bad_smiles = [
    'C1(CCCCC1)(C)C(=O)N', 'C1(CCCCC1)(C)C', 'C12=CC=CC=C1C=C3[N]2CC(NC3)(C)C',
    'CC(=O)Nc1c2n(c3ccccc13)C[C@](C)(C(=O)NC1CCCCC1)N(C1CCCCC1)C2=O'
]


class TestStart(TestCase):
    def test_hypergraph_roundtrip(self):
        mol = MolFromSmiles(smiles1)
        hg = HyperGraph.from_mol(mol)
        re_mol = to_mol(hg)
        re_smiles = MolToSmiles(re_mol)
        assert re_smiles == smiles1

    def test_hypergraph_via_nx_graph_roundtrip(self):
        mol = MolFromSmiles(smiles1)
        hg = HyperGraph.from_mol(mol)
def train_policy_gradient(molecules=True,
                          grammar=True,
                          smiles_source='ZINC',
                          EPOCHS=None,
                          BATCH_SIZE=None,
                          reward_fun_on=None,
                          reward_fun_off=None,
                          max_steps=277,
                          lr_on=2e-4,
                          lr_discrim=1e-4,
                          lr_schedule=None,
                          discrim_wt=2,
                          p_thresh=0.5,
                          drop_rate=0.0,
                          plot_ignore_initial=0,
                          randomize_reward=False,
                          save_file_root_name=None,
                          reward_sm=0.0,
                          preload_file_root_name=None,
                          anchor_file=None,
                          anchor_weight=0.0,
                          decoder_type='action',
                          plot_prefix='',
                          dashboard='policy gradient',
                          smiles_save_file=None,
                          on_policy_loss_type='best',
                          priors=True,
                          node_temperature_schedule=lambda x: 1.0,
                          rule_temperature_schedule=lambda x: 1.0,
                          eps=0.0,
                          half_float=False,
                          extra_repetition_penalty=0.0,
                          entropy_wgt=1.0):
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../../'

    def full_path(x):
        return os.path.realpath(root_location + 'pretrained/' + x)

    zinc_data = get_smiles_from_database(source=smiles_source)
    zinc_set = set(zinc_data)
    lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

    if save_file_root_name is not None:
        gen_save_file = save_file_root_name + '_gen.h5'
        disc_save_file = save_file_root_name + '_disc.h5'
    if preload_file_root_name is not None:
        gen_preload_file = preload_file_root_name + '_gen.h5'
        disc_preload_file = preload_file_root_name + '_disc.h5'

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    discrim_model = GraphDiscriminator(codec.grammar, drop_rate=drop_rate)
    if False and preload_file_root_name is not None:
        try:
            preload_path = full_path(disc_preload_file)
            discrim_model.load_state_dict(torch.load(preload_path),
                                          strict=False)
            print('Discriminator weights loaded successfully!')
        except Exception as e:
            print('failed to load discriminator weights ' + str(e))

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    alt_reward_calc = AdjustedRewardCalculator(reward_fun_on,
                                               zinc_set,
                                               lookbacks,
                                               extra_repetition_penalty,
                                               discrim_wt,
                                               discrim_model=None)

    reward_fun = lambda x: adj_reward(discrim_wt,
                                      discrim_model,
                                      reward_fun_on,
                                      zinc_set,
                                      history_data,
                                      extra_repetition_penalty,
                                      x,
                                      alt_calc=alt_reward_calc)

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=reward_fun,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=None)

    node_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0),
                                            eps=eps)
    rule_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(2.0),
                                            eps=eps)

    model = get_decoder(molecules,
                        grammar,
                        z_size=settings['z_size'],
                        decoder_hidden_n=200,
                        feature_len=codec.feature_len(),
                        max_seq_length=max_steps,
                        drop_rate=drop_rate,
                        batch_size=BATCH_SIZE,
                        decoder_type=decoder_type,
                        reward_fun=reward_fun,
                        task=task,
                        node_policy=node_policy,
                        rule_policy=rule_policy,
                        priors=priors)[0]

    if preload_file_root_name is not None:
        try:
            preload_path = full_path(gen_preload_file)
            model.load_state_dict(torch.load(preload_path, map_location='cpu'),
                                  strict=False)
            print('Generator weights loaded successfully!')
        except Exception as e:
            print('failed to load generator weights ' + str(e))

    anchor_model = None

    # construct the loader to feed the discriminator
    def make_callback(data):
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            for s in smiles:  # only store unique instances of molecules so discriminator can't guess on frequency
                if s not in data:
                    data.append(s)

        return hc

    class TemperatureCallback:
        def __init__(self, policy, temperature_function):
            self.policy = policy
            self.counter = 0
            self.temp_fun = temperature_function

        def __call__(self, inputs, model, outputs, loss_fn, loss):
            self.counter += 1
            target_temp = self.temp_fun(self.counter)
            self.policy.set_temperature(target_temp)

    # need to have something there to begin with, else the DataLoader constructor barfs

    def get_rl_fitter(model,
                      loss_obj,
                      train_gen,
                      save_path,
                      fit_plot_prefix='',
                      model_process_fun=None,
                      lr=None,
                      lr_schedule=lr_schedule,
                      extra_callbacks=[],
                      loss_display_cap=float('inf'),
                      anchor_model=None,
                      anchor_weight=0):
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr, eps=1e-4)
        if lr_schedule is None:
            lr_schedule = lambda x: 1.0
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun,
                smooth_weight=reward_sm,
                save_location=os.path.dirname(save_path))
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=10,
                                    save_path=save_path,
                                    save_always=True)

        fitter = fit_rl(train_gen=train_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        half_float=half_float,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        callbacks=[metric_monitor, checkpointer] +
                        extra_callbacks)

        return fitter

    class GeneratorToIterable:
        def __init__(self, gen):
            self.gen = gen
            # we assume the generator is finite
            self.len = 0
            for _ in gen():
                self.len += 1

        def __len__(self):
            return self.len

        def __iter__(self):
            return self.gen()

    def my_gen():
        for _ in range(1000):
            yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

    # the on-policy fitter

    gen_extra_callbacks = [make_callback(d) for d in history_data]

    if smiles_save_file is not None:
        smiles_save_path = os.path.realpath(root_location + 'pretrained/' +
                                            smiles_save_file)
        gen_extra_callbacks.append(MoleculeSaver(smiles_save_path, gzip=True))
        print('Saving generated SMILES to {}'.format(smiles_save_file))

    if node_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(node_policy, node_temperature_schedule))

    if rule_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(rule_policy, rule_temperature_schedule))

    fitter1 = get_rl_fitter(
        model,
        PolicyGradientLoss(
            on_policy_loss_type,
            entropy_wgt=entropy_wgt),  # last_reward_wgt=reward_sm),
        GeneratorToIterable(my_gen),
        full_path(gen_save_file),
        plot_prefix + 'on-policy',
        model_process_fun=model_process_fun,
        lr=lr_on,
        extra_callbacks=gen_extra_callbacks,
        anchor_model=anchor_model,
        anchor_weight=anchor_weight)
    #
    # # get existing molecule data to add training
    pre_dataset = EvenlyBlendedDataset(
        2 * [history_data[0]] + history_data[1:],
        labels=False)  # a blend of 3 time horizons
    dataset = EvenlyBlendedDataset([pre_dataset, zinc_data], labels=True)
    discrim_loader = DataLoader(dataset, shuffle=True, batch_size=50)

    class MyLoss(nn.Module):
        def __init__(self):
            super().__init__()
            self.celoss = nn.CrossEntropyLoss()

        def forward(self, x):
            # tmp = discriminator_reward_mult(x['smiles'])
            # tmp2 = F.softmax(x['p_zinc'], dim=1)[:,1].detach().cpu().numpy()
            # import numpy as np
            # assert np.max(np.abs(tmp-tmp2)) < 1e-6
            return self.celoss(x['p_zinc'].to(device),
                               x['dataset_index'].to(device))

    fitter2 = get_rl_fitter(discrim_model,
                            MyLoss(),
                            IterableTransform(
                                discrim_loader, lambda x: {
                                    'smiles': x['X'],
                                    'dataset_index': x['dataset_index']
                                }),
                            full_path(disc_save_file),
                            plot_prefix + ' discriminator',
                            lr=lr_discrim,
                            model_process_fun=None)

    def on_policy_gen(fitter, model):
        while True:
            # model.policy = SoftmaxRandomSamplePolicy()#bias=codec.grammar.get_log_frequencies())
            yield next(fitter)

    return model, fitter1, fitter2  #,on_policy_gen(fitter1, model)
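The function returns the model plus two fit_rl generators, and the commented-out on_policy_gen shows the intended usage: pull training batches by calling next() on each fitter. A hedged driver sketch, with a placeholder reward function and assumed grammar/file names:

# Hedged driver sketch for train_policy_gradient; reward and names are placeholders
model, fitter1, fitter2 = train_policy_gradient(
    molecules=True,
    grammar='hypergraph:hyper_grammar.pickle',   # assumed grammar cache spec
    EPOCHS=10,
    BATCH_SIZE=20,
    reward_fun_on=lambda smiles_list: [0.0 for _ in smiles_list],  # placeholder reward
    save_file_root_name='pg_smoke_test')         # hypothetical checkpoint root

for step in range(100):
    next(fitter1)    # one on-policy generator update
    next(fitter2)    # one discriminator update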
Example #9
def run_mcts(
    num_batches=10,  # respawn after that - workaround for memory leak
    batch_size=20,
    ver='trivial',  # 'v2'#'
    obj_num=4,
    grammar_cache='hyper_grammar_guac_10k_with_clique_collapse.pickle',  # 'hyper_grammar.pickle'
    max_seq_length=60,
    root_name='',
    compress_data_store=True,
    kind='thompson_local',
    reset_cache=False,
    penalize_repetition=True,
    save_every=10,
    num_bins=50,  # TODO: replace with a Value Distribution object
    updates_to_refresh=10,
    decay=0.95,
    lr=0.05,
    grad_clip=5,
    entropy_weight=3,
):

    if penalize_repetition:
        zinc_set = set(get_smiles_from_database(source='ChEMBL:train'))
        lookbacks = [batch_size, 10 * batch_size, 100 * batch_size]
        pre_reward_fun = guacamol_goal_scoring_functions(ver)[obj_num]
        reward_fun_ = AdjustedRewardCalculator(pre_reward_fun, zinc_set,
                                               lookbacks)
        # reward_fun_ = CountRewardAdjuster(pre_reward_fun)
    else:
        reward_fun_ = lambda x: guacamol_goal_scoring_functions(ver)[obj_num]([x])[0]

    # load or create the global variables needed
    here = os.path.dirname(__file__)
    save_path = os.path.realpath(
        here + '/../../../../molecules/train/mcts/data/') + '/'
    globals_name = os.path.realpath(save_path + root_name + '.gpkl')
    db_path = '/' + os.path.realpath(save_path + root_name + '.db').replace(
        '\\', '/')

    plotter = MetricPlotter(plot_prefix='',
                            save_file=None,
                            loss_display_cap=4,
                            dashboard_name=root_name,
                            plot_ignore_initial=0,
                            process_model_fun=model_process_fun,
                            extra_metric_fun=None,
                            smooth_weight=0.5,
                            frequent_calls=False)

    # if 'thompson' in kind:
    #     my_globals = get_thompson_globals(num_bins=num_bins,
    #                                       reward_fun_=reward_fun_,
    #                                       grammar_cache=grammar_cache,
    #                                       max_seq_length=max_seq_length,
    #                                       decay=decay,
    #                                       updates_to_refresh=updates_to_refresh,
    #                                       plotter=plotter
    #                                       )
    # elif 'model' in kind:
    my_globals = GlobalParametersModel(
        batch_size=batch_size,
        reward_fun_=reward_fun_,
        grammar_cache=grammar_cache,  # 'hyper_grammar.pickle'
        max_depth=max_seq_length,
        lr=lr,
        grad_clip=grad_clip,
        entropy_weight=entropy_weight,
        decay=decay,
        num_bins=num_bins,
        updates_to_refresh=updates_to_refresh,
        plotter=plotter,
        degenerate=True if kind == 'model_thompson' else False)
    if reset_cache:
        try:
            os.remove(globals_name)
            print('removed globals cache ' + globals_name)
        except:
            print("Could not remove globals cache" + globals_name)
        try:
            os.remove(db_path[1:])
            print('removed locals cache ' + db_path[1:])
        except:
            print("Could not remove locals cache" + db_path[1:])
    else:
        try:
            with gzip.open(globals_name) as f:
                global_state = dill.load(f)
                my_globals.set_mutable_state(global_state)
                print("Loaded global state cache!")
        except:
            pass

    node_type = class_from_kind(kind)
    from generative_playground.utils.deep_getsizeof import memory_by_type
    with Shelve(db_path, 'kv_table',
                compress=compress_data_store) as state_store:
        my_globals.state_store = state_store
        root_node = node_type(my_globals,
                              parent=None,
                              source_action=None,
                              depth=1)

        for b in range(num_batches):
            mem = memory_by_type()
            print("memory pre-explore", sum([x[2] for x in mem]), mem[:5])
            rewards, infos = explore(root_node, batch_size)
            state_store.flush()
            mem = memory_by_type()
            print("memory post-explore", sum([x[2] for x in mem]), mem[:5])
            if b % save_every == save_every - 1:
                print('**** saving global state ****')
                with gzip.open(globals_name, 'wb') as f:
                    dill.dump(my_globals.get_mutable_state(), f)

            # visualisation code goes here
            plotter_input = {
                'rewards': np.array(rewards),
                'info': [[x['smiles'] for x in infos],
                         np.ones(len(infos))]
            }
            my_globals.plotter(None, None, plotter_input, None, None)
            print(max(rewards))

    # print(root_node.result_repo.avg_reward())
    # temp_dir.cleanup()
    print("done!")
Example #10
def train_policy_gradient_ppo(molecules=True,
                              grammar=True,
                              smiles_source='ZINC',
                              EPOCHS=None,
                              BATCH_SIZE=None,
                              reward_fun_on=None,
                              reward_fun_off=None,
                              max_steps=277,
                              lr_on=2e-4,
                              lr_discrim=1e-4,
                              discrim_wt=2,
                              p_thresh=0.5,
                              drop_rate=0.0,
                              plot_ignore_initial=0,
                              randomize_reward=False,
                              save_file_root_name=None,
                              reward_sm=0.0,
                              preload_file_root_name=None,
                              anchor_file=None,
                              anchor_weight=0.0,
                              decoder_type='action',
                              plot_prefix='',
                              dashboard='policy gradient',
                              smiles_save_file=None,
                              on_policy_loss_type='best',
                              priors=True,
                              node_temperature_schedule=lambda x: 1.0,
                              rule_temperature_schedule=lambda x: 1.0,
                              eps=0.0,
                              half_float=False,
                              extra_repetition_penalty=0.0):
    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../../'

    def full_path(x):
        return os.path.realpath(root_location + 'pretrained/' + x)

    if save_file_root_name is not None:
        gen_save_file = save_file_root_name + '_gen.h5'
        disc_save_file = save_file_root_name + '_disc.h5'
    if preload_file_root_name is not None:
        gen_preload_file = preload_file_root_name + '_gen.h5'
        disc_preload_file = preload_file_root_name + '_disc.h5'

    settings = get_settings(molecules=molecules, grammar=grammar)
    codec = get_codec(molecules, grammar, settings['max_seq_length'])
    discrim_model = GraphDiscriminator(codec.grammar, drop_rate=drop_rate)
    if False and preload_file_root_name is not None:
        try:
            preload_path = full_path(disc_preload_file)
            discrim_model.load_state_dict(torch.load(preload_path),
                                          strict=False)
            print('Discriminator weights loaded successfully!')
        except Exception as e:
            print('failed to load discriminator weights ' + str(e))

    zinc_data = get_smiles_from_database(source=smiles_source)
    zinc_set = set(zinc_data)
    lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
    history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

    def originality_mult(smiles_list):
        out = []
        for s in smiles_list:
            if s in zinc_set:
                out.append(0.5)
            elif s in history_data[0]:
                out.append(0.5)
            elif s in history_data[1]:
                out.append(0.70)
            elif s in history_data[2]:
                out.append(0.85)
            else:
                out.append(1.0)
        return np.array(out)

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def discriminator_reward_mult(smiles_list):
        orig_state = discrim_model.training
        discrim_model.eval()
        discrim_out_logits = discrim_model(smiles_list)['p_zinc']
        discrim_probs = F.softmax(discrim_out_logits, dim=1)
        prob_zinc = discrim_probs[:, 1].detach().cpu().numpy()
        if orig_state:
            discrim_model.train()
        return prob_zinc

    def apply_originality_penalty(x, orig_mult):
        assert x <= 1, "Reward must be no greater than 1"
        if x > 0.5:  # want to punish nearly-perfect scores less and less
            out = math.pow(x, 1 / orig_mult)
        else:  # continuous join at 0.5
            penalty = math.pow(0.5, 1 / orig_mult) - 0.5
            out = x + penalty

        out -= extra_repetition_penalty * (1 - 1 / orig_mult)
        return out

    def adj_reward(x):
        if discrim_wt > 1e-5:
            p = discriminator_reward_mult(x)
        else:
            p = 0
        rwd = np.array(reward_fun_on(x))
        orig_mult = originality_mult(x)
        # we assume the reward is <=1, first term will dominate for reward <0, second for 0 < reward < 1
        # reward = np.minimum(rwd/orig_mult, np.power(np.abs(rwd),1/orig_mult))
        reward = np.array([
            apply_originality_penalty(r, om) for r, om in zip(rwd, orig_mult)
        ])
        out = reward + discrim_wt * p * orig_mult
        return out

    def adj_reward_old(x):
        p = discriminator_reward_mult(x)
        w = sigmoid(-(p - p_thresh) / 0.01)
        if randomize_reward:
            rand = np.random.uniform(size=p.shape)
            w *= rand
        reward = np.maximum(reward_fun_on(x), p_thresh)
        weighted_reward = w * p + (1 - w) * reward
        out = weighted_reward * originality_mult(x)  #
        return out

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    task = SequenceGenerationTask(molecules=molecules,
                                  grammar=grammar,
                                  reward_fun=adj_reward,
                                  batch_size=BATCH_SIZE,
                                  max_steps=max_steps,
                                  save_dataset=None)

    node_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0),
                                            eps=eps)
    rule_policy = SoftmaxRandomSamplePolicy(temperature=torch.tensor(1.0),
                                            eps=eps)

    model = get_decoder(molecules,
                        grammar,
                        z_size=settings['z_size'],
                        decoder_hidden_n=200,
                        feature_len=codec.feature_len(),
                        max_seq_length=max_steps,
                        drop_rate=drop_rate,
                        batch_size=BATCH_SIZE,
                        decoder_type=decoder_type,
                        reward_fun=adj_reward,
                        task=task,
                        node_policy=node_policy,
                        rule_policy=rule_policy,
                        priors=priors)[0]

    if preload_file_root_name is not None:
        try:
            preload_path = full_path(gen_preload_file)
            model.load_state_dict(torch.load(preload_path, map_location='cpu'),
                                  strict=False)
            print('Generator weights loaded successfully!')
        except Exception as e:
            print('failed to load generator weights ' + str(e))

    anchor_model = None

    from generative_playground.molecules.rdkit_utils.rdkit_utils import NormalizedScorer
    import numpy as np
    scorer = NormalizedScorer()

    if reward_fun_off is None:
        reward_fun_off = reward_fun_on

    # construct the loader to feed the discriminator
    def make_callback(data):
        def hc(inputs, model, outputs, loss_fn, loss):
            graphs = outputs['graphs']
            smiles = [g.to_smiles() for g in graphs]
            for s in smiles:  # only store unique instances of molecules so discriminator can't guess on frequency
                if s not in data:
                    data.append(s)

        return hc

    class TemperatureCallback:
        def __init__(self, policy, temperature_function):
            self.policy = policy
            self.counter = 0
            self.temp_fun = temperature_function

        def __call__(self, inputs, model, outputs, loss_fn, loss):
            self.counter += 1
            target_temp = self.temp_fun(self.counter)
            self.policy.set_temperature(target_temp)

    # need to have something there to begin with, else the DataLoader constructor barfs

    def get_rl_fitter(model,
                      loss_obj,
                      train_gen,
                      save_path,
                      fit_plot_prefix='',
                      model_process_fun=None,
                      lr=None,
                      extra_callbacks=[],
                      loss_display_cap=float('inf'),
                      anchor_model=None,
                      anchor_weight=0):
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr, eps=1e-4)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)

        if dashboard is not None:
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun,
                smooth_weight=reward_sm)
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=True)

        fitter = fit_rl(train_gen=train_gen,
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epochs=EPOCHS,
                        loss_fn=loss_obj,
                        grad_clip=5,
                        half_float=half_float,
                        anchor_model=anchor_model,
                        anchor_weight=anchor_weight,
                        callbacks=[metric_monitor, checkpointer] +
                        extra_callbacks)

        return fitter

    def my_gen():
        for _ in range(1000):
            yield to_gpu(torch.zeros(BATCH_SIZE, settings['z_size']))

    # the on-policy fitter

    gen_extra_callbacks = [make_callback(d) for d in history_data]

    if smiles_save_file is not None:
        smiles_save_path = os.path.realpath(root_location + 'pretrained/' +
                                            smiles_save_file)
        gen_extra_callbacks.append(MoleculeSaver(smiles_save_path, gzip=True))
        print('Saving generated SMILES to {}'.format(smiles_save_file))

    if node_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(node_policy, node_temperature_schedule))

    if rule_temperature_schedule is not None:
        gen_extra_callbacks.append(
            TemperatureCallback(rule_policy, rule_temperature_schedule))

    fitter1 = get_rl_fitter(
        model,
        PolicyGradientLoss(on_policy_loss_type),  # last_reward_wgt=reward_sm),
        GeneratorToIterable(my_gen),
        full_path(gen_save_file),
        plot_prefix + 'on-policy',
        model_process_fun=model_process_fun,
        lr=lr_on,
        extra_callbacks=gen_extra_callbacks,
        anchor_model=anchor_model,
        anchor_weight=anchor_weight)
    #
    # # get existing molecule data to add training
    pre_dataset = EvenlyBlendedDataset(
        2 * [history_data[0]] + history_data[1:],
        labels=False)  # a blend of 3 time horizons
    dataset = EvenlyBlendedDataset([pre_dataset, zinc_data], labels=True)
    discrim_loader = DataLoader(dataset, shuffle=True, batch_size=50)

    class MyLoss(nn.Module):
        def __init__(self):
            super().__init__()
            self.celoss = nn.CrossEntropyLoss()

        def forward(self, x):
            # tmp = discriminator_reward_mult(x['smiles'])
            # tmp2 = F.softmax(x['p_zinc'], dim=1)[:,1].detach().cpu().numpy()
            # import numpy as np
            # assert np.max(np.abs(tmp-tmp2)) < 1e-6
            return self.celoss(x['p_zinc'].to(device),
                               x['dataset_index'].to(device))

    fitter2 = get_rl_fitter(discrim_model,
                            MyLoss(),
                            IterableTransform(
                                discrim_loader, lambda x: {
                                    'smiles': x['X'],
                                    'dataset_index': x['dataset_index']
                                }),
                            full_path(disc_save_file),
                            plot_prefix + ' discriminator',
                            lr=lr_discrim,
                            model_process_fun=None)

    def on_policy_gen(fitter, model):
        while True:
            # model.policy = SoftmaxRandomSamplePolicy()#bias=codec.grammar.get_log_frequencies())
            yield next(fitter)

    return model, fitter1, fitter2  #,on_policy_gen(fitter1, model)
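For reference, the originality penalty above maps a raw reward x in (0.5, 1] to x**(1/orig_mult) and shifts rewards at or below 0.5 down by the constant that makes the two branches join continuously at 0.5. A small standalone illustration of that curve (re-stating the nested helper, since it is a closure; extra_repetition_penalty is taken as 0):

# Standalone re-statement of apply_originality_penalty for illustration only
import math

def _originality_penalty(x, orig_mult):
    if x > 0.5:                                   # dampen near-perfect scores
        return math.pow(x, 1 / orig_mult)
    penalty = math.pow(0.5, 1 / orig_mult) - 0.5  # continuous join at 0.5
    return x + penalty

for orig_mult in (1.0, 0.85, 0.5):   # 1.0 = novel molecule, 0.5 = already in ZINC
    print(orig_mult, round(_originality_penalty(0.9, orig_mult), 4))
# a fully repeated molecule (orig_mult=0.5) sees a reward of 0.9 mapped down to 0.81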
    def __init__(self,
                 grammar,
                 smiles_source='ZINC',
                 BATCH_SIZE=None,
                 reward_fun=None,
                 max_steps=277,
                 num_batches=100,
                 lr=2e-4,
                 entropy_wgt=1.0,
                 lr_schedule=None,
                 root_name=None,
                 preload_file_root_name=None,
                 save_location=None,
                 plot_metrics=True,
                 metric_smooth=0.0,
                 decoder_type='graph_conditional',
                 on_policy_loss_type='advantage_record',
                 priors='conditional',
                 rule_temperature_schedule=None,
                 eps=0.0,
                 half_float=False,
                 extra_repetition_penalty=0.0):

        self.num_batches = num_batches
        self.save_location = save_location
        self.molecule_saver = MoleculeSaver(None, gzip=True)
        self.metric_monitor = None  # to be populated by self.set_root_name(...)

        zinc_data = get_smiles_from_database(source=smiles_source)
        zinc_set = set(zinc_data)
        lookbacks = [BATCH_SIZE, 10 * BATCH_SIZE, 100 * BATCH_SIZE]
        history_data = [deque(['O'], maxlen=lb) for lb in lookbacks]

        if root_name is not None:
            pass
            # gen_save_file = root_name + '_gen.h5'
        if preload_file_root_name is not None:
            gen_preload_file = preload_file_root_name + '_gen.h5'

        settings = get_settings(molecules=True, grammar=grammar)
        codec = get_codec(True, grammar, settings['max_seq_length'])

        if BATCH_SIZE is not None:
            settings['BATCH_SIZE'] = BATCH_SIZE

        self.alt_reward_calc = AdjustedRewardCalculator(
            reward_fun,
            zinc_set,
            lookbacks,
            extra_repetition_penalty,
            0,
            discrim_model=None)
        self.reward_fun = lambda x: adj_reward(0,
                                               None,
                                               reward_fun,
                                               zinc_set,
                                               history_data,
                                               extra_repetition_penalty,
                                               x,
                                               alt_calc=self.alt_reward_calc)

        task = SequenceGenerationTask(molecules=True,
                                      grammar=grammar,
                                      reward_fun=self.alt_reward_calc,
                                      batch_size=BATCH_SIZE,
                                      max_steps=max_steps,
                                      save_dataset=None)

        if 'sparse' in decoder_type:
            rule_policy = SoftmaxRandomSamplePolicySparse()
        else:
            rule_policy = SoftmaxRandomSamplePolicy(
                temperature=torch.tensor(1.0), eps=eps)

        # TODO: strip this down to the normal call
        self.model = get_decoder(True,
                                 grammar,
                                 z_size=settings['z_size'],
                                 decoder_hidden_n=200,
                                 feature_len=codec.feature_len(),
                                 max_seq_length=max_steps,
                                 batch_size=BATCH_SIZE,
                                 decoder_type=decoder_type,
                                 reward_fun=self.alt_reward_calc,
                                 task=task,
                                 rule_policy=rule_policy,
                                 priors=priors)[0]

        if preload_file_root_name is not None:
            try:
                preload_path = os.path.realpath(save_location +
                                                gen_preload_file)
                self.model.load_state_dict(torch.load(preload_path,
                                                      map_location='cpu'),
                                           strict=False)
                print('Generator weights loaded successfully!')
            except Exception as e:
                print('failed to load generator weights ' + str(e))

        # construct the loader to feed the discriminator
        def make_callback(data):
            def hc(inputs, model, outputs, loss_fn, loss):
                graphs = outputs['graphs']
                smiles = [g.to_smiles() for g in graphs]
                for s in smiles:  # only store unique instances of molecules so discriminator can't guess on frequency
                    if s not in data:
                        data.append(s)

            return hc

        if plot_metrics:
            # TODO: save_file for rewards data goes here?
            self.metric_monitor_factory = lambda name: MetricPlotter(
                plot_prefix='',
                loss_display_cap=float('inf'),
                dashboard_name=name,
                save_location=save_location,
                process_model_fun=model_process_fun,
                smooth_weight=metric_smooth)
        else:
            self.metric_monitor_factory = lambda x: None

        # the on-policy fitter

        gen_extra_callbacks = [make_callback(d) for d in history_data]
        gen_extra_callbacks.append(self.molecule_saver)
        if rule_temperature_schedule is not None:
            gen_extra_callbacks.append(
                TemperatureCallback(rule_policy, rule_temperature_schedule))

        nice_params = filter(lambda p: p.requires_grad,
                             self.model.parameters())
        self.optimizer = optim.Adam(nice_params, lr=lr, eps=1e-4)

        if lr_schedule is None:
            lr_schedule = lambda x: 1.0
        self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_schedule)
        self.loss = PolicyGradientLoss(on_policy_loss_type,
                                       entropy_wgt=entropy_wgt)
        self.fitter_factory = lambda: make_fitter(
            BATCH_SIZE, settings['z_size'],
            [self.metric_monitor] + gen_extra_callbacks, self)

        self.fitter = self.fitter_factory()
        self.set_root_name(root_name)
        print('Runner initialized!')
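A hedged sketch of driving the runner whose __init__ is shown above: the class statement is omitted from the snippet, so the name PolicyGradientRunner is an assumption, but the attributes used (num_batches, fitter) are the ones set up above, and the fitter is a generator from make_fitter that is advanced with next().

# Hedged driver sketch; class name, grammar spec and paths are assumptions
runner = PolicyGradientRunner('hypergraph:hyper_grammar.pickle',
                              BATCH_SIZE=20,
                              reward_fun=lambda smiles_list: [0.0] * len(smiles_list),
                              num_batches=10,
                              root_name='pg_smoke_test',
                              save_location='./pretrained/')

for _ in range(runner.num_batches):
    next(runner.fitter)      # one policy-gradient training batch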